application_manager.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753
  1. import json
  2. import logging
  3. import threading
  4. import uuid
  5. from collections.abc import Generator
  6. from typing import Any, Optional, Union, cast
  7. from flask import Flask, current_app
  8. from pydantic import ValidationError
  9. from core.app_runner.assistant_app_runner import AssistantApplicationRunner
  10. from core.app_runner.basic_app_runner import BasicApplicationRunner
  11. from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
  12. from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
  13. from core.entities.application_entities import (
  14. AdvancedChatPromptTemplateEntity,
  15. AdvancedCompletionPromptTemplateEntity,
  16. AgentEntity,
  17. AgentPromptEntity,
  18. AgentToolEntity,
  19. ApplicationGenerateEntity,
  20. AppOrchestrationConfigEntity,
  21. DatasetEntity,
  22. DatasetRetrieveConfigEntity,
  23. ExternalDataVariableEntity,
  24. FileUploadEntity,
  25. InvokeFrom,
  26. ModelConfigEntity,
  27. PromptTemplateEntity,
  28. SensitiveWordAvoidanceEntity,
  29. TextToSpeechEntity,
  30. )
  31. from core.entities.model_entities import ModelStatus
  32. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  33. from core.file.file_obj import FileObj
  34. from core.model_runtime.entities.message_entities import PromptMessageRole
  35. from core.model_runtime.entities.model_entities import ModelType
  36. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  37. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  38. from core.prompt.prompt_template import PromptTemplateParser
  39. from core.provider_manager import ProviderManager
  40. from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
  41. from extensions.ext_database import db
  42. from models.account import Account
  43. from models.model import App, Conversation, EndUser, Message, MessageFile
  44. logger = logging.getLogger(__name__)
  45. class ApplicationManager:
  46. """
  47. This class is responsible for managing application
  48. """
  49. def generate(self, tenant_id: str,
  50. app_id: str,
  51. app_model_config_id: str,
  52. app_model_config_dict: dict,
  53. app_model_config_override: bool,
  54. user: Union[Account, EndUser],
  55. invoke_from: InvokeFrom,
  56. inputs: dict[str, str],
  57. query: Optional[str] = None,
  58. files: Optional[list[FileObj]] = None,
  59. conversation: Optional[Conversation] = None,
  60. stream: bool = False,
  61. extras: Optional[dict[str, Any]] = None) \
  62. -> Union[dict, Generator]:
  63. """
  64. Generate App response.
  65. :param tenant_id: workspace ID
  66. :param app_id: app ID
  67. :param app_model_config_id: app model config id
  68. :param app_model_config_dict: app model config dict
  69. :param app_model_config_override: app model config override
  70. :param user: account or end user
  71. :param invoke_from: invoke from source
  72. :param inputs: inputs
  73. :param query: query
  74. :param files: file obj list
  75. :param conversation: conversation
  76. :param stream: is stream
  77. :param extras: extras
  78. """
  79. # init task id
  80. task_id = str(uuid.uuid4())
  81. # init application generate entity
  82. application_generate_entity = ApplicationGenerateEntity(
  83. task_id=task_id,
  84. tenant_id=tenant_id,
  85. app_id=app_id,
  86. app_model_config_id=app_model_config_id,
  87. app_model_config_dict=app_model_config_dict,
  88. app_orchestration_config_entity=self._convert_from_app_model_config_dict(
  89. tenant_id=tenant_id,
  90. app_model_config_dict=app_model_config_dict
  91. ),
  92. app_model_config_override=app_model_config_override,
  93. conversation_id=conversation.id if conversation else None,
  94. inputs=conversation.inputs if conversation else inputs,
  95. query=query.replace('\x00', '') if query else None,
  96. files=files if files else [],
  97. user_id=user.id,
  98. stream=stream,
  99. invoke_from=invoke_from,
  100. extras=extras
  101. )
  102. if not stream and application_generate_entity.app_orchestration_config_entity.agent:
  103. raise ValueError("Agent app is not supported in blocking mode.")
  104. # init generate records
  105. (
  106. conversation,
  107. message
  108. ) = self._init_generate_records(application_generate_entity)
  109. # init queue manager
  110. queue_manager = ApplicationQueueManager(
  111. task_id=application_generate_entity.task_id,
  112. user_id=application_generate_entity.user_id,
  113. invoke_from=application_generate_entity.invoke_from,
  114. conversation_id=conversation.id,
  115. app_mode=conversation.mode,
  116. message_id=message.id
  117. )
  118. # new thread
  119. worker_thread = threading.Thread(target=self._generate_worker, kwargs={
  120. 'flask_app': current_app._get_current_object(),
  121. 'application_generate_entity': application_generate_entity,
  122. 'queue_manager': queue_manager,
  123. 'conversation_id': conversation.id,
  124. 'message_id': message.id,
  125. })
  126. worker_thread.start()
  127. # return response or stream generator
  128. return self._handle_response(
  129. application_generate_entity=application_generate_entity,
  130. queue_manager=queue_manager,
  131. conversation=conversation,
  132. message=message,
  133. stream=stream
  134. )
  135. def _generate_worker(self, flask_app: Flask,
  136. application_generate_entity: ApplicationGenerateEntity,
  137. queue_manager: ApplicationQueueManager,
  138. conversation_id: str,
  139. message_id: str) -> None:
  140. """
  141. Generate worker in a new thread.
  142. :param flask_app: Flask app
  143. :param application_generate_entity: application generate entity
  144. :param queue_manager: queue manager
  145. :param conversation_id: conversation ID
  146. :param message_id: message ID
  147. :return:
  148. """
  149. with flask_app.app_context():
  150. try:
  151. # get conversation and message
  152. conversation = self._get_conversation(conversation_id)
  153. message = self._get_message(message_id)
  154. if application_generate_entity.app_orchestration_config_entity.agent:
  155. # agent app
  156. runner = AssistantApplicationRunner()
  157. runner.run(
  158. application_generate_entity=application_generate_entity,
  159. queue_manager=queue_manager,
  160. conversation=conversation,
  161. message=message
  162. )
  163. else:
  164. # basic app
  165. runner = BasicApplicationRunner()
  166. runner.run(
  167. application_generate_entity=application_generate_entity,
  168. queue_manager=queue_manager,
  169. conversation=conversation,
  170. message=message
  171. )
  172. except ConversationTaskStoppedException:
  173. pass
  174. except InvokeAuthorizationError:
  175. queue_manager.publish_error(
  176. InvokeAuthorizationError('Incorrect API key provided'),
  177. PublishFrom.APPLICATION_MANAGER
  178. )
  179. except ValidationError as e:
  180. logger.exception("Validation Error when generating")
  181. queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
  182. except Exception as e:
  183. logger.exception("Unknown Error when generating")
  184. queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
  185. finally:
  186. db.session.close()
  187. def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
  188. queue_manager: ApplicationQueueManager,
  189. conversation: Conversation,
  190. message: Message,
  191. stream: bool = False) -> Union[dict, Generator]:
  192. """
  193. Handle response.
  194. :param application_generate_entity: application generate entity
  195. :param queue_manager: queue manager
  196. :param conversation: conversation
  197. :param message: message
  198. :param stream: is stream
  199. :return:
  200. """
  201. # init generate task pipeline
  202. generate_task_pipeline = GenerateTaskPipeline(
  203. application_generate_entity=application_generate_entity,
  204. queue_manager=queue_manager,
  205. conversation=conversation,
  206. message=message
  207. )
  208. try:
  209. return generate_task_pipeline.process(stream=stream)
  210. except ValueError as e:
  211. if e.args[0] == "I/O operation on closed file.": # ignore this error
  212. raise ConversationTaskStoppedException()
  213. else:
  214. logger.exception(e)
  215. raise e
  216. def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
  217. -> AppOrchestrationConfigEntity:
  218. """
  219. Convert app model config dict to entity.
  220. :param tenant_id: tenant ID
  221. :param app_model_config_dict: app model config dict
  222. :raises ProviderTokenNotInitError: provider token not init error
  223. :return: app orchestration config entity
  224. """
  225. properties = {}
  226. copy_app_model_config_dict = app_model_config_dict.copy()
  227. provider_manager = ProviderManager()
  228. provider_model_bundle = provider_manager.get_provider_model_bundle(
  229. tenant_id=tenant_id,
  230. provider=copy_app_model_config_dict['model']['provider'],
  231. model_type=ModelType.LLM
  232. )
  233. provider_name = provider_model_bundle.configuration.provider.provider
  234. model_name = copy_app_model_config_dict['model']['name']
  235. model_type_instance = provider_model_bundle.model_type_instance
  236. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  237. # check model credentials
  238. model_credentials = provider_model_bundle.configuration.get_current_credentials(
  239. model_type=ModelType.LLM,
  240. model=copy_app_model_config_dict['model']['name']
  241. )
  242. if model_credentials is None:
  243. raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
  244. # check model
  245. provider_model = provider_model_bundle.configuration.get_provider_model(
  246. model=copy_app_model_config_dict['model']['name'],
  247. model_type=ModelType.LLM
  248. )
  249. if provider_model is None:
  250. model_name = copy_app_model_config_dict['model']['name']
  251. raise ValueError(f"Model {model_name} not exist.")
  252. if provider_model.status == ModelStatus.NO_CONFIGURE:
  253. raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
  254. elif provider_model.status == ModelStatus.NO_PERMISSION:
  255. raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
  256. elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
  257. raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
  258. # model config
  259. completion_params = copy_app_model_config_dict['model'].get('completion_params')
  260. stop = []
  261. if 'stop' in completion_params:
  262. stop = completion_params['stop']
  263. del completion_params['stop']
  264. # get model mode
  265. model_mode = copy_app_model_config_dict['model'].get('mode')
  266. if not model_mode:
  267. mode_enum = model_type_instance.get_model_mode(
  268. model=copy_app_model_config_dict['model']['name'],
  269. credentials=model_credentials
  270. )
  271. model_mode = mode_enum.value
  272. model_schema = model_type_instance.get_model_schema(
  273. copy_app_model_config_dict['model']['name'],
  274. model_credentials
  275. )
  276. if not model_schema:
  277. raise ValueError(f"Model {model_name} not exist.")
  278. properties['model_config'] = ModelConfigEntity(
  279. provider=copy_app_model_config_dict['model']['provider'],
  280. model=copy_app_model_config_dict['model']['name'],
  281. model_schema=model_schema,
  282. mode=model_mode,
  283. provider_model_bundle=provider_model_bundle,
  284. credentials=model_credentials,
  285. parameters=completion_params,
  286. stop=stop,
  287. )
  288. # prompt template
  289. prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
  290. if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
  291. simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
  292. properties['prompt_template'] = PromptTemplateEntity(
  293. prompt_type=prompt_type,
  294. simple_prompt_template=simple_prompt_template
  295. )
  296. else:
  297. advanced_chat_prompt_template = None
  298. chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
  299. if chat_prompt_config:
  300. chat_prompt_messages = []
  301. for message in chat_prompt_config.get("prompt", []):
  302. chat_prompt_messages.append({
  303. "text": message["text"],
  304. "role": PromptMessageRole.value_of(message["role"])
  305. })
  306. advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
  307. messages=chat_prompt_messages
  308. )
  309. advanced_completion_prompt_template = None
  310. completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
  311. if completion_prompt_config:
  312. completion_prompt_template_params = {
  313. 'prompt': completion_prompt_config['prompt']['text'],
  314. }
  315. if 'conversation_histories_role' in completion_prompt_config:
  316. completion_prompt_template_params['role_prefix'] = {
  317. 'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
  318. 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
  319. }
  320. advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
  321. **completion_prompt_template_params
  322. )
  323. properties['prompt_template'] = PromptTemplateEntity(
  324. prompt_type=prompt_type,
  325. advanced_chat_prompt_template=advanced_chat_prompt_template,
  326. advanced_completion_prompt_template=advanced_completion_prompt_template
  327. )
  328. # external data variables
  329. properties['external_data_variables'] = []
  330. # old external_data_tools
  331. external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
  332. for external_data_tool in external_data_tools:
  333. if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
  334. continue
  335. properties['external_data_variables'].append(
  336. ExternalDataVariableEntity(
  337. variable=external_data_tool['variable'],
  338. type=external_data_tool['type'],
  339. config=external_data_tool['config']
  340. )
  341. )
  342. # current external_data_tools
  343. for variable in copy_app_model_config_dict.get('user_input_form', []):
  344. typ = list(variable.keys())[0]
  345. if typ == 'external_data_tool':
  346. val = variable[typ]
  347. properties['external_data_variables'].append(
  348. ExternalDataVariableEntity(
  349. variable=val['variable'],
  350. type=val['type'],
  351. config=val['config']
  352. )
  353. )
  354. # show retrieve source
  355. show_retrieve_source = False
  356. retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
  357. if retriever_resource_dict:
  358. if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
  359. show_retrieve_source = True
  360. properties['show_retrieve_source'] = show_retrieve_source
  361. dataset_ids = []
  362. if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
  363. datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
  364. 'strategy': 'router',
  365. 'datasets': []
  366. })
  367. for dataset in datasets.get('datasets', []):
  368. keys = list(dataset.keys())
  369. if len(keys) == 0 or keys[0] != 'dataset':
  370. continue
  371. dataset = dataset['dataset']
  372. if 'enabled' not in dataset or not dataset['enabled']:
  373. continue
  374. dataset_id = dataset.get('id', None)
  375. if dataset_id:
  376. dataset_ids.append(dataset_id)
  377. else:
  378. datasets = {'strategy': 'router', 'datasets': []}
  379. if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
  380. and 'enabled' in copy_app_model_config_dict['agent_mode'] \
  381. and copy_app_model_config_dict['agent_mode']['enabled']:
  382. agent_dict = copy_app_model_config_dict.get('agent_mode', {})
  383. agent_strategy = agent_dict.get('strategy', 'cot')
  384. if agent_strategy == 'function_call':
  385. strategy = AgentEntity.Strategy.FUNCTION_CALLING
  386. elif agent_strategy == 'cot' or agent_strategy == 'react':
  387. strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
  388. else:
  389. # old configs, try to detect default strategy
  390. if copy_app_model_config_dict['model']['provider'] == 'openai':
  391. strategy = AgentEntity.Strategy.FUNCTION_CALLING
  392. else:
  393. strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
  394. agent_tools = []
  395. for tool in agent_dict.get('tools', []):
  396. keys = tool.keys()
  397. if len(keys) >= 4:
  398. if "enabled" not in tool or not tool["enabled"]:
  399. continue
  400. agent_tool_properties = {
  401. 'provider_type': tool['provider_type'],
  402. 'provider_id': tool['provider_id'],
  403. 'tool_name': tool['tool_name'],
  404. 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
  405. }
  406. agent_tools.append(AgentToolEntity(**agent_tool_properties))
  407. elif len(keys) == 1:
  408. # old standard
  409. key = list(tool.keys())[0]
  410. if key != 'dataset':
  411. continue
  412. tool_item = tool[key]
  413. if "enabled" not in tool_item or not tool_item["enabled"]:
  414. continue
  415. dataset_id = tool_item['id']
  416. dataset_ids.append(dataset_id)
  417. if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
  418. copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
  419. agent_prompt = agent_dict.get('prompt', None) or {}
  420. # check model mode
  421. model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
  422. if model_mode == 'completion':
  423. agent_prompt_entity = AgentPromptEntity(
  424. first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
  425. next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
  426. )
  427. else:
  428. agent_prompt_entity = AgentPromptEntity(
  429. first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
  430. next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
  431. )
  432. properties['agent'] = AgentEntity(
  433. provider=properties['model_config'].provider,
  434. model=properties['model_config'].model,
  435. strategy=strategy,
  436. prompt=agent_prompt_entity,
  437. tools=agent_tools,
  438. max_iteration=agent_dict.get('max_iteration', 5)
  439. )
  440. if len(dataset_ids) > 0:
  441. # dataset configs
  442. dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
  443. query_variable = copy_app_model_config_dict.get('dataset_query_variable')
  444. if dataset_configs['retrieval_model'] == 'single':
  445. properties['dataset'] = DatasetEntity(
  446. dataset_ids=dataset_ids,
  447. retrieve_config=DatasetRetrieveConfigEntity(
  448. query_variable=query_variable,
  449. retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
  450. dataset_configs['retrieval_model']
  451. ),
  452. single_strategy=datasets.get('strategy', 'router')
  453. )
  454. )
  455. else:
  456. properties['dataset'] = DatasetEntity(
  457. dataset_ids=dataset_ids,
  458. retrieve_config=DatasetRetrieveConfigEntity(
  459. query_variable=query_variable,
  460. retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
  461. dataset_configs['retrieval_model']
  462. ),
  463. top_k=dataset_configs.get('top_k'),
  464. score_threshold=dataset_configs.get('score_threshold'),
  465. reranking_model=dataset_configs.get('reranking_model')
  466. )
  467. )
  468. # file upload
  469. file_upload_dict = copy_app_model_config_dict.get('file_upload')
  470. if file_upload_dict:
  471. if 'image' in file_upload_dict and file_upload_dict['image']:
  472. if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
  473. properties['file_upload'] = FileUploadEntity(
  474. image_config={
  475. 'number_limits': file_upload_dict['image']['number_limits'],
  476. 'detail': file_upload_dict['image']['detail'],
  477. 'transfer_methods': file_upload_dict['image']['transfer_methods']
  478. }
  479. )
  480. # opening statement
  481. properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
  482. # suggested questions after answer
  483. suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
  484. if suggested_questions_after_answer_dict:
  485. if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
  486. properties['suggested_questions_after_answer'] = True
  487. # more like this
  488. more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
  489. if more_like_this_dict:
  490. if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
  491. properties['more_like_this'] = True
  492. # speech to text
  493. speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
  494. if speech_to_text_dict:
  495. if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
  496. properties['speech_to_text'] = True
  497. # text to speech
  498. text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
  499. if text_to_speech_dict:
  500. if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
  501. properties['text_to_speech'] = TextToSpeechEntity(
  502. enabled=text_to_speech_dict.get('enabled'),
  503. voice=text_to_speech_dict.get('voice'),
  504. language=text_to_speech_dict.get('language'),
  505. )
  506. # sensitive word avoidance
  507. sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
  508. if sensitive_word_avoidance_dict:
  509. if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
  510. properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
  511. type=sensitive_word_avoidance_dict.get('type'),
  512. config=sensitive_word_avoidance_dict.get('config'),
  513. )
  514. return AppOrchestrationConfigEntity(**properties)
  515. def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
  516. -> tuple[Conversation, Message]:
  517. """
  518. Initialize generate records
  519. :param application_generate_entity: application generate entity
  520. :return:
  521. """
  522. app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
  523. model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
  524. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  525. model_schema = model_type_instance.get_model_schema(
  526. model=app_orchestration_config_entity.model_config.model,
  527. credentials=app_orchestration_config_entity.model_config.credentials
  528. )
  529. app_record = (db.session.query(App)
  530. .filter(App.id == application_generate_entity.app_id).first())
  531. app_mode = app_record.mode
  532. # get from source
  533. end_user_id = None
  534. account_id = None
  535. if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
  536. from_source = 'api'
  537. end_user_id = application_generate_entity.user_id
  538. else:
  539. from_source = 'console'
  540. account_id = application_generate_entity.user_id
  541. override_model_configs = None
  542. if application_generate_entity.app_model_config_override:
  543. override_model_configs = application_generate_entity.app_model_config_dict
  544. introduction = ''
  545. if app_mode == 'chat':
  546. # get conversation introduction
  547. introduction = self._get_conversation_introduction(application_generate_entity)
  548. if not application_generate_entity.conversation_id:
  549. conversation = Conversation(
  550. app_id=app_record.id,
  551. app_model_config_id=application_generate_entity.app_model_config_id,
  552. model_provider=app_orchestration_config_entity.model_config.provider,
  553. model_id=app_orchestration_config_entity.model_config.model,
  554. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  555. mode=app_mode,
  556. name='New conversation',
  557. inputs=application_generate_entity.inputs,
  558. introduction=introduction,
  559. system_instruction="",
  560. system_instruction_tokens=0,
  561. status='normal',
  562. from_source=from_source,
  563. from_end_user_id=end_user_id,
  564. from_account_id=account_id,
  565. )
  566. db.session.add(conversation)
  567. db.session.commit()
  568. db.session.refresh(conversation)
  569. else:
  570. conversation = (
  571. db.session.query(Conversation)
  572. .filter(
  573. Conversation.id == application_generate_entity.conversation_id,
  574. Conversation.app_id == app_record.id
  575. ).first()
  576. )
  577. currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
  578. message = Message(
  579. app_id=app_record.id,
  580. model_provider=app_orchestration_config_entity.model_config.provider,
  581. model_id=app_orchestration_config_entity.model_config.model,
  582. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  583. conversation_id=conversation.id,
  584. inputs=application_generate_entity.inputs,
  585. query=application_generate_entity.query or "",
  586. message="",
  587. message_tokens=0,
  588. message_unit_price=0,
  589. message_price_unit=0,
  590. answer="",
  591. answer_tokens=0,
  592. answer_unit_price=0,
  593. answer_price_unit=0,
  594. provider_response_latency=0,
  595. total_price=0,
  596. currency=currency,
  597. from_source=from_source,
  598. from_end_user_id=end_user_id,
  599. from_account_id=account_id,
  600. agent_based=app_orchestration_config_entity.agent is not None
  601. )
  602. db.session.add(message)
  603. db.session.commit()
  604. db.session.refresh(message)
  605. for file in application_generate_entity.files:
  606. message_file = MessageFile(
  607. message_id=message.id,
  608. type=file.type.value,
  609. transfer_method=file.transfer_method.value,
  610. belongs_to='user',
  611. url=file.url,
  612. upload_file_id=file.upload_file_id,
  613. created_by_role=('account' if account_id else 'end_user'),
  614. created_by=account_id or end_user_id,
  615. )
  616. db.session.add(message_file)
  617. db.session.commit()
  618. return conversation, message
  619. def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
  620. """
  621. Get conversation introduction
  622. :param application_generate_entity: application generate entity
  623. :return: conversation introduction
  624. """
  625. app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
  626. introduction = app_orchestration_config_entity.opening_statement
  627. if introduction:
  628. try:
  629. inputs = application_generate_entity.inputs
  630. prompt_template = PromptTemplateParser(template=introduction)
  631. prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
  632. introduction = prompt_template.format(prompt_inputs)
  633. except KeyError:
  634. pass
  635. return introduction
  636. def _get_conversation(self, conversation_id: str) -> Conversation:
  637. """
  638. Get conversation by conversation id
  639. :param conversation_id: conversation id
  640. :return: conversation
  641. """
  642. conversation = (
  643. db.session.query(Conversation)
  644. .filter(Conversation.id == conversation_id)
  645. .first()
  646. )
  647. return conversation
  648. def _get_message(self, message_id: str) -> Message:
  649. """
  650. Get message by message id
  651. :param message_id: message id
  652. :return: message
  653. """
  654. message = (
  655. db.session.query(Message)
  656. .filter(Message.id == message_id)
  657. .first()
  658. )
  659. return message