import logging
from typing import cast
from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import AgentEntity, ApplicationGenerateEntity, ModelConfigEntity
from core.features.assistant_cot_runner import AssistantCotApplicationRunner
from core.features.assistant_fc_runner import AssistantFunctionCallApplicationRunner
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class AssistantApplicationRunner(AppRunner):
    """
    Assistant Application Runner

    Prepares the prompt for an assistant (agent) application, applies
    moderation / annotation short-circuits, then delegates the actual
    generation to either the chain-of-thought or the function-calling
    agent runner and publishes the result through the queue manager.
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run assistant application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        :raises ValueError: if the app record does not exist, if no agent is
            configured, or if the agent strategy is not supported
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError("App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            # moderation rejected the input: answer with the exception text and stop
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from
            )

            if annotation_reply:
                # an annotation matched: publish it and answer with its content
                # instead of invoking the agent
                queue_manager.publish_annotation_reply(
                    message_annotation_id=annotation_reply.id,
                    pub_from=PublishFrom.APPLICATION_MANAGER
                )
                self.direct_output(
                    queue_manager=queue_manager,
                    app_orchestration_config=app_orchestration_config,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional), external data, dataset context(optional)
        prompt_messages, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        agent_entity = app_orchestration_config.agent
        if agent_entity is None:
            # fail before any tool-variable row is created/committed below
            raise ValueError("Agent entity not found")

        # load tool variables
        tool_conversation_variables = self._load_tool_variables(
            conversation_id=conversation.id,
            user_id=application_generate_entity.user_id,
            tenant_id=application_generate_entity.tenant_id
        )

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

        # init model instance
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )

        # NOTE(review): this repeats the organize_prompt_messages call above with
        # the same arguments; kept as-is because it is not visible here whether
        # that helper is side-effect free - confirm before de-duplicating.
        prompt_message, _ = self.organize_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory,
        )

        # change function call strategy based on LLM model
        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)

        tool_call_features = {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}
        if tool_call_features.intersection(model_schema.features or []):
            # the model advertises tool calling, so override the configured strategy
            agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

        # reload both records, then close the session before handing over to the agent runner
        db.session.refresh(conversation)
        db.session.refresh(message)
        db.session.close()

        # both agent runners are constructed with exactly the same arguments
        runner_kwargs = {
            'tenant_id': application_generate_entity.tenant_id,
            'application_generate_entity': application_generate_entity,
            'app_orchestration_config': app_orchestration_config,
            'model_config': app_orchestration_config.model_config,
            'config': agent_entity,
            'queue_manager': queue_manager,
            'message': message,
            'user_id': application_generate_entity.user_id,
            'memory': memory,
            'prompt_messages': prompt_message,
            'variables_pool': tool_variables,
            'db_variables': tool_conversation_variables,
            'model_instance': model_instance,
        }

        # start agent runner
        if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
            assistant_cot_runner = AssistantCotApplicationRunner(**runner_kwargs)
            invoke_result = assistant_cot_runner.run(
                conversation=conversation,
                message=message,
                query=query,
                inputs=inputs,
            )
        elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
            assistant_fc_runner = AssistantFunctionCallApplicationRunner(**runner_kwargs)
            invoke_result = assistant_fc_runner.run(
                conversation=conversation,
                message=message,
                query=query,
            )
        else:
            # without this branch invoke_result would be unbound below
            raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream,
            agent=True
        )

    def _load_tool_variables(self, conversation_id: str, user_id: str, tenant_id: str) -> ToolConversationVariables:
        """
        load tool variables from database

        Looks up the row by conversation id and tenant id; when none exists a
        new row with an empty variable list is created for the given user.
        The session is committed in both cases.
        :param conversation_id: conversation id
        :param user_id: user id, only used when a new row is created
        :param tenant_id: tenant id
        :return: the existing or newly created tool conversation variables
        """
        tool_variables: ToolConversationVariables = db.session.query(ToolConversationVariables).filter(
            ToolConversationVariables.conversation_id == conversation_id,
            ToolConversationVariables.tenant_id == tenant_id
        ).first()

        if not tool_variables:
            # create new tool variables
            tool_variables = ToolConversationVariables(
                conversation_id=conversation_id,
                user_id=user_id,
                tenant_id=tenant_id,
                variables_str='[]',
            )

        # save tool variables to session, so that we can update it later
        db.session.add(tool_variables)
        db.session.commit()

        return tool_variables

    def _convert_db_variables_to_tool_variables(self, db_variables: ToolConversationVariables) -> ToolRuntimeVariablePool:
        """
        convert db variables to tool variables
        :param db_variables: tool conversation variables loaded from database
        :return: runtime variable pool carrying the same ids and variables
        """
        return ToolRuntimeVariablePool(
            conversation_id=db_variables.conversation_id,
            user_id=db_variables.user_id,
            tenant_id=db_variables.tenant_id,
            pool=db_variables.variables
        )

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
        :param model_config: model config
        :param message: message
        :return: usage calculated from the summed message / answer tokens of
            every agent thought recorded for the message
        """
        agent_thoughts = (db.session.query(MessageAgentThought)
                          .filter(MessageAgentThought.message_id == message.id).all())

        all_message_tokens = sum(agent_thought.message_tokens for agent_thought in agent_thoughts)
        all_answer_tokens = sum(agent_thought.answer_tokens for agent_thought in agent_thoughts)

        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        return model_type_instance._calc_response_usage(
            model_config.model,
            model_config.credentials,
            all_message_tokens,
            all_answer_tokens
        )