import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # if node hash in cached embedding tables, use cached embedding
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # pre embed nodes for cached embedding
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    db.session.rollback()
                    continue
                except:
                    logging.exception('Failed to add embedding to db')
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
        for node in nodes:
            node_id = node.doc_id
            exists_duplicate_node = index.exists_by_node_id(node_id)
            if exists_duplicate_node:
                nodes.remove(node)

        return nodes