base.py

import decimal
import json
import logging
import os
import re
import time
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple

from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
from extensions.ext_database import db

logger = logging.getLogger(__name__)

class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError

    @property
    def base_model_name(self) -> str:
        """
        get llm base model name

        :return: str
        """
        return self.name

    @property
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            price_config = rules['price_config'][
                self.base_model_name] if 'price_config' in rules else default_price_config
            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),
                'unit': decimal.Decimal(price_config['unit']),
                'currency': price_config['currency']
            }
            return price_config

        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")
        return self._price_config
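
    # Illustrative shape of the provider rules consumed above. The model name and
    # the numbers are made up for the example, not actual provider pricing:
    #
    #   rules = {
    #       'price_config': {
    #           'gpt-3.5-turbo': {
    #               'prompt': '0.0015',     # price per `unit` prompt tokens
    #               'completion': '0.002',  # price per `unit` completion tokens
    #               'unit': '0.001',        # i.e. prices are quoted per 1,000 tokens
    #               'currency': 'USD'
    #           }
    #       }
    #   }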

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        moderation_result = moderation.check_moderation(
            self.model_provider,
            "\n".join([message.content for message in messages])
        )

        if not moderation_result:
            kwargs['fake_response'] = "I apologize for any confusion, " \
                                      "but I'm an AI assistant to be helpful, harmless, and honest."

        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        db.session.commit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming:
            # use FakeLLM to simulate streaming when streaming is requested
            # but the current model does not support it
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        raise NotImplementedError

    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
        """
        calc tokens total price.

        :param tokens:
        :param message_type:
        :return:
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit = self.get_price_unit(message_type)

        total_price = tokens * unit_price * unit
        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
        return total_price
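
    # Worked example (illustrative numbers only): with price_config['prompt'] = 0.0015
    # and unit = 0.001 (price quoted per 1,000 tokens),
    # calc_tokens_price(200, MessageType.HUMAN) computes
    # Decimal('200') * Decimal('0.0015') * Decimal('0.001') = Decimal('0.0003000').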

    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
        """
        get token price.

        :param message_type:
        :return: decimal.Decimal('0.0001')
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"unit_price={unit_price}")
        return unit_price

    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
        """
        get price unit.

        :param message_type:
        :return: decimal.Decimal('0.000001')
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            price_unit = self.price_config['unit']
        else:
            price_unit = self.price_config['unit']

        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"price_unit={price_unit}")
        return price_unit

    def get_currency(self) -> str:
        """
        get token currency.

        :return: get from price config, default 'USD'
        """
        currency = self.price_config['currency']
        return currency

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle llm run exceptions.

        :param ex:
        :return:
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to client.

        :param callbacks:
        :return:
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @property
    def support_streaming(self):
        return False

    def get_prompt(self, mode: str,
                   pre_prompt: str, inputs: dict,
                   query: str,
                   context: Optional[str],
                   memory: Optional[BaseChatMemory]) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
        return [PromptMessage(content=prompt)], stops

    def prompt_file_name(self, mode: str) -> str:
        if mode == 'completion':
            return 'common_completion'
        else:
            return 'common_chat'

    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                             query: str,
                             context: Optional[str],
                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                context=context
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            pre_prompt_content = prompt_template.format(
                **prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

        if memory and 'histories_prompt' in prompt_rules:
            # append chat histories
            tmp_human_message = PromptBuilder.to_human_message(
                prompt_content=prompt + query_prompt,
                inputs={
                    'query': query
                }
            )

            if self.model_rules.max_tokens.max:
                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
                max_tokens = self.model_kwargs.max_tokens
                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
            else:
                rest_tokens = 2000

            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

            histories = self._get_history_messages_from_memory(memory, rest_tokens)
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format(
                histories=histories
            )

            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
        query_prompt_content = prompt_template.format(
            query=query
        )

        prompt += query_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None

        return prompt, stops

    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
            'prompt/generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)
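
    # Illustrative prompt rules file. The keys are inferred from the lookups in
    # _get_prompt_and_stop above; the actual JSON files live under
    # prompt/generate_prompts (e.g. common_chat.json), and their real contents
    # may differ:
    #
    #   {
    #       "context_prompt": "Use the following context:\n{{context}}\n",
    #       "histories_prompt": "{{histories}}\n",
    #       "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #       "query_prompt": "Human: {{query}}\n\nAssistant: ",
    #       "stops": ["\nHuman:"],
    #       "human_prefix": "Human",
    #       "assistant_prefix": "Assistant"
    #   }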

    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                          max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            if len(messages) == 0:
                return ''

            return messages[0].content
        else:
            if len(messages) == 0:
                return []

            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        convert model kwargs to provider model kwargs.

        :param model_rules:
        :param model_kwargs:
        :return:
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            if rule.min is not None:
                value = max(value, rule.min)

            if rule.max is not None:
                value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
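

# --- Minimal usage sketch (kept as comments; illustrative only) ---------------
# `EchoLLM` and the wiring below are hypothetical and not part of this module;
# they only show which hooks a concrete provider model must implement and how
# callers drive `run` with PromptMessage objects.
#
# class EchoLLM(BaseLLM):
#     def _init_client(self) -> Any:
#         # Real implementations return a langchain LLM / chat-model client built
#         # from self.credentials and self.model_kwargs.
#         return FakeLLM(response='', num_token_func=self.get_num_tokens,
#                        streaming=self.streaming, callbacks=self.callbacks)
#
#     def _run(self, messages, stop=None, callbacks=None, **kwargs) -> LLMResult:
#         prompts = self._get_prompt_from_messages(messages)
#         return self.client.generate([prompts], stop, callbacks)
#
#     def get_num_tokens(self, messages) -> int:
#         return sum(len(m.content) // 4 for m in messages)  # rough heuristic
#
#     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
#         pass
#
#     def handle_exceptions(self, ex: Exception) -> Exception:
#         return ex
#
# Callers (given a concrete BaseModelProvider instance) would then do:
#
#     model = EchoLLM(model_provider, name='echo', model_kwargs=None)
#     result = model.run([PromptMessage(content='Hello')])
#     print(result.content, result.prompt_tokens, result.completion_tokens)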