# completion.py

import concurrent
import json
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple

from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.embedding.cached_embedding import CacheEmbedding
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.dataset import Dataset
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
from services.annotation_service import AppAnnotationService
from services.dataset_service import DatasetCollectionBindingService


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
                 auto_generate_name: bool = True, from_source: str = 'console'):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptTemplateParser.remove_template_variables(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            files=files,
            streaming=streaming,
            model_instance=final_model_instance,
            auto_generate_name=auto_generate_name
        )

        prompt_message_files = [file.prompt_message_file for file in files]

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs,
            files=prompt_message_files
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        try:
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
            try:
                # process sensitive_word_avoidance
                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
            except ModerationException as e:
                cls.run_final_llm(
                    model_instance=final_model_instance,
                    mode=app.mode,
                    app_model_config=app_model_config,
                    query=query,
                    inputs=inputs,
                    files=prompt_message_files,
                    agent_execute_result=None,
                    conversation_message_task=conversation_message_task,
                    memory=memory,
                    fake_response=str(e)
                )
                return

            # check annotation reply
            annotation_reply = cls.query_app_annotations_to_reply(conversation_message_task, from_source)
            if annotation_reply:
                return

            # fill in variable inputs from external data tools if exists
            external_data_tools = app_model_config.external_data_tools_list
            if external_data_tools:
                inputs = cls.fill_in_inputs_from_external_data_tools(
                    tenant_id=app.tenant_id,
                    app_id=app.id,
                    external_data_tools=external_data_tools,
                    inputs=inputs,
                    query=query
                )

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                tenant_id=app.tenant_id,
                retriever_from=retriever_from
            )

            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)

            # run agent executor
            agent_execute_result = None
            if query_for_agent and agent_executor:
                should_use_agent = agent_executor.should_use_agent(query_for_agent)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query_for_agent)

            # When no extra pre prompt is specified, the agent output can be used
            # directly as the main output content without calling the LLM again
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                files=prompt_message_files,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
            return
        except ChunkedEncodingError as e:
            # stream interrupted by the LLM provider (e.g. OpenAI); log and return
            logging.warning(f'ChunkedEncodingError: {e}')
            return
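
    # A minimal invocation sketch for the entry point above (illustrative only:
    # it assumes a Flask app context plus already-loaded App, AppModelConfig and
    # Account rows, none of which are constructed here):
    #
    #   Completion.generate(
    #       task_id=str(uuid.uuid4()),
    #       app=app,
    #       app_model_config=app_model_config,
    #       query="Summarize the latest report",
    #       inputs={},
    #       files=[],
    #       user=account,
    #       conversation=None,
    #       streaming=False
    #   )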

    @classmethod
    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
                              query: str):
        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
            return inputs, query

        type = app_model_config.sensitive_word_avoidance_dict['type']

        moderation = ModerationFactory(type, app_id, tenant_id,
                                       app_model_config.sensitive_word_avoidance_dict['config'])
        moderation_result = moderation.moderation_for_inputs(inputs, query)

        if not moderation_result.flagged:
            return inputs, query

        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationException(moderation_result.preset_response)
        elif moderation_result.action == ModerationAction.OVERRIDED:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return inputs, query
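
    # Caller contract sketch (a sketch of how generate() above consumes the two
    # moderation actions used here, not an exhaustive description of ModerationAction):
    #
    #   try:
    #       inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
    #   except ModerationException as e:
    #       # DIRECT_OUTPUT: the preset response is returned as a fake LLM answer
    #       ...
    #   # OVERRIDED: inputs/query come back rewritten and the normal pipeline continues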

    @classmethod
    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
                                                inputs: dict, query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        # Group tools by type and config
        grouped_tools = {}
        for tool in external_data_tools:
            if not tool.get("enabled"):
                continue

            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
            grouped_tools.setdefault(tool_key, []).append(tool)

        results = {}
        with ThreadPoolExecutor() as executor:
            futures = {}
            for tool in external_data_tools:
                if not tool.get("enabled"):
                    continue

                future = executor.submit(
                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
                    inputs, query
                )

                futures[future] = tool

            for future in concurrent.futures.as_completed(futures):
                tool_variable, result = future.result()
                results[tool_variable] = result

        inputs.update(results)
        return inputs
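
    # Note: the live Flask app object is passed into each worker via
    # current_app._get_current_object() because the application context is
    # thread-local and does not follow ThreadPoolExecutor workers; each worker
    # re-enters the context itself (see query_external_data_tool below).
    # Rough shape of the fan-out, assuming two enabled tools bound to the
    # variables "weather" and "stock_price" (hypothetical names):
    #
    #   futures = {executor.submit(...weather tool...): tool1,
    #              executor.submit(...stock_price tool...): tool2}
    #   results -> {"weather": "...", "stock_price": "..."}   # merged into inputs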

    @classmethod
    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
        with flask_app.app_context():
            tool_variable = external_data_tool.get("variable")
            tool_type = external_data_tool.get("type")
            tool_config = external_data_tool.get("config")

            external_data_tool_factory = ExternalDataToolFactory(
                name=tool_type,
                tenant_id=tenant_id,
                app_id=app_id,
                variable=tool_variable,
                config=tool_config
            )

            # query external data tool
            result = external_data_tool_factory.query(
                inputs=inputs,
                query=query
            )

            return tool_variable, result

    @classmethod
    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
        if app.mode != 'completion':
            return query

        return inputs.get(app_model_config.dataset_query_variable, "")

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      files: List[PromptMessageFile],
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
        prompt_transform = PromptTransform()

        # get llm prompt
        if app_model_config.prompt_type == 'simple':
            prompt_messages, stop_words = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )

            # the advanced prompt path does not return stop words,
            # so take them from the model's completion params instead
            model_config = app_model_config.model_dict
            completion_params = model_config.get("completion_params", {})
            stop_words = completion_params.get("stop", [])

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words if stop_words else None,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response
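
    # Illustrative shape of the two branches above (a sketch, not exhaustive):
    # with prompt_type == 'simple', stop words come back from PromptTransform
    # alongside the messages; otherwise they are read from the app's model
    # config, e.g. a completion_params dict along the lines of
    #
    #   {"max_tokens": 512, "temperature": 0.7, "stop": ["\nHuman:"]}   # example values only
    #
    # in which case stop_words would be ["\nHuman:"] for that (hypothetical) config.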

    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def query_app_annotations_to_reply(cls, conversation_message_task: ConversationMessageTask,
                                       from_source: str) -> bool:
        """Query app annotations and reply directly when a sufficiently similar one exists."""
        app_model_config = conversation_message_task.app_model_config
        app = conversation_message_task.app
        annotation_reply = app_model_config.annotation_reply_dict
        if annotation_reply['enabled']:
            try:
                score_threshold = annotation_reply.get('score_threshold', 1)
                embedding_provider_name = annotation_reply['embedding_model']['embedding_provider_name']
                embedding_model_name = annotation_reply['embedding_model']['embedding_model_name']

                # get embedding model
                embedding_model = ModelFactory.get_embedding_model(
                    tenant_id=app.tenant_id,
                    model_provider_name=embedding_provider_name,
                    model_name=embedding_model_name
                )
                embeddings = CacheEmbedding(embedding_model)

                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
                    embedding_provider_name,
                    embedding_model_name,
                    'annotation'
                )

                dataset = Dataset(
                    id=app.id,
                    tenant_id=app.tenant_id,
                    indexing_technique='high_quality',
                    embedding_model_provider=embedding_provider_name,
                    embedding_model=embedding_model_name,
                    collection_binding_id=dataset_collection_binding.id
                )

                vector_index = VectorIndex(
                    dataset=dataset,
                    config=current_app.config,
                    embeddings=embeddings,
                    attributes=['doc_id', 'annotation_id', 'app_id']
                )

                documents = vector_index.search(
                    conversation_message_task.query,
                    search_type='similarity_score_threshold',
                    search_kwargs={
                        'k': 1,
                        'score_threshold': score_threshold,
                        'filter': {
                            'group_id': [dataset.id]
                        }
                    }
                )

                if documents:
                    annotation_id = documents[0].metadata['annotation_id']
                    score = documents[0].metadata['score']
                    annotation = AppAnnotationService.get_annotation_by_id(annotation_id)
                    if annotation:
                        conversation_message_task.annotation_end(annotation.content, annotation.id,
                                                                 annotation.account.name)
                        # insert annotation history
                        AppAnnotationService.add_annotation_history(annotation.id,
                                                                    app.id,
                                                                    annotation.question,
                                                                    annotation.content,
                                                                    conversation_message_task.query,
                                                                    conversation_message_task.user.id,
                                                                    conversation_message_task.message.id,
                                                                    from_source,
                                                                    score)
                        return True
            except Exception as e:
                logging.warning(f'Query annotation failed, exception: {str(e)}.')
                return False

        return False

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only for calc token in memory
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10)
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_transform = PromptTransform()

        # get prompt without memory and context
        if app_model_config.prompt_type == 'simple':
            prompt_messages, _ = prompt_transform.get_prompt(
                app_mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                files=files,
                context=None,
                memory=None,
                model_instance=model_instance
            )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long. You can reduce the prefix prompt, "
                                     "shrink the max tokens, or switch to an LLM with a larger token limit.")

        return rest_tokens
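
    # Worked example of the budget above (illustrative numbers only): for a model
    # limited to 4096 tokens, a configured max_tokens of 512 and a prompt that
    # measures 1200 tokens,
    #
    #   rest_tokens = 4096 - 512 - 1200 = 2384
    #
    # which is the budget handed to the agent/retrieval step as
    # rest_tokens_for_context_and_memory; a negative result raises LLMBadRequestError.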

    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

            # update model instance max tokens
            model_kwargs = model_instance.get_model_kwargs()
            model_kwargs.max_tokens = max_tokens
            model_instance.set_model_kwargs(model_kwargs)
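
# Worked example of the clamp above (illustrative numbers only): with a model
# limit of 4096 tokens, a prompt measuring 3900 tokens and max_tokens configured
# as 512, the sum 4412 exceeds the limit, so max_tokens is rewritten to
# max(4096 - 3900, 16) = 196 before the final call in run_final_llm.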