
import json
import os
import re
import time
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
import decimal

from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
    to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging
from extensions.ext_database import db

logger = logging.getLogger(__name__)


class BaseLLM(BaseProviderModel):
    """
    Base class for provider LLM models: wraps a LangChain client and handles
    prompt building, pricing, quota deduction and streaming callbacks.
    """

    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)
    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError

    @property
    def base_model_name(self) -> str:
        """
        get llm base model name

        :return: str
        """
        return self.name
    @property
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            price_config = rules['price_config'][
                self.base_model_name] if 'price_config' in rules else default_price_config
            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),
                'unit': decimal.Decimal(price_config['unit']),
                'currency': price_config['currency']
            }
            return price_config

        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")
        return self._price_config
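
    # Illustrative shape of the provider rules consumed by price_config above.
    # The model name and prices below are assumptions for illustration, not actual
    # provider data; only the keys ('prompt', 'completion', 'unit', 'currency') are
    # what the property reads:
    #
    #   rules['price_config'] = {
    #       'gpt-3.5-turbo': {
    #           'prompt': '0.0015', 'completion': '0.002',
    #           'unit': '0.001', 'currency': 'USD'
    #       }
    #   }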
    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        moderation_result = moderation.check_moderation(
            self.model_provider,
            "\n".join([message.content for message in messages])
        )

        if not moderation_result:
            kwargs['fake_response'] = "I apologize for any confusion, " \
                                      "but I'm an AI assistant designed to be helpful, harmless, and honest."

        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        function_call = None
        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
            if 'function_call' in result.generations[0][0].message.additional_kwargs:
                function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming:
            # use FakeLLM to simulate streaming when the current model does not support
            # streaming but streaming was requested
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            function_call=function_call
        )
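
    # Illustrative call sequence for run(). The concrete subclass and model name below
    # are assumptions (they are not defined in this module); the PromptMessage,
    # MessageType and LLMRunResult fields are the ones used above:
    #
    #   llm = SomeProviderLLM(model_provider, name='some-model',
    #                         model_kwargs=ModelKwargs(max_tokens=None, temperature=None,
    #                                                  top_p=None, presence_penalty=None,
    #                                                  frequency_penalty=None))
    #   result = llm.run([PromptMessage(content='Hello', type=MessageType.USER)])
    #   result.content, result.prompt_tokens, result.completion_tokens, result.function_call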
    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        run predict by prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        get num tokens of prompt messages.

        :param messages:
        :return:
        """
        raise NotImplementedError
    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
        """
        calc tokens total price.

        :param tokens:
        :param message_type:
        :return:
        """
        if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit = self.get_price_unit(message_type)

        total_price = tokens * unit_price * unit
        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price:{total_price}")
        return total_price
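
    # Worked example with assumed prices (not real provider pricing): 1,000 prompt
    # tokens at unit_price=0.0015 and unit=0.001 cost
    #   1000 * 0.0015 * 0.001 = 0.0015,
    # which quantizes to Decimal('0.0015000') with 7 decimal places, ROUND_HALF_UP.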
    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
        """
        get token price.

        :param message_type:
        :return: decimal.Decimal('0.0001')
        """
        if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"unit_price={unit_price}")
        return unit_price

    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
        """
        get price unit.

        :param message_type:
        :return: decimal.Decimal('0.000001')
        """
        # the same price unit applies to prompt and completion tokens
        price_unit = self.price_config['unit']
        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
        logging.debug(f"price_unit={price_unit}")
        return price_unit

    def get_currency(self) -> str:
        """
        get token currency.

        :return: get from price config, default 'USD'
        """
        currency = self.price_config['currency']
        return currency
    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle llm run exceptions.

        :param ex:
        :return:
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to client.

        :param callbacks:
        :return:
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @property
    def support_streaming(self):
        return False
    def get_prompt(self, mode: str,
                   pre_prompt: str, inputs: dict,
                   query: str,
                   context: Optional[str],
                   memory: Optional[BaseChatMemory]) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
        return [PromptMessage(content=prompt)], stops
    def get_advanced_prompt(self, app_mode: str,
                            app_model_config, inputs: dict,
                            query: str,
                            context: Optional[str],
                            memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
        model_mode = app_model_config.model_dict['mode']
        conversation_histories_role = {}

        raw_prompt_list = []
        prompt_messages = []

        if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
            prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
            raw_prompt_list = [{
                'role': MessageType.USER.value,
                'text': prompt_text
            }]
            conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
        elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
            raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
        elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
            raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
        elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
            prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
            raw_prompt_list = [{
                'role': MessageType.USER.value,
                'text': prompt_text
            }]
        else:
            raise Exception("app_mode or model_mode not supported")

        for prompt_item in raw_prompt_list:
            prompt = prompt_item['text']

            # set prompt template variables
            prompt_template = PromptTemplateParser(template=prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

            if '#context#' in prompt:
                if context:
                    prompt_inputs['#context#'] = context
                else:
                    prompt_inputs['#context#'] = ''

            if '#query#' in prompt:
                if query:
                    prompt_inputs['#query#'] = query
                else:
                    prompt_inputs['#query#'] = ''

            if '#histories#' in prompt:
                if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
                    memory.human_prefix = conversation_histories_role['user_prefix']
                    memory.ai_prefix = conversation_histories_role['assistant_prefix']
                    histories = self._get_history_messages_from_memory(memory, 2000)
                    prompt_inputs['#histories#'] = histories
                else:
                    prompt_inputs['#histories#'] = ''

            prompt = prompt_template.format(
                prompt_inputs
            )

            prompt = re.sub(r'<\|.*?\|>', '', prompt)

            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

        if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
            memory.human_prefix = MessageType.USER.value
            memory.ai_prefix = MessageType.ASSISTANT.value
            histories = self._get_history_messages_list_from_memory(memory, 2000)
            prompt_messages.extend(histories)

        if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
            prompt_messages.append(PromptMessage(type=MessageType.USER, content=query))

        return prompt_messages
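
    # Illustrative advanced prompt text for the placeholders substituted above.
    # The wording is an example, not a shipped template; '{{...}}' is the variable
    # syntax used by PromptTemplateParser (see the default '{{query}}' prompt below):
    #
    #   "Answer using this context:\n{{#context#}}\n\n{{#histories#}}\nQuestion: {{#query#}}"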
    def prompt_file_name(self, mode: str) -> str:
        if mode == 'completion':
            return 'common_completion'
        else:
            return 'common_chat'
    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                             query: str,
                             context: Optional[str],
                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                {'context': context}
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = PromptTemplateParser(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            pre_prompt_content = prompt_template.format(
                prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

        if memory and 'histories_prompt' in prompt_rules:
            # append chat histories
            tmp_human_message = PromptBuilder.to_human_message(
                prompt_content=prompt + query_prompt,
                inputs={
                    'query': query
                }
            )

            if self.model_rules.max_tokens.max:
                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
                max_tokens = self.model_kwargs.max_tokens
                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
            else:
                rest_tokens = 2000

            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

            histories = self._get_history_messages_from_memory(memory, rest_tokens)
            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format({'histories': histories})

            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = PromptTemplateParser(template=query_prompt)
        query_prompt_content = prompt_template.format({'query': query})

        prompt += query_prompt_content

        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None

        return prompt, stops
    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
            'prompt/generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')

        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)
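
    # Illustrative shape of a generate_prompts/<name>.json rules file, based only on
    # the keys read in _get_prompt_and_stop (the template strings and stop word below
    # are examples, not the shipped files):
    #
    #   {
    #       "context_prompt": "Use the following context:\n{{context}}\n",
    #       "histories_prompt": "{{histories}}\n",
    #       "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #       "query_prompt": "{{query}}",
    #       "human_prefix": "Human",
    #       "assistant_prefix": "Assistant",
    #       "stops": ["\nHuman:"]
    #   }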
    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                          max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
                                               max_token_limit: int) -> List[PromptMessage]:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory.return_messages = True
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        memory.return_messages = False
        return to_prompt_messages(external_context[memory_key])
    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            if len(messages) == 0:
                return ''
            return messages[0].content
        else:
            if len(messages) == 0:
                return []
            return to_lc_messages(messages)
    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        convert model kwargs to provider model kwargs.

        :param model_rules:
        :param model_kwargs:
        :return:
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            if rule.min is not None:
                value = max(value, rule.min)

            if rule.max is not None:
                value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
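
    # Sketch of how the aliasing and clamping above behave, assuming a hypothetical
    # rule for 'temperature' with alias='temp', min=0, max=2 and default=1 (these
    # rule values are illustrative, not a provider's actual ModelKwargsRules):
    #
    #   temperature=3.5  ->  {'temp': 2}     # clamped to rule.max
    #   temperature=None ->  {'temp': 1}     # rule.default fills in a missing value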