vector_index.py

import json
import logging
from typing import List, Optional

from llama_index.data_structs import Node
from requests import ReadTimeout
from sqlalchemy.exc import IntegrityError
from tenacity import retry, stop_after_attempt, retry_if_exception_type

from core.index.index_builder import IndexBuilder
from core.vector_store.base import BaseGPTVectorStoreIndex
from extensions.ext_vector_store import vector_store
from extensions.ext_database import db
from models.dataset import Dataset, Embedding


class VectorIndex:
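    """Vector index wrapper for a single Dataset.

    Handles inserting embedded nodes (with a shared embedding cache keyed by
    content hash), deleting nodes and documents, and exposing a query index.
    """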

    def __init__(self, dataset: Dataset):
        self._dataset = dataset

    def add_nodes(self, nodes: List[Node], duplicate_check: bool = False):
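        """Embed the given nodes and insert them into the vector index.

        Cached embeddings are reused by content hash; only uncached nodes are
        sent to the embedding model, and new embeddings are persisted for reuse.
        """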
        # create and persist the index struct on first use
        if not self._dataset.index_struct_dict:
            index_id = "Vector_index_" + self._dataset.id.replace("-", "_")
            self._dataset.index_struct = json.dumps(vector_store.to_index_struct(index_id))
            db.session.commit()

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        if duplicate_check:
            nodes = self._filter_duplicate_nodes(index, nodes)

        embedding_queue_nodes = []
        embedded_nodes = []
        for node in nodes:
            node_hash = node.doc_hash

            # reuse the cached embedding if one already exists for this content hash
            embedding = db.session.query(Embedding).filter_by(hash=node_hash).first()
            if embedding:
                node.embedding = embedding.get_embedding()
                embedded_nodes.append(node)
            else:
                embedding_queue_nodes.append(node)

        if embedding_queue_nodes:
            embedding_results = index._get_node_embedding_results(
                embedding_queue_nodes,
                set(),
            )

            # embed the uncached nodes, then persist each new embedding for reuse
            for embedding_result in embedding_results:
                node = embedding_result.node
                node.embedding = embedding_result.embedding

                try:
                    embedding = Embedding(hash=node.doc_hash)
                    embedding.set_embedding(node.embedding)
                    db.session.add(embedding)
                    db.session.commit()
                except IntegrityError:
                    # another worker cached this hash first; skip the duplicate row
                    db.session.rollback()
                    continue
                except Exception:
                    logging.exception('Failed to add embedding to db')
                    db.session.rollback()
                    continue

                embedded_nodes.append(node)

        self.index_insert_nodes(index, embedded_nodes)
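
    # The wrappers below retry remote vector store calls that raise
    # ReadTimeout, up to three attempts, then re-raise.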
    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_insert_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]):
        index.insert_nodes(nodes)

    def del_nodes(self, node_ids: List[str]):
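        """Delete the given nodes from the vector index, one by one.

        Deletion needs no real LLM calls, so a fake LLM service context is used.
        """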
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        for node_id in node_ids:
            self.index_delete_node(index, node_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_node(self, index: BaseGPTVectorStoreIndex, node_id: str):
        index.delete_node(node_id)

    def del_doc(self, doc_id: str):
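        """Delete an entire document from the vector index."""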
        if not self._dataset.index_struct_dict:
            return

        service_context = IndexBuilder.get_fake_llm_service_context(tenant_id=self._dataset.tenant_id)

        index = vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

        self.index_delete_doc(index, doc_id)

    @retry(reraise=True, retry=retry_if_exception_type(ReadTimeout), stop=stop_after_attempt(3))
    def index_delete_doc(self, index: BaseGPTVectorStoreIndex, doc_id: str):
        index.delete(doc_id)

    @property
    def query_index(self) -> Optional[BaseGPTVectorStoreIndex]:
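        """Return the dataset's vector index for querying, or None if the
        index struct has not been created yet."""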
        if not self._dataset.index_struct_dict:
            return None

        service_context = IndexBuilder.get_default_service_context(tenant_id=self._dataset.tenant_id)

        return vector_store.get_index(
            service_context=service_context,
            index_struct=self._dataset.index_struct_dict
        )

    def _filter_duplicate_nodes(self, index: BaseGPTVectorStoreIndex, nodes: List[Node]) -> List[Node]:
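        """Drop nodes whose doc ids already exist in the index."""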
        # build a new list rather than calling nodes.remove() inside the loop,
        # which would skip elements while iterating
        return [node for node in nodes if not index.exists_by_node_id(node.doc_id)]
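
# A minimal usage sketch (hypothetical caller; assumes a Flask application
# context where `db` is bound, `dataset` is a persisted Dataset row, and
# `nodes` is a list of llama_index Node objects; the names below are
# placeholders, not part of this module):
#
#     vector_index = VectorIndex(dataset)
#     vector_index.add_nodes(nodes, duplicate_check=True)
#     vector_index.del_doc(document_id)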