tool_chain.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from typing import List, Dict
  2. from langchain.chains.base import Chain
  3. from langchain.tools import BaseTool
  4. class ToolChain(Chain):
  5. input_key: str = "input" #: :meta private:
  6. output_key: str = "output" #: :meta private:
  7. tool: BaseTool
  8. @property
  9. def _chain_type(self) -> str:
  10. return "tool_chain"
  11. @property
  12. def input_keys(self) -> List[str]:
  13. """Expect input key.
  14. :meta private:
  15. """
  16. return [self.input_key]
  17. @property
  18. def output_keys(self) -> List[str]:
  19. """Return output key.
  20. :meta private:
  21. """
  22. return [self.output_key]
  23. def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
  24. input = inputs[self.input_key]
  25. output = self.tool.run(input, self.verbose)
  26. return {self.output_key: output}
  27. async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
  28. """Run the logic of this chain and return the output."""
  29. input = inputs[self.input_key]
  30. output = await self.tool.arun(input, self.verbose)
  31. return {self.output_key: output}