application_manager.py 32 KB


  1. import json
  2. import logging
  3. import threading
  4. import uuid
  5. from collections.abc import Generator
  6. from typing import Any, Optional, Union, cast
  7. from flask import Flask, current_app
  8. from pydantic import ValidationError
  9. from core.app_runner.assistant_app_runner import AssistantApplicationRunner
  10. from core.app_runner.basic_app_runner import BasicApplicationRunner
  11. from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
  12. from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
  13. from core.entities.application_entities import (
  14. AdvancedChatPromptTemplateEntity,
  15. AdvancedCompletionPromptTemplateEntity,
  16. AgentEntity,
  17. AgentPromptEntity,
  18. AgentToolEntity,
  19. ApplicationGenerateEntity,
  20. AppOrchestrationConfigEntity,
  21. DatasetEntity,
  22. DatasetRetrieveConfigEntity,
  23. ExternalDataVariableEntity,
  24. FileUploadEntity,
  25. InvokeFrom,
  26. ModelConfigEntity,
  27. PromptTemplateEntity,
  28. SensitiveWordAvoidanceEntity,
  29. TextToSpeechEntity,
  30. )
  31. from core.entities.model_entities import ModelStatus
  32. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  33. from core.file.file_obj import FileObj
  34. from core.model_runtime.entities.message_entities import PromptMessageRole
  35. from core.model_runtime.entities.model_entities import ModelType
  36. from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
  37. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  38. from core.prompt.prompt_template import PromptTemplateParser
  39. from core.provider_manager import ProviderManager
  40. from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
  41. from extensions.ext_database import db
  42. from models.account import Account
  43. from models.model import App, Conversation, EndUser, Message, MessageFile
  # Module-scoped logger; records are tagged with this module's import path.
  logger = logging.getLogger(name=__name__)
  class ApplicationManager:
      """
      Run an app generation request end to end.

      ``generate`` converts the raw app model config into an
      ``AppOrchestrationConfigEntity``, creates (or loads) the conversation and
      message records, starts the matching app runner in a worker thread and
      returns whatever ``GenerateTaskPipeline.process`` produces.
      """

      def generate(
              self,
              tenant_id: str,
              app_id: str,
              app_model_config_id: str,
              app_model_config_dict: dict,
              app_model_config_override: bool,
              user: Union[Account, EndUser],
              invoke_from: InvokeFrom,
              inputs: dict[str, str],
              query: Optional[str] = None,
              files: Optional[list[FileObj]] = None,
              conversation: Optional[Conversation] = None,
              stream: bool = False,
              extras: Optional[dict[str, Any]] = None,
      ) -> Union[dict, Generator]:
          """
          Generate App response.

          :param tenant_id: workspace ID
          :param app_id: app ID
          :param app_model_config_id: app model config id
          :param app_model_config_dict: raw app model config; this call does not modify it
          :param app_model_config_override: when true, ``app_model_config_dict`` is stored as
              ``override_model_configs`` on the created records
          :param user: account or end user
          :param invoke_from: invoke from source
          :param inputs: inputs; ignored when ``conversation`` is given (its stored inputs are used)
          :param query: query; NUL characters are stripped
          :param files: file obj list
          :param conversation: existing conversation to continue, or None to start a new one
          :param stream: is stream
          :param extras: extras
          :raises ValueError: if an agent app is invoked with ``stream=False``, or the
              model / app / conversation configuration is invalid
          :raises ProviderTokenNotInitError: model credentials are not initialized
          :raises ModelCurrentlyNotSupportError: the hosted model is not permitted
          :raises QuotaExceededError: the model provider quota is exhausted
          :return: the result of ``GenerateTaskPipeline.process(stream=stream)``
          """
          # init task id
          task_id = str(uuid.uuid4())

          # init application generate entity
          application_generate_entity = ApplicationGenerateEntity(
              task_id=task_id,
              tenant_id=tenant_id,
              app_id=app_id,
              app_model_config_id=app_model_config_id,
              app_model_config_dict=app_model_config_dict,
              app_orchestration_config_entity=self._convert_from_app_model_config_dict(
                  tenant_id=tenant_id,
                  app_model_config_dict=app_model_config_dict
              ),
              app_model_config_override=app_model_config_override,
              conversation_id=conversation.id if conversation else None,
              inputs=conversation.inputs if conversation else inputs,
              query=query.replace('\x00', '') if query else None,
              files=files if files else [],
              user_id=user.id,
              stream=stream,
              invoke_from=invoke_from,
              extras=extras
          )

          if not stream and application_generate_entity.app_orchestration_config_entity.agent:
              raise ValueError("Agent app is not supported in blocking mode.")

          # init generate records
          conversation, message = self._init_generate_records(application_generate_entity)

          # init queue manager
          queue_manager = ApplicationQueueManager(
              task_id=application_generate_entity.task_id,
              user_id=application_generate_entity.user_id,
              invoke_from=application_generate_entity.invoke_from,
              conversation_id=conversation.id,
              app_mode=conversation.mode,
              message_id=message.id
          )

          # The runner executes in its own thread; only ids cross the thread boundary
          # and the worker pushes its own app context from the real Flask app object.
          worker_thread = threading.Thread(target=self._generate_worker, kwargs={
              'flask_app': current_app._get_current_object(),
              'application_generate_entity': application_generate_entity,
              'queue_manager': queue_manager,
              'conversation_id': conversation.id,
              'message_id': message.id,
          })
          worker_thread.start()

          # return response or stream generator
          return self._handle_response(
              application_generate_entity=application_generate_entity,
              queue_manager=queue_manager,
              conversation=conversation,
              message=message,
              stream=stream
          )

      def _generate_worker(self, flask_app: Flask,
                           application_generate_entity: ApplicationGenerateEntity,
                           queue_manager: ApplicationQueueManager,
                           conversation_id: str,
                           message_id: str) -> None:
          """
          Run the app runner in a worker thread and report failures through the queue.

          Errors are never raised out of this thread; they are published to
          ``queue_manager`` instead (a stopped task is ignored silently).

          :param flask_app: Flask app whose app context is pushed for this thread
          :param application_generate_entity: application generate entity
          :param queue_manager: queue manager
          :param conversation_id: conversation ID
          :param message_id: message ID
          """
          with flask_app.app_context():
              try:
                  # Re-load the records in this thread; only their ids were passed in.
                  conversation = self._get_conversation(conversation_id)
                  message = self._get_message(message_id)

                  if application_generate_entity.app_orchestration_config_entity.agent:
                      runner = AssistantApplicationRunner()  # agent app
                  else:
                      runner = BasicApplicationRunner()  # basic app

                  runner.run(
                      application_generate_entity=application_generate_entity,
                      queue_manager=queue_manager,
                      conversation=conversation,
                      message=message
                  )
              except ConversationTaskStoppedException:
                  pass
              except InvokeAuthorizationError:
                  queue_manager.publish_error(
                      InvokeAuthorizationError('Incorrect API key provided'),
                      PublishFrom.APPLICATION_MANAGER
                  )
              except ValidationError as e:
                  logger.exception("Validation Error when generating")
                  queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
              except (ValueError, InvokeError) as e:
                  queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
              except Exception as e:
                  logger.exception("Unknown Error when generating")
                  queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
              finally:
                  db.session.remove()

      def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
                           queue_manager: ApplicationQueueManager,
                           conversation: Conversation,
                           message: Message,
                           stream: bool = False) -> Union[dict, Generator]:
          """
          Build the task pipeline that consumes ``queue_manager`` and return its output.

          :param application_generate_entity: application generate entity
          :param queue_manager: queue manager
          :param conversation: conversation
          :param message: message
          :param stream: is stream
          :raises ConversationTaskStoppedException: when the pipeline fails with
              ``ValueError("I/O operation on closed file.")``
          :return: the result of ``GenerateTaskPipeline.process(stream=stream)``
          """
          generate_task_pipeline = GenerateTaskPipeline(
              application_generate_entity=application_generate_entity,
              queue_manager=queue_manager,
              conversation=conversation,
              message=message
          )

          try:
              return generate_task_pipeline.process(stream=stream)
          except ValueError as e:
              # Presumably raised when the output stream was closed underneath the
              # pipeline; treat it as a stopped task rather than an error. Guard
              # against arg-less ValueErrors, which would otherwise raise IndexError.
              if e.args and e.args[0] == "I/O operation on closed file.":
                  raise ConversationTaskStoppedException() from e
              logger.exception(e)
              raise
          finally:
              db.session.remove()

      def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
              -> AppOrchestrationConfigEntity:
          """
          Convert app model config dict to entity.

          ``app_model_config_dict`` is only read, never modified.

          :param tenant_id: tenant ID
          :param app_model_config_dict: app model config dict
          :raises ProviderTokenNotInitError: provider token not init error
          :raises ModelCurrentlyNotSupportError: the hosted model is not permitted
          :raises QuotaExceededError: the model provider quota is exhausted
          :raises ValueError: the configured model does not exist
          :return: app orchestration config entity
          """
          config = app_model_config_dict
          properties: dict[str, Any] = {}

          model_config = self._convert_model_config(tenant_id, config)
          properties['model_config'] = model_config
          properties['prompt_template'] = self._convert_prompt_template(config)
          properties['external_data_variables'] = self._convert_external_data_variables(config)

          # show retrieve source
          properties['show_retrieve_source'] = self._is_enabled(config.get('retriever_resource'))

          # agent and dataset settings share the collected dataset ids
          agent, dataset = self._convert_agent_and_dataset(config, model_config)
          if agent is not None:
              properties['agent'] = agent
          if dataset is not None:
              properties['dataset'] = dataset

          # file upload
          file_upload_dict = config.get('file_upload')
          image_dict = file_upload_dict.get('image') if file_upload_dict else None
          if self._is_enabled(image_dict):
              properties['file_upload'] = FileUploadEntity(
                  image_config={
                      'number_limits': image_dict['number_limits'],
                      'detail': image_dict['detail'],
                      'transfer_methods': image_dict['transfer_methods']
                  }
              )

          # opening statement
          properties['opening_statement'] = config.get('opening_statement')

          # simple on/off features are only set when enabled, so entity defaults apply otherwise
          if self._is_enabled(config.get('suggested_questions_after_answer')):
              properties['suggested_questions_after_answer'] = True
          if self._is_enabled(config.get('more_like_this')):
              properties['more_like_this'] = True
          if self._is_enabled(config.get('speech_to_text')):
              properties['speech_to_text'] = True

          # text to speech
          text_to_speech_dict = config.get('text_to_speech')
          if self._is_enabled(text_to_speech_dict):
              properties['text_to_speech'] = TextToSpeechEntity(
                  enabled=text_to_speech_dict.get('enabled'),
                  voice=text_to_speech_dict.get('voice'),
                  language=text_to_speech_dict.get('language'),
              )

          # sensitive word avoidance
          sensitive_word_avoidance_dict = config.get('sensitive_word_avoidance')
          if self._is_enabled(sensitive_word_avoidance_dict):
              properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
                  type=sensitive_word_avoidance_dict.get('type'),
                  config=sensitive_word_avoidance_dict.get('config'),
              )

          return AppOrchestrationConfigEntity(**properties)

      @staticmethod
      def _is_enabled(section: Optional[dict]) -> bool:
          """Return True when ``section`` is a non-empty dict whose ``enabled`` value is truthy."""
          return bool(section) and bool(section.get('enabled'))

      def _convert_model_config(self, tenant_id: str, config: dict) -> ModelConfigEntity:
          """
          Validate the configured LLM for ``tenant_id`` and build its ``ModelConfigEntity``.

          :param tenant_id: tenant ID
          :param config: app model config dict (read only)
          :raises ProviderTokenNotInitError: credentials missing or model not configured
          :raises ModelCurrentlyNotSupportError: the hosted model is not permitted
          :raises QuotaExceededError: the model provider quota is exhausted
          :raises ValueError: the model or its schema does not exist
          :return: model config entity
          """
          model_dict = config['model']
          model_name = model_dict['name']

          provider_manager = ProviderManager()
          provider_model_bundle = provider_manager.get_provider_model_bundle(
              tenant_id=tenant_id,
              provider=model_dict['provider'],
              model_type=ModelType.LLM
          )
          provider_name = provider_model_bundle.configuration.provider.provider
          model_type_instance = cast(LargeLanguageModel, provider_model_bundle.model_type_instance)

          # check model credentials
          model_credentials = provider_model_bundle.configuration.get_current_credentials(
              model_type=ModelType.LLM,
              model=model_name
          )
          if model_credentials is None:
              raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

          # check model
          provider_model = provider_model_bundle.configuration.get_provider_model(
              model=model_name,
              model_type=ModelType.LLM
          )
          if provider_model is None:
              raise ValueError(f"Model {model_name} not exist.")

          if provider_model.status == ModelStatus.NO_CONFIGURE:
              raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
          elif provider_model.status == ModelStatus.NO_PERMISSION:
              raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
          elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
              raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

          # Work on a copy: removing 'stop' must not mutate the caller's config,
          # which is also persisted as override_model_configs. A missing/None
          # 'completion_params' is treated as empty instead of crashing.
          completion_params = dict(model_dict.get('completion_params') or {})
          stop = completion_params.pop('stop', [])

          # get model mode; fall back to the mode reported by the model itself
          model_mode = model_dict.get('mode')
          if not model_mode:
              mode_enum = model_type_instance.get_model_mode(
                  model=model_name,
                  credentials=model_credentials
              )
              model_mode = mode_enum.value

          model_schema = model_type_instance.get_model_schema(
              model_name,
              model_credentials
          )
          if not model_schema:
              raise ValueError(f"Model {model_name} not exist.")

          return ModelConfigEntity(
              provider=model_dict['provider'],
              model=model_name,
              model_schema=model_schema,
              mode=model_mode,
              provider_model_bundle=provider_model_bundle,
              credentials=model_credentials,
              parameters=completion_params,
              stop=stop,
          )

      def _convert_prompt_template(self, config: dict) -> PromptTemplateEntity:
          """
          Build the prompt template entity (simple or advanced) from ``config``.

          :param config: app model config dict (read only)
          :return: prompt template entity
          """
          prompt_type = PromptTemplateEntity.PromptType.value_of(config['prompt_type'])
          if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
              return PromptTemplateEntity(
                  prompt_type=prompt_type,
                  simple_prompt_template=config.get("pre_prompt", "")
              )

          advanced_chat_prompt_template = None
          chat_prompt_config = config.get("chat_prompt_config", {})
          if chat_prompt_config:
              chat_prompt_messages = [
                  {
                      "text": prompt_message["text"],
                      "role": PromptMessageRole.value_of(prompt_message["role"])
                  }
                  for prompt_message in chat_prompt_config.get("prompt", [])
              ]
              advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
                  messages=chat_prompt_messages
              )

          advanced_completion_prompt_template = None
          completion_prompt_config = config.get("completion_prompt_config", {})
          if completion_prompt_config:
              completion_prompt_template_params = {
                  'prompt': completion_prompt_config['prompt']['text'],
              }
              if 'conversation_histories_role' in completion_prompt_config:
                  histories_role = completion_prompt_config['conversation_histories_role']
                  completion_prompt_template_params['role_prefix'] = {
                      'user': histories_role['user_prefix'],
                      'assistant': histories_role['assistant_prefix']
                  }
              advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
                  **completion_prompt_template_params
              )

          return PromptTemplateEntity(
              prompt_type=prompt_type,
              advanced_chat_prompt_template=advanced_chat_prompt_template,
              advanced_completion_prompt_template=advanced_completion_prompt_template
          )

      def _convert_external_data_variables(self, config: dict) -> list[ExternalDataVariableEntity]:
          """
          Collect external data variables from both config formats.

          Enabled entries of the legacy ``external_data_tools`` list come first,
          followed by ``external_data_tool`` entries of ``user_input_form``.

          :param config: app model config dict (read only)
          :return: external data variable entities
          """
          variables: list[ExternalDataVariableEntity] = []

          # old external_data_tools
          for external_data_tool in config.get('external_data_tools', []):
              if not self._is_enabled(external_data_tool):
                  continue
              variables.append(
                  ExternalDataVariableEntity(
                      variable=external_data_tool['variable'],
                      type=external_data_tool['type'],
                      config=external_data_tool['config']
                  )
              )

          # current external_data_tools: each form item is a single-key dict {type: settings}
          for form_item in config.get('user_input_form', []):
              item_type = next(iter(form_item), None)  # empty items are skipped
              if item_type != 'external_data_tool':
                  continue
              settings = form_item[item_type]
              variables.append(
                  ExternalDataVariableEntity(
                      variable=settings['variable'],
                      type=settings['type'],
                      config=settings['config']
                  )
              )

          return variables

      def _convert_agent_and_dataset(self, config: dict, model_config: ModelConfigEntity) \
              -> tuple[Optional[AgentEntity], Optional[DatasetEntity]]:
          """
          Build the agent and dataset entities, which share the collected dataset ids.

          Dataset ids come from enabled entries of ``dataset_configs.datasets`` and,
          when the agent mode is enabled, from old-style ``{'dataset': ...}`` agent tools.

          :param config: app model config dict (read only)
          :param model_config: the already-converted model config
          :return: (agent entity or None, dataset entity or None)
          """
          dataset_ids = []
          dataset_configs_dict = config.get('dataset_configs', {})
          if 'datasets' in dataset_configs_dict:
              datasets = dataset_configs_dict['datasets']
              for entry in datasets.get('datasets', []):
                  if next(iter(entry), None) != 'dataset':
                      continue
                  dataset_item = entry['dataset']
                  if not self._is_enabled(dataset_item):
                      continue
                  dataset_id = dataset_item.get('id', None)
                  if dataset_id:
                      dataset_ids.append(dataset_id)
          else:
              datasets = {'strategy': 'router', 'datasets': []}

          agent = None
          agent_dict = config.get('agent_mode')
          if self._is_enabled(agent_dict):
              agent_strategy = agent_dict.get('strategy', 'cot')
              if agent_strategy == 'function_call':
                  strategy = AgentEntity.Strategy.FUNCTION_CALLING
              elif agent_strategy in ('cot', 'react'):
                  strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
              elif config['model']['provider'] == 'openai':
                  # old configs, try to detect default strategy
                  strategy = AgentEntity.Strategy.FUNCTION_CALLING
              else:
                  strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

              agent_tools = []
              for tool in agent_dict.get('tools', []):
                  if len(tool) >= 4:
                      if not self._is_enabled(tool):
                          continue
                      agent_tools.append(AgentToolEntity(
                          provider_type=tool['provider_type'],
                          provider_id=tool['provider_id'],
                          tool_name=tool['tool_name'],
                          tool_parameters=tool.get('tool_parameters', {})
                      ))
                  elif len(tool) == 1:
                      # old standard: {'dataset': {'enabled': ..., 'id': ...}}
                      key = next(iter(tool))
                      if key != 'dataset':
                          continue
                      tool_item = tool[key]
                      if not self._is_enabled(tool_item):
                          continue
                      dataset_ids.append(tool_item['id'])

              if 'strategy' in agent_dict and agent_dict['strategy'] not in ['react_router', 'router']:
                  agent_prompt = agent_dict.get('prompt', None) or {}
                  # Pick default ReAct templates for the resolved model mode (the one
                  # model_config uses) instead of assuming 'completion' whenever the
                  # config omits an explicit mode.
                  template_key = 'completion' if model_config.mode == 'completion' else 'chat'
                  default_templates = REACT_PROMPT_TEMPLATES['english'][template_key]
                  agent_prompt_entity = AgentPromptEntity(
                      first_prompt=agent_prompt.get('first_prompt', default_templates['prompt']),
                      next_iteration=agent_prompt.get('next_iteration', default_templates['agent_scratchpad']),
                  )

                  agent = AgentEntity(
                      provider=model_config.provider,
                      model=model_config.model,
                      strategy=strategy,
                      prompt=agent_prompt_entity,
                      tools=agent_tools,
                      max_iteration=agent_dict.get('max_iteration', 5)
                  )

          dataset = None
          if dataset_ids:
              dataset_configs = config.get('dataset_configs', {'retrieval_model': 'single'})
              query_variable = config.get('dataset_query_variable')
              retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                  dataset_configs['retrieval_model']
              )
              if dataset_configs['retrieval_model'] == 'single':
                  retrieve_config = DatasetRetrieveConfigEntity(
                      query_variable=query_variable,
                      retrieve_strategy=retrieve_strategy,
                      single_strategy=datasets.get('strategy', 'router')
                  )
              else:
                  retrieve_config = DatasetRetrieveConfigEntity(
                      query_variable=query_variable,
                      retrieve_strategy=retrieve_strategy,
                      top_k=dataset_configs.get('top_k'),
                      score_threshold=dataset_configs.get('score_threshold'),
                      reranking_model=dataset_configs.get('reranking_model')
                  )
              dataset = DatasetEntity(
                  dataset_ids=dataset_ids,
                  retrieve_config=retrieve_config
              )

          return agent, dataset

      def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
              -> tuple[Conversation, Message]:
          """
          Persist the message (and, for a new conversation, the conversation) records.

          :param application_generate_entity: application generate entity
          :raises ValueError: if the app, or the referenced conversation of this app, does not exist
          :return: the conversation and the newly created message
          """
          app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
          model_config = app_orchestration_config_entity.model_config
          model_type_instance = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)
          model_schema = model_type_instance.get_model_schema(
              model=model_config.model,
              credentials=model_config.credentials
          )

          app_record = (
              db.session.query(App)
              .filter(App.id == application_generate_entity.app_id)
              .first()
          )
          if app_record is None:
              raise ValueError(f"App {application_generate_entity.app_id} not found.")
          app_mode = app_record.mode

          # get from source
          end_user_id = None
          account_id = None
          if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
              from_source = 'api'
              end_user_id = application_generate_entity.user_id
          else:
              from_source = 'console'
              account_id = application_generate_entity.user_id

          override_model_configs = None
          if application_generate_entity.app_model_config_override:
              override_model_configs = application_generate_entity.app_model_config_dict
          serialized_override = json.dumps(override_model_configs) if override_model_configs else None

          if not application_generate_entity.conversation_id:
              introduction = ''
              if app_mode == 'chat':
                  # get conversation introduction
                  introduction = self._get_conversation_introduction(application_generate_entity)

              conversation = Conversation(
                  app_id=app_record.id,
                  app_model_config_id=application_generate_entity.app_model_config_id,
                  model_provider=model_config.provider,
                  model_id=model_config.model,
                  override_model_configs=serialized_override,
                  mode=app_mode,
                  name='New conversation',
                  inputs=application_generate_entity.inputs,
                  introduction=introduction,
                  system_instruction="",
                  system_instruction_tokens=0,
                  status='normal',
                  from_source=from_source,
                  from_end_user_id=end_user_id,
                  from_account_id=account_id,
              )
              db.session.add(conversation)
              db.session.commit()
          else:
              conversation = (
                  db.session.query(Conversation)
                  .filter(
                      Conversation.id == application_generate_entity.conversation_id,
                      Conversation.app_id == app_record.id
                  ).first()
              )
              if conversation is None:
                  raise ValueError("Conversation not exists.")

          currency = model_schema.pricing.currency if model_schema and model_schema.pricing else 'USD'

          message = Message(
              app_id=app_record.id,
              model_provider=model_config.provider,
              model_id=model_config.model,
              override_model_configs=serialized_override,
              conversation_id=conversation.id,
              inputs=application_generate_entity.inputs,
              query=application_generate_entity.query or "",
              message="",
              message_tokens=0,
              message_unit_price=0,
              message_price_unit=0,
              answer="",
              answer_tokens=0,
              answer_unit_price=0,
              answer_price_unit=0,
              provider_response_latency=0,
              total_price=0,
              currency=currency,
              from_source=from_source,
              from_end_user_id=end_user_id,
              from_account_id=account_id,
              agent_based=app_orchestration_config_entity.agent is not None
          )
          db.session.add(message)
          db.session.commit()

          # Attach uploaded files to the message; one commit for the whole batch.
          for file_obj in application_generate_entity.files:
              db.session.add(MessageFile(
                  message_id=message.id,
                  type=file_obj.type.value,
                  transfer_method=file_obj.transfer_method.value,
                  belongs_to='user',
                  url=file_obj.url,
                  upload_file_id=file_obj.upload_file_id,
                  created_by_role=('account' if account_id else 'end_user'),
                  created_by=account_id or end_user_id,
              ))
          if application_generate_entity.files:
              db.session.commit()

          return conversation, message

      def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) \
              -> Optional[str]:
          """
          Render the app's opening statement with the request inputs.

          Only template variables present in the inputs are passed to ``format``;
          if formatting raises ``KeyError`` the raw opening statement is returned.

          :param application_generate_entity: application generate entity
          :return: the rendered opening statement, or the configured value unchanged
              when it is empty/None
          """
          introduction = application_generate_entity.app_orchestration_config_entity.opening_statement
          if not introduction:
              return introduction

          try:
              inputs = application_generate_entity.inputs
              prompt_template = PromptTemplateParser(template=introduction)
              prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
              return prompt_template.format(prompt_inputs)
          except KeyError:
              return introduction

      def _get_conversation(self, conversation_id: str) -> Optional[Conversation]:
          """
          Get conversation by conversation id.

          :param conversation_id: conversation id
          :return: the conversation, or None if no row matches
          """
          return (
              db.session.query(Conversation)
              .filter(Conversation.id == conversation_id)
              .first()
          )

      def _get_message(self, message_id: str) -> Optional[Message]:
          """
          Get message by message id.

          :param message_id: message id
          :return: the message, or None if no row matches
          """
          return (
              db.session.query(Message)
              .filter(Message.id == message_id)
              .first()
          )