from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union

from typing_extensions import deprecated

from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariableKey

VariableValue = Union[str, int, float, dict, list, FileVar]


SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"


class VariablePool:
    def __init__(
        self,
        system_variables: Mapping[SystemVariableKey, Any],
        user_inputs: Mapping[str, Any],
        environment_variables: Sequence[Variable],
        conversation_variables: Sequence[Variable] | None = None,
    ) -> None:
        # system variables
        # for example:
        # {
        #     'query': 'abc',
        #     'files': []
        # }

        # Varaible dictionary is a dictionary for looking up variables by their selector.
        # The first element of the selector is the node id, it's the first-level key in the dictionary.
        # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
        # elements of the selector except the first one.
        self._variable_dictionary: dict[str, dict[int, Segment]] = defaultdict(dict)

        # TODO: This user inputs is not used for pool.
        self.user_inputs = user_inputs

        # Add system variables to the variable pool
        self.system_variables = system_variables
        for key, value in system_variables.items():
            self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)

        # Add environment variables to the variable pool
        for var in environment_variables:
            self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)

        # Add conversation variables to the variable pool
        for var in conversation_variables or []:
            self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)

    def add(self, selector: Sequence[str], value: Any, /) -> None:
        """
        Adds a variable to the variable pool.

        Args:
            selector (Sequence[str]): The selector for the variable.
            value (VariableValue): The value of the variable.

        Raises:
            ValueError: If the selector is invalid.

        Returns:
            None
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")

        if value is None:
            return

        if isinstance(value, Segment):
            v = value
        else:
            v = factory.build_segment(value)

        hash_key = hash(tuple(selector[1:]))
        self._variable_dictionary[selector[0]][hash_key] = v

    def get(self, selector: Sequence[str], /) -> Segment | None:
        """
        Retrieves the value from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Any: The value associated with the given selector.

        Raises:
            ValueError: If the selector is invalid.
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        hash_key = hash(tuple(selector[1:]))
        value = self._variable_dictionary[selector[0]].get(hash_key)

        return value

    @deprecated("This method is deprecated, use `get` instead.")
    def get_any(self, selector: Sequence[str], /) -> Any | None:
        """
        Retrieves the value from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): The selector used to identify the variable.

        Returns:
            Any: The value associated with the given selector.

        Raises:
            ValueError: If the selector is invalid.
        """
        if len(selector) < 2:
            raise ValueError("Invalid selector")
        hash_key = hash(tuple(selector[1:]))
        value = self._variable_dictionary[selector[0]].get(hash_key)
        return value.to_object() if value else None

    def remove(self, selector: Sequence[str], /):
        """
        Remove variables from the variable pool based on the given selector.

        Args:
            selector (Sequence[str]): A sequence of strings representing the selector.

        Returns:
            None
        """
        if not selector:
            return
        if len(selector) == 1:
            self._variable_dictionary[selector[0]] = {}
            return
        hash_key = hash(tuple(selector[1:]))
        self._variable_dictionary[selector[0]].pop(hash_key, None)