123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443 |
- import json
- import time
- from typing import Optional, Union, List
- from core.callback_handler.entity.agent_loop import AgentLoop
- from core.callback_handler.entity.dataset_query import DatasetQueryObj
- from core.callback_handler.entity.llm_message import LLMMessage
- from core.callback_handler.entity.chain_result import ChainResult
- from core.model_providers.model_factory import ModelFactory
- from core.model_providers.models.entity.message import to_prompt_messages, MessageType
- from core.model_providers.models.llm.base import BaseLLM
- from core.prompt.prompt_builder import PromptBuilder
- from core.prompt.prompt_template import JinjaPromptTemplate
- from events.message_event import message_was_created
- from extensions.ext_database import db
- from extensions.ext_redis import redis_client
- from models.dataset import DatasetQuery
- from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
- MessageChain, DatasetRetrieverResource
class ConversationMessageTask:
    """Database bookkeeping and event publishing for one query sent to an app.

    Constructing the task creates the ``Message`` row (and a ``Conversation``
    row when none is supplied) and a ``PubHandler`` used to publish events.
    The ``on_*`` methods record chain, agent-thought and dataset activity
    against that message; ``save_message`` stores the final result.

    NOTE(review): this class was restored from a copy whose indentation had
    been stripped. The nesting in ``init`` (the ``pre_prompt`` branch is
    assumed to sit inside the ``mode == 'chat'`` branch) and in
    ``on_dataset_query_finish`` (per-item flush, assignment inside the
    non-empty guard) is a reconstruction - confirm against the canonical file.
    """

    def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig,
                 user: Union[Account, EndUser], inputs: dict, query: str, streaming: bool,
                 model_instance: BaseLLM, conversation: Optional[Conversation] = None,
                 is_override: bool = False):
        """Store the request context, create the DB rows and the publisher.

        Args:
            task_id: Identifier used to build the publish channel / stop key.
            app: App being queried; supplies ``tenant_id``, ``mode`` and ``id``.
            app_model_config: Model configuration used for this request.
            user: Account or end user issuing the request.
            inputs: Variable values stored on the rows and used for templating.
            query: The user's query text.
            streaming: Stored on the task; not read elsewhere in this class.
            model_instance: Model used for currency and price lookups.
            conversation: Existing conversation, or ``None`` to create one.
            is_override: When true, the config dict is stored on the rows as
                ``override_model_configs``.
        """
        # Latency reported by save_message() is measured from this point.
        self.start_at = time.perf_counter()

        self.task_id = task_id

        self.app = app
        self.tenant_id = app.tenant_id
        self.app_model_config = app_model_config
        self.is_override = is_override

        self.user = user
        self.inputs = inputs
        self.query = query
        self.streaming = streaming

        self.conversation = conversation
        self.is_new_conversation = False

        self.model_instance = model_instance

        self.message = None
        self.retriever_resource = None

        self.model_dict = self.app_model_config.model_dict
        self.provider_name = self.model_dict.get('provider')
        self.model_name = self.model_dict.get('name')
        self.mode = app.mode

        # init() must run first: the handler below is given self.message and
        # self.conversation, both of which init() creates.
        self.init()

        self._pub_handler = PubHandler(
            user=self.user,
            task_id=self.task_id,
            message=self.message,
            conversation=self.conversation,
            chain_pub=False,  # disabled currently
            agent_thought_pub=True
        )

    def _user_source_fields(self) -> dict:
        """Return the from_source / from_*_id keyword arguments for new rows."""
        is_account = isinstance(self.user, Account)
        return {
            'from_source': 'console' if is_account else 'api',
            'from_end_user_id': self.user.id if isinstance(self.user, EndUser) else None,
            'from_account_id': self.user.id if is_account else None,
        }

    def _creator_role(self) -> str:
        """Return the ``created_by_role`` value for the current user."""
        return 'account' if isinstance(self.user, Account) else 'end_user'

    def init(self):
        """Create (and flush) the conversation, if needed, and the message row."""
        override_model_configs = None
        if self.is_override:
            override_model_configs = self.app_model_config.to_dict()
        # Serialised once and shared by the Conversation and Message rows.
        override_json = json.dumps(override_model_configs) if override_model_configs else None

        introduction = ''
        system_instruction = ''
        system_instruction_tokens = 0
        if self.mode == 'chat':
            introduction = self.app_model_config.opening_statement
            if introduction:
                prompt_template = JinjaPromptTemplate.from_template(template=introduction)
                prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables
                                 if k in self.inputs}
                try:
                    introduction = prompt_template.format(**prompt_inputs)
                except KeyError:
                    # Best effort: keep the unformatted opening statement when
                    # a template variable is missing from the inputs.
                    pass

            if self.app_model_config.pre_prompt:
                system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt,
                                                                 self.inputs)
                system_instruction = system_message.content
                model_instance = ModelFactory.get_text_generation_model(
                    tenant_id=self.tenant_id,
                    model_provider_name=self.provider_name,
                    model_name=self.model_name
                )
                system_instruction_tokens = model_instance.get_num_tokens(
                    to_prompt_messages([system_message]))

        user_fields = self._user_source_fields()

        if not self.conversation:
            self.is_new_conversation = True
            self.conversation = Conversation(
                app_id=self.app_model_config.app_id,
                app_model_config_id=self.app_model_config.id,
                model_provider=self.provider_name,
                model_id=self.model_name,
                override_model_configs=override_json,
                mode=self.mode,
                name='',
                inputs=self.inputs,
                introduction=introduction,
                system_instruction=system_instruction,
                system_instruction_tokens=system_instruction_tokens,
                status='normal',
                **user_fields
            )

            db.session.add(self.conversation)
            # Flush so that self.conversation.id can be read for the message below.
            db.session.flush()

        # Token counts, prices and text start empty; save_message() fills them in.
        self.message = Message(
            app_id=self.app_model_config.app_id,
            model_provider=self.provider_name,
            model_id=self.model_name,
            override_model_configs=override_json,
            conversation_id=self.conversation.id,
            inputs=self.inputs,
            query=self.query,
            message="",
            message_tokens=0,
            message_unit_price=0,
            message_price_unit=0,
            answer="",
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency=self.model_instance.get_currency(),
            agent_based=self.app_model_config.agent_mode_dict.get('enabled'),
            **user_fields
        )

        db.session.add(self.message)
        db.session.flush()

    def append_message_text(self, text: str):
        """Publish a piece of answer text; ``None`` is ignored."""
        if text is not None:
            self._pub_handler.pub_text(text)

    def save_message(self, llm_message: LLMMessage, by_stopped: bool = False):
        """Store the final prompt/answer, tokens and prices, then commit.

        After the commit the ``message_was_created`` signal is sent. Unless
        ``by_stopped`` is true, ``end()`` is then called to publish the
        closing events.
        """
        message_tokens = llm_message.prompt_tokens
        answer_tokens = llm_message.completion_tokens

        message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
        message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
        answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
        answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)

        message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
        answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
        total_price = message_total_price + answer_total_price

        self.message.message = llm_message.prompt
        self.message.message_tokens = message_tokens
        self.message.message_unit_price = message_unit_price
        self.message.message_price_unit = message_price_unit
        self.message.answer = PromptBuilder.process_template(
            llm_message.completion.strip()) if llm_message.completion else ''
        self.message.answer_tokens = answer_tokens
        self.message.answer_unit_price = answer_unit_price
        self.message.answer_price_unit = answer_price_unit
        # Elapsed time since this task object was constructed.
        self.message.provider_response_latency = time.perf_counter() - self.start_at
        self.message.total_price = total_price

        db.session.commit()

        message_was_created.send(
            self.message,
            conversation=self.conversation,
            is_first_message=self.is_new_conversation
        )

        if not by_stopped:
            self.end()

    def init_chain(self, chain_result: ChainResult):
        """Create and flush a ``MessageChain`` row with an empty output."""
        message_chain = MessageChain(
            message_id=self.message.id,
            type=chain_result.type,
            input=json.dumps(chain_result.prompt),
            output=''
        )

        db.session.add(message_chain)
        db.session.flush()

        return message_chain

    def on_chain_end(self, message_chain: MessageChain, chain_result: ChainResult):
        """Store the chain completion as JSON on the row and publish the chain."""
        message_chain.output = json.dumps(chain_result.completion)

        self._pub_handler.pub_chain(message_chain)

    def on_agent_start(self, message_chain: MessageChain, agent_loop: AgentLoop) -> MessageAgentThought:
        """Create, flush and publish a ``MessageAgentThought`` for an agent loop."""
        message_agent_thought = MessageAgentThought(
            message_id=self.message.id,
            message_chain_id=message_chain.id,
            position=agent_loop.position,
            thought=agent_loop.thought,
            tool=agent_loop.tool_name,
            tool_input=agent_loop.tool_input,
            message=agent_loop.prompt,
            message_price_unit=0,
            answer=agent_loop.completion,
            answer_price_unit=0,
            created_by_role=self._creator_role(),
            created_by=self.user.id
        )

        db.session.add(message_agent_thought)
        db.session.flush()

        self._pub_handler.pub_agent_thought(message_agent_thought)

        return message_agent_thought

    def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
                     agent_loop: AgentLoop):
        """Record the tool output, token usage and prices of a finished agent loop."""
        agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
        agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
        agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
        agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)

        loop_message_tokens = agent_loop.prompt_tokens
        loop_answer_tokens = agent_loop.completion_tokens

        loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens,
                                                                          MessageType.HUMAN)
        loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens,
                                                                         MessageType.ASSISTANT)
        loop_total_price = loop_message_total_price + loop_answer_total_price

        message_agent_thought.observation = agent_loop.tool_output
        message_agent_thought.tool_process_data = ''  # currently not support
        message_agent_thought.message_token = loop_message_tokens
        message_agent_thought.message_unit_price = agent_message_unit_price
        message_agent_thought.message_price_unit = agent_message_price_unit
        message_agent_thought.answer_token = loop_answer_tokens
        message_agent_thought.answer_unit_price = agent_answer_unit_price
        message_agent_thought.answer_price_unit = agent_answer_price_unit
        message_agent_thought.latency = agent_loop.latency
        message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens
        message_agent_thought.total_price = loop_total_price
        message_agent_thought.currency = agent_model_instance.get_currency()
        db.session.flush()

    def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj):
        """Add a ``DatasetQuery`` row for this app to the session (no flush here)."""
        dataset_query = DatasetQuery(
            dataset_id=dataset_query_obj.dataset_id,
            content=dataset_query_obj.query,
            source='app',
            source_app_id=self.app.id,
            created_by_role=self._creator_role(),
            created_by=self.user.id
        )

        db.session.add(dataset_query)

    def on_dataset_query_finish(self, resource: List):
        """Persist one ``DatasetRetrieverResource`` per item and remember the list.

        Args:
            resource: List of dicts describing retrieved segments. An empty or
                ``None`` value is ignored and leaves ``retriever_resource``
                untouched.
        """
        if not resource:
            return

        for item in resource:
            # dict.get() already yields None for absent keys, so the optional
            # fields need no explicit membership test.
            dataset_retriever_resource = DatasetRetrieverResource(
                message_id=self.message.id,
                position=item.get('position'),
                dataset_id=item.get('dataset_id'),
                dataset_name=item.get('dataset_name'),
                document_id=item.get('document_id'),
                document_name=item.get('document_name'),
                data_source_type=item.get('data_source_type'),
                segment_id=item.get('segment_id'),
                score=item.get('score'),
                hit_count=item.get('hit_count'),
                word_count=item.get('word_count'),
                segment_position=item.get('segment_position'),
                index_node_hash=item.get('index_node_hash'),
                content=item.get('content'),
                retriever_from=item.get('retriever_from'),
                created_by=self.user.id
            )
            db.session.add(dataset_retriever_resource)
            db.session.flush()

        # Kept so message_end() can attach the resources to the final event.
        self.retriever_resource = resource

    def message_end(self):
        """Publish the ``message_end`` event with any retriever resources."""
        self._pub_handler.pub_message_end(self.retriever_resource)

    def end(self):
        """Publish ``message_end`` followed by the terminal ``end`` event."""
        self.message_end()
        self._pub_handler.pub_end()
