# application_manager.py (31 KB)
  import json
  import logging
  import threading
  import uuid
  from typing import Any, Generator, Optional, Tuple, Union, cast

  from flask import Flask, current_app
  from pydantic import ValidationError

  from core.app_runner.assistant_app_runner import AssistantApplicationRunner
  from core.app_runner.basic_app_runner import BasicApplicationRunner
  from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
  from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
  from core.entities.application_entities import (
      AdvancedChatPromptTemplateEntity,
      AdvancedCompletionPromptTemplateEntity,
      AgentEntity,
      AgentPromptEntity,
      AgentToolEntity,
      ApplicationGenerateEntity,
      AppOrchestrationConfigEntity,
      DatasetEntity,
      DatasetRetrieveConfigEntity,
      ExternalDataVariableEntity,
      FileUploadEntity,
      InvokeFrom,
      ModelConfigEntity,
      PromptTemplateEntity,
      SensitiveWordAvoidanceEntity,
  )
  from core.entities.model_entities import ModelStatus
  from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  from core.file.file_obj import FileObj
  from core.model_runtime.entities.message_entities import PromptMessageRole
  from core.model_runtime.entities.model_entities import ModelType
  from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
  from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  from core.prompt.prompt_template import PromptTemplateParser
  from core.provider_manager import ProviderManager
  from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
  from extensions.ext_database import db
  from models.account import Account
  from models.model import App, Conversation, EndUser, Message, MessageFile

  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)
  class ApplicationManager:
      """
      Drive a single "generate" request for an app.

      The manager converts the raw app model config dict into an
      ``AppOrchestrationConfigEntity``, creates the conversation / message
      records, runs the matching application runner in a worker thread and
      hands the queue to ``GenerateTaskPipeline`` to produce the response.
      """

      def generate(
              self,
              tenant_id: str,
              app_id: str,
              app_model_config_id: str,
              app_model_config_dict: dict,
              app_model_config_override: bool,
              user: Union[Account, EndUser],
              invoke_from: InvokeFrom,
              inputs: dict[str, str],
              query: Optional[str] = None,
              files: Optional[list[FileObj]] = None,
              conversation: Optional[Conversation] = None,
              stream: bool = False,
              extras: Optional[dict[str, Any]] = None
      ) -> Union[dict, Generator]:
          """
          Generate App response.

          :param tenant_id: workspace ID
          :param app_id: app ID
          :param app_model_config_id: app model config id
          :param app_model_config_dict: app model config dict
          :param app_model_config_override: app model config override
          :param user: account or end user
          :param invoke_from: invoke from source
          :param inputs: inputs; ignored when ``conversation`` is given, in which
              case ``conversation.inputs`` is used instead
          :param query: query; NUL characters are stripped from it
          :param files: file obj list
          :param conversation: conversation
          :param stream: is stream
          :param extras: extras
          :return: the result of ``_handle_response``
          :raises ValueError: if the app has an agent config and ``stream`` is False
          """
          # init task id
          task_id = str(uuid.uuid4())

          # init application generate entity
          application_generate_entity = ApplicationGenerateEntity(
              task_id=task_id,
              tenant_id=tenant_id,
              app_id=app_id,
              app_model_config_id=app_model_config_id,
              app_model_config_dict=app_model_config_dict,
              app_orchestration_config_entity=self._convert_from_app_model_config_dict(
                  tenant_id=tenant_id,
                  app_model_config_dict=app_model_config_dict
              ),
              app_model_config_override=app_model_config_override,
              conversation_id=conversation.id if conversation else None,
              inputs=conversation.inputs if conversation else inputs,
              query=query.replace('\x00', '') if query else None,
              files=files if files else [],
              user_id=user.id,
              stream=stream,
              invoke_from=invoke_from,
              extras=extras
          )

          if not stream and application_generate_entity.app_orchestration_config_entity.agent:
              raise ValueError("Agent app is not supported in blocking mode.")

          # init generate records
          conversation, message = self._init_generate_records(application_generate_entity)

          # init queue manager
          queue_manager = ApplicationQueueManager(
              task_id=application_generate_entity.task_id,
              user_id=application_generate_entity.user_id,
              invoke_from=application_generate_entity.invoke_from,
              conversation_id=conversation.id,
              app_mode=conversation.mode,
              message_id=message.id
          )

          # new thread; only IDs are passed so the worker re-loads the records
          # through its own session inside its own app context
          worker_thread = threading.Thread(target=self._generate_worker, kwargs={
              'flask_app': current_app._get_current_object(),
              'application_generate_entity': application_generate_entity,
              'queue_manager': queue_manager,
              'conversation_id': conversation.id,
              'message_id': message.id,
          })
          worker_thread.start()

          # return response or stream generator
          return self._handle_response(
              application_generate_entity=application_generate_entity,
              queue_manager=queue_manager,
              conversation=conversation,
              message=message,
              stream=stream
          )

      def _generate_worker(
              self,
              flask_app: Flask,
              application_generate_entity: ApplicationGenerateEntity,
              queue_manager: ApplicationQueueManager,
              conversation_id: str,
              message_id: str
      ) -> None:
          """
          Generate worker in a new thread.

          Errors are not raised to the caller; they are published on
          ``queue_manager`` (a stopped task is silently ignored).

          :param flask_app: Flask app
          :param application_generate_entity: application generate entity
          :param queue_manager: queue manager
          :param conversation_id: conversation ID
          :param message_id: message ID
          :return:
          """
          with flask_app.app_context():
              try:
                  # get conversation and message
                  conversation = self._get_conversation(conversation_id)
                  message = self._get_message(message_id)

                  if application_generate_entity.app_orchestration_config_entity.agent:
                      # agent app
                      runner = AssistantApplicationRunner()
                  else:
                      # basic app
                      runner = BasicApplicationRunner()

                  runner.run(
                      application_generate_entity=application_generate_entity,
                      queue_manager=queue_manager,
                      conversation=conversation,
                      message=message
                  )
              except ConversationTaskStoppedException:
                  pass
              except InvokeAuthorizationError:
                  # the original error is replaced by a fixed message
                  queue_manager.publish_error(
                      InvokeAuthorizationError('Incorrect API key provided'),
                      PublishFrom.APPLICATION_MANAGER
                  )
              except ValidationError as e:
                  logger.exception("Validation Error when generating")
                  queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
              except (ValueError, InvokeError) as e:
                  queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
              except Exception as e:
                  logger.exception("Unknown Error when generating")
                  queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
              finally:
                  db.session.remove()

      def _handle_response(
              self,
              application_generate_entity: ApplicationGenerateEntity,
              queue_manager: ApplicationQueueManager,
              conversation: Conversation,
              message: Message,
              stream: bool = False
      ) -> Union[dict, Generator]:
          """
          Handle response.

          :param application_generate_entity: application generate entity
          :param queue_manager: queue manager
          :param conversation: conversation
          :param message: message
          :param stream: is stream
          :return: the result of ``GenerateTaskPipeline.process``
          :raises ConversationTaskStoppedException: when processing fails with
              ``ValueError("I/O operation on closed file.")``
          :raises ValueError: any other ``ValueError`` from the pipeline (logged first)
          """
          # init generate task pipeline
          generate_task_pipeline = GenerateTaskPipeline(
              application_generate_entity=application_generate_entity,
              queue_manager=queue_manager,
              conversation=conversation,
              message=message
          )

          try:
              return generate_task_pipeline.process(stream=stream)
          except ValueError as e:
              # ``e.args`` is empty for a bare ``ValueError()``, so guard the lookup
              if e.args and e.args[0] == "I/O operation on closed file.":  # ignore this error
                  raise ConversationTaskStoppedException()
              logger.exception(e)
              raise
          finally:
              db.session.remove()

      @staticmethod
      def _is_enabled(feature_dict: Optional[dict]) -> bool:
          """
          Tell whether a feature config is switched on.

          :param feature_dict: feature config dict, may be None or empty
          :return: True only for a non-empty dict whose ``'enabled'`` value is truthy
          """
          return bool(feature_dict and feature_dict.get('enabled'))

      @staticmethod
      def _split_stop_words(completion_params: Optional[dict]) -> Tuple[dict, list]:
          """
          Separate the ``'stop'`` entry from the completion params.

          Works on a copy, so the dict passed in is never modified.

          :param completion_params: raw completion params, may be None
          :return: ``(params without 'stop', stop)``; ``stop`` is ``[]`` when absent
          """
          params = dict(completion_params or {})
          stop = params.pop('stop', [])
          return params, stop

      def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
              -> AppOrchestrationConfigEntity:
          """
          Convert app model config dict to entity.

          :param tenant_id: tenant ID
          :param app_model_config_dict: app model config dict (not modified)
          :raises ProviderTokenNotInitError: provider token not init error
          :raises ModelCurrentlyNotSupportError: model status is NO_PERMISSION
          :raises QuotaExceededError: model status is QUOTA_EXCEEDED
          :raises ValueError: model or model schema does not exist
          :return: app orchestration config entity
          """
          properties = {}

          # shallow copy: nested dicts are still shared with the caller, so they
          # are only read below, never modified
          copy_app_model_config_dict = app_model_config_dict.copy()

          # model config
          properties['model_config'] = self._convert_model_config(tenant_id, copy_app_model_config_dict)

          # prompt template
          properties['prompt_template'] = self._convert_prompt_template(copy_app_model_config_dict)

          # external data variables
          properties['external_data_variables'] = self._convert_external_data_variables(
              copy_app_model_config_dict
          )

          # show retrieve source
          properties['show_retrieve_source'] = self._is_enabled(
              copy_app_model_config_dict.get('retriever_resource')
          )

          dataset_ids = []
          if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
              datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
                  'strategy': 'router',
                  'datasets': []
              })

              for dataset_item in datasets.get('datasets', []):
                  keys = list(dataset_item.keys())
                  if len(keys) == 0 or keys[0] != 'dataset':
                      continue

                  dataset = dataset_item['dataset']
                  if not self._is_enabled(dataset):
                      continue

                  dataset_id = dataset.get('id', None)
                  if dataset_id:
                      dataset_ids.append(dataset_id)
          else:
              datasets = {'strategy': 'router', 'datasets': []}

          agent_dict = copy_app_model_config_dict.get('agent_mode')
          if self._is_enabled(agent_dict):
              agent_strategy = agent_dict.get('strategy', 'cot')

              if agent_strategy == 'function_call':
                  strategy = AgentEntity.Strategy.FUNCTION_CALLING
              elif agent_strategy in ('cot', 'react'):
                  strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
              else:
                  # old configs, try to detect default strategy
                  if copy_app_model_config_dict['model']['provider'] == 'openai':
                      strategy = AgentEntity.Strategy.FUNCTION_CALLING
                  else:
                      strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT

              agent_tools = []
              for tool in agent_dict.get('tools', []):
                  keys = tool.keys()
                  if len(keys) >= 4:
                      if not self._is_enabled(tool):
                          continue

                      agent_tools.append(AgentToolEntity(
                          provider_type=tool['provider_type'],
                          provider_id=tool['provider_id'],
                          tool_name=tool['tool_name'],
                          tool_parameters=tool.get('tool_parameters', {})
                      ))
                  elif len(keys) == 1:
                      # old standard: a single-key entry, only 'dataset' entries are used
                      key = list(keys)[0]
                      if key != 'dataset':
                          continue

                      tool_item = tool[key]
                      if not self._is_enabled(tool_item):
                          continue

                      dataset_ids.append(tool_item['id'])

              if 'strategy' in agent_dict and agent_dict['strategy'] not in ['react_router', 'router']:
                  agent_prompt = agent_dict.get('prompt', None) or {}

                  # check model mode
                  model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
                  if model_mode == 'completion':
                      default_prompts = REACT_PROMPT_TEMPLATES['english']['completion']
                  else:
                      default_prompts = REACT_PROMPT_TEMPLATES['english']['chat']

                  agent_prompt_entity = AgentPromptEntity(
                      first_prompt=agent_prompt.get('first_prompt', default_prompts['prompt']),
                      next_iteration=agent_prompt.get('next_iteration', default_prompts['agent_scratchpad']),
                  )

                  # NOTE(review): kept inside this branch because agent_prompt_entity
                  # is only bound here - confirm against the original nesting
                  properties['agent'] = AgentEntity(
                      provider=properties['model_config'].provider,
                      model=properties['model_config'].model,
                      strategy=strategy,
                      prompt=agent_prompt_entity,
                      tools=agent_tools,
                      max_iteration=agent_dict.get('max_iteration', 5)
                  )

          if dataset_ids:
              # dataset configs
              dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
              query_variable = copy_app_model_config_dict.get('dataset_query_variable')
              retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                  dataset_configs['retrieval_model']
              )

              if dataset_configs['retrieval_model'] == 'single':
                  retrieve_config = DatasetRetrieveConfigEntity(
                      query_variable=query_variable,
                      retrieve_strategy=retrieve_strategy,
                      single_strategy=datasets.get('strategy', 'router')
                  )
              else:
                  retrieve_config = DatasetRetrieveConfigEntity(
                      query_variable=query_variable,
                      retrieve_strategy=retrieve_strategy,
                      top_k=dataset_configs.get('top_k'),
                      score_threshold=dataset_configs.get('score_threshold'),
                      reranking_model=dataset_configs.get('reranking_model')
                  )

              properties['dataset'] = DatasetEntity(
                  dataset_ids=dataset_ids,
                  retrieve_config=retrieve_config
              )

          # file upload
          file_upload_dict = copy_app_model_config_dict.get('file_upload')
          image_dict = file_upload_dict.get('image') if file_upload_dict else None
          if self._is_enabled(image_dict):
              properties['file_upload'] = FileUploadEntity(
                  image_config={
                      'number_limits': image_dict['number_limits'],
                      'detail': image_dict['detail'],
                      'transfer_methods': image_dict['transfer_methods']
                  }
              )

          # opening statement
          properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')

          # boolean features: the property is only set when the feature is enabled
          for feature in ('suggested_questions_after_answer', 'more_like_this',
                          'speech_to_text', 'text_to_speech'):
              if self._is_enabled(copy_app_model_config_dict.get(feature)):
                  properties[feature] = True

          # sensitive word avoidance
          sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
          if self._is_enabled(sensitive_word_avoidance_dict):
              properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
                  type=sensitive_word_avoidance_dict.get('type'),
                  config=sensitive_word_avoidance_dict.get('config'),
              )

          return AppOrchestrationConfigEntity(**properties)

      def _convert_model_config(self, tenant_id: str, app_model_config_dict: dict) -> ModelConfigEntity:
          """
          Build the model config entity from the ``'model'`` section of the config.

          :param tenant_id: tenant ID
          :param app_model_config_dict: app model config dict
          :raises ProviderTokenNotInitError: credentials are missing or status is NO_CONFIGURE
          :raises ModelCurrentlyNotSupportError: model status is NO_PERMISSION
          :raises QuotaExceededError: model status is QUOTA_EXCEEDED
          :raises ValueError: provider model or model schema does not exist
          :return: model config entity
          """
          model_dict = app_model_config_dict['model']

          provider_manager = ProviderManager()
          provider_model_bundle = provider_manager.get_provider_model_bundle(
              tenant_id=tenant_id,
              provider=model_dict['provider'],
              model_type=ModelType.LLM
          )

          provider_name = provider_model_bundle.configuration.provider.provider
          model_name = model_dict['name']

          model_type_instance = provider_model_bundle.model_type_instance
          model_type_instance = cast(LargeLanguageModel, model_type_instance)

          # check model credentials
          model_credentials = provider_model_bundle.configuration.get_current_credentials(
              model_type=ModelType.LLM,
              model=model_name
          )
          if model_credentials is None:
              raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

          # check model
          provider_model = provider_model_bundle.configuration.get_provider_model(
              model=model_name,
              model_type=ModelType.LLM
          )
          if provider_model is None:
              raise ValueError(f"Model {model_name} not exist.")

          if provider_model.status == ModelStatus.NO_CONFIGURE:
              raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
          elif provider_model.status == ModelStatus.NO_PERMISSION:
              raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
          elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
              raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")

          # model config: 'stop' travels separately from the other parameters;
          # the split works on a copy so the caller's config keeps its 'stop' entry
          completion_params, stop = self._split_stop_words(model_dict.get('completion_params'))

          # get model mode
          model_mode = model_dict.get('mode')
          if not model_mode:
              mode_enum = model_type_instance.get_model_mode(
                  model=model_name,
                  credentials=model_credentials
              )
              model_mode = mode_enum.value

          model_schema = model_type_instance.get_model_schema(
              model_name,
              model_credentials
          )
          if not model_schema:
              raise ValueError(f"Model {model_name} not exist.")

          return ModelConfigEntity(
              provider=model_dict['provider'],
              model=model_name,
              model_schema=model_schema,
              mode=model_mode,
              provider_model_bundle=provider_model_bundle,
              credentials=model_credentials,
              parameters=completion_params,
              stop=stop,
          )

      def _convert_prompt_template(self, app_model_config_dict: dict) -> PromptTemplateEntity:
          """
          Build the prompt template entity for the configured ``'prompt_type'``.

          :param app_model_config_dict: app model config dict
          :return: prompt template entity (simple, or advanced chat / completion)
          """
          prompt_type = PromptTemplateEntity.PromptType.value_of(app_model_config_dict['prompt_type'])
          if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
              return PromptTemplateEntity(
                  prompt_type=prompt_type,
                  simple_prompt_template=app_model_config_dict.get("pre_prompt", "")
              )

          advanced_chat_prompt_template = None
          chat_prompt_config = app_model_config_dict.get("chat_prompt_config", {})
          if chat_prompt_config:
              chat_prompt_messages = [
                  {
                      "text": message["text"],
                      "role": PromptMessageRole.value_of(message["role"])
                  }
                  for message in chat_prompt_config.get("prompt", [])
              ]
              advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
                  messages=chat_prompt_messages
              )

          advanced_completion_prompt_template = None
          completion_prompt_config = app_model_config_dict.get("completion_prompt_config", {})
          if completion_prompt_config:
              completion_prompt_template_params = {
                  'prompt': completion_prompt_config['prompt']['text'],
              }

              if 'conversation_histories_role' in completion_prompt_config:
                  histories_role = completion_prompt_config['conversation_histories_role']
                  completion_prompt_template_params['role_prefix'] = {
                      'user': histories_role['user_prefix'],
                      'assistant': histories_role['assistant_prefix']
                  }

              advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
                  **completion_prompt_template_params
              )

          return PromptTemplateEntity(
              prompt_type=prompt_type,
              advanced_chat_prompt_template=advanced_chat_prompt_template,
              advanced_completion_prompt_template=advanced_completion_prompt_template
          )

      def _convert_external_data_variables(self, app_model_config_dict: dict) -> list:
          """
          Collect external data variables from both config formats.

          :param app_model_config_dict: app model config dict
          :return: list of ``ExternalDataVariableEntity``; enabled entries of
              ``'external_data_tools'`` first, then ``'external_data_tool'``
              entries of ``'user_input_form'``
          """
          external_data_variables = []

          # old external_data_tools
          for external_data_tool in app_model_config_dict.get('external_data_tools', []):
              if not self._is_enabled(external_data_tool):
                  continue

              external_data_variables.append(
                  ExternalDataVariableEntity(
                      variable=external_data_tool['variable'],
                      type=external_data_tool['type'],
                      config=external_data_tool['config']
                  )
              )

          # current external_data_tools
          for variable in app_model_config_dict.get('user_input_form', []):
              # each form entry is keyed by its type; an empty entry is skipped
              typ = next(iter(variable), None)
              if typ != 'external_data_tool':
                  continue

              val = variable[typ]
              external_data_variables.append(
                  ExternalDataVariableEntity(
                      variable=val['variable'],
                      type=val['type'],
                      config=val['config']
                  )
              )

          return external_data_variables

      def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
              -> Tuple[Conversation, Message]:
          """
          Initialize generate records

          Creates (and commits) the conversation when the entity carries no
          conversation id, otherwise loads it; then creates the message and one
          ``MessageFile`` per attached file.

          :param application_generate_entity: application generate entity
          :raises ValueError: if the app, or the conversation referenced by the
              entity, cannot be found
          :return: ``(conversation, message)``
          """
          app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
          model_config = app_orchestration_config_entity.model_config

          model_type_instance = model_config.provider_model_bundle.model_type_instance
          model_type_instance = cast(LargeLanguageModel, model_type_instance)
          model_schema = model_type_instance.get_model_schema(
              model=model_config.model,
              credentials=model_config.credentials
          )

          app_record = (db.session.query(App)
                        .filter(App.id == application_generate_entity.app_id).first())
          if app_record is None:
              raise ValueError(f"App {application_generate_entity.app_id} not found.")

          app_mode = app_record.mode

          # get from source
          end_user_id = None
          account_id = None
          if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
              from_source = 'api'
              end_user_id = application_generate_entity.user_id
          else:
              from_source = 'console'
              account_id = application_generate_entity.user_id

          override_model_configs = None
          if application_generate_entity.app_model_config_override:
              override_model_configs = application_generate_entity.app_model_config_dict
          override_model_configs_json = json.dumps(override_model_configs) if override_model_configs else None

          introduction = ''
          if app_mode == 'chat':
              # get conversation introduction
              introduction = self._get_conversation_introduction(application_generate_entity)

          if not application_generate_entity.conversation_id:
              conversation = Conversation(
                  app_id=app_record.id,
                  app_model_config_id=application_generate_entity.app_model_config_id,
                  model_provider=model_config.provider,
                  model_id=model_config.model,
                  override_model_configs=override_model_configs_json,
                  mode=app_mode,
                  name='New conversation',
                  inputs=application_generate_entity.inputs,
                  introduction=introduction,
                  system_instruction="",
                  system_instruction_tokens=0,
                  status='normal',
                  from_source=from_source,
                  from_end_user_id=end_user_id,
                  from_account_id=account_id,
              )

              db.session.add(conversation)
              db.session.commit()
          else:
              conversation = (
                  db.session.query(Conversation)
                  .filter(
                      Conversation.id == application_generate_entity.conversation_id,
                      Conversation.app_id == app_record.id
                  ).first()
              )
              if conversation is None:
                  raise ValueError(
                      f"Conversation {application_generate_entity.conversation_id} not found."
                  )

          currency = model_schema.pricing.currency if model_schema and model_schema.pricing else 'USD'

          message = Message(
              app_id=app_record.id,
              model_provider=model_config.provider,
              model_id=model_config.model,
              override_model_configs=override_model_configs_json,
              conversation_id=conversation.id,
              inputs=application_generate_entity.inputs,
              query=application_generate_entity.query or "",
              message="",
              message_tokens=0,
              message_unit_price=0,
              message_price_unit=0,
              answer="",
              answer_tokens=0,
              answer_unit_price=0,
              answer_price_unit=0,
              provider_response_latency=0,
              total_price=0,
              currency=currency,
              from_source=from_source,
              from_end_user_id=end_user_id,
              from_account_id=account_id,
              agent_based=app_orchestration_config_entity.agent is not None
          )

          db.session.add(message)
          db.session.commit()

          for file in application_generate_entity.files:
              message_file = MessageFile(
                  message_id=message.id,
                  type=file.type.value,
                  transfer_method=file.transfer_method.value,
                  belongs_to='user',
                  url=file.url,
                  upload_file_id=file.upload_file_id,
                  created_by_role=('account' if account_id else 'end_user'),
                  created_by=account_id or end_user_id,
              )
              db.session.add(message_file)
              db.session.commit()

          return conversation, message

      def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) \
              -> Optional[str]:
          """
          Get conversation introduction

          The opening statement is treated as a template and filled with the
          entity inputs whose keys the template references.

          :param application_generate_entity: application generate entity
          :return: conversation introduction; the opening statement is returned
              unchanged when it is empty/None or formatting raises ``KeyError``
          """
          app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
          introduction = app_orchestration_config_entity.opening_statement

          if introduction:
              try:
                  inputs = application_generate_entity.inputs
                  prompt_template = PromptTemplateParser(template=introduction)
                  prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
                  introduction = prompt_template.format(prompt_inputs)
              except KeyError:
                  # best effort: keep the unformatted opening statement
                  pass

          return introduction

      def _get_conversation(self, conversation_id: str) -> Optional[Conversation]:
          """
          Get conversation by conversation id

          :param conversation_id: conversation id
          :return: conversation, or None when no row matches
          """
          return (
              db.session.query(Conversation)
              .filter(Conversation.id == conversation_id)
              .first()
          )

      def _get_message(self, message_id: str) -> Optional[Message]:
          """
          Get message by message id

          :param message_id: message id
          :return: message, or None when no row matches
          """
          return (
              db.session.query(Message)
              .filter(Message.id == message_id)
              .first()
          )