# tool_model_manager.py
"""
Tools such as WebScraperTool and WikipediaSearchTool need to invoke a model,
so a model manager is provided here to list/invoke/validate models.
"""
  5. import json
  6. from typing import cast
  7. from core.model_manager import ModelManager
  8. from core.model_runtime.entities.llm_entities import LLMResult
  9. from core.model_runtime.entities.message_entities import PromptMessage
  10. from core.model_runtime.entities.model_entities import ModelType
  11. from core.model_runtime.errors.invoke import (
  12. InvokeAuthorizationError,
  13. InvokeBadRequestError,
  14. InvokeConnectionError,
  15. InvokeRateLimitError,
  16. InvokeServerUnavailableError,
  17. )
  18. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.model.errors import InvokeModelError
  21. from extensions.ext_database import db
  22. from models.tools import ToolModelInvoke
  23. class ToolModelManager:
  24. @staticmethod
  25. def get_max_llm_context_tokens(
  26. tenant_id: str,
  27. ) -> int:
  28. """
  29. get max llm context tokens of the model
  30. """
  31. model_manager = ModelManager()
  32. model_instance = model_manager.get_default_model_instance(
  33. tenant_id=tenant_id, model_type=ModelType.LLM,
  34. )
  35. if not model_instance:
  36. raise InvokeModelError('Model not found')
  37. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  38. schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  39. if not schema:
  40. raise InvokeModelError('No model schema found')
  41. max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
  42. if max_tokens is None:
  43. return 2048
  44. return max_tokens
  45. @staticmethod
  46. def calculate_tokens(
  47. tenant_id: str,
  48. prompt_messages: list[PromptMessage]
  49. ) -> int:
  50. """
  51. calculate tokens from prompt messages and model parameters
  52. """
  53. # get model instance
  54. model_manager = ModelManager()
  55. model_instance = model_manager.get_default_model_instance(
  56. tenant_id=tenant_id, model_type=ModelType.LLM
  57. )
  58. if not model_instance:
  59. raise InvokeModelError('Model not found')
  60. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  61. # get tokens
  62. tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
  63. return tokens
  64. @staticmethod
  65. def invoke(
  66. user_id: str, tenant_id: str,
  67. tool_type: str, tool_name: str,
  68. prompt_messages: list[PromptMessage]
  69. ) -> LLMResult:
  70. """
  71. invoke model with parameters in user's own context
  72. :param user_id: user id
  73. :param tenant_id: tenant id, the tenant id of the creator of the tool
  74. :param tool_provider: tool provider
  75. :param tool_id: tool id
  76. :param tool_name: tool name
  77. :param provider: model provider
  78. :param model: model name
  79. :param model_parameters: model parameters
  80. :param prompt_messages: prompt messages
  81. :return: AssistantPromptMessage
  82. """
  83. # get model manager
  84. model_manager = ModelManager()
  85. # get model instance
  86. model_instance = model_manager.get_default_model_instance(
  87. tenant_id=tenant_id, model_type=ModelType.LLM,
  88. )
  89. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  90. # get model credentials
  91. model_credentials = model_instance.credentials
  92. # get prompt tokens
  93. prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
  94. model_parameters = {
  95. 'temperature': 0.8,
  96. 'top_p': 0.8,
  97. }
  98. # create tool model invoke
  99. tool_model_invoke = ToolModelInvoke(
  100. user_id=user_id,
  101. tenant_id=tenant_id,
  102. provider=model_instance.provider,
  103. tool_type=tool_type,
  104. tool_name=tool_name,
  105. model_parameters=json.dumps(model_parameters),
  106. prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
  107. model_response='',
  108. prompt_tokens=prompt_tokens,
  109. answer_tokens=0,
  110. answer_unit_price=0,
  111. answer_price_unit=0,
  112. provider_response_latency=0,
  113. total_price=0,
  114. currency='USD',
  115. )
  116. db.session.add(tool_model_invoke)
  117. db.session.commit()
  118. try:
  119. response: LLMResult = llm_model.invoke(
  120. model=model_instance.model,
  121. credentials=model_credentials,
  122. prompt_messages=prompt_messages,
  123. model_parameters=model_parameters,
  124. tools=[], stop=[], stream=False, user=user_id, callbacks=[]
  125. )
  126. except InvokeRateLimitError as e:
  127. raise InvokeModelError(f'Invoke rate limit error: {e}')
  128. except InvokeBadRequestError as e:
  129. raise InvokeModelError(f'Invoke bad request error: {e}')
  130. except InvokeConnectionError as e:
  131. raise InvokeModelError(f'Invoke connection error: {e}')
  132. except InvokeAuthorizationError as e:
  133. raise InvokeModelError('Invoke authorization error')
  134. except InvokeServerUnavailableError as e:
  135. raise InvokeModelError(f'Invoke server unavailable error: {e}')
  136. except Exception as e:
  137. raise InvokeModelError(f'Invoke error: {e}')
  138. # update tool model invoke
  139. tool_model_invoke.model_response = response.message.content
  140. if response.usage:
  141. tool_model_invoke.answer_tokens = response.usage.completion_tokens
  142. tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
  143. tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
  144. tool_model_invoke.provider_response_latency = response.usage.latency
  145. tool_model_invoke.total_price = response.usage.total_price
  146. tool_model_invoke.currency = response.usage.currency
  147. db.session.commit()
  148. return response