model_manager.py

from collections.abc import Generator
from typing import IO, Optional, Union, cast

from core.entities.provider_configuration import ProviderModelBundle
from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.provider_manager import ProviderManager


class ModelInstance:
    """
    Model instance class
    """

    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str) -> None:
        self._provider_model_bundle = provider_model_bundle
        self.model = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        self.model_type_instance = self._provider_model_bundle.model_type_instance
    def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBundle, model: str) -> dict:
        """
        Fetch credentials from provider model bundle

        :param provider_model_bundle: provider model bundle
        :param model: model name
        :return: credentials dict
        """
        credentials = provider_model_bundle.configuration.get_current_credentials(
            model_type=provider_model_bundle.model_type_instance.model_type,
            model=model
        )

        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model} credentials are not initialized.")

        return credentials
    def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
                   tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                   stream: bool = True, user: Optional[str] = None,
                   callbacks: Optional[list[Callback]] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        """
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")

        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user,
            callbacks=callbacks
        )
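    # Example usage (a sketch; assumes `instance` is a ModelInstance bound to an LLM
    # and that message_entities provides a UserPromptMessage class):
    #
    #   from core.model_runtime.entities.message_entities import UserPromptMessage
    #
    #   result = instance.invoke_llm(
    #       prompt_messages=[UserPromptMessage(content="Hello")],
    #       model_parameters={"temperature": 0.7},
    #       stream=False,
    #   )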
    def invoke_text_embedding(self, texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        if not isinstance(self.model_type_instance, TextEmbeddingModel):
            raise Exception("Model type instance is not TextEmbeddingModel")

        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            texts=texts,
            user=user
        )
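    # Example usage (a sketch; assumes `instance` is a ModelInstance bound to a
    # text embedding model and that TextEmbeddingResult exposes an `embeddings` field):
    #
    #   result = instance.invoke_text_embedding(texts=["hello", "world"])
    #   vectors = result.embeddings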
    def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None,
                      top_n: Optional[int] = None, user: Optional[str] = None) \
            -> RerankResult:
        """
        Invoke rerank model

        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n
        :param user: unique user id
        :return: rerank result
        """
        if not isinstance(self.model_type_instance, RerankModel):
            raise Exception("Model type instance is not RerankModel")

        self.model_type_instance = cast(RerankModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            query=query,
            docs=docs,
            score_threshold=score_threshold,
            top_n=top_n,
            user=user
        )
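    # Example usage (a sketch; assumes `instance` is a ModelInstance bound to a rerank model):
    #
    #   result = instance.invoke_rerank(
    #       query="What is a vector database?",
    #       docs=["A note about SQL joins", "An intro to vector search"],
    #       top_n=1,
    #   )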
    def invoke_moderation(self, text: str, user: Optional[str] = None) \
            -> bool:
        """
        Invoke moderation model

        :param text: text to moderate
        :param user: unique user id
        :return: False if text is safe, True otherwise
        """
        if not isinstance(self.model_type_instance, ModerationModel):
            raise Exception("Model type instance is not ModerationModel")

        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            text=text,
            user=user
        )
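    # Example usage (a sketch; assumes `instance` is a ModelInstance bound to a moderation model):
    #
    #   flagged = instance.invoke_moderation(text="some user input")
    #   # `flagged` is True when the text is judged unsafe, False otherwise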
    def invoke_speech2text(self, file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech-to-text model

        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        if not isinstance(self.model_type_instance, Speech2TextModel):
            raise Exception("Model type instance is not Speech2TextModel")

        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            file=file,
            user=user
        )
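    # Example usage (a sketch; assumes `instance` is a ModelInstance bound to a
    # speech-to-text model and "audio.wav" is a local audio file):
    #
    #   with open("audio.wav", "rb") as audio_file:
    #       transcript = instance.invoke_speech2text(file=audio_file)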
    def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \
            -> str:
        """
        Invoke text-to-speech model

        :param content_text: text content to be converted to speech
        :param streaming: whether the output is streamed
        :param user: unique user id
        :return: generated audio for the given text
        """
        if not isinstance(self.model_type_instance, TTSModel):
            raise Exception("Model type instance is not TTSModel")

        self.model_type_instance = cast(TTSModel, self.model_type_instance)
        return self.model_type_instance.invoke(
            model=self.model,
            credentials=self.credentials,
            content_text=content_text,
            user=user,
            streaming=streaming
        )
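    # Example usage (a sketch; assumes `instance` is a ModelInstance bound to a TTS model):
    #
    #   audio = instance.invoke_tts(content_text="Hello there", streaming=False)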


class ModelManager:
    def __init__(self) -> None:
        self._provider_manager = ProviderManager()

    def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance:
        """
        Get model instance

        :param tenant_id: tenant id
        :param provider: provider name
        :param model_type: model type
        :param model: model name
        :return:
        """
        if not provider:
            return self.get_default_model_instance(tenant_id, model_type)

        provider_model_bundle = self._provider_manager.get_provider_model_bundle(
            tenant_id=tenant_id,
            provider=provider,
            model_type=model_type
        )

        return ModelInstance(provider_model_bundle, model)
    def get_default_model_instance(self, tenant_id: str, model_type: ModelType) -> ModelInstance:
        """
        Get default model instance

        :param tenant_id: tenant id
        :param model_type: model type
        :return:
        """
        default_model_entity = self._provider_manager.get_default_model(
            tenant_id=tenant_id,
            model_type=model_type
        )

        if not default_model_entity:
            raise ProviderTokenNotInitError(f"Default model not found for {model_type}")

        return self.get_model_instance(
            tenant_id=tenant_id,
            provider=default_model_entity.provider.provider,
            model_type=model_type,
            model=default_model_entity.model
        )
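

# Example usage (a sketch; the tenant id, provider and model names below are placeholders):
#
#   manager = ModelManager()
#   llm_instance = manager.get_model_instance(
#       tenant_id="tenant-1",
#       provider="openai",
#       model_type=ModelType.LLM,
#       model="gpt-4",
#   )
#
#   # Passing an empty provider falls back to the tenant's default model for the type:
#   default_llm = manager.get_model_instance(
#       tenant_id="tenant-1",
#       provider="",
#       model_type=ModelType.LLM,
#       model="",
#   )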