# llm_router_chain.py (4.9 KB)
  1. """Base classes for LLM-powered router chains."""
  2. from __future__ import annotations
  3. import json
  4. from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
  5. from langchain.chains.base import Chain
  6. from pydantic import root_validator
  7. from langchain.chains import LLMChain
  8. from langchain.prompts import BasePromptTemplate
  9. from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
  10. class Route(NamedTuple):
  11. destination: Optional[str]
  12. next_inputs: Dict[str, Any]
  13. class LLMRouterChain(Chain):
  14. """A router chain that uses an LLM chain to perform routing."""
  15. llm_chain: LLMChain
  16. """LLM chain used to perform routing"""
  17. @root_validator()
  18. def validate_prompt(cls, values: dict) -> dict:
  19. prompt = values["llm_chain"].prompt
  20. if prompt.output_parser is None:
  21. raise ValueError(
  22. "LLMRouterChain requires base llm_chain prompt to have an output"
  23. " parser that converts LLM text output to a dictionary with keys"
  24. " 'destination' and 'next_inputs'. Received a prompt with no output"
  25. " parser."
  26. )
  27. return values
  28. @property
  29. def input_keys(self) -> List[str]:
  30. """Will be whatever keys the LLM chain prompt expects.
  31. :meta private:
  32. """
  33. return self.llm_chain.input_keys
  34. def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
  35. super()._validate_outputs(outputs)
  36. if not isinstance(outputs["next_inputs"], dict):
  37. raise ValueError
  38. def _call(
  39. self,
  40. inputs: Dict[str, Any]
  41. ) -> Dict[str, Any]:
  42. output = cast(
  43. Dict[str, Any],
  44. self.llm_chain.predict_and_parse(**inputs),
  45. )
  46. return output
  47. @classmethod
  48. def from_llm(
  49. cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
  50. ) -> LLMRouterChain:
  51. """Convenience constructor."""
  52. llm_chain = LLMChain(llm=llm, prompt=prompt)
  53. return cls(llm_chain=llm_chain, **kwargs)
  54. @property
  55. def output_keys(self) -> List[str]:
  56. return ["destination", "next_inputs"]
  57. def route(self, inputs: Dict[str, Any]) -> Route:
  58. result = self(inputs)
  59. return Route(result["destination"], result["next_inputs"])
  60. class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
  61. """Parser for output of router chain int he multi-prompt chain."""
  62. default_destination: str = "DEFAULT"
  63. next_inputs_type: Type = str
  64. next_inputs_inner_key: str = "input"
  65. def parse_json_markdown(self, json_string: str) -> dict:
  66. # Remove the triple backticks if present
  67. json_string = json_string.strip()
  68. start_index = json_string.find("```json")
  69. end_index = json_string.find("```", start_index + len("```json"))
  70. if start_index != -1 and end_index != -1:
  71. extracted_content = json_string[start_index + len("```json"):end_index].strip()
  72. # Parse the JSON string into a Python dictionary
  73. parsed = json.loads(extracted_content)
  74. elif json_string.startswith("{"):
  75. # Parse the JSON string into a Python dictionary
  76. parsed = json.loads(json_string)
  77. else:
  78. raise Exception("Could not find JSON block in the output.")
  79. return parsed
  80. def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:
  81. try:
  82. json_obj = self.parse_json_markdown(text)
  83. except json.JSONDecodeError as e:
  84. raise OutputParserException(f"Got invalid JSON object. Error: {e}")
  85. for key in expected_keys:
  86. if key not in json_obj:
  87. raise OutputParserException(
  88. f"Got invalid return object. Expected key `{key}` "
  89. f"to be present, but got {json_obj}"
  90. )
  91. return json_obj
  92. def parse(self, text: str) -> Dict[str, Any]:
  93. try:
  94. expected_keys = ["destination", "next_inputs"]
  95. parsed = self.parse_and_check_json_markdown(text, expected_keys)
  96. if not isinstance(parsed["destination"], str):
  97. raise ValueError("Expected 'destination' to be a string.")
  98. if not isinstance(parsed["next_inputs"], self.next_inputs_type):
  99. raise ValueError(
  100. f"Expected 'next_inputs' to be {self.next_inputs_type}."
  101. )
  102. parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
  103. if (
  104. parsed["destination"].strip().lower()
  105. == self.default_destination.lower()
  106. ):
  107. parsed["destination"] = None
  108. else:
  109. parsed["destination"] = parsed["destination"].strip()
  110. return parsed
  111. except Exception as e:
  112. raise OutputParserException(
  113. f"Parsing text\n{text}\n raised following error:\n{e}"
  114. )