model_manager.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. from typing import IO, Generator, List, Optional, Union, cast
  2. from core.entities.provider_configuration import ProviderModelBundle
  3. from core.errors.error import ProviderTokenNotInitError
  4. from core.model_runtime.callbacks.base_callback import Callback
  5. from core.model_runtime.entities.llm_entities import LLMResult
  6. from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
  7. from core.model_runtime.entities.model_entities import ModelType
  8. from core.model_runtime.entities.rerank_entities import RerankResult
  9. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  10. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  11. from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
  12. from core.model_runtime.model_providers.__base.rerank_model import RerankModel
  13. from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
  14. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  15. from core.model_runtime.model_providers.__base.tts_model import TTSModel
  16. from core.provider_manager import ProviderManager
  17. class ModelInstance:
  18. """
  19. Model instance class
  20. """
  21. def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
  22. self._provider_model_bundle = provider_model_bundle
  23. self.model = model
  24. self.provider = provider_model_bundle.configuration.provider.provider
  25. self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
  26. self.model_type_instance = self._provider_model_bundle.model_type_instance
  27. def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
  28. """
  29. Fetch credentials from provider model bundle
  30. :param provider_model_bundle: provider model bundle
  31. :param model: model name
  32. :return:
  33. """
  34. credentials = provider_model_bundle.configuration.get_current_credentials(
  35. model_type=provider_model_bundle.model_type_instance.model_type,
  36. model=model
  37. )
  38. if credentials is None:
  39. raise ProviderTokenNotInitError(f"Model {model} credentials is not initialized.")
  40. return credentials
  41. def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
  42. tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
  43. stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \
  44. -> Union[LLMResult, Generator]:
  45. """
  46. Invoke large language model
  47. :param prompt_messages: prompt messages
  48. :param model_parameters: model parameters
  49. :param tools: tools for tool calling
  50. :param stop: stop words
  51. :param stream: is stream response
  52. :param user: unique user id
  53. :param callbacks: callbacks
  54. :return: full response or stream response chunk generator result
  55. """
  56. if not isinstance(self.model_type_instance, LargeLanguageModel):
  57. raise Exception(f"Model type instance is not LargeLanguageModel")
  58. self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
  59. return self.model_type_instance.invoke(
  60. model=self.model,
  61. credentials=self.credentials,
  62. prompt_messages=prompt_messages,
  63. model_parameters=model_parameters,
  64. tools=tools,
  65. stop=stop,
  66. stream=stream,
  67. user=user,
  68. callbacks=callbacks
  69. )
  70. def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
  71. -> TextEmbeddingResult:
  72. """
  73. Invoke large language model
  74. :param texts: texts to embed
  75. :param user: unique user id
  76. :return: embeddings result
  77. """
  78. if not isinstance(self.model_type_instance, TextEmbeddingModel):
  79. raise Exception(f"Model type instance is not TextEmbeddingModel")
  80. self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
  81. return self.model_type_instance.invoke(
  82. model=self.model,
  83. credentials=self.credentials,
  84. texts=texts,
  85. user=user
  86. )
  87. def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
  88. user: Optional[str] = None) \
  89. -> RerankResult:
  90. """
  91. Invoke rerank model
  92. :param query: search query
  93. :param docs: docs for reranking
  94. :param score_threshold: score threshold
  95. :param top_n: top n
  96. :param user: unique user id
  97. :return: rerank result
  98. """
  99. if not isinstance(self.model_type_instance, RerankModel):
  100. raise Exception(f"Model type instance is not RerankModel")
  101. self.model_type_instance = cast(RerankModel, self.model_type_instance)
  102. return self.model_type_instance.invoke(
  103. model=self.model,
  104. credentials=self.credentials,
  105. query=query,
  106. docs=docs,
  107. score_threshold=score_threshold,
  108. top_n=top_n,
  109. user=user
  110. )
  111. def invoke_moderation(self, text: str, user: Optional[str] = None) \
  112. -> bool:
  113. """
  114. Invoke moderation model
  115. :param text: text to moderate
  116. :param user: unique user id
  117. :return: false if text is safe, true otherwise
  118. """
  119. if not isinstance(self.model_type_instance, ModerationModel):
  120. raise Exception(f"Model type instance is not ModerationModel")
  121. self.model_type_instance = cast(ModerationModel, self.model_type_instance)
  122. return self.model_type_instance.invoke(
  123. model=self.model,
  124. credentials=self.credentials,
  125. text=text,
  126. user=user
  127. )
  128. def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
  129. -> str:
  130. """
  131. Invoke large language model
  132. :param file: audio file
  133. :param user: unique user id
  134. :return: text for given audio file
  135. """
  136. if not isinstance(self.model_type_instance, Speech2TextModel):
  137. raise Exception(f"Model type instance is not Speech2TextModel")
  138. self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
  139. return self.model_type_instance.invoke(
  140. model=self.model,
  141. credentials=self.credentials,
  142. file=file,
  143. user=user
  144. )
  145. def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
  146. -> str:
  147. """
  148. Invoke large language model
  149. :param content_text: text content to be translated
  150. :param user: unique user id
  151. :param streaming: output is streaming
  152. :return: text for given audio file
  153. """
  154. if not isinstance(self.model_type_instance, TTSModel):
  155. raise Exception(f"Model type instance is not TTSModel")
  156. self.model_type_instance = cast(TTSModel, self.model_type_instance)
  157. return self.model_type_instance.invoke(
  158. model=self.model,
  159. credentials=self.credentials,
  160. content_text=content_text,
  161. user=user,
  162. streaming=streaming
  163. )
  164. class ModelManager:
  165. def __init__(self) -> None:
  166. self._provider_manager = ProviderManager()
  167. def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
  168. """
  169. Get model instance
  170. :param tenant_id: tenant id
  171. :param provider: provider name
  172. :param model_type: model type
  173. :param model: model name
  174. :return:
  175. """
  176. if not provider:
  177. return self.get_default_model_instance(tenant_id, model_type)
  178. provider_model_bundle = self._provider_manager.get_provider_model_bundle(
  179. tenant_id=tenant_id,
  180. provider=provider,
  181. model_type=model_type
  182. )
  183. return ModelInstance(provider_model_bundle, model)
  184. def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
  185. """
  186. Get default model instance
  187. :param tenant_id: tenant id
  188. :param model_type: model type
  189. :return:
  190. """
  191. default_model_entity = self._provider_manager.get_default_model(
  192. tenant_id=tenant_id,
  193. model_type=model_type
  194. )
  195. if not default_model_entity:
  196. raise ProviderTokenNotInitError(f"Default model not found for {model_type}")
  197. return self.get_model_instance(
  198. tenant_id=tenant_id,
  199. provider=default_model_entity.provider.provider,
  200. model_type=model_type,
  201. model=default_model_entity.model
  202. )