- import math
- import re
- from typing import Mapping, List, Dict, Any, Optional
- from langchain import PromptTemplate
- from langchain.callbacks.manager import CallbackManagerForChainRun
- from langchain.chains.base import Chain
- from pydantic import Extra
- from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
- from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
- from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
- from core.conversation_message_task import ConversationMessageTask
- from core.llm.llm_builder import LLMBuilder
- from core.tool.dataset_index_tool import DatasetTool
- from models.dataset import Dataset, DatasetProcessRule
# Default number of segments (k) to retrieve per dataset when no token-budget
# calculation applies (see MultiDatasetRouterChain._dynamic_calc_retrieve_k).
DEFAULT_K = 2
# Fraction of the remaining token budget that retrieved context may occupy
# when expanding k beyond the default.
CONTEXT_TOKENS_PERCENT = 0.3
# Router prompt template. It is formatted ONCE below with `{destinations}`
# (the candidate dataset list); the doubled braces (`{{{{`, `{{input}}`) survive
# that str.format() pass as literal braces so the JSON example and the
# PromptTemplate's `{input}` placeholder remain intact. The LLM is asked to
# reply with a fenced ```json block holding "destination" and "next_inputs".
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like, \
no any other string out of markdown code snippet:
```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ a potentially modified version of the original input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.

<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT >>
"""
class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate chains.

    An LLM router decides which dataset (by UUID) should answer the query;
    the matching :class:`DatasetTool` is then run with the (possibly revised)
    input. The chain emits a single output key, ``"text"``.
    """

    router_chain: LLMRouterChain
    """Chain for deciding a destination chain and the input to it."""
    dataset_tools: Mapping[str, DatasetTool]
    """Map of name to candidate chains that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""
        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Single output key holding the routed tool's answer text.

        :meta private:
        """
        return ["text"]

    @classmethod
    def from_datasets(
            cls,
            tenant_id: str,
            datasets: List[Dataset],
            conversation_message_task: ConversationMessageTask,
            rest_tokens: int,
            **kwargs: Any,
    ):
        """Convenience constructor for instantiating from destination prompts.

        :param tenant_id: tenant owning the datasets; used to build the router LLM
        :param datasets: candidate datasets that queries may be routed to
        :param conversation_message_task: task handed to the dataset tool callback
        :param rest_tokens: remaining token budget for retrieved context
        :param kwargs: extra arguments forwarded to the Chain constructor
        """
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        # One "[[<uuid>]]: <description>" line per dataset for the router prompt;
        # newlines in descriptions are flattened so each candidate stays on one line.
        destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
                                            else ('useful for when you want to answer queries about the ' + d.name))
                        for d in datasets]
        destinations_str = "\n".join(destinations)
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        dataset_tools = {}
        for dataset in datasets:
            # Skip datasets with nothing to retrieve from.
            # BUGFIX: the original condition checked available_document_count
            # twice ("x == 0 or x == 0"); collapsed the redundant duplicate
            # into a single check.
            if dataset.available_document_count == 0:
                continue

            # fulfill description when it is empty
            description = dataset.description
            if not description:
                description = 'useful for when you want to answer queries about the ' + dataset.name

            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
            if k == 0:
                # no token budget left for even one segment of this dataset
                continue

            dataset_tool = DatasetTool(
                name=f"dataset-{dataset.id}",
                description=description,
                k=k,
                dataset=dataset,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
            )

            dataset_tools[str(dataset.id)] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        """Calculate how many segments (k) to retrieve for a dataset.

        Derives k from the dataset's segmentation ``max_tokens`` rule and the
        remaining token budget. Returns 0 when not even one segment fits;
        may exceed DEFAULT_K when CONTEXT_TOKENS_PERCENT of the budget allows.
        """
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            # NOTE(review): assumes custom rules always contain
            # rules["segmentation"]["max_tokens"]; a malformed rule dict would
            # raise KeyError here — confirm against the rule schema.
            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space
        return context_limit_tokens // segment_max_tokens

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Route the input to the best-matching dataset tool and run it.

        :param inputs: chain inputs; ``inputs['input']`` is the user query
        :param run_manager: optional callback manager (unused here)
        :return: ``{"text": ...}`` — empty string when no dataset applies or
            the router yields no destination
        :raises ValueError: when the router returns a UUID that is not a
            known dataset tool
        """
        if len(self.dataset_tools) == 0:
            return {"text": ''}
        elif len(self.dataset_tools) == 1:
            # single candidate: skip the router LLM call entirely
            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}

        route = self.router_chain.route(inputs)

        destination = ''
        if route.destination:
            # The router may wrap the dataset id in extra text; extract the
            # bare UUID, case-insensitively.
            pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
            match = re.search(pattern, route.destination, re.IGNORECASE)
            if match:
                destination = match.group()

        if not destination:
            return {"text": ''}
        elif destination in self.dataset_tools:
            return {"text": self.dataset_tools[destination].run(
                route.next_inputs['input']
            )}
        else:
            raise ValueError(
                f"Received invalid destination chain name '{destination}'"
            )
|