application_manager.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738
  1. import json
  2. import logging
  3. import threading
  4. import uuid
  5. from typing import Any, Generator, Optional, Tuple, Union, cast
  6. from core.app_runner.assistant_app_runner import AssistantApplicationRunner
  7. from core.app_runner.basic_app_runner import BasicApplicationRunner
  8. from core.app_runner.generate_task_pipeline import GenerateTaskPipeline
  9. from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
  10. from core.entities.application_entities import (AdvancedChatPromptTemplateEntity,
  11. AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentPromptEntity,
  12. AgentToolEntity, ApplicationGenerateEntity,
  13. AppOrchestrationConfigEntity, DatasetEntity,
  14. DatasetRetrieveConfigEntity, ExternalDataVariableEntity,
  15. FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity,
  16. SensitiveWordAvoidanceEntity)
  17. from core.entities.model_entities import ModelStatus
  18. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  19. from core.file.file_obj import FileObj
  20. from core.model_runtime.entities.message_entities import PromptMessageRole
  21. from core.model_runtime.entities.model_entities import ModelType
  22. from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
  23. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  24. from core.prompt.prompt_template import PromptTemplateParser
  25. from core.provider_manager import ProviderManager
  26. from core.tools.prompt.template import REACT_PROMPT_TEMPLATES
  27. from extensions.ext_database import db
  28. from flask import Flask, current_app
  29. from models.account import Account
  30. from models.model import App, Conversation, EndUser, Message, MessageFile
  31. from pydantic import ValidationError
  32. logger = logging.getLogger(__name__)
  33. class ApplicationManager:
  34. """
  35. This class is responsible for managing application
  36. """
  37. def generate(self, tenant_id: str,
  38. app_id: str,
  39. app_model_config_id: str,
  40. app_model_config_dict: dict,
  41. app_model_config_override: bool,
  42. user: Union[Account, EndUser],
  43. invoke_from: InvokeFrom,
  44. inputs: dict[str, str],
  45. query: Optional[str] = None,
  46. files: Optional[list[FileObj]] = None,
  47. conversation: Optional[Conversation] = None,
  48. stream: bool = False,
  49. extras: Optional[dict[str, Any]] = None) \
  50. -> Union[dict, Generator]:
  51. """
  52. Generate App response.
  53. :param tenant_id: workspace ID
  54. :param app_id: app ID
  55. :param app_model_config_id: app model config id
  56. :param app_model_config_dict: app model config dict
  57. :param app_model_config_override: app model config override
  58. :param user: account or end user
  59. :param invoke_from: invoke from source
  60. :param inputs: inputs
  61. :param query: query
  62. :param files: file obj list
  63. :param conversation: conversation
  64. :param stream: is stream
  65. :param extras: extras
  66. """
  67. # init task id
  68. task_id = str(uuid.uuid4())
  69. # init application generate entity
  70. application_generate_entity = ApplicationGenerateEntity(
  71. task_id=task_id,
  72. tenant_id=tenant_id,
  73. app_id=app_id,
  74. app_model_config_id=app_model_config_id,
  75. app_model_config_dict=app_model_config_dict,
  76. app_orchestration_config_entity=self._convert_from_app_model_config_dict(
  77. tenant_id=tenant_id,
  78. app_model_config_dict=app_model_config_dict
  79. ),
  80. app_model_config_override=app_model_config_override,
  81. conversation_id=conversation.id if conversation else None,
  82. inputs=conversation.inputs if conversation else inputs,
  83. query=query.replace('\x00', '') if query else None,
  84. files=files if files else [],
  85. user_id=user.id,
  86. stream=stream,
  87. invoke_from=invoke_from,
  88. extras=extras
  89. )
  90. if not stream and application_generate_entity.app_orchestration_config_entity.agent:
  91. raise ValueError("Agent app is not supported in blocking mode.")
  92. # init generate records
  93. (
  94. conversation,
  95. message
  96. ) = self._init_generate_records(application_generate_entity)
  97. # init queue manager
  98. queue_manager = ApplicationQueueManager(
  99. task_id=application_generate_entity.task_id,
  100. user_id=application_generate_entity.user_id,
  101. invoke_from=application_generate_entity.invoke_from,
  102. conversation_id=conversation.id,
  103. app_mode=conversation.mode,
  104. message_id=message.id
  105. )
  106. # new thread
  107. worker_thread = threading.Thread(target=self._generate_worker, kwargs={
  108. 'flask_app': current_app._get_current_object(),
  109. 'application_generate_entity': application_generate_entity,
  110. 'queue_manager': queue_manager,
  111. 'conversation_id': conversation.id,
  112. 'message_id': message.id,
  113. })
  114. worker_thread.start()
  115. # return response or stream generator
  116. return self._handle_response(
  117. application_generate_entity=application_generate_entity,
  118. queue_manager=queue_manager,
  119. conversation=conversation,
  120. message=message,
  121. stream=stream
  122. )
  123. def _generate_worker(self, flask_app: Flask,
  124. application_generate_entity: ApplicationGenerateEntity,
  125. queue_manager: ApplicationQueueManager,
  126. conversation_id: str,
  127. message_id: str) -> None:
  128. """
  129. Generate worker in a new thread.
  130. :param flask_app: Flask app
  131. :param application_generate_entity: application generate entity
  132. :param queue_manager: queue manager
  133. :param conversation_id: conversation ID
  134. :param message_id: message ID
  135. :return:
  136. """
  137. with flask_app.app_context():
  138. try:
  139. # get conversation and message
  140. conversation = self._get_conversation(conversation_id)
  141. message = self._get_message(message_id)
  142. if application_generate_entity.app_orchestration_config_entity.agent:
  143. # agent app
  144. runner = AssistantApplicationRunner()
  145. runner.run(
  146. application_generate_entity=application_generate_entity,
  147. queue_manager=queue_manager,
  148. conversation=conversation,
  149. message=message
  150. )
  151. else:
  152. # basic app
  153. runner = BasicApplicationRunner()
  154. runner.run(
  155. application_generate_entity=application_generate_entity,
  156. queue_manager=queue_manager,
  157. conversation=conversation,
  158. message=message
  159. )
  160. except ConversationTaskStoppedException:
  161. pass
  162. except InvokeAuthorizationError:
  163. queue_manager.publish_error(
  164. InvokeAuthorizationError('Incorrect API key provided'),
  165. PublishFrom.APPLICATION_MANAGER
  166. )
  167. except ValidationError as e:
  168. logger.exception("Validation Error when generating")
  169. queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
  170. except (ValueError, InvokeError) as e:
  171. queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
  172. except Exception as e:
  173. logger.exception("Unknown Error when generating")
  174. queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
  175. finally:
  176. db.session.remove()
  177. def _handle_response(self, application_generate_entity: ApplicationGenerateEntity,
  178. queue_manager: ApplicationQueueManager,
  179. conversation: Conversation,
  180. message: Message,
  181. stream: bool = False) -> Union[dict, Generator]:
  182. """
  183. Handle response.
  184. :param application_generate_entity: application generate entity
  185. :param queue_manager: queue manager
  186. :param conversation: conversation
  187. :param message: message
  188. :param stream: is stream
  189. :return:
  190. """
  191. # init generate task pipeline
  192. generate_task_pipeline = GenerateTaskPipeline(
  193. application_generate_entity=application_generate_entity,
  194. queue_manager=queue_manager,
  195. conversation=conversation,
  196. message=message
  197. )
  198. try:
  199. return generate_task_pipeline.process(stream=stream)
  200. except ValueError as e:
  201. if e.args[0] == "I/O operation on closed file.": # ignore this error
  202. raise ConversationTaskStoppedException()
  203. else:
  204. logger.exception(e)
  205. raise e
  206. finally:
  207. db.session.remove()
  208. def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \
  209. -> AppOrchestrationConfigEntity:
  210. """
  211. Convert app model config dict to entity.
  212. :param tenant_id: tenant ID
  213. :param app_model_config_dict: app model config dict
  214. :raises ProviderTokenNotInitError: provider token not init error
  215. :return: app orchestration config entity
  216. """
  217. properties = {}
  218. copy_app_model_config_dict = app_model_config_dict.copy()
  219. provider_manager = ProviderManager()
  220. provider_model_bundle = provider_manager.get_provider_model_bundle(
  221. tenant_id=tenant_id,
  222. provider=copy_app_model_config_dict['model']['provider'],
  223. model_type=ModelType.LLM
  224. )
  225. provider_name = provider_model_bundle.configuration.provider.provider
  226. model_name = copy_app_model_config_dict['model']['name']
  227. model_type_instance = provider_model_bundle.model_type_instance
  228. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  229. # check model credentials
  230. model_credentials = provider_model_bundle.configuration.get_current_credentials(
  231. model_type=ModelType.LLM,
  232. model=copy_app_model_config_dict['model']['name']
  233. )
  234. if model_credentials is None:
  235. raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
  236. # check model
  237. provider_model = provider_model_bundle.configuration.get_provider_model(
  238. model=copy_app_model_config_dict['model']['name'],
  239. model_type=ModelType.LLM
  240. )
  241. if provider_model is None:
  242. model_name = copy_app_model_config_dict['model']['name']
  243. raise ValueError(f"Model {model_name} not exist.")
  244. if provider_model.status == ModelStatus.NO_CONFIGURE:
  245. raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
  246. elif provider_model.status == ModelStatus.NO_PERMISSION:
  247. raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
  248. elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
  249. raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
  250. # model config
  251. completion_params = copy_app_model_config_dict['model'].get('completion_params')
  252. stop = []
  253. if 'stop' in completion_params:
  254. stop = completion_params['stop']
  255. del completion_params['stop']
  256. # get model mode
  257. model_mode = copy_app_model_config_dict['model'].get('mode')
  258. if not model_mode:
  259. mode_enum = model_type_instance.get_model_mode(
  260. model=copy_app_model_config_dict['model']['name'],
  261. credentials=model_credentials
  262. )
  263. model_mode = mode_enum.value
  264. model_schema = model_type_instance.get_model_schema(
  265. copy_app_model_config_dict['model']['name'],
  266. model_credentials
  267. )
  268. if not model_schema:
  269. raise ValueError(f"Model {model_name} not exist.")
  270. properties['model_config'] = ModelConfigEntity(
  271. provider=copy_app_model_config_dict['model']['provider'],
  272. model=copy_app_model_config_dict['model']['name'],
  273. model_schema=model_schema,
  274. mode=model_mode,
  275. provider_model_bundle=provider_model_bundle,
  276. credentials=model_credentials,
  277. parameters=completion_params,
  278. stop=stop,
  279. )
  280. # prompt template
  281. prompt_type = PromptTemplateEntity.PromptType.value_of(copy_app_model_config_dict['prompt_type'])
  282. if prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
  283. simple_prompt_template = copy_app_model_config_dict.get("pre_prompt", "")
  284. properties['prompt_template'] = PromptTemplateEntity(
  285. prompt_type=prompt_type,
  286. simple_prompt_template=simple_prompt_template
  287. )
  288. else:
  289. advanced_chat_prompt_template = None
  290. chat_prompt_config = copy_app_model_config_dict.get("chat_prompt_config", {})
  291. if chat_prompt_config:
  292. chat_prompt_messages = []
  293. for message in chat_prompt_config.get("prompt", []):
  294. chat_prompt_messages.append({
  295. "text": message["text"],
  296. "role": PromptMessageRole.value_of(message["role"])
  297. })
  298. advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(
  299. messages=chat_prompt_messages
  300. )
  301. advanced_completion_prompt_template = None
  302. completion_prompt_config = copy_app_model_config_dict.get("completion_prompt_config", {})
  303. if completion_prompt_config:
  304. completion_prompt_template_params = {
  305. 'prompt': completion_prompt_config['prompt']['text'],
  306. }
  307. if 'conversation_histories_role' in completion_prompt_config:
  308. completion_prompt_template_params['role_prefix'] = {
  309. 'user': completion_prompt_config['conversation_histories_role']['user_prefix'],
  310. 'assistant': completion_prompt_config['conversation_histories_role']['assistant_prefix']
  311. }
  312. advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity(
  313. **completion_prompt_template_params
  314. )
  315. properties['prompt_template'] = PromptTemplateEntity(
  316. prompt_type=prompt_type,
  317. advanced_chat_prompt_template=advanced_chat_prompt_template,
  318. advanced_completion_prompt_template=advanced_completion_prompt_template
  319. )
  320. # external data variables
  321. properties['external_data_variables'] = []
  322. # old external_data_tools
  323. external_data_tools = copy_app_model_config_dict.get('external_data_tools', [])
  324. for external_data_tool in external_data_tools:
  325. if 'enabled' not in external_data_tool or not external_data_tool['enabled']:
  326. continue
  327. properties['external_data_variables'].append(
  328. ExternalDataVariableEntity(
  329. variable=external_data_tool['variable'],
  330. type=external_data_tool['type'],
  331. config=external_data_tool['config']
  332. )
  333. )
  334. # current external_data_tools
  335. for variable in copy_app_model_config_dict.get('user_input_form', []):
  336. typ = list(variable.keys())[0]
  337. if typ == 'external_data_tool':
  338. val = variable[typ]
  339. properties['external_data_variables'].append(
  340. ExternalDataVariableEntity(
  341. variable=val['variable'],
  342. type=val['type'],
  343. config=val['config']
  344. )
  345. )
  346. # show retrieve source
  347. show_retrieve_source = False
  348. retriever_resource_dict = copy_app_model_config_dict.get('retriever_resource')
  349. if retriever_resource_dict:
  350. if 'enabled' in retriever_resource_dict and retriever_resource_dict['enabled']:
  351. show_retrieve_source = True
  352. properties['show_retrieve_source'] = show_retrieve_source
  353. dataset_ids = []
  354. if 'datasets' in copy_app_model_config_dict.get('dataset_configs', {}):
  355. datasets = copy_app_model_config_dict.get('dataset_configs', {}).get('datasets', {
  356. 'strategy': 'router',
  357. 'datasets': []
  358. })
  359. for dataset in datasets.get('datasets', []):
  360. keys = list(dataset.keys())
  361. if len(keys) == 0 or keys[0] != 'dataset':
  362. continue
  363. dataset = dataset['dataset']
  364. if 'enabled' not in dataset or not dataset['enabled']:
  365. continue
  366. dataset_id = dataset.get('id', None)
  367. if dataset_id:
  368. dataset_ids.append(dataset_id)
  369. else:
  370. datasets = {'strategy': 'router', 'datasets': []}
  371. if 'agent_mode' in copy_app_model_config_dict and copy_app_model_config_dict['agent_mode'] \
  372. and 'enabled' in copy_app_model_config_dict['agent_mode'] \
  373. and copy_app_model_config_dict['agent_mode']['enabled']:
  374. agent_dict = copy_app_model_config_dict.get('agent_mode', {})
  375. agent_strategy = agent_dict.get('strategy', 'cot')
  376. if agent_strategy == 'function_call':
  377. strategy = AgentEntity.Strategy.FUNCTION_CALLING
  378. elif agent_strategy == 'cot' or agent_strategy == 'react':
  379. strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
  380. else:
  381. # old configs, try to detect default strategy
  382. if copy_app_model_config_dict['model']['provider'] == 'openai':
  383. strategy = AgentEntity.Strategy.FUNCTION_CALLING
  384. else:
  385. strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
  386. agent_tools = []
  387. for tool in agent_dict.get('tools', []):
  388. keys = tool.keys()
  389. if len(keys) >= 4:
  390. if "enabled" not in tool or not tool["enabled"]:
  391. continue
  392. agent_tool_properties = {
  393. 'provider_type': tool['provider_type'],
  394. 'provider_id': tool['provider_id'],
  395. 'tool_name': tool['tool_name'],
  396. 'tool_parameters': tool['tool_parameters'] if 'tool_parameters' in tool else {}
  397. }
  398. agent_tools.append(AgentToolEntity(**agent_tool_properties))
  399. elif len(keys) == 1:
  400. # old standard
  401. key = list(tool.keys())[0]
  402. if key != 'dataset':
  403. continue
  404. tool_item = tool[key]
  405. if "enabled" not in tool_item or not tool_item["enabled"]:
  406. continue
  407. dataset_id = tool_item['id']
  408. dataset_ids.append(dataset_id)
  409. if 'strategy' in copy_app_model_config_dict['agent_mode'] and \
  410. copy_app_model_config_dict['agent_mode']['strategy'] not in ['react_router', 'router']:
  411. agent_prompt = agent_dict.get('prompt', None) or {}
  412. # check model mode
  413. model_mode = copy_app_model_config_dict.get('model', {}).get('mode', 'completion')
  414. if model_mode == 'completion':
  415. agent_prompt_entity = AgentPromptEntity(
  416. first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['completion']['prompt']),
  417. next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['completion']['agent_scratchpad']),
  418. )
  419. else:
  420. agent_prompt_entity = AgentPromptEntity(
  421. first_prompt=agent_prompt.get('first_prompt', REACT_PROMPT_TEMPLATES['english']['chat']['prompt']),
  422. next_iteration=agent_prompt.get('next_iteration', REACT_PROMPT_TEMPLATES['english']['chat']['agent_scratchpad']),
  423. )
  424. properties['agent'] = AgentEntity(
  425. provider=properties['model_config'].provider,
  426. model=properties['model_config'].model,
  427. strategy=strategy,
  428. prompt=agent_prompt_entity,
  429. tools=agent_tools,
  430. max_iteration=agent_dict.get('max_iteration', 5)
  431. )
  432. if len(dataset_ids) > 0:
  433. # dataset configs
  434. dataset_configs = copy_app_model_config_dict.get('dataset_configs', {'retrieval_model': 'single'})
  435. query_variable = copy_app_model_config_dict.get('dataset_query_variable')
  436. if dataset_configs['retrieval_model'] == 'single':
  437. properties['dataset'] = DatasetEntity(
  438. dataset_ids=dataset_ids,
  439. retrieve_config=DatasetRetrieveConfigEntity(
  440. query_variable=query_variable,
  441. retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
  442. dataset_configs['retrieval_model']
  443. ),
  444. single_strategy=datasets.get('strategy', 'router')
  445. )
  446. )
  447. else:
  448. properties['dataset'] = DatasetEntity(
  449. dataset_ids=dataset_ids,
  450. retrieve_config=DatasetRetrieveConfigEntity(
  451. query_variable=query_variable,
  452. retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
  453. dataset_configs['retrieval_model']
  454. ),
  455. top_k=dataset_configs.get('top_k'),
  456. score_threshold=dataset_configs.get('score_threshold'),
  457. reranking_model=dataset_configs.get('reranking_model')
  458. )
  459. )
  460. # file upload
  461. file_upload_dict = copy_app_model_config_dict.get('file_upload')
  462. if file_upload_dict:
  463. if 'image' in file_upload_dict and file_upload_dict['image']:
  464. if 'enabled' in file_upload_dict['image'] and file_upload_dict['image']['enabled']:
  465. properties['file_upload'] = FileUploadEntity(
  466. image_config={
  467. 'number_limits': file_upload_dict['image']['number_limits'],
  468. 'detail': file_upload_dict['image']['detail'],
  469. 'transfer_methods': file_upload_dict['image']['transfer_methods']
  470. }
  471. )
  472. # opening statement
  473. properties['opening_statement'] = copy_app_model_config_dict.get('opening_statement')
  474. # suggested questions after answer
  475. suggested_questions_after_answer_dict = copy_app_model_config_dict.get('suggested_questions_after_answer')
  476. if suggested_questions_after_answer_dict:
  477. if 'enabled' in suggested_questions_after_answer_dict and suggested_questions_after_answer_dict['enabled']:
  478. properties['suggested_questions_after_answer'] = True
  479. # more like this
  480. more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
  481. if more_like_this_dict:
  482. if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
  483. properties['more_like_this'] = True
  484. # speech to text
  485. speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
  486. if speech_to_text_dict:
  487. if 'enabled' in speech_to_text_dict and speech_to_text_dict['enabled']:
  488. properties['speech_to_text'] = True
  489. # text to speech
  490. text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech')
  491. if text_to_speech_dict:
  492. if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']:
  493. properties['text_to_speech'] = True
  494. # sensitive word avoidance
  495. sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance')
  496. if sensitive_word_avoidance_dict:
  497. if 'enabled' in sensitive_word_avoidance_dict and sensitive_word_avoidance_dict['enabled']:
  498. properties['sensitive_word_avoidance'] = SensitiveWordAvoidanceEntity(
  499. type=sensitive_word_avoidance_dict.get('type'),
  500. config=sensitive_word_avoidance_dict.get('config'),
  501. )
  502. return AppOrchestrationConfigEntity(**properties)
  503. def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
  504. -> Tuple[Conversation, Message]:
  505. """
  506. Initialize generate records
  507. :param application_generate_entity: application generate entity
  508. :return:
  509. """
  510. app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
  511. model_type_instance = app_orchestration_config_entity.model_config.provider_model_bundle.model_type_instance
  512. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  513. model_schema = model_type_instance.get_model_schema(
  514. model=app_orchestration_config_entity.model_config.model,
  515. credentials=app_orchestration_config_entity.model_config.credentials
  516. )
  517. app_record = (db.session.query(App)
  518. .filter(App.id == application_generate_entity.app_id).first())
  519. app_mode = app_record.mode
  520. # get from source
  521. end_user_id = None
  522. account_id = None
  523. if application_generate_entity.invoke_from in [InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API]:
  524. from_source = 'api'
  525. end_user_id = application_generate_entity.user_id
  526. else:
  527. from_source = 'console'
  528. account_id = application_generate_entity.user_id
  529. override_model_configs = None
  530. if application_generate_entity.app_model_config_override:
  531. override_model_configs = application_generate_entity.app_model_config_dict
  532. introduction = ''
  533. if app_mode == 'chat':
  534. # get conversation introduction
  535. introduction = self._get_conversation_introduction(application_generate_entity)
  536. if not application_generate_entity.conversation_id:
  537. conversation = Conversation(
  538. app_id=app_record.id,
  539. app_model_config_id=application_generate_entity.app_model_config_id,
  540. model_provider=app_orchestration_config_entity.model_config.provider,
  541. model_id=app_orchestration_config_entity.model_config.model,
  542. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  543. mode=app_mode,
  544. name='New conversation',
  545. inputs=application_generate_entity.inputs,
  546. introduction=introduction,
  547. system_instruction="",
  548. system_instruction_tokens=0,
  549. status='normal',
  550. from_source=from_source,
  551. from_end_user_id=end_user_id,
  552. from_account_id=account_id,
  553. )
  554. db.session.add(conversation)
  555. db.session.commit()
  556. else:
  557. conversation = (
  558. db.session.query(Conversation)
  559. .filter(
  560. Conversation.id == application_generate_entity.conversation_id,
  561. Conversation.app_id == app_record.id
  562. ).first()
  563. )
  564. currency = model_schema.pricing.currency if model_schema.pricing else 'USD'
  565. message = Message(
  566. app_id=app_record.id,
  567. model_provider=app_orchestration_config_entity.model_config.provider,
  568. model_id=app_orchestration_config_entity.model_config.model,
  569. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  570. conversation_id=conversation.id,
  571. inputs=application_generate_entity.inputs,
  572. query=application_generate_entity.query or "",
  573. message="",
  574. message_tokens=0,
  575. message_unit_price=0,
  576. message_price_unit=0,
  577. answer="",
  578. answer_tokens=0,
  579. answer_unit_price=0,
  580. answer_price_unit=0,
  581. provider_response_latency=0,
  582. total_price=0,
  583. currency=currency,
  584. from_source=from_source,
  585. from_end_user_id=end_user_id,
  586. from_account_id=account_id,
  587. agent_based=app_orchestration_config_entity.agent is not None
  588. )
  589. db.session.add(message)
  590. db.session.commit()
  591. for file in application_generate_entity.files:
  592. message_file = MessageFile(
  593. message_id=message.id,
  594. type=file.type.value,
  595. transfer_method=file.transfer_method.value,
  596. belongs_to='user',
  597. url=file.url,
  598. upload_file_id=file.upload_file_id,
  599. created_by_role=('account' if account_id else 'end_user'),
  600. created_by=account_id or end_user_id,
  601. )
  602. db.session.add(message_file)
  603. db.session.commit()
  604. return conversation, message
  605. def _get_conversation_introduction(self, application_generate_entity: ApplicationGenerateEntity) -> str:
  606. """
  607. Get conversation introduction
  608. :param application_generate_entity: application generate entity
  609. :return: conversation introduction
  610. """
  611. app_orchestration_config_entity = application_generate_entity.app_orchestration_config_entity
  612. introduction = app_orchestration_config_entity.opening_statement
  613. if introduction:
  614. try:
  615. inputs = application_generate_entity.inputs
  616. prompt_template = PromptTemplateParser(template=introduction)
  617. prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
  618. introduction = prompt_template.format(prompt_inputs)
  619. except KeyError:
  620. pass
  621. return introduction
  622. def _get_conversation(self, conversation_id: str) -> Conversation:
  623. """
  624. Get conversation by conversation id
  625. :param conversation_id: conversation id
  626. :return: conversation
  627. """
  628. conversation = (
  629. db.session.query(Conversation)
  630. .filter(Conversation.id == conversation_id)
  631. .first()
  632. )
  633. return conversation
  634. def _get_message(self, message_id: str) -> Message:
  635. """
  636. Get message by message id
  637. :param message_id: message id
  638. :return: message
  639. """
  640. message = (
  641. db.session.query(Message)
  642. .filter(Message.id == message_id)
  643. .first()
  644. )
  645. return message