conversation_message_task.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. import decimal
  2. import json
  3. from typing import Optional, Union
  4. from core.callback_handler.entity.agent_loop import AgentLoop
  5. from core.callback_handler.entity.dataset_query import DatasetQueryObj
  6. from core.callback_handler.entity.llm_message import LLMMessage
  7. from core.callback_handler.entity.chain_result import ChainResult
  8. from core.constant import llm_constant
  9. from core.llm.llm_builder import LLMBuilder
  10. from core.llm.provider.llm_provider_service import LLMProviderService
  11. from core.prompt.prompt_builder import PromptBuilder
  12. from core.prompt.prompt_template import OutLinePromptTemplate
  13. from events.message_event import message_was_created
  14. from extensions.ext_database import db
  15. from extensions.ext_redis import redis_client
  16. from models.dataset import DatasetQuery
  17. from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, MessageChain
  18. from models.provider import ProviderType, Provider
  19. class ConversationMessageTask:
  20. def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
  21. inputs: dict, query: str, streaming: bool,
  22. conversation: Optional[Conversation] = None, is_override: bool = False):
  23. self.task_id = task_id
  24. self.app = app
  25. self.tenant_id = app.tenant_id
  26. self.app_model_config = app_model_config
  27. self.is_override = is_override
  28. self.user = user
  29. self.inputs = inputs
  30. self.query = query
  31. self.streaming = streaming
  32. self.conversation = conversation
  33. self.is_new_conversation = False
  34. self.message = None
  35. self.model_dict = self.app_model_config.model_dict
  36. self.model_name = self.model_dict.get('name')
  37. self.mode = app.mode
  38. self.init()
  39. self._pub_handler = PubHandler(
  40. user=self.user,
  41. task_id=self.task_id,
  42. message=self.message,
  43. conversation=self.conversation,
  44. chain_pub=False, # disabled currently
  45. agent_thought_pub=False # disabled currently
  46. )
  47. def init(self):
  48. override_model_configs = None
  49. if self.is_override:
  50. override_model_configs = {
  51. "model": self.app_model_config.model_dict,
  52. "pre_prompt": self.app_model_config.pre_prompt,
  53. "agent_mode": self.app_model_config.agent_mode_dict,
  54. "opening_statement": self.app_model_config.opening_statement,
  55. "suggested_questions": self.app_model_config.suggested_questions_list,
  56. "suggested_questions_after_answer": self.app_model_config.suggested_questions_after_answer_dict,
  57. "more_like_this": self.app_model_config.more_like_this_dict,
  58. "user_input_form": self.app_model_config.user_input_form_list,
  59. }
  60. introduction = ''
  61. system_instruction = ''
  62. system_instruction_tokens = 0
  63. if self.mode == 'chat':
  64. introduction = self.app_model_config.opening_statement
  65. if introduction:
  66. prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction))
  67. prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
  68. introduction = prompt_template.format(**prompt_inputs)
  69. if self.app_model_config.pre_prompt:
  70. pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt)
  71. system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs)
  72. system_instruction = system_message.content
  73. llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
  74. system_instruction_tokens = llm.get_messages_tokens([system_message])
  75. if not self.conversation:
  76. self.is_new_conversation = True
  77. self.conversation = Conversation(
  78. app_id=self.app_model_config.app_id,
  79. app_model_config_id=self.app_model_config.id,
  80. model_provider=self.model_dict.get('provider'),
  81. model_id=self.model_name,
  82. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  83. mode=self.mode,
  84. name='',
  85. inputs=self.inputs,
  86. introduction=introduction,
  87. system_instruction=system_instruction,
  88. system_instruction_tokens=system_instruction_tokens,
  89. status='normal',
  90. from_source=('console' if isinstance(self.user, Account) else 'api'),
  91. from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
  92. from_account_id=(self.user.id if isinstance(self.user, Account) else None),
  93. )
  94. db.session.add(self.conversation)
  95. db.session.flush()
  96. self.message = Message(
  97. app_id=self.app_model_config.app_id,
  98. model_provider=self.model_dict.get('provider'),
  99. model_id=self.model_name,
  100. override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
  101. conversation_id=self.conversation.id,
  102. inputs=self.inputs,
  103. query=self.query,
  104. message="",
  105. message_tokens=0,
  106. message_unit_price=0,
  107. answer="",
  108. answer_tokens=0,
  109. answer_unit_price=0,
  110. provider_response_latency=0,
  111. total_price=0,
  112. currency=llm_constant.model_currency,
  113. from_source=('console' if isinstance(self.user, Account) else 'api'),
  114. from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None),
  115. from_account_id=(self.user.id if isinstance(self.user, Account) else None),
  116. agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
  117. )
  118. db.session.add(self.message)
  119. db.session.flush()
  120. def append_message_text(self, text: str):
  121. self._pub_handler.pub_text(text)
  122. def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
  123. model_name = self.app_model_config.model_dict.get('name')
  124. message_tokens = llm_message.prompt_tokens
  125. answer_tokens = llm_message.completion_tokens
  126. message_unit_price = llm_constant.model_prices[model_name]['prompt']
  127. answer_unit_price = llm_constant.model_prices[model_name]['completion']
  128. total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price)
  129. self.message.message = llm_message.prompt
  130. self.message.message_tokens = message_tokens
  131. self.message.message_unit_price = message_unit_price
  132. self.message.answer = llm_message.completion.strip() if llm_message.completion else ''
  133. self.message.answer_tokens = answer_tokens
  134. self.message.answer_unit_price = answer_unit_price
  135. self.message.provider_response_latency = llm_message.latency
  136. self.message.total_price = total_price
  137. self.update_provider_quota()
  138. db.session.commit()
  139. message_was_created.send(
  140. self.message,
  141. conversation=self.conversation,
  142. is_first_message=self.is_new_conversation
  143. )
  144. if not by_stopped:
  145. self._pub_handler.pub_end()
  146. def update_provider_quota(self):
  147. llm_provider_service = LLMProviderService(
  148. tenant_id=self.app.tenant_id,
  149. provider_name=self.message.model_provider,
  150. )
  151. provider = llm_provider_service.get_provider_db_record()
  152. if provider and provider.provider_type == ProviderType.SYSTEM.value:
  153. db.session.query(Provider).filter(
  154. Provider.tenant_id == self.app.tenant_id,
  155. Provider.quota_limit > Provider.quota_used
  156. ).update({'quota_used': Provider.quota_used + 1})
  157. def init_chain(self, chain_result: ChainResult):
  158. message_chain = MessageChain(
  159. message_id=self.message.id,
  160. type=chain_result.type,
  161. input=json.dumps(chain_result.prompt),
  162. output=''
  163. )
  164. db.session.add(message_chain)
  165. db.session.flush()
  166. return message_chain
  167. def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
  168. message_chain.output = json.dumps(chain_result.completion)
  169. self._pub_handler.pub_chain(message_chain)
  170. def on_agent_end(self, message_chain: MessageChain, agent_model_name: str,
  171. agent_loop: AgentLoop):
  172. agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt']
  173. agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion']
  174. loop_message_tokens = agent_loop.prompt_tokens
  175. loop_answer_tokens = agent_loop.completion_tokens
  176. loop_total_price = self.calc_total_price(
  177. loop_message_tokens,
  178. agent_message_unit_price,
  179. loop_answer_tokens,
  180. agent_answer_unit_price
  181. )
  182. message_agent_loop = MessageAgentThought(
  183. message_id=self.message.id,
  184. message_chain_id=message_chain.id,
  185. position=agent_loop.position,
  186. thought=agent_loop.thought,
  187. tool=agent_loop.tool_name,
  188. tool_input=agent_loop.tool_input,
  189. observation=agent_loop.tool_output,
  190. tool_process_data='', # currently not support
  191. message=agent_loop.prompt,
  192. message_token=loop_message_tokens,
  193. message_unit_price=agent_message_unit_price,
  194. answer=agent_loop.completion,
  195. answer_token=loop_answer_tokens,
  196. answer_unit_price=agent_answer_unit_price,
  197. latency=agent_loop.latency,
  198. tokens=agent_loop.prompt_tokens + agent_loop.completion_tokens,
  199. total_price=loop_total_price,
  200. currency=llm_constant.model_currency,
  201. created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
  202. created_by=self.user.id
  203. )
  204. db.session.add(message_agent_loop)
  205. db.session.flush()
  206. self._pub_handler.pub_agent_thought(message_agent_loop)
  207. def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
  208. dataset_query = DatasetQuery(
  209. dataset_id=dataset_query_obj.dataset_id,
  210. content=dataset_query_obj.query,
  211. source='app',
  212. source_app_id=self.app.id,
  213. created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
  214. created_by=self.user.id
  215. )
  216. db.session.add(dataset_query)
  217. def calc_total_price(self, message_tokens, message_unit_price, answer_tokens, answer_unit_price):
  218. message_tokens_per_1k = (decimal.Decimal(message_tokens) / 1000).quantize(decimal.Decimal('0.001'),
  219. rounding=decimal.ROUND_HALF_UP)
  220. answer_tokens_per_1k = (decimal.Decimal(answer_tokens) / 1000).quantize(decimal.Decimal('0.001'),
  221. rounding=decimal.ROUND_HALF_UP)
  222. total_price = message_tokens_per_1k * message_unit_price + answer_tokens_per_1k * answer_unit_price
  223. return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
  224. class PubHandler:
  225. def __init__(self, user: Union[Account | EndUser], task_id: str,
  226. message: Message, conversation: Conversation,
  227. chain_pub: bool = False, agent_thought_pub: bool = False):
  228. self._channel = PubHandler.generate_channel_name(user, task_id)
  229. self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
  230. self._task_id = task_id
  231. self._message = message
  232. self._conversation = conversation
  233. self._chain_pub = chain_pub
  234. self._agent_thought_pub = agent_thought_pub
  235. @classmethod
  236. def generate_channel_name(cls, user: Union[Account | EndUser], task_id: str):
  237. if not user:
  238. raise ValueError("user is required")
  239. user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
  240. return "generate_result:{}-{}".format(user_str, task_id)
  241. @classmethod
  242. def generate_stopped_cache_key(cls, user: Union[Account | EndUser], task_id: str):
  243. user_str = 'account-' + user.id if isinstance(user, Account) else 'end-user-' + user.id
  244. return "generate_result_stopped:{}-{}".format(user_str, task_id)
  245. def pub_text(self, text: str):
  246. content = {
  247. 'event': 'message',
  248. 'data': {
  249. 'task_id': self._task_id,
  250. 'message_id': self._message.id,
  251. 'text': text,
  252. 'mode': self._conversation.mode,
  253. 'conversation_id': self._conversation.id
  254. }
  255. }
  256. redis_client.publish(self._channel, json.dumps(content))
  257. if self._is_stopped():
  258. self.pub_end()
  259. raise ConversationTaskStoppedException()
  260. def pub_chain(self, message_chain: MessageChain):
  261. if self._chain_pub:
  262. content = {
  263. 'event': 'chain',
  264. 'data': {
  265. 'task_id': self._task_id,
  266. 'message_id': self._message.id,
  267. 'chain_id': message_chain.id,
  268. 'type': message_chain.type,
  269. 'input': json.loads(message_chain.input),
  270. 'output': json.loads(message_chain.output),
  271. 'mode': self._conversation.mode,
  272. 'conversation_id': self._conversation.id
  273. }
  274. }
  275. redis_client.publish(self._channel, json.dumps(content))
  276. if self._is_stopped():
  277. self.pub_end()
  278. raise ConversationTaskStoppedException()
  279. def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
  280. if self._agent_thought_pub:
  281. content = {
  282. 'event': 'agent_thought',
  283. 'data': {
  284. 'task_id': self._task_id,
  285. 'message_id': self._message.id,
  286. 'chain_id': message_agent_thought.message_chain_id,
  287. 'agent_thought_id': message_agent_thought.id,
  288. 'position': message_agent_thought.position,
  289. 'thought': message_agent_thought.thought,
  290. 'tool': message_agent_thought.tool,
  291. 'tool_input': message_agent_thought.tool_input,
  292. 'observation': message_agent_thought.observation,
  293. 'answer': message_agent_thought.answer,
  294. 'mode': self._conversation.mode,
  295. 'conversation_id': self._conversation.id
  296. }
  297. }
  298. redis_client.publish(self._channel, json.dumps(content))
  299. if self._is_stopped():
  300. self.pub_end()
  301. raise ConversationTaskStoppedException()
  302. def pub_end(self):
  303. content = {
  304. 'event': 'end',
  305. }
  306. redis_client.publish(self._channel, json.dumps(content))
  307. @classmethod
  308. def pub_error(cls, user: Union[Account | EndUser], task_id: str, e):
  309. content = {
  310. 'error': type(e).__name__,
  311. 'description': e.description if getattr(e, 'description', None) is not None else str(e)
  312. }
  313. channel = cls.generate_channel_name(user, task_id)
  314. redis_client.publish(channel, json.dumps(content))
  315. def _is_stopped(self):
  316. return redis_client.get(self._stopped_cache_key) is not None
  317. @classmethod
  318. def stop(cls, user: Union[Account | EndUser], task_id: str):
  319. stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
  320. redis_client.setex(stopped_cache_key, 600, 1)
  321. class ConversationTaskStoppedException(Exception):
  322. pass