llama_index_tool.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from typing import Dict
  2. from langchain.tools import BaseTool
  3. from llama_index.indices.base import BaseGPTIndex
  4. from llama_index.langchain_helpers.agents import IndexToolConfig
  5. from pydantic import Field
  6. from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
  class EnhanceLlamaIndexTool(BaseTool):
      """LangChain tool that answers a text query by querying a LlamaIndex index.

      After each query, the raw index response is passed to
      ``callback_handler.on_tool_end``. The tool then returns ``str(response)``.

      NOTE: when constructing the tool directly rather than through
      ``from_tool_config``, ``name`` and ``description`` must still be set.
      """

      # Index queried by both the sync (_run) and async (_arun) paths.
      index: BaseGPTIndex
      # Extra keyword arguments forwarded to every index.query / index.aquery call.
      query_kwargs: Dict = Field(default_factory=dict)
      # NOTE(review): stored but never read by _run/_arun, so source information
      # is not added to the output. Confirm whether this flag should be honored.
      return_sources: bool = False
      # Receives each raw query response via on_tool_end().
      callback_handler: IndexToolCallbackHandler

      @classmethod
      def from_tool_config(cls, tool_config: IndexToolConfig,
                           callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
          """Create a tool from a llama_index ``IndexToolConfig``.

          Args:
              tool_config: Supplies the index, name, description, index query
                  kwargs and any extra tool kwargs. It is not modified.
              callback_handler: Handler notified after every query.

          Returns:
              A configured ``EnhanceLlamaIndexTool``.
          """
          # Work on a copy. Popping from tool_config.tool_kwargs directly would
          # mutate the caller's config, so a second tool built from the same
          # config would silently lose its "return_sources" setting.
          tool_kwargs = dict(tool_config.tool_kwargs)
          return_sources = tool_kwargs.pop("return_sources", False)
          return cls(
              index=tool_config.index,
              callback_handler=callback_handler,
              name=tool_config.name,
              description=tool_config.description,
              return_sources=return_sources,
              query_kwargs=tool_config.index_query_kwargs,
              **tool_kwargs,
          )

      def _run(self, tool_input: str) -> str:
          """Query the index synchronously and return the response as text."""
          response = self.index.query(tool_input, **self.query_kwargs)
          self.callback_handler.on_tool_end(response)
          return str(response)

      async def _arun(self, tool_input: str) -> str:
          """Query the index asynchronously and return the response as text."""
          response = await self.index.aquery(tool_input, **self.query_kwargs)
          # The handler is called without await, the same way as in _run.
          self.callback_handler.on_tool_end(response)
          return str(response)