base.py

from abc import abstractmethod
from typing import List, Optional, Any, Union
import decimal
import logging

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM

logger = logging.getLogger(__name__)


class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError
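
    # A minimal subclass sketch (hypothetical names; assumes a LangChain-style
    # chat client whose constructor accepts callbacks, credentials, and sampling
    # kwargs — a sketch of the intended override, not the actual implementation):
    #
    #   class HostedChatModel(BaseLLM):
    #       def _init_client(self) -> Any:
    #           provider_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
    #           return SomeChatClient(callbacks=self.callbacks, **self.credentials, **provider_kwargs)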

    @property
    def base_model_name(self) -> str:
        """
        Get the LLM base model name.

        :return: str
        """
        return self.name

    @property
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            price_config = rules['price_config'][
                self.base_model_name] if 'price_config' in rules else default_price_config
            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),
                'unit': decimal.Decimal(price_config['unit']),
                'currency': price_config['currency']
            }
            return price_config

        # Lazily compute and cache the price config on first access.
        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")

        return self._price_config
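
    # For illustration, a provider's rules are expected to look roughly like this
    # (hypothetical values; the exact schema comes from each provider's rule file):
    #
    #   {
    #       'price_config': {
    #           'gpt-3.5-turbo': {
    #               'prompt': '0.0015',      # price per `unit` prompt tokens
    #               'completion': '0.002',   # price per `unit` completion tokens
    #               'unit': '0.001',         # i.e. prices are per 1,000 tokens
    #               'currency': 'USD'
    #           }
    #       }
    #   }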

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        Run a prediction from prompt messages and stop words.

        :param messages: prompt messages
        :param stop: stop words
        :param callbacks: callbacks to run alongside the model's own
        :return: LLMRunResult
        """
        moderation_result = moderation.check_moderation(
            self.model_provider,
            "\n".join([message.content for message in messages])
        )

        if not moderation_result:
            kwargs['fake_response'] = "I apologize for any confusion, " \
                                      "but I'm an AI assistant to be helpful, harmless, and honest."

        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        function_call = None
        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
            if 'function_call' in result.generations[0][0].message.additional_kwargs:
                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming:
            # Use FakeLLM to simulate streaming when the current model does not
            # support streaming but streaming was requested.
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            function_call=function_call
        )
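
    # Illustrative call (hypothetical subclass and message contents; a sketch of
    # intended usage, not a test of the real API):
    #
    #   llm = HostedChatModel(model_provider=provider, name="gpt-3.5-turbo",
    #                         model_kwargs=ModelKwargs(max_tokens=256, temperature=0.7,
    #                                                  top_p=None, presence_penalty=None,
    #                                                  frequency_penalty=None))
    #   result = llm.run([PromptMessage(content="Hello!", type=MessageType.USER)])
    #   print(result.content, result.prompt_tokens, result.completion_tokens)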

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run a prediction from prompt messages and stop words.

        :param messages: prompt messages
        :param stop: stop words
        :param callbacks: callbacks to run alongside the model's own
        :return: LLMResult
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages: prompt messages
        :return: token count
        """
        raise NotImplementedError

    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
        """
        Calculate the total price for the given number of tokens.

        :param tokens: token count
        :param message_type: message type, used to pick prompt vs. completion pricing
        :return: total price
        """
        if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']

        unit = self.get_price_unit(message_type)

        total_price = tokens * unit_price * unit
        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price={total_price}")
        return total_price
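
    # Worked example (hypothetical prices): with unit_price = 0.002 USD per 1,000
    # tokens (unit = 0.001) and tokens = 1500, the total is
    # 1500 * 0.002 * 0.001 = 0.003 USD, quantized to seven decimal places.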

    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the unit price per token block.

        :param message_type: message type, used to pick prompt vs. completion pricing
        :return: unit price, e.g. decimal.Decimal('0.0001')
        """
        if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"unit_price={unit_price}")
        return unit_price

    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the price unit (the token block size the unit prices refer to).

        :param message_type: message type (prompt and completion share the same unit)
        :return: price unit, e.g. decimal.Decimal('0.000001')
        """
        # Both prompt and completion pricing use the same unit from the price config.
        price_unit = self.price_config['unit']
        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"price_unit={price_unit}")
        return price_unit

    def get_currency(self) -> str:
        """
        Get the pricing currency.

        :return: currency from the price config, default 'USD'
        """
        return self.price_config['currency']

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle LLM run exceptions.

        :param ex: the raised exception
        :return: the exception to re-raise
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to the client.

        :param callbacks: callbacks to add
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @property
    def support_streaming(self):
        return False

    @property
    def support_function_call(self):
        return False

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            if len(messages) == 0:
                return ''

            return messages[0].content
        else:
            if len(messages) == 0:
                return []

            return to_lc_messages(messages)
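
    # For illustration: in COMPLETION mode this returns the first message's raw
    # string; in CHAT mode it returns LangChain BaseMessage objects, e.g.
    # (hypothetical contents, assuming to_lc_messages maps USER to HumanMessage):
    #
    #   self._get_prompt_from_messages([PromptMessage(content="Hi", type=MessageType.USER)],
    #                                  ModelMode.CHAT)
    #   # -> [HumanMessage(content="Hi")]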

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        Convert model kwargs to provider-specific model kwargs.

        :param model_rules: parameter rules for this model (enabled flags, aliases, defaults, min/max)
        :param model_kwargs: user-supplied model kwargs
        :return: dict of kwargs ready to pass to the provider client
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            # Rename the parameter if the provider uses a different name for it.
            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            # Clamp into [min, max]; skip when the value is still unset to avoid
            # comparing None against a number.
            if rule.min is not None and value is not None:
                value = max(value, rule.min)

            if rule.max is not None and value is not None:
                value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
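
# For illustration, given hypothetical rules where max_tokens is aliased to
# max_output_tokens with min=1, max=4096, default=512:
#
#   llm._to_model_kwargs_input(rules, ModelKwargs(max_tokens=999999, temperature=None, ...))
#   # -> {'max_output_tokens': 4096, ...}
#
# i.e. parameters are renamed via the alias, filled from defaults when unset,
# and clamped into the rule's [min, max] range.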