conversation_message_task.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. import decimal
  2. import json
  3. from typing import Optional, Union
  4. from core.callback_handler.entity.agent_loop import AgentLoop
  5. from core.callback_handler.entity.dataset_query import DatasetQueryObj
  6. from core.callback_handler.entity.llm_message import LLMMessage
  7. from core.callback_handler.entity.chain_result import ChainResult
  8. from core.model_providers.model_factory import ModelFactory
  9. from core.model_providers.models.entity.message import to_prompt_messages, MessageType
  10. from core.model_providers.models.llm.base import BaseLLM
  11. from core.prompt.prompt_builder import PromptBuilder
  12. from core.prompt.prompt_template import JinjaPromptTemplate
  13. from events.message_event import message_was_created
  14. from extensions.ext_database import db
  15. from extensions.ext_redis import redis_client
  16. from models.dataset import DatasetQuery
  17. from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
  class ConversationMessageTask:
      """Persist one generation turn and stream its progress to subscribers.

      On construction the task adds and flushes the Message row and, when no
      conversation is supplied, a new Conversation row. Callback methods then
      record chains, agent thoughts and dataset queries, and ``save_message``
      commits the final answer. Progress events go through a ``PubHandler``.
      """

      def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig,
                   user: Union[Account, EndUser],
                   inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
                   conversation: Optional[Conversation] = None, is_override: bool = False):
          """Create the task and immediately create its database rows (see ``init``).

          :param task_id: id of this generation task; scopes the pub/sub channel
          :param app: the app being run; supplies tenant_id and mode
          :param app_model_config: model/prompt configuration for this run
          :param user: the console Account or the EndUser issuing the query
          :param inputs: variable values used to render the prompt templates
          :param query: the user's query text
          :param streaming: stored on the task; not read elsewhere in this class
          :param model_instance: model used for currency and token pricing
          :param conversation: existing conversation to append to; a new one is
              created when omitted
          :param is_override: when true, the model config is serialized onto the
              Conversation/Message rows as ``override_model_configs``
          """
          self.task_id = task_id
          self.app = app
          self.tenant_id = app.tenant_id
          self.app_model_config = app_model_config
          self.is_override = is_override
          self.user = user
          self.inputs = inputs
          self.query = query
          self.streaming = streaming
          self.conversation = conversation
          self.is_new_conversation = False
          self.model_instance = model_instance
          self.message = None
          self.model_dict = self.app_model_config.model_dict
          self.provider_name = self.model_dict.get('provider')
          self.model_name = self.model_dict.get('name')
          self.mode = app.mode

          self.init()

          # Created after init() so it captures the persisted message/conversation.
          self._pub_handler = PubHandler(
              user=self.user,
              task_id=self.task_id,
              message=self.message,
              conversation=self.conversation,
              chain_pub=False,  # disabled currently
              agent_thought_pub=True
          )

      def init(self):
          """Create the Message row, plus a Conversation row when none was given.

          Rows are added to the session and flushed so their ids are available;
          nothing is committed here.
          """
          override_model_configs = None
          if self.is_override:
              override_model_configs = self.app_model_config.to_dict()
          # Shared by both rows; an empty/absent override is stored as NULL.
          override_model_configs_json = json.dumps(override_model_configs) if override_model_configs else None

          if not self.conversation:
              self.is_new_conversation = True
              # Introduction and system instruction are only stored on a brand-new
              # Conversation, so the (model-instantiating) work is done only here.
              introduction = self._render_introduction()
              system_instruction, system_instruction_tokens = self._build_system_instruction()

              self.conversation = Conversation(
                  app_id=self.app_model_config.app_id,
                  app_model_config_id=self.app_model_config.id,
                  model_provider=self.provider_name,
                  model_id=self.model_name,
                  override_model_configs=override_model_configs_json,
                  mode=self.mode,
                  name='',
                  inputs=self.inputs,
                  introduction=introduction,
                  system_instruction=system_instruction,
                  system_instruction_tokens=system_instruction_tokens,
                  status='normal',
                  **self._user_source_fields(),
              )
              db.session.add(self.conversation)
              db.session.flush()

          # Placeholder values; save_message fills in the real prompt/answer/pricing.
          self.message = Message(
              app_id=self.app_model_config.app_id,
              model_provider=self.provider_name,
              model_id=self.model_name,
              override_model_configs=override_model_configs_json,
              conversation_id=self.conversation.id,
              inputs=self.inputs,
              query=self.query,
              message="",
              message_tokens=0,
              message_unit_price=0,
              answer="",
              answer_tokens=0,
              answer_unit_price=0,
              provider_response_latency=0,
              total_price=0,
              currency=self.model_instance.get_currency(),
              agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
              **self._user_source_fields(),
          )
          db.session.add(self.message)
          db.session.flush()

      def _render_introduction(self) -> Optional[str]:
          """Render the opening statement for chat mode; return '' for other modes.

          The statement is treated as a Jinja template filled from matching keys of
          ``self.inputs``. If formatting raises KeyError the raw statement is kept,
          and a falsy statement (e.g. None) is returned unchanged.
          """
          if self.mode != 'chat':
              return ''

          introduction = self.app_model_config.opening_statement
          if introduction:
              prompt_template = JinjaPromptTemplate.from_template(template=introduction)
              prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
              try:
                  introduction = prompt_template.format(**prompt_inputs)
              except KeyError:
                  # Best effort: keep the unrendered statement.
                  pass
          return introduction

      def _build_system_instruction(self) -> tuple[str, int]:
          """Return (content, token count) of the system message built from pre_prompt.

          Returns ('', 0) when there is no pre_prompt. Tokens are counted with a model
          instance obtained from ModelFactory for the configured provider/model.
          """
          if not self.app_model_config.pre_prompt:
              return '', 0

          system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
          system_instruction = system_message.content
          counting_model = ModelFactory.get_text_generation_model(
              tenant_id=self.tenant_id,
              model_provider_name=self.provider_name,
              model_name=self.model_name
          )
          return system_instruction, counting_model.get_num_tokens(to_prompt_messages([system_message]))

      def _user_source_fields(self) -> dict:
          """Return the from_source/from_end_user_id/from_account_id row fields for self.user."""
          is_account = isinstance(self.user, Account)
          return {
              'from_source': 'console' if is_account else 'api',
              'from_end_user_id': self.user.id if isinstance(self.user, EndUser) else None,
              'from_account_id': self.user.id if is_account else None,
          }

      def _created_by_role(self) -> str:
          """Return the created_by_role value ('account' or 'end_user') for self.user."""
          return 'account' if isinstance(self.user, Account) else 'end_user'

      def append_message_text(self, text: str):
          """Publish a chunk of answer text.

          May raise ConversationTaskStoppedException if a stop was requested.
          """
          self._pub_handler.pub_text(text)

      def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
          """Store the final prompt/answer, tokens and pricing on the message and commit.

          Sends ``message_was_created`` after the commit and publishes the end event
          unless ``by_stopped`` is true.
          """
          message_tokens = llm_message.prompt_tokens
          answer_tokens = llm_message.completion_tokens
          message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN)
          answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT)
          total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)

          self.message.message = llm_message.prompt
          self.message.message_tokens = message_tokens
          self.message.message_unit_price = message_unit_price
          self.message.answer = PromptBuilder.process_template(llm_message.completion.strip()) if llm_message.completion else ''
          self.message.answer_tokens = answer_tokens
          self.message.answer_unit_price = answer_unit_price
          self.message.provider_response_latency = llm_message.latency
          self.message.total_price = total_price
          db.session.commit()

          message_was_created.send(
              self.message,
              conversation=self.conversation,
              is_first_message=self.is_new_conversation
          )

          if not by_stopped:
              self.end()

      def init_chain(self, chain_result: ChainResult):
          """Add and flush a MessageChain row for this message (output left empty)."""
          message_chain = MessageChain(
              message_id=self.message.id,
              type=chain_result.type,
              input=json.dumps(chain_result.prompt),
              output=''
          )
          db.session.add(message_chain)
          db.session.flush()
          return message_chain

      def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
          """Record the chain output as JSON and publish it (chain publishing is disabled in __init__)."""
          message_chain.output = json.dumps(chain_result.completion)
          self._pub_handler.pub_chain(message_chain)

      def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
          """Add, flush and publish a MessageAgentThought row for one agent loop step."""
          message_agent_thought = MessageAgentThought(
              message_id=self.message.id,
              message_chain_id=message_chain.id,
              position=agent_loop.position,
              thought=agent_loop.thought,
              tool=agent_loop.tool_name,
              tool_input=agent_loop.tool_input,
              message=agent_loop.prompt,
              answer=agent_loop.completion,
              created_by_role=self._created_by_role(),
              created_by=self.user.id
          )
          db.session.add(message_agent_thought)
          db.session.flush()
          self._pub_handler.pub_agent_thought(message_agent_thought)
          return message_agent_thought

      def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM,
                       agent_loop: AgentLoop):
          """Fill in observation, token usage and pricing for an agent step and flush."""
          agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN)
          agent_answer_unit_price = agent_model_instant.get_token_price(1, MessageType.ASSISTANT)
          loop_message_tokens = agent_loop.prompt_tokens
          loop_answer_tokens = agent_loop.completion_tokens
          loop_total_price = self.calc_total_price(
              loop_message_tokens,
              agent_message_unit_price,
              loop_answer_tokens,
              agent_answer_unit_price
          )

          message_agent_thought.observation = agent_loop.tool_output
          message_agent_thought.tool_process_data = ''  # currently not support
          message_agent_thought.message_token = loop_message_tokens
          message_agent_thought.message_unit_price = agent_message_unit_price
          message_agent_thought.answer_token = loop_answer_tokens
          message_agent_thought.answer_unit_price = agent_answer_unit_price
          message_agent_thought.latency = agent_loop.latency
          message_agent_thought.tokens = loop_message_tokens + loop_answer_tokens
          message_agent_thought.total_price = loop_total_price
          message_agent_thought.currency = agent_model_instant.get_currency()
          db.session.flush()

      def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
          """Add a DatasetQuery row (source 'app') to the session; it is not flushed here."""
          dataset_query = DatasetQuery(
              dataset_id=dataset_query_obj.dataset_id,
              content=dataset_query_obj.query,
              source='app',
              source_app_id=self.app.id,
              created_by_role=self._created_by_role(),
              created_by=self.user.id
          )
          db.session.add(dataset_query)

      def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
          """Return the Decimal total price for the given token counts.

          Each token count is expressed in thousands (quantized to 0.001) and
          multiplied by its unit price. The sum is quantized to 7 decimal places
          with ROUND_HALF_UP.
          NOTE(review): callers pass ``get_token_price(1, ...)`` as the unit
          price; confirm that value is a per-1K-token price.
          """
          thousandth = decimal.Decimal('0.001')
          message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(thousandth,
                                                                                    rounding=decimal.ROUND_HALF_UP)
          answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(thousandth,
                                                                                  rounding=decimal.ROUND_HALF_UP)
          total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
          return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)

      def end(self):
          """Publish the end-of-stream event."""
          self._pub_handler.pub_end()
  211. return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
  212. def end(self):
  213. self._pub_handler.pub_end()
  214. class PubHandler:
  215. def __init__(self, user: Union[Account | EndUser], task_id: str,
  216. message: Message, conversation: Conversation,
  217. chain_pub: bool = False, agent_thought_pub: bool = False):
  218. self._channel = PubHandler.generate_channel_name(user, task_id)
  219. self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
  220. self._task_id = task_id
  221. self._message = message
  222. self._conversation = conversation
  223. self._chain_pub = chain_pub
  224. self._agent_thought_pub = agent_thought_pub
  225. @classmethod
  226. def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str):
  227. if not user:
  228. raise ValueError("user is required")
  229. user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
  230. return "generate_result:{}-{}".format(user_str, task_id)
  231. @classmethod
  232. def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
  233. user_str = 'account-' + str(user.id) if isinstance(user, Account) else 'end-user-' + str(user.id)
  234. return "generate_result_stopped:{}-{}".format(user_str, task_id)
  235. def pub_text(self, text: str):
  236. content = {
  237. 'event': 'message',
  238. 'data': {
  239. 'task_id': self._task_id,
  240. 'message_id': str(self._message.id),
  241. 'text': text,
  242. 'mode': self._conversation.mode,
  243. 'conversation_id': str(self._conversation.id)
  244. }
  245. }
  246. redis_client.publish(self._channel, json.dumps(content))
  247. if self._is_stopped():
  248. self.pub_end()
  249. raise ConversationTaskStoppedException()
  250. def pub_chain(self, message_chain: MessageChain):
  251. if self._chain_pub:
  252. content = {
  253. 'event': 'chain',
  254. 'data': {
  255. 'task_id': self._task_id,
  256. 'message_id': self._message.id,
  257. 'chain_id': message_chain.id,
  258. 'type': message_chain.type,
  259. 'input': json.loads(message_chain.input),
  260. 'output': json.loads(message_chain.output),
  261. 'mode': self._conversation.mode,
  262. 'conversation_id': self._conversation.id
  263. }
  264. }
  265. redis_client.publish(self._channel, json.dumps(content))
  266. if self._is_stopped():
  267. self.pub_end()
  268. raise ConversationTaskStoppedException()
  269. def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
  270. if self._agent_thought_pub:
  271. content = {
  272. 'event': 'agent_thought',
  273. 'data': {
  274. 'id': message_agent_thought.id,
  275. 'task_id': self._task_id,
  276. 'message_id': self._message.id,
  277. 'chain_id': message_agent_thought.message_chain_id,
  278. 'position': message_agent_thought.position,
  279. 'thought': message_agent_thought.thought,
  280. 'tool': message_agent_thought.tool,
  281. 'tool_input': message_agent_thought.tool_input,
  282. 'mode': self._conversation.mode,
  283. 'conversation_id': self._conversation.id
  284. }
  285. }
  286. redis_client.publish(self._channel, json.dumps(content))
  287. if self._is_stopped():
  288. self.pub_end()
  289. raise ConversationTaskStoppedException()
  290. def pub_end(self):
  291. content = {
  292. 'event': 'end',
  293. }
  294. redis_client.publish(self._channel, json.dumps(content))
  295. @classmethod
  296. def pub_error(cls, user: Union[Account | EndUser], task_id: str, e):
  297. content = {
  298. 'error': type(e).__name__,
  299. 'description': e.description if getattr(e, 'description', None) is not None else str(e)
  300. }
  301. channel = cls.generate_channel_name(user, task_id)
  302. redis_client.publish(channel, json.dumps(content))
  303. def _is_stopped(self):
  304. return redis_client.get(self._stopped_cache_key) is not None
  305. @classmethod
  306. def ping(cls, user: Union[Account | EndUser], task_id: str):
  307. content = {
  308. 'event': 'ping'
  309. }
  310. channel = cls.generate_channel_name(user, task_id)
  311. redis_client.publish(channel, json.dumps(content))
  312. @classmethod
  313. def stop(cls, user: Union[Account | EndUser], task_id: str):
  314. stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
  315. redis_client.setex(stopped_cache_key, 600, 1)
  316. class ConversationTaskStoppedException(Exception):
  317. pass