model_invocation_utils.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
"""
Models are used inside tools such as WebScraperTool, WikipediaSearchTool, etc.
Therefore a model manager is needed to list/invoke/validate models on their behalf.
"""
  5. import json
  6. from typing import cast
  7. from core.model_manager import ModelManager
  8. from core.model_runtime.entities.llm_entities import LLMResult
  9. from core.model_runtime.entities.message_entities import PromptMessage
  10. from core.model_runtime.entities.model_entities import ModelType
  11. from core.model_runtime.errors.invoke import (
  12. InvokeAuthorizationError,
  13. InvokeBadRequestError,
  14. InvokeConnectionError,
  15. InvokeRateLimitError,
  16. InvokeServerUnavailableError,
  17. )
  18. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from extensions.ext_database import db
  21. from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
    """Raised when listing, validating, or invoking a model fails in ModelInvocationUtils."""
    pass
  24. class ModelInvocationUtils:
  25. @staticmethod
  26. def get_max_llm_context_tokens(
  27. tenant_id: str,
  28. ) -> int:
  29. """
  30. get max llm context tokens of the model
  31. """
  32. model_manager = ModelManager()
  33. model_instance = model_manager.get_default_model_instance(
  34. tenant_id=tenant_id, model_type=ModelType.LLM,
  35. )
  36. if not model_instance:
  37. raise InvokeModelError('Model not found')
  38. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  39. schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  40. if not schema:
  41. raise InvokeModelError('No model schema found')
  42. max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
  43. if max_tokens is None:
  44. return 2048
  45. return max_tokens
  46. @staticmethod
  47. def calculate_tokens(
  48. tenant_id: str,
  49. prompt_messages: list[PromptMessage]
  50. ) -> int:
  51. """
  52. calculate tokens from prompt messages and model parameters
  53. """
  54. # get model instance
  55. model_manager = ModelManager()
  56. model_instance = model_manager.get_default_model_instance(
  57. tenant_id=tenant_id, model_type=ModelType.LLM
  58. )
  59. if not model_instance:
  60. raise InvokeModelError('Model not found')
  61. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  62. # get tokens
  63. tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
  64. return tokens
  65. @staticmethod
  66. def invoke(
  67. user_id: str, tenant_id: str,
  68. tool_type: str, tool_name: str,
  69. prompt_messages: list[PromptMessage]
  70. ) -> LLMResult:
  71. """
  72. invoke model with parameters in user's own context
  73. :param user_id: user id
  74. :param tenant_id: tenant id, the tenant id of the creator of the tool
  75. :param tool_provider: tool provider
  76. :param tool_id: tool id
  77. :param tool_name: tool name
  78. :param provider: model provider
  79. :param model: model name
  80. :param model_parameters: model parameters
  81. :param prompt_messages: prompt messages
  82. :return: AssistantPromptMessage
  83. """
  84. # get model manager
  85. model_manager = ModelManager()
  86. # get model instance
  87. model_instance = model_manager.get_default_model_instance(
  88. tenant_id=tenant_id, model_type=ModelType.LLM,
  89. )
  90. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  91. # get model credentials
  92. model_credentials = model_instance.credentials
  93. # get prompt tokens
  94. prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
  95. model_parameters = {
  96. 'temperature': 0.8,
  97. 'top_p': 0.8,
  98. }
  99. # create tool model invoke
  100. tool_model_invoke = ToolModelInvoke(
  101. user_id=user_id,
  102. tenant_id=tenant_id,
  103. provider=model_instance.provider,
  104. tool_type=tool_type,
  105. tool_name=tool_name,
  106. model_parameters=json.dumps(model_parameters),
  107. prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
  108. model_response='',
  109. prompt_tokens=prompt_tokens,
  110. answer_tokens=0,
  111. answer_unit_price=0,
  112. answer_price_unit=0,
  113. provider_response_latency=0,
  114. total_price=0,
  115. currency='USD',
  116. )
  117. db.session.add(tool_model_invoke)
  118. db.session.commit()
  119. try:
  120. response: LLMResult = llm_model.invoke(
  121. model=model_instance.model,
  122. credentials=model_credentials,
  123. prompt_messages=prompt_messages,
  124. model_parameters=model_parameters,
  125. tools=[], stop=[], stream=False, user=user_id, callbacks=[]
  126. )
  127. except InvokeRateLimitError as e:
  128. raise InvokeModelError(f'Invoke rate limit error: {e}')
  129. except InvokeBadRequestError as e:
  130. raise InvokeModelError(f'Invoke bad request error: {e}')
  131. except InvokeConnectionError as e:
  132. raise InvokeModelError(f'Invoke connection error: {e}')
  133. except InvokeAuthorizationError as e:
  134. raise InvokeModelError('Invoke authorization error')
  135. except InvokeServerUnavailableError as e:
  136. raise InvokeModelError(f'Invoke server unavailable error: {e}')
  137. except Exception as e:
  138. raise InvokeModelError(f'Invoke error: {e}')
  139. # update tool model invoke
  140. tool_model_invoke.model_response = response.message.content
  141. if response.usage:
  142. tool_model_invoke.answer_tokens = response.usage.completion_tokens
  143. tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
  144. tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
  145. tool_model_invoke.provider_response_latency = response.usage.latency
  146. tool_model_invoke.total_price = response.usage.total_price
  147. tool_model_invoke.currency = response.usage.currency
  148. db.session.commit()
  149. return response