model_invocation_utils.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. """
  2. For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
  3. Therefore, a model manager is needed to list/invoke/validate models.
  4. """
  5. import json
  6. from typing import cast
  7. from core.model_manager import ModelManager
  8. from core.model_runtime.entities.llm_entities import LLMResult
  9. from core.model_runtime.entities.message_entities import PromptMessage
  10. from core.model_runtime.entities.model_entities import ModelType
  11. from core.model_runtime.errors.invoke import (
  12. InvokeAuthorizationError,
  13. InvokeBadRequestError,
  14. InvokeConnectionError,
  15. InvokeRateLimitError,
  16. InvokeServerUnavailableError,
  17. )
  18. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from extensions.ext_database import db
  21. from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
    """Raised when a tool-initiated model invocation fails or cannot be prepared."""

    pass
  24. class ModelInvocationUtils:
  25. @staticmethod
  26. def get_max_llm_context_tokens(
  27. tenant_id: str,
  28. ) -> int:
  29. """
  30. get max llm context tokens of the model
  31. """
  32. model_manager = ModelManager()
  33. model_instance = model_manager.get_default_model_instance(
  34. tenant_id=tenant_id,
  35. model_type=ModelType.LLM,
  36. )
  37. if not model_instance:
  38. raise InvokeModelError("Model not found")
  39. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  40. schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  41. if not schema:
  42. raise InvokeModelError("No model schema found")
  43. max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
  44. if max_tokens is None:
  45. return 2048
  46. return max_tokens
  47. @staticmethod
  48. def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
  49. """
  50. calculate tokens from prompt messages and model parameters
  51. """
  52. # get model instance
  53. model_manager = ModelManager()
  54. model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
  55. if not model_instance:
  56. raise InvokeModelError("Model not found")
  57. # get tokens
  58. tokens = model_instance.get_llm_num_tokens(prompt_messages)
  59. return tokens
  60. @staticmethod
  61. def invoke(
  62. user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
  63. ) -> LLMResult:
  64. """
  65. invoke model with parameters in user's own context
  66. :param user_id: user id
  67. :param tenant_id: tenant id, the tenant id of the creator of the tool
  68. :param tool_provider: tool provider
  69. :param tool_id: tool id
  70. :param tool_name: tool name
  71. :param provider: model provider
  72. :param model: model name
  73. :param model_parameters: model parameters
  74. :param prompt_messages: prompt messages
  75. :return: AssistantPromptMessage
  76. """
  77. # get model manager
  78. model_manager = ModelManager()
  79. # get model instance
  80. model_instance = model_manager.get_default_model_instance(
  81. tenant_id=tenant_id,
  82. model_type=ModelType.LLM,
  83. )
  84. # get prompt tokens
  85. prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
  86. model_parameters = {
  87. "temperature": 0.8,
  88. "top_p": 0.8,
  89. }
  90. # create tool model invoke
  91. tool_model_invoke = ToolModelInvoke(
  92. user_id=user_id,
  93. tenant_id=tenant_id,
  94. provider=model_instance.provider,
  95. tool_type=tool_type,
  96. tool_name=tool_name,
  97. model_parameters=json.dumps(model_parameters),
  98. prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
  99. model_response="",
  100. prompt_tokens=prompt_tokens,
  101. answer_tokens=0,
  102. answer_unit_price=0,
  103. answer_price_unit=0,
  104. provider_response_latency=0,
  105. total_price=0,
  106. currency="USD",
  107. )
  108. db.session.add(tool_model_invoke)
  109. db.session.commit()
  110. try:
  111. response: LLMResult = model_instance.invoke_llm(
  112. prompt_messages=prompt_messages,
  113. model_parameters=model_parameters,
  114. tools=[],
  115. stop=[],
  116. stream=False,
  117. user=user_id,
  118. callbacks=[],
  119. )
  120. except InvokeRateLimitError as e:
  121. raise InvokeModelError(f"Invoke rate limit error: {e}")
  122. except InvokeBadRequestError as e:
  123. raise InvokeModelError(f"Invoke bad request error: {e}")
  124. except InvokeConnectionError as e:
  125. raise InvokeModelError(f"Invoke connection error: {e}")
  126. except InvokeAuthorizationError as e:
  127. raise InvokeModelError("Invoke authorization error")
  128. except InvokeServerUnavailableError as e:
  129. raise InvokeModelError(f"Invoke server unavailable error: {e}")
  130. except Exception as e:
  131. raise InvokeModelError(f"Invoke error: {e}")
  132. # update tool model invoke
  133. tool_model_invoke.model_response = response.message.content
  134. if response.usage:
  135. tool_model_invoke.answer_tokens = response.usage.completion_tokens
  136. tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
  137. tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
  138. tool_model_invoke.provider_response_latency = response.usage.latency
  139. tool_model_invoke.total_price = response.usage.total_price
  140. tool_model_invoke.currency = response.usage.currency
  141. db.session.commit()
  142. return response