base.py

from abc import abstractmethod
from typing import List, Optional, Any, Union
import decimal
import logging

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM

logger = logging.getLogger(__name__)


class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)
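
    # Illustrative construction (hypothetical concrete subclass, not part of
    # this module):
    #   llm = MyProviderLLM(
    #       model_provider=provider,
    #       name='my-model',
    #       model_kwargs=ModelKwargs(max_tokens=256, temperature=0.7, top_p=None,
    #                                presence_penalty=None, frequency_penalty=None),
    #       streaming=True,
    #   )
    # A concrete subclass must implement _init_client(), _run(), get_num_tokens(),
    # _set_model_kwargs() and handle_exceptions().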

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError

    @property
    def base_model_name(self) -> str:
        """
        Get the base model name of this LLM.

        :return: base model name
        """
        return self.name

    @property
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            if 'price_config' in rules:
                price_config = rules['price_config'][self.base_model_name]
            else:
                price_config = default_price_config

            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),
                'unit': decimal.Decimal(price_config['unit']),
                'currency': price_config['currency']
            }
            return price_config

        if not hasattr(self, '_price_config'):
            self._price_config = get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")
        return self._price_config
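
    # For reference, the resolved price_config is a dict of Decimals, e.g.
    # (illustrative values, not actual provider pricing):
    #   {'prompt': Decimal('0.0015'), 'completion': Decimal('0.002'),
    #    'unit': Decimal('0.001'), 'currency': 'USD'}
    # Since prices are multiplied by 'unit' in calc_tokens_price, a unit of
    # Decimal('0.001') makes the configured prices effectively per 1,000 tokens.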

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        Run a prediction from prompt messages and optional stop words.

        :param messages: prompt messages to send to the model
        :param stop: stop words that terminate generation
        :param callbacks: extra callbacks, merged with the instance callbacks
        :return: LLMRunResult with the completion content and token usage
        """
        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming():
            # use FakeLLM to simulate streaming when streaming was requested
            # but the current model does not support it
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]
            )
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )
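
    # Illustrative call of run() (message contents and stop words are examples only):
    #   result = llm.run(
    #       messages=[
    #           PromptMessage(content='You are a helpful assistant.', type=MessageType.SYSTEM),
    #           PromptMessage(content='Hello!', type=MessageType.HUMAN),
    #       ],
    #       stop=['\nHuman:'],
    #   )
    #   # then inspect result.content, result.prompt_tokens, result.completion_tokens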

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run a prediction from prompt messages and optional stop words.

        :param messages: prompt messages to send to the model
        :param stop: stop words that terminate generation
        :param callbacks: callbacks passed through to the underlying client
        :return: raw LLMResult from the client
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the given prompt messages.

        :param messages: prompt messages to count
        :return: token count
        """
        raise NotImplementedError

    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
        """
        Calculate the total price for the given number of tokens.

        :param tokens: number of tokens
        :param message_type: message type, selects prompt or completion pricing
        :return: total price as a Decimal
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']

        unit = self.get_price_unit(message_type)

        total_price = tokens * unit_price * unit
        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price={total_price}")
        return total_price
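
    # Worked example for calc_tokens_price (illustrative numbers): with
    # unit_price=Decimal('0.002'), unit=Decimal('0.001') and tokens=1500:
    #   1500 * Decimal('0.002') * Decimal('0.001') = Decimal('0.003')
    # which quantizes to Decimal('0.0030000') at 7 decimal places.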

    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the unit price per token, e.g. decimal.Decimal('0.0001').

        :param message_type: message type, selects prompt or completion pricing
        :return: unit price as a Decimal
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"unit_price={unit_price}")
        return unit_price

    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the price unit, e.g. decimal.Decimal('0.000001').

        The unit is shared by prompt and completion pricing, so message_type
        does not affect the result.

        :param message_type: message type (unused, kept for interface symmetry)
        :return: price unit as a Decimal
        """
        price_unit = self.price_config['unit']
        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"price_unit={price_unit}")
        return price_unit

    def get_currency(self) -> str:
        """
        Get the pricing currency from the price config, default 'USD'.

        :return: currency code
        """
        return self.price_config['currency']

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle LLM run exceptions.

        :param ex: the raised exception
        :return: a provider-specific exception to re-raise
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to the client.

        :param callbacks: callbacks to add
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @classmethod
    def support_streaming(cls):
        return False

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            if len(messages) == 0:
                return ''

            return messages[0].content
        else:
            if len(messages) == 0:
                return []

            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages
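
    # Mapping summary: in chat mode MessageType.HUMAN -> HumanMessage,
    # MessageType.ASSISTANT -> AIMessage, MessageType.SYSTEM -> SystemMessage;
    # in completion mode only messages[0].content is returned as a plain string.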

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        Convert model kwargs to provider-specific model kwargs.

        :param model_rules: parameter rules for this model
        :param model_kwargs: generic model kwargs
        :return: dict of provider-specific kwargs
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            # clamp to the allowed range; skip when no value was provided or defaulted
            if value is not None:
                if rule.min is not None:
                    value = max(value, rule.min)

                if rule.max is not None:
                    value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
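

# Minimal sketch of the _to_model_kwargs_input clamping behaviour, shown with a
# stand-in rule object (assumption: a rule exposes enabled/alias/default/min/max,
# as used above; SimpleNamespace and the alias name are illustrative, not part of
# this module):
#
#   from types import SimpleNamespace
#   rule = SimpleNamespace(enabled=True, alias='max_tokens_to_sample',
#                          default=256, min=1, max=4096)
#   value = None                                        # caller passed no value
#   value = rule.default if value is None else value    # -> 256
#   value = max(value, rule.min)                        # lower bound -> 256
#   value = min(value, rule.max)                        # upper bound -> 256
#   {rule.alias: value}                                 # -> {'max_tokens_to_sample': 256}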