conversation_message_task.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. import json
  2. import time
  3. from typing import Optional, Union, List
  4. from core.callback_handler.entity.agent_loop import AgentLoop
  5. from core.callback_handler.entity.dataset_query import DatasetQueryObj
  6. from core.callback_handler.entity.llm_message import LLMMessage
  7. from core.callback_handler.entity.chain_result import ChainResult
  8. from core.model_providers.model_factory import ModelFactory
  9. from core.model_providers.models.entity.message import to_prompt_messages, MessageType
  10. from core.model_providers.models.llm.base import BaseLLM
  11. from core.prompt.prompt_builder import PromptBuilder
  12. from core.prompt.prompt_template import JinjaPromptTemplate
  13. from events.message_event import message_was_created
  14. from extensions.ext_database import db
  15. from extensions.ext_redis import redis_client
  16. from models.dataset import DatasetQuery
  17. from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
  18. MessageChain, DatasetRetrieverResource
  19. class ConversationMessageTask:
  20. def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
  21. inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
  22. conversation: Optional[Conversation] = None, is_override: bool = False):
  23. self.start_at = time.perf_counter()
  24. self.task_id = task_id
  25. self.app = app
  26. self.tenant_id = app.tenant_id
  27. self.app_model_config = app_model_config
  28. self.is_override = is_override
  29. self.user = user
  30. self.inputs = inputs
  31. self.query = query
  32. self.streaming = streaming
  33. self.conversation = conversation
  34. self.is_new_conversation = False
  35. self.model_instance = model_instance
  36. self.message = None
  37. self.retriever_resource = None
  38. self.model_dict = self.app_model_config.model_dict
  39. self.provider_name = self.model_dict.get('provider')
  40. self.model_name = self.model_dict.get('name')
  41. self.mode = app.mode
  42. self.init()
  43. self._pub_handler = PubHandler(
  44. user=self.user,
  45. task_id=self.task_id,
  46. message=self.message,
  47. conversation=self.conversation,
  48. chain_pub=False, # disabled currently
  49. agent_thought_pub=True
  50. )
  51. def init(self):
  52. override_model_configs = None
  53. if self.is_override:
  54. override_model_configs = self.app_model_config.to_dict()
  55. introduction = ''
  56. system_instruction = ''
  57. system_instruction_tokens = 0
  58. if self.mode == 'chat':
  59. introduction = self.app_model_config.opening_statement
  60. if introduction:
  61. prompt_template = JinjaPromptTemplate.from_template(template=introduction)
  62. prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
  63. try:
  64. introduction = prompt_template.format(**prompt_inputs)
  65. except KeyError:
  66. pass
  67. if self.app_model_config.pre_prompt:
  68. system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
  69. system_instruction = system_message.content
  70. model_instance = ModelFactory.get_text_generation_model(
  71. tenant_id=self.tenant_id,
  72. model_provider_name=self.provider_name,
  73. model_name=self.model_name
  74. )
  75. system_instruction_tokens = model_instance.get_num_tokens(to_prompt_messages([system_message]))
  76. if not self.conversation:
  77. self.is_new_conversation = True
  78. self.conversation = Conversation(
  79. app_id=self.app.id,
  80. app_model_config_id=self.app_model_config.id,
  81. model_provider=self.provider_name,
  82. model_id=self.model_name,
  83. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  84. mode=self.mode,
  85. name='',
  86. inputs=self.inputs,
  87. introduction=introduction,
  88. system_instruction=system_instruction,
  89. system_instruction_tokens=system_instruction_tokens,
  90. status='normal',
  91. from_source=('console' if isinstance(self.user, Account) else 'api'),
  92. from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
  93. from_account_id=(self.user.id if isinstance(self.user, Account) else None),
  94. )
  95. db.session.add(self.conversation)
  96. db.session.commit()
  97. self.message = Message(
  98. app_id=self.app.id,
  99. model_provider=self.provider_name,
  100. model_id=self.model_name,
  101. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  102. conversation_id=self.conversation.id,
  103. inputs=self.inputs,
  104. query=self.query,
  105. message="",
  106. message_tokens=0,
  107. message_unit_price=0,
  108. message_price_unit=0,
  109. answer="",
  110. answer_tokens=0,
  111. answer_unit_price=0,
  112. answer_price_unit=0,
  113. provider_response_latency=0,
  114. total_price=0,
  115. currency=self.model_instance.get_currency(),
  116. from_source=('console' if isinstance(self.user, Account) else 'api'),
  117. from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
  118. from_account_id=(self.user.id if isinstance(self.user, Account) else None),
  119. agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
  120. )
  121. db.session.add(self.message)
  122. db.session.commit()
  123. def append_message_text(self, text: str):
  124. if text is not None:
  125. self._pub_handler.pub_text(text)
  126. def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
  127. message_tokens = llm_message.prompt_tokens
  128. answer_tokens = llm_message.completion_tokens
  129. message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
  130. message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
  131. answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
  132. answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
  133. message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
  134. answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
  135. total_price = message_total_price + answer_total_price
  136. self.message.message = llm_message.prompt
  137. self.message.message_tokens = message_tokens
  138. self.message.message_unit_price = message_unit_price
  139. self.message.message_price_unit = message_price_unit
  140. self.message.answer = PromptBuilder.process_template(
  141. llm_message.completion.strip()) if llm_message.completion else ''
  142. self.message.answer_tokens = answer_tokens
  143. self.message.answer_unit_price = answer_unit_price
  144. self.message.answer_price_unit = answer_price_unit
  145. self.message.provider_response_latency = time.perf_counter() - self.start_at
  146. self.message.total_price = total_price
  147. db.session.commit()
  148. message_was_created.send(
  149. self.message,
  150. conversation=self.conversation,
  151. is_first_message=self.is_new_conversation
  152. )
  153. if not by_stopped:
  154. self.end()
  155. def init_chain(self, chain_result: ChainResult):
  156. message_chain = MessageChain(
  157. message_id=self.message.id,
  158. type=chain_result.type,
  159. input=json.dumps(chain_result.prompt),
  160. output=''
  161. )
  162. db.session.add(message_chain)
  163. db.session.commit()
  164. return message_chain
  165. def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
  166. message_chain.output = json.dumps(chain_result.completion)
  167. db.session.commit()
  168. self._pub_handler.pub_chain(message_chain)
  169. def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
  170. message_agent_thought = MessageAgentThought(
  171. message_id=self.message.id,
  172. message_chain_id=message_chain.id,
  173. position=agent_loop.position,
  174. thought=agent_loop.thought,
  175. tool=agent_loop.tool_name,
  176. tool_input=agent_loop.tool_input,
  177. message=agent_loop.prompt,
  178. message_price_unit=0,
  179. answer=agent_loop.completion,
  180. answer_price_unit=0,
  181. created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
  182. created_by=self.user.id
  183. )
  184. db.session.add(message_agent_thought)
  185. db.session.commit()
  186. self._pub_handler.pub_agent_thought(message_agent_thought)
  187. return message_agent_thought
  188. def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
  189. agent_loop: AgentLoop):
  190. agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
  191. agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
  192. agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
  193. agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
  194. loop_message_tokens = agent_loop.prompt_tokens
  195. loop_answer_tokens = agent_loop.completion_tokens
  196. loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
  197. loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
  198. loop_total_price = loop_message_total_price + loop_answer_total_price
  199. message_agent_thought.observation = agent_loop.tool_output
  200. message_agent_thought.tool_process_data = '' # currently not support
  201. message_agent_thought.message_token = loop_message_tokens
  202. message_agent_thought.message_unit_price = agent_message_unit_price
  203. message_agent_thought.message_price_unit = agent_message_price_unit
  204. message_agent_thought.answer_token = loop_answer_tokens
  205. message_agent_thought.answer_unit_price = agent_answer_unit_price
  206. message_agent_thought.answer_price_unit = agent_answer_price_unit
  207. message_agent_thought.latency = agent_loop.latency
  208. message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
  209. message_agent_thought.total_price = loop_total_price
  210. message_agent_thought.currency = agent_model_instance.get_currency()
  211. db.session.commit()
  212. def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
  213. dataset_query = DatasetQuery(
  214. dataset_id=dataset_query_obj.dataset_id,
  215. content=dataset_query_obj.query,
  216. source='app',
  217. source_app_id=self.app.id,
  218. created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
  219. created_by=self.user.id
  220. )
  221. db.session.add(dataset_query)
  222. db.session.commit()
  223. def on_dataset_query_finish(self, resource: List):
  224. if resource and len(resource) > 0:
  225. for item in resource:
  226. dataset_retriever_resource = DatasetRetrieverResource(
  227. message_id=self.message.id,
  228. position=item.get('position'),
  229. dataset_id=item.get('dataset_id'),
  230. dataset_name=item.get('dataset_name'),
  231. document_id=item.get('document_id'),
  232. document_name=item.get('document_name'),
  233. data_source_type=item.get('data_source_type'),
  234. segment_id=item.get('segment_id'),
  235. score=item.get('score') if 'score' in item else None,
  236. hit_count=item.get('hit_count') if 'hit_count' else None,
  237. word_count=item.get('word_count') if 'word_count' in item else None,
  238. segment_position=item.get('segment_position') if 'segment_position' in item else None,
  239. index_node_hash=item.get('index_node_hash') if 'index_node_hash' in item else None,
  240. content=item.get('content'),
  241. retriever_from=item.get('retriever_from'),
  242. created_by=self.user.id
  243. )
  244. db.session.add(dataset_retriever_resource)
  245. db.session.commit()
  246. self.retriever_resource = resource
  247. def message_end(self):
  248. self._pub_handler.pub_message_end(self.retriever_resource)
  249. def end(self):
  250. self._pub_handler.pub_message_end(self.retriever_resource)
  251. self._pub_handler.pub_end()
  252. class PubHandler:
  253. def __init__(self, user: Union[Account | EndUser], task_id: str,
  254. message: Message, conversation: Conversation,
  255. chain_pub: bool = False, agent_thought_pub: bool = False):
  256. self._channel = PubHandler.generate_channel_name(user, task_id)
  257. self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
  258. self._task_id = task_id
  259. self._message = message
  260. self._conversation = conversation
  261. self._chain_pub = chain_pub
  262. self._agent_thought_pub = agent_thought_pub
  263. @classmethod
  264. def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str):
  265. if not user:
  266. raise ValueError("user is required")
  267. user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
  268. return "generate_result:{}-{}".format(user_str, task_id)
  269. @classmethod
  270. def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
  271. user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
  272. return "generate_result_stopped:{}-{}".format(user_str, task_id)
  273. def pub_text(self, text: str):
  274. content = {
  275. 'event': 'message',
  276. 'data': {
  277. 'task_id': self._task_id,
  278. 'message_id': str(self._message.id),
  279. 'text': text,
  280. 'mode': self._conversation.mode,
  281. 'conversation_id': str(self._conversation.id)
  282. }
  283. }
  284. redis_client.publish(self._channel, json.dumps(content))
  285. if self._is_stopped():
  286. self.pub_end()
  287. raise ConversationTaskStoppedException()
  288. def pub_chain(self, message_chain: MessageChain):
  289. if self._chain_pub:
  290. content = {
  291. 'event': 'chain',
  292. 'data': {
  293. 'task_id': self._task_id,
  294. 'message_id': self._message.id,
  295. 'chain_id': message_chain.id,
  296. 'type': message_chain.type,
  297. 'input': json.loads(message_chain.input),
  298. 'output': json.loads(message_chain.output),
  299. 'mode': self._conversation.mode,
  300. 'conversation_id': self._conversation.id
  301. }
  302. }
  303. redis_client.publish(self._channel, json.dumps(content))
  304. if self._is_stopped():
  305. self.pub_end()
  306. raise ConversationTaskStoppedException()
  307. def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
  308. if self._agent_thought_pub:
  309. content = {
  310. 'event': 'agent_thought',
  311. 'data': {
  312. 'id': message_agent_thought.id,
  313. 'task_id': self._task_id,
  314. 'message_id': self._message.id,
  315. 'chain_id': message_agent_thought.message_chain_id,
  316. 'position': message_agent_thought.position,
  317. 'thought': message_agent_thought.thought,
  318. 'tool': message_agent_thought.tool,
  319. 'tool_input': message_agent_thought.tool_input,
  320. 'mode': self._conversation.mode,
  321. 'conversation_id': self._conversation.id
  322. }
  323. }
  324. redis_client.publish(self._channel, json.dumps(content))
  325. if self._is_stopped():
  326. self.pub_end()
  327. raise ConversationTaskStoppedException()
  328. def pub_message_end(self, retriever_resource: List):
  329. content = {
  330. 'event': 'message_end',
  331. 'data': {
  332. 'task_id': self._task_id,
  333. 'message_id': self._message.id,
  334. 'mode': self._conversation.mode,
  335. 'conversation_id': self._conversation.id
  336. }
  337. }
  338. if retriever_resource:
  339. content['data']['retriever_resources'] = retriever_resource
  340. redis_client.publish(self._channel, json.dumps(content))
  341. if self._is_stopped():
  342. self.pub_end()
  343. raise ConversationTaskStoppedException()
  344. def pub_end(self):
  345. content = {
  346. 'event': 'end',
  347. }
  348. redis_client.publish(self._channel, json.dumps(content))
  349. @classmethod
  350. def pub_error(cls, user: Union[Account | EndUser], task_id: str, e):
  351. content = {
  352. 'error': type(e).__name__,
  353. 'description': e.description if getattr(e, 'description', None) is not None else str(e)
  354. }
  355. channel = cls.generate_channel_name(user, task_id)
  356. redis_client.publish(channel, json.dumps(content))
  357. def _is_stopped(self):
  358. return redis_client.get(self._stopped_cache_key) is not None
  359. @classmethod
  360. def ping(cls, user: Union[Account | EndUser], task_id: str):
  361. content = {
  362. 'event': 'ping'
  363. }
  364. channel = cls.generate_channel_name(user, task_id)
  365. redis_client.publish(channel, json.dumps(content))
  366. @classmethod
  367. def stop(cls, user: Union[Account | EndUser], task_id: str):
  368. stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
  369. redis_client.setex(stopped_cache_key, 600, 1)
  370. class ConversationTaskStoppedException(Exception):
  371. pass