class PubHandler:
    """Publishes generation events for one task to a Redis channel.

    The channel name and the "stopped" cache key are both derived from the
    user and the task id. After publishing, the ``pub_text``, ``pub_chain``,
    ``pub_agent_thought`` and ``pub_message_end`` methods check the stop key;
    when it is set they publish ``end`` and raise
    ``ConversationTaskStoppedException``.

    NOTE(review): restored from a copy whose indentation had been stripped.
    In ``pub_chain`` and ``pub_agent_thought`` the stop check is assumed to
    run even when publishing is disabled by the flag - confirm against the
    canonical file.
    NOTE(review): only ``pub_text`` wraps the message/conversation ids in
    ``str()``; the other events pass them to ``json.dumps`` as-is, which
    presumes they are already JSON-serialisable - verify.
    """

    # Expiry, in seconds, of the stop flag written by stop().
    _STOPPED_FLAG_TTL = 600

    def __init__(self, user: Union[Account, EndUser], task_id: str,
                 message: Message, conversation: Conversation,
                 chain_pub: bool = False, agent_thought_pub: bool = False):
        """Bind the handler to a user/task and the rows its events describe.

        Args:
            user: Account or end user owning the task.
            task_id: Task identifier included in every event.
            message: Message whose id is included in events.
            conversation: Conversation whose id and mode are included in events.
            chain_pub: Whether ``pub_chain`` actually publishes.
            agent_thought_pub: Whether ``pub_agent_thought`` actually publishes.

        Raises:
            ValueError: If ``user`` is missing.
        """
        self._channel = PubHandler.generate_channel_name(user, task_id)
        self._stopped_cache_key = PubHandler.generate_stopped_cache_key(user, task_id)
        self._task_id = task_id
        self._message = message
        self._conversation = conversation
        self._chain_pub = chain_pub
        self._agent_thought_pub = agent_thought_pub

    @staticmethod
    def _user_key(user) -> str:
        """Return the per-user key fragment shared by channel and stop keys.

        Raises:
            ValueError: If ``user`` is falsy.
        """
        if not user:
            raise ValueError("user is required")

        prefix = 'account-' if isinstance(user, Account) else 'end-user-'
        return prefix + str(user.id)

    @classmethod
    def generate_channel_name(cls, user: Union[Account, EndUser], task_id: str):
        """Return the channel name on which events for the task are published."""
        return f"generate_result:{cls._user_key(user)}-{task_id}"

    @classmethod
    def generate_stopped_cache_key(cls, user: Union[Account, EndUser], task_id: str):
        """Return the cache key whose presence marks the task as stopped."""
        return f"generate_result_stopped:{cls._user_key(user)}-{task_id}"

    def _publish(self, content: dict) -> None:
        """Serialise ``content`` as JSON and publish it on this task's channel."""
        redis_client.publish(self._channel, json.dumps(content))

    def _raise_if_stopped(self) -> None:
        """Publish ``end`` and raise if the stop flag has been set."""
        if self._is_stopped():
            self.pub_end()
            raise ConversationTaskStoppedException()

    def pub_text(self, text: str):
        """Publish a ``message`` event carrying a piece of answer text."""
        self._publish({
            'event': 'message',
            'data': {
                'task_id': self._task_id,
                'message_id': str(self._message.id),
                'text': text,
                'mode': self._conversation.mode,
                'conversation_id': str(self._conversation.id)
            }
        })

        self._raise_if_stopped()

    def pub_chain(self, message_chain: MessageChain):
        """Publish a ``chain`` event when chain publishing is enabled."""
        if self._chain_pub:
            self._publish({
                'event': 'chain',
                'data': {
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_chain.id,
                    'type': message_chain.type,
                    # input/output are stored as JSON text; decode them so the
                    # event carries structured values.
                    'input': json.loads(message_chain.input),
                    'output': json.loads(message_chain.output),
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            })

        self._raise_if_stopped()

    def pub_agent_thought(self, message_agent_thought: MessageAgentThought):
        """Publish an ``agent_thought`` event when that publishing is enabled."""
        if self._agent_thought_pub:
            self._publish({
                'event': 'agent_thought',
                'data': {
                    'id': message_agent_thought.id,
                    'task_id': self._task_id,
                    'message_id': self._message.id,
                    'chain_id': message_agent_thought.message_chain_id,
                    'position': message_agent_thought.position,
                    'thought': message_agent_thought.thought,
                    'tool': message_agent_thought.tool,
                    'tool_input': message_agent_thought.tool_input,
                    'mode': self._conversation.mode,
                    'conversation_id': self._conversation.id
                }
            })

        self._raise_if_stopped()

    def pub_message_end(self, retriever_resource: Optional[List]):
        """Publish ``message_end``; include ``retriever_resources`` when non-empty."""
        data = {
            'task_id': self._task_id,
            'message_id': self._message.id,
            'mode': self._conversation.mode,
            'conversation_id': self._conversation.id
        }
        if retriever_resource:
            data['retriever_resources'] = retriever_resource

        self._publish({'event': 'message_end', 'data': data})

        self._raise_if_stopped()

    def pub_end(self):
        """Publish the terminal ``end`` event (no stop check)."""
        self._publish({'event': 'end'})

    @classmethod
    def pub_error(cls, user: Union[Account, EndUser], task_id: str, e):
        """Publish an error payload built from exception ``e``.

        The description is ``e.description`` when that attribute exists and
        is not ``None``; otherwise ``str(e)``.
        """
        description = getattr(e, 'description', None)
        if description is None:
            description = str(e)

        content = {
            'error': type(e).__name__,
            'description': description
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    def _is_stopped(self):
        """Return True when the stop key exists in Redis."""
        return redis_client.get(self._stopped_cache_key) is not None

    @classmethod
    def ping(cls, user: Union[Account, EndUser], task_id: str):
        """Publish a ``ping`` event on the task's channel."""
        content = {
            'event': 'ping'
        }

        channel = cls.generate_channel_name(user, task_id)
        redis_client.publish(channel, json.dumps(content))

    @classmethod
    def stop(cls, user: Union[Account, EndUser], task_id: str):
        """Set the stop flag for the task; it expires after the TTL."""
        stopped_cache_key = cls.generate_stopped_cache_key(user, task_id)
        redis_client.setex(stopped_cache_key, cls._STOPPED_FLAG_TTL, 1)
class ConversationTaskStoppedException(Exception):
    """Raised by ``PubHandler`` publish methods once the task's stop flag is set."""
|