12345678910111213141516171819202122232425262728293031323334353637383940414243 |
- from typing import Dict
- from langchain.tools import BaseTool
- from llama_index.indices.base import BaseGPTIndex
- from llama_index.langchain_helpers.agents import IndexToolConfig
- from pydantic import Field
- from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler
class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex.

    Each query is sent to ``index.query`` / ``index.aquery`` with
    ``query_kwargs``. The raw response is passed to
    ``callback_handler.on_tool_end``, then returned to the caller as
    ``str(response)``.
    """

    # NOTE: name/description still needs to be set
    index: BaseGPTIndex
    # Extra keyword arguments forwarded to every index.query / index.aquery call.
    query_kwargs: Dict = Field(default_factory=dict)
    # NOTE(review): stored on the tool but not consulted by _run/_arun below;
    # the response is always returned as plain str(response).
    return_sources: bool = False
    # Notified with each raw query response via on_tool_end.
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config.

        Args:
            tool_config: Supplies the index, tool name, description, the
                kwargs forwarded to index queries (``index_query_kwargs``),
                and any extra tool kwargs (``tool_kwargs``). An optional
                ``return_sources`` entry in ``tool_kwargs`` is extracted
                into the ``return_sources`` field.
            callback_handler: Handler notified with each query response.

        Returns:
            A configured ``EnhanceLlamaIndexTool``. ``tool_config`` is left
            unmodified.
        """
        # Pop from a copy: popping directly from tool_config.tool_kwargs would
        # mutate the caller's config. A second tool built from the same config
        # would then silently lose its return_sources setting.
        tool_kwargs = dict(tool_config.tool_kwargs)
        return_sources = tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        """Query the index synchronously and return the response as text."""
        response = self.index.query(tool_input, **self.query_kwargs)
        # Hand the raw response object to the handler before stringifying it.
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        """Query the index asynchronously and return the response as text."""
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        # on_tool_end is invoked synchronously here, exactly as in _run.
        self.callback_handler.on_tool_end(response)
        return str(response)
|