model_manager.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. from typing import Optional, Union, Generator, cast, List, IO
  2. from core.entities.provider_configuration import ProviderModelBundle
  3. from core.errors.error import ProviderTokenNotInitError
  4. from core.model_runtime.callbacks.base_callback import Callback
  5. from core.model_runtime.entities.llm_entities import LLMResult
  6. from core.model_runtime.entities.message_entities import PromptMessageTool, PromptMessage
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.entities.rerank_entities import RerankResult
  9. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  10. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  11. from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
  12. from core.model_runtime.model_providers.__base.rerank_model import RerankModel
  13. from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
  14. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  15. from core.provider_manager import ProviderManager
  class ModelInstance:
      """
      Model instance class

      Pairs a model name with the provider model bundle that serves it and
      exposes one ``invoke_*`` method per model type. Each ``invoke_*`` method
      checks that the bundle's model type instance is of the matching kind and
      then delegates to that instance's ``invoke``.
      """

      def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
          """
          Initialize the instance and resolve the credentials for the model

          :param provider_model_bundle: provider model bundle
          :param model: model name
          :raises ProviderTokenNotInitError: if the bundle returns no credentials for the model
          """
          self._provider_model_bundle = provider_model_bundle
          self.model = model
          self.provider = provider_model_bundle.configuration.provider.provider
          # Credentials are looked up once here and reused by every invoke_* call.
          self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
          self.model_type_instance = self._provider_model_bundle.model_type_instance

      def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
          """
          Fetch credentials from provider model bundle

          :param provider_model_bundle: provider model bundle
          :param model: model name
          :return: current credentials for the model
          :raises ProviderTokenNotInitError: if the configuration returns None for the model
          """
          credentials = provider_model_bundle.configuration.get_current_credentials(
              model_type=provider_model_bundle.model_type_instance.model_type,
              model=model
          )

          if credentials is None:
              raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")

          return credentials

      def invoke_llm(
          self,
          prompt_messages: list[PromptMessage],
          model_parameters: Optional[dict] = None,
          tools: Optional[list[PromptMessageTool]] = None,
          stop: Optional[list[str]] = None,
          stream: bool = True,
          user: Optional[str] = None,
          callbacks: Optional[list[Callback]] = None
      ) -> Union[LLMResult, Generator]:
          """
          Invoke large language model

          :param prompt_messages: prompt messages
          :param model_parameters: model parameters
          :param tools: tools for tool calling
          :param stop: stop words
          :param stream: is stream response
          :param user: unique user id
          :param callbacks: callbacks
          :return: full response or stream response chunk generator result
          :raises Exception: if the model type instance is not a LargeLanguageModel
          """
          if not isinstance(self.model_type_instance, LargeLanguageModel):
              raise Exception("Model type instance is not LargeLanguageModel")

          # cast() only informs the type checker; keep the narrowed value in a
          # local instead of rebinding the instance attribute on every call.
          llm = cast(LargeLanguageModel, self.model_type_instance)
          return llm.invoke(
              model=self.model,
              credentials=self.credentials,
              prompt_messages=prompt_messages,
              model_parameters=model_parameters,
              tools=tools,
              stop=stop,
              stream=stream,
              user=user,
              callbacks=callbacks
          )

      def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
              -> TextEmbeddingResult:
          """
          Invoke text embedding model

          :param texts: texts to embed
          :param user: unique user id
          :return: embeddings result
          :raises Exception: if the model type instance is not a TextEmbeddingModel
          """
          if not isinstance(self.model_type_instance, TextEmbeddingModel):
              raise Exception("Model type instance is not TextEmbeddingModel")

          embedding_model = cast(TextEmbeddingModel, self.model_type_instance)
          return embedding_model.invoke(
              model=self.model,
              credentials=self.credentials,
              texts=texts,
              user=user
          )

      def invoke_rerank(
          self,
          query: str,
          docs: list[str],
          score_threshold: Optional[float] = None,
          top_n: Optional[int] = None,
          user: Optional[str] = None
      ) -> RerankResult:
          """
          Invoke rerank model

          :param query: search query
          :param docs: docs for reranking
          :param score_threshold: score threshold
          :param top_n: top n
          :param user: unique user id
          :return: rerank result
          :raises Exception: if the model type instance is not a RerankModel
          """
          if not isinstance(self.model_type_instance, RerankModel):
              raise Exception("Model type instance is not RerankModel")

          rerank_model = cast(RerankModel, self.model_type_instance)
          return rerank_model.invoke(
              model=self.model,
              credentials=self.credentials,
              query=query,
              docs=docs,
              score_threshold=score_threshold,
              top_n=top_n,
              user=user
          )

      def invoke_moderation(self, text: str, user: Optional[str] = None) \
              -> bool:
          """
          Invoke moderation model

          :param text: text to moderate
          :param user: unique user id
          :return: false if text is safe, true otherwise
          :raises Exception: if the model type instance is not a ModerationModel
          """
          if not isinstance(self.model_type_instance, ModerationModel):
              raise Exception("Model type instance is not ModerationModel")

          moderation_model = cast(ModerationModel, self.model_type_instance)
          return moderation_model.invoke(
              model=self.model,
              credentials=self.credentials,
              text=text,
              user=user
          )

      def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
              -> str:
          """
          Invoke speech2text model

          :param file: audio file
          :param user: unique user id
          :return: text for given audio file
          :raises Exception: if the model type instance is not a Speech2TextModel
          """
          if not isinstance(self.model_type_instance, Speech2TextModel):
              raise Exception("Model type instance is not Speech2TextModel")

          speech2text_model = cast(Speech2TextModel, self.model_type_instance)
          return speech2text_model.invoke(
              model=self.model,
              credentials=self.credentials,
              file=file,
              user=user
          )
  class ModelManager:
      """
      Model manager class

      Builds ModelInstance objects from the provider model bundles and default
      model settings returned by a ProviderManager.
      """

      def __init__(self) -> None:
          # A ProviderManager is created per ModelManager and used for every lookup.
          self._provider_manager = ProviderManager()

      def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
          """
          Get model instance

          :param tenant_id: tenant id
          :param provider: provider name
          :param model_type: model type
          :param model: model name
          :return: model instance built from the matching provider model bundle
          """
          return ModelInstance(
              self._provider_manager.get_provider_model_bundle(
                  tenant_id=tenant_id,
                  provider=provider,
                  model_type=model_type
              ),
              model
          )

      def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
          """
          Get default model instance

          :param tenant_id: tenant id
          :param model_type: model type
          :return: model instance for the tenant's default model of this type
          :raises ProviderTokenNotInitError: if no default model is returned for the model type
          """
          default_model = self._provider_manager.get_default_model(
              tenant_id=tenant_id,
              model_type=model_type
          )

          if default_model:
              # Resolve through get_model_instance so both entry points share one construction path.
              return self.get_model_instance(
                  tenant_id=tenant_id,
                  provider=default_model.provider.provider,
                  model_type=model_type,
                  model=default_model.model
              )

          raise ProviderTokenNotInitError(f"Default model not found for {model_type}")