base.py

from abc import abstractmethod
from typing import List, Optional, Any, Union

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM


class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError
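
    # Illustrative sketch only: a concrete subclass's _init_client() typically
    # builds the underlying LangChain client from self.name, self.credentials,
    # self.model_kwargs and self.callbacks, e.g. (SomeProviderChatModel is a
    # hypothetical placeholder, not a real class in this codebase):
    #
    #     def _init_client(self) -> Any:
    #         provider_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
    #         return SomeProviderChatModel(model_name=self.name,
    #                                      streaming=self.streaming,
    #                                      callbacks=self.callbacks,
    #                                      **self.credentials,
    #                                      **provider_kwargs)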

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        Run a prediction from prompt messages and optional stop words.

        :param messages: prompt messages
        :param stop: stop words
        :param callbacks: callbacks for this run
        :return: LLM run result
        """
        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming():
            # use FakeLLM to simulate streaming when the current model does not
            # support streaming but streaming is requested
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )
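
    # Usage sketch (assumes `model` is an instance of a concrete BaseLLM subclass):
    #
    #     result = model.run(
    #         messages=[
    #             PromptMessage(content='You are a helpful assistant.', type=MessageType.SYSTEM),
    #             PromptMessage(content='Hello!', type=MessageType.HUMAN),
    #         ],
    #         stop=['\nHuman:']
    #     )
    #     print(result.content, result.prompt_tokens, result.completion_tokens)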

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run the provider-specific prediction from prompt messages and optional stop words.

        :param messages: prompt messages
        :param stop: stop words
        :param callbacks: callbacks for this run
        :return: LangChain LLMResult
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages: prompt messages
        :return: token count
        """
        raise NotImplementedError

    @abstractmethod
    def get_token_price(self, tokens: int, message_type: MessageType):
        """
        Get the price for the given number of tokens.

        :param tokens: token count
        :param message_type: message type (e.g. HUMAN or ASSISTANT)
        :return: price
        """
        raise NotImplementedError

    @abstractmethod
    def get_currency(self):
        """
        Get the currency used for token pricing.

        :return: currency
        """
        raise NotImplementedError

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle LLM run exceptions and return the exception to raise.

        :param ex: original exception
        :return: exception to raise
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to client.

        :param callbacks:
        :return:
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @classmethod
    def support_streaming(cls):
        return False
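
    # Subclasses whose underlying model can stream override support_streaming()
    # to return True; otherwise run() replays the final completion through
    # FakeLLM so callers still receive streaming callbacks.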

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            if len(messages) == 0:
                return ''

            return messages[0].content
        else:
            if len(messages) == 0:
                return []

            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages
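
    # For example, in chat mode a [SYSTEM, HUMAN] PromptMessage list becomes
    # [SystemMessage(...), HumanMessage(...)], while in completion mode only
    # the first message's content is returned as a plain string.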

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        Convert model kwargs to provider model kwargs.

        :param model_rules: parameter rules for this model
        :param model_kwargs: requested model kwargs
        :return: provider model kwargs dict
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)

            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            if rule.min is not None:
                value = max(value, rule.min)

            if rule.max is not None:
                value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
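
    # Worked example (hypothetical rule values): given a temperature rule with
    # enabled=True, alias=None, default=0.7, min=0.0, max=1.0, a requested
    # ModelKwargs(temperature=None, ...) yields {'temperature': 0.7, ...},
    # while temperature=2.0 is clamped to 1.0.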