from typing import Dict

from langchain.tools import BaseTool
from llama_index.indices.base import BaseGPTIndex
from llama_index.langchain_helpers.agents import IndexToolConfig
from pydantic import Field

from core.callback_handler.index_tool_callback_handler import IndexToolCallbackHandler


class EnhanceLlamaIndexTool(BaseTool):
    """Tool for querying a LlamaIndex.

    Each query is forwarded to the wrapped index; the raw response object is
    handed to ``callback_handler.on_tool_end`` before its string form is
    returned to the caller.
    """

    # NOTE: name/description still needs to be set
    # Index that receives every query issued through this tool.
    index: BaseGPTIndex
    # Extra keyword arguments forwarded verbatim to ``index.query`` / ``index.aquery``.
    query_kwargs: Dict = Field(default_factory=dict)
    # Stored on the tool; not read by any method defined in this class.
    return_sources: bool = False
    # Notified with the raw response object after each query completes.
    callback_handler: IndexToolCallbackHandler

    @classmethod
    def from_tool_config(cls, tool_config: IndexToolConfig,
                         callback_handler: IndexToolCallbackHandler) -> "EnhanceLlamaIndexTool":
        """Create a tool from a tool config.

        Args:
            tool_config: Supplies the index, tool name, description, query
                kwargs and any additional tool kwargs.
            callback_handler: Handler invoked with the raw response after
                every query.

        Returns:
            A new tool instance configured from ``tool_config``.
        """
        # Pop from a shallow copy so the caller's config is left untouched;
        # popping from ``tool_config.tool_kwargs`` directly would drop
        # "return_sources" for any later tool built from the same config.
        tool_kwargs: Dict = dict(tool_config.tool_kwargs)
        return_sources = tool_kwargs.pop("return_sources", False)
        return cls(
            index=tool_config.index,
            callback_handler=callback_handler,
            name=tool_config.name,
            description=tool_config.description,
            return_sources=return_sources,
            query_kwargs=tool_config.index_query_kwargs,
            **tool_kwargs,
        )

    def _run(self, tool_input: str) -> str:
        """Query the index synchronously and return the response as a string.

        Args:
            tool_input: Query text passed to ``index.query``.

        Returns:
            ``str(response)`` for the response returned by the index.
        """
        response = self.index.query(tool_input, **self.query_kwargs)
        # The handler gets the raw response object, not its string form.
        self.callback_handler.on_tool_end(response)
        return str(response)

    async def _arun(self, tool_input: str) -> str:
        """Query the index asynchronously and return the response as a string.

        Args:
            tool_input: Query text passed to ``index.aquery``.

        Returns:
            ``str(response)`` for the response returned by the index.
        """
        response = await self.index.aquery(tool_input, **self.query_kwargs)
        # The handler gets the raw response object, not its string form.
        self.callback_handler.on_tool_end(response)
        return str(response)