@@ -1,4 +1,5 @@
 import logging
+import re
 from typing import Optional, List, Union, Tuple
 
 from langchain.base_language import BaseLanguageModel
@@ -8,20 +9,21 @@ from langchain.llms import BaseLLM
 from langchain.schema import BaseMessage, HumanMessage
 from requests.exceptions import ChunkedEncodingError
 
+from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
+from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
     DifyStdOutCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
 from core.llm.error import LLMBadRequestError
+from core.llm.fake import FakeLLM
 from core.llm.llm_builder import LLMBuilder
-from core.chain.main_chain_builder import MainChainBuilder
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
-from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
-    ReadOnlyConversationTokenDBStringBufferSharedMemory
+from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
@@ -69,18 +71,33 @@ class Completion:
             streaming=streaming
         )
 
-        main_chain = MainChainBuilder.to_langchain_components(
+        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+
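+        # parse the orchestrator rules (agent mode, tools) from the app model config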
+        orchestrator_rule_parser = OrchestratorRuleParser(
             tenant_id=app.tenant_id,
-            agent_mode=app_model_config.agent_mode_dict,
+            app_model_config=app_model_config
+        )
+
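+        # sanitize the query through the sensitive word avoidance chain, if one is configured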
+        sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
+        if sensitive_word_avoidance_chain:
+            query = sensitive_word_avoidance_chain.run(query)
+
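+        # build an agent executor from the parsed orchestrator rules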
+        agent_executor = orchestrator_rule_parser.to_agent_executor(
+            conversation_message_task=conversation_message_task,
+            memory=memory,
             rest_tokens=rest_tokens_for_context_and_memory,
-            memory=ReadOnlyConversationTokenDBStringBufferSharedMemory(memory=memory) if memory else None,
-            conversation_message_task=conversation_message_task
+            chain_callback=chain_callback
         )
 
-        chain_output = ''
-        if main_chain:
-            chain_output = main_chain.run(query)
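+        # run the agent only when the executor judges that the query needs it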
+        agent_execute_result = None
+        if agent_executor:
+            should_use_agent = agent_executor.should_use_agent(query)
+            if should_use_agent:
+                agent_execute_result = agent_executor.run(query)
 
         try:
@@ -90,7 +107,7 @@ class Completion:
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
-                chain_output=chain_output,
+                agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
                 streaming=streaming
@@ -105,9 +122,20 @@ class Completion:
 
     @classmethod
     def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
-                      chain_output: str,
+                      agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
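+        # when no pre-prompt is configured and a non-router agent produced output,
+        # return that output directly as the final answer instead of calling the LLM again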
+        if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
+                and agent_execute_result.strategy != PlanningStrategy.ROUTER:
+            final_llm = FakeLLM(response=agent_execute_result.output,
+                                origin_llm=agent_execute_result.configuration.llm,
+                                streaming=streaming)
+            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
+            response = final_llm.generate([[HumanMessage(content=query)]])
+            return response
+
         final_llm = LLMBuilder.to_llm_from_model(
             tenant_id=tenant_id,
             model=app_model_config.model_dict,
@@ -122,7 +150,7 @@ class Completion:
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
-            chain_output=chain_output,
+            agent_execute_result=agent_execute_result,
             memory=memory
         )
 
@@ -142,16 +170,9 @@ class Completion:
     @classmethod
     def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
                             pre_prompt: str, query: str, inputs: dict,
-                            chain_output: Optional[str],
+                            agent_execute_result: Optional[AgentExecuteResult],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
-        # disable template string in query
-        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
-        # if query_params:
-        #     for query_param in query_params:
-        #         if query_param not in inputs:
-        #             inputs[query_param] = '{{' + query_param + '}}'
-
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@@ -165,18 +186,13 @@ When answer to user:
 - If you don't know when you are not sure, ask for clarification.
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
-""" if chain_output else "")
+""" if agent_execute_result else "")
                          + (pre_prompt + "\n" if pre_prompt else "")
                          + "{{query}}\n"
             )
 
-            if chain_output:
-                inputs['context'] = chain_output
-                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
-                # if context_params:
-                #     for context_param in context_params:
-                #         if context_param not in inputs:
-                #             inputs[context_param] = '{{' + context_param + '}}'
+            if agent_execute_result:
+                inputs['context'] = agent_execute_result.output
 
             prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
             prompt_content = prompt_template.format(
@@ -206,8 +222,8 @@ And answer according to the language of the user's question.
             if pre_prompt_inputs:
                 human_inputs.update(pre_prompt_inputs)
 
-            if chain_output:
-                human_inputs['context'] = chain_output
+            if agent_execute_result:
+                human_inputs['context'] = agent_execute_result.output
                 human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
 
 <context>
@@ -240,18 +256,10 @@ And answer according to the language of the user's question.
                               - max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
-
-                # disable template string in query
-                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
-                # if histories_params:
-                #     for histories_param in histories_params:
-                #         if histories_param not in human_inputs:
-                #             human_inputs[histories_param] = '{{' + histories_param + '}}'
-
                 human_message_prompt += "\n\n" if human_message_prompt else ""
                 human_message_prompt += "Here is the chat histories between human and assistant, " \
-                                        "inside <histories></histories> XML tags.\n\n<histories>"
-                human_message_prompt += histories + "</histories>"
+                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
+                human_message_prompt += histories + "\n</histories>"
 
             human_message_prompt += query_prompt
@@ -263,10 +271,13 @@ And answer according to the language of the user's question.
 
         messages.append(human_message)
 
-        return messages, ['\nHuman:']
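+        # strip reserved special tokens (e.g. <|endoftext|>) from the prompt messages,
+        # and stop generation at a closing histories tag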
+        for message in messages:
+            message.content = re.sub(r'<\|.*?\|>', '', message.content)
+
+        return messages, ['\nHuman:', '</histories>']
 
     @classmethod
-    def get_llm_callbacks(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def get_llm_callbacks(cls, llm: BaseLanguageModel,
                           streaming: bool,
                           conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
@@ -277,8 +288,7 @@ And answer according to the language of the user's question.
 
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
-                                         max_token_limit: int) -> \
-            str:
+                                         max_token_limit: int) -> str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
@@ -329,7 +339,7 @@ And answer according to the language of the user's question.
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
-            chain_output=None,
+            agent_execute_result=None,
             memory=None
         )
 
@@ -379,6 +389,7 @@ And answer according to the language of the user's question.
             query=message.query,
             inputs=message.inputs,
             chain_output=None,
+            agent_execute_result=None,
             memory=None
         )