- import json
- import logging
- from typing import List, Optional
-
- from llama_index.data_structs import Node
- from requests import ReadTimeout
- from sqlalchemy.exc import IntegrityError
- from tenacity import retry, stop_after_attempt, retry_if_exception_type
-
- from core.index.index_builder import IndexBuilder
- from core.vector_store.base import BaseGPTVectorStoreIndex
- from extensions.ext_vector_store import vector_store
- from extensions.ext_database import db
- from models.dataset import Dataset, Embedding
-
-
- class VectorIndex:
-     """Dataset-scoped wrapper around the configured vector store index."""
-
-     def __init__(self, dataset: Dataset):
-         self._dataset = dataset
-
-     def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
-         """Embed `nodes` (reusing cached embeddings where possible) and
-         insert them into the dataset's vector index."""
-         if not self._dataset.index_struct_dict:
-             # first write: create and persist the index struct for this dataset
-             index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
-             self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
-             db.session.commit()
-
-         service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-         index = vector_store.get_index(
-             service_context=service_context,
-             index_struct=self._dataset.index_struct_dict
-         )
-
-         if duplicate_check:
-             nodes = self._filter_duplicate_nodes(index, nodes)
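-
-         # Split incoming nodes by cache state: the Embedding table is keyed by
-         # content hash, so a node whose text was embedded before can reuse the
-         # stored vector instead of calling the embedding model again.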
-         embedding_queue_nodes = []
-         embedded_nodes = []
-         for node in nodes:
-             node_hash = node.doc_hash
-
-             # if the node hash is in the embedding cache table, reuse the cached embedding
-             embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
-             if embedding:
-                 node.embedding = embedding.get_embedding()
-                 embedded_nodes.append(node)
-             else:
-                 embedding_queue_nodes.append(node)
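-
-         # Batch-embed whatever missed the cache. Note this reaches into a
-         # private llama_index helper; the second argument is, as far as we can
-         # tell, the set of node ids whose embeddings already exist (empty here),
-         # so every queued node gets embedded.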
-         if embedding_queue_nodes:
-             embedding_results = index._get_node_embedding_results(
-                 embedding_queue_nodes,
-                 set(),
-             )
-
-             # persist the fresh embeddings so future runs can hit the cache
-             for embedding_result in embedding_results:
-                 node = embedding_result.node
-                 node.embedding = embedding_result.embedding
-
-                 try:
-                     embedding = Embedding(hash=node.doc_hash)
-                     embedding.set_embedding(node.embedding)
-                     db.session.add(embedding)
-                     db.session.commit()
-                 except IntegrityError:
-                     # another worker cached the same hash first; skip this node
-                     db.session.rollback()
-                     continue
-                 except Exception:
-                     logging.exception('Failed to add embedding to db')
-                     continue
-
-                 embedded_nodes.append(node)
-
-         self.index_insert_nodes(index, embedded_nodes)
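-
-     # The raw insert/delete calls below are wrapped with tenacity so transient
-     # vector-store timeouts are retried up to three times before the final
-     # ReadTimeout is re-raised to the caller.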
-     @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-     def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
-         index.insert_nodes(nodes)
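-
-     # The deletion paths use a fake-LLM service context: removing nodes never
-     # calls the model, so no real LLM needs to be configured. This reading is
-     # inferred from the IndexBuilder helper names, not stated in the code.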
-     def del_nodes(self, node_ids: List[str]):
-         if not self._dataset.index_struct_dict:
-             return
-
-         service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
-
-         index = vector_store.get_index(
-             service_context=service_context,
-             index_struct=self._dataset.index_struct_dict
-         )
-
-         for node_id in node_ids:
-             self.index_delete_node(index, node_id)
-
-     @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-     def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
-         index.delete_node(node_id)
-
-     def del_doc(self, doc_id: str):
-         if not self._dataset.index_struct_dict:
-             return
-
-         service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)
-
-         index = vector_store.get_index(
-             service_context=service_context,
-             index_struct=self._dataset.index_struct_dict
-         )
-
-         self.index_delete_doc(index, doc_id)
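-
-     # index.delete(doc_id) should drop every node belonging to that document;
-     # this matches llama_index's delete-by-document semantics at this version,
-     # though the exact behaviour depends on the vector store backend.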
-     @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
-     def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
-         index.delete(doc_id)
-
-     @property
-     def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
-         if not self._dataset.index_struct_dict:
-             return None
-
-         service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)
-
-         return vector_store.get_index(
-             service_context=service_context,
-             index_struct=self._dataset.index_struct_dict
-         )
-
-     def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
-         # iterate over a copy: removing items from the list being iterated
-         # would silently skip the element after each removal
-         for node in list(nodes):
-             node_id = node.doc_id
-             exists_duplicate_node = index.exists_by_node_id(node_id)
-             if exists_duplicate_node:
-                 nodes.remove(node)
-
-         return nodes
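-
-
- # Hypothetical usage sketch (not part of the original module): `dataset`,
- # `nodes`, and `document_id` are assumed to come from the surrounding
- # indexing pipeline.
- #
- #     vector_index = VectorIndex(dataset)
- #     vector_index.add_nodes(nodes, duplicate_check=True)
- #
- #     index = vector_index.query_index
- #     if index is not None:
- #         ...  # query the dataset's vector index
- #
- #     vector_index.del_doc(document_id)  # drop an entire document's nodes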