"""
Some tools (e.g. WebScraperTool, WikipediaSearchTool) need an LLM internally.
Therefore, a model manager is needed to list/invoke/validate models for tools.
"""
import json
from typing import List, cast

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError,
                                              InvokeRateLimitError, InvokeServerUnavailableError)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.model.errors import InvokeModelError
from extensions.ext_database import db
from models.tools import ToolModelInvoke
  18. class ToolModelManager:
  19. @staticmethod
  20. def get_max_llm_context_tokens(
  21. tenant_id: str,
  22. ) -> int:
  23. """
  24. get max llm context tokens of the model
  25. """
  26. model_manager = ModelManager()
  27. model_instance = model_manager.get_default_model_instance(
  28. tenant_id=tenant_id, model_type=ModelType.LLM,
  29. )
  30. if not model_instance:
  31. raise InvokeModelError(f'Model not found')
  32. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  33. schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
  34. if not schema:
  35. raise InvokeModelError(f'No model schema found')
  36. max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
  37. if max_tokens is None:
  38. return 2048
  39. return max_tokens
  40. @staticmethod
  41. def calculate_tokens(
  42. tenant_id: str,
  43. prompt_messages: List[PromptMessage]
  44. ) -> int:
  45. """
  46. calculate tokens from prompt messages and model parameters
  47. """
  48. # get model instance
  49. model_manager = ModelManager()
  50. model_instance = model_manager.get_default_model_instance(
  51. tenant_id=tenant_id, model_type=ModelType.LLM
  52. )
  53. if not model_instance:
  54. raise InvokeModelError(f'Model not found')
  55. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  56. # get tokens
  57. tokens = llm_model.get_num_tokens(model_instance.model, model_instance.credentials, prompt_messages)
  58. return tokens
  59. @staticmethod
  60. def invoke(
  61. user_id: str, tenant_id: str,
  62. tool_type: str, tool_name: str,
  63. prompt_messages: List[PromptMessage]
  64. ) -> LLMResult:
  65. """
  66. invoke model with parameters in user's own context
  67. :param user_id: user id
  68. :param tenant_id: tenant id, the tenant id of the creator of the tool
  69. :param tool_provider: tool provider
  70. :param tool_id: tool id
  71. :param tool_name: tool name
  72. :param provider: model provider
  73. :param model: model name
  74. :param model_parameters: model parameters
  75. :param prompt_messages: prompt messages
  76. :return: AssistantPromptMessage
  77. """
  78. # get model manager
  79. model_manager = ModelManager()
  80. # get model instance
  81. model_instance = model_manager.get_default_model_instance(
  82. tenant_id=tenant_id, model_type=ModelType.LLM,
  83. )
  84. llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
  85. # get model credentials
  86. model_credentials = model_instance.credentials
  87. # get prompt tokens
  88. prompt_tokens = llm_model.get_num_tokens(model_instance.model, model_credentials, prompt_messages)
  89. model_parameters = {
  90. 'temperature': 0.8,
  91. 'top_p': 0.8,
  92. }
  93. # create tool model invoke
  94. tool_model_invoke = ToolModelInvoke(
  95. user_id=user_id,
  96. tenant_id=tenant_id,
  97. provider=model_instance.provider,
  98. tool_type=tool_type,
  99. tool_name=tool_name,
  100. model_parameters=json.dumps(model_parameters),
  101. prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
  102. model_response='',
  103. prompt_tokens=prompt_tokens,
  104. answer_tokens=0,
  105. answer_unit_price=0,
  106. answer_price_unit=0,
  107. provider_response_latency=0,
  108. total_price=0,
  109. currency='USD',
  110. )
  111. db.session.add(tool_model_invoke)
  112. db.session.commit()
  113. try:
  114. response: LLMResult = llm_model.invoke(
  115. model=model_instance.model,
  116. credentials=model_credentials,
  117. prompt_messages=prompt_messages,
  118. model_parameters=model_parameters,
  119. tools=[], stop=[], stream=False, user=user_id, callbacks=[]
  120. )
  121. except InvokeRateLimitError as e:
  122. raise InvokeModelError(f'Invoke rate limit error: {e}')
  123. except InvokeBadRequestError as e:
  124. raise InvokeModelError(f'Invoke bad request error: {e}')
  125. except InvokeConnectionError as e:
  126. raise InvokeModelError(f'Invoke connection error: {e}')
  127. except InvokeAuthorizationError as e:
  128. raise InvokeModelError(f'Invoke authorization error')
  129. except InvokeServerUnavailableError as e:
  130. raise InvokeModelError(f'Invoke server unavailable error: {e}')
  131. except Exception as e:
  132. raise InvokeModelError(f'Invoke error: {e}')
  133. # update tool model invoke
  134. tool_model_invoke.model_response = response.message.content
  135. if response.usage:
  136. tool_model_invoke.answer_tokens = response.usage.completion_tokens
  137. tool_model_invoke.answer_unit_price = response.usage.completion_unit_price
  138. tool_model_invoke.answer_price_unit = response.usage.completion_price_unit
  139. tool_model_invoke.provider_response_latency = response.usage.latency
  140. tool_model_invoke.total_price = response.usage.total_price
  141. tool_model_invoke.currency = response.usage.currency
  142. db.session.commit()
  143. return response