123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
- """Base classes for LLM-powered router chains."""
- from __future__ import annotations
- import json
- from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
- from langchain.chains.base import Chain
- from pydantic import root_validator
- from langchain.chains import LLMChain
- from langchain.prompts import BasePromptTemplate
- from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
- class Route(NamedTuple):
- destination: Optional[str]
- next_inputs: Dict[str, Any]
- class LLMRouterChain(Chain):
- """A router chain that uses an LLM chain to perform routing."""
- llm_chain: LLMChain
- """LLM chain used to perform routing"""
- @root_validator()
- def validate_prompt(cls, values: dict) -> dict:
- prompt = values["llm_chain"].prompt
- if prompt.output_parser is None:
- raise ValueError(
- "LLMRouterChain requires base llm_chain prompt to have an output"
- " parser that converts LLM text output to a dictionary with keys"
- " 'destination' and 'next_inputs'. Received a prompt with no output"
- " parser."
- )
- return values
- @property
- def input_keys(self) -> List[str]:
- """Will be whatever keys the LLM chain prompt expects.
- :meta private:
- """
- return self.llm_chain.input_keys
- def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
- super()._validate_outputs(outputs)
- if not isinstance(outputs["next_inputs"], dict):
- raise ValueError
- def _call(
- self,
- inputs: Dict[str, Any]
- ) -> Dict[str, Any]:
- output = cast(
- Dict[str, Any],
- self.llm_chain.predict_and_parse(**inputs),
- )
- return output
- @classmethod
- def from_llm(
- cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
- ) -> LLMRouterChain:
- """Convenience constructor."""
- llm_chain = LLMChain(llm=llm, prompt=prompt)
- return cls(llm_chain=llm_chain, **kwargs)
- @property
- def output_keys(self) -> List[str]:
- return ["destination", "next_inputs"]
- def route(self, inputs: Dict[str, Any]) -> Route:
- result = self(inputs)
- return Route(result["destination"], result["next_inputs"])
- class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
- """Parser for output of router chain int he multi-prompt chain."""
- default_destination: str = "DEFAULT"
- next_inputs_type: Type = str
- next_inputs_inner_key: str = "input"
- def parse_json_markdown(self, json_string: str) -> dict:
- # Remove the triple backticks if present
- start_index = json_string.find("```json")
- end_index = json_string.find("```", start_index + len("```json"))
- if start_index != -1 and end_index != -1:
- extracted_content = json_string[start_index + len("```json"):end_index].strip()
- # Parse the JSON string into a Python dictionary
- parsed = json.loads(extracted_content)
- else:
- raise Exception("Could not find JSON block in the output.")
- return parsed
- def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
- try:
- json_obj = self.parse_json_markdown(text)
- except json.JSONDecodeError as e:
- raise OutputParserException(f"Got invalid JSON object. Error: {e}")
- for key in expected_keys:
- if key not in json_obj:
- raise OutputParserException(
- f"Got invalid return object. Expected key `{key}` "
- f"to be present, but got {json_obj}"
- )
- return json_obj
- def parse(self, text: str) -> Dict[str, Any]:
- try:
- expected_keys = ["destination", "next_inputs"]
- parsed = self.parse_and_check_json_markdown(text, expected_keys)
- if not isinstance(parsed["destination"], str):
- raise ValueError("Expected 'destination' to be a string.")
- if not isinstance(parsed["next_inputs"], self.next_inputs_type):
- raise ValueError(
- f"Expected 'next_inputs' to be {self.next_inputs_type}."
- )
- parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
- if (
- parsed["destination"].strip().lower()
- == self.default_destination.lower()
- ):
- parsed["destination"] = None
- else:
- parsed["destination"] = parsed["destination"].strip()
- return parsed
- except Exception as e:
- raise OutputParserException(
- f"Parsing text\n{text}\n raised following error:\n{e}"
- )
|