deal_dataset_vector_index_task.py

import logging
import time

import click
from celery import shared_task
from llama_index.data_structs.node_v2 import DocumentRelationship, Node

from core.index.vector_index import VectorIndex
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment


@shared_task
def deal_dataset_vector_index_task(dataset_id: str, action: str):
    """
    Async task: add a dataset's documents to, or remove them from, the vector index.
    :param dataset_id: dataset id
    :param action: 'add' or 'remove'

    Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
    """
    logging.info(click.style('Start processing dataset vector index: {}'.format(dataset_id), fg='green'))
    start_at = time.perf_counter()

    try:
        dataset = Dataset.query.filter_by(
            id=dataset_id
        ).first()

        if not dataset:
            raise Exception('Dataset not found')
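
        # NOTE: exceptions raised in this block (including 'Dataset not found')
        # are caught by the except clause at the bottom and only logged; the
        # task does not re-raise or retry.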

        documents = Document.query.filter_by(dataset_id=dataset_id).all()
        if documents:
            vector_index = VectorIndex(dataset=dataset)
            for document in documents:
                if action == "remove":
                    # delete from vector index
                    vector_index.del_doc(document.id)
                elif action == "add":
                    segments = db.session.query(DocumentSegment).filter(
                        DocumentSegment.document_id == document.id,
                        DocumentSegment.enabled == True
                    ).order_by(DocumentSegment.position.asc()).all()
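
                    # Segments arrive in ascending `position` order, so the
                    # PREVIOUS/NEXT relationships built below follow the
                    # document's reading order.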
                    nodes = []
                    previous_node = None
                    for segment in segments:
                        relationships = {
                            DocumentRelationship.SOURCE: document.id
                        }

                        if previous_node:
                            relationships[DocumentRelationship.PREVIOUS] = previous_node.doc_id
                            previous_node.relationships[DocumentRelationship.NEXT] = segment.index_node_id

                        node = Node(
                            doc_id=segment.index_node_id,
                            doc_hash=segment.index_node_hash,
                            text=segment.content,
                            extra_info=None,
                            node_info=None,
                            relationships=relationships
                        )

                        previous_node = node
                        nodes.append(node)
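
                    # At this point the nodes form a doubly linked chain, e.g.
                    # for three segments A -> B -> C:
                    #   A.relationships = {SOURCE, NEXT}
                    #   B.relationships = {SOURCE, PREVIOUS, NEXT}
                    #   C.relationships = {SOURCE, PREVIOUS}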

                    # save vector index
                    vector_index.add_nodes(
                        nodes=nodes,
                        duplicate_check=True
                    )

        end_at = time.perf_counter()
        logging.info(
            click.style('Processed dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at),
                        fg='green'))
    except Exception:
        logging.exception("Processing dataset vector index failed")
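
A minimal dispatch sketch, assuming a running Celery worker and an application
context in which `dataset` is a loaded Dataset row; the import path
`tasks.deal_dataset_vector_index_task` is an assumption about where this module
lives, not confirmed by the file itself:

    # hypothetical caller, e.g. a service method toggling a dataset's indexing
    from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task

    # enqueue a rebuild of the vector index from the dataset's enabled segments
    deal_dataset_vector_index_task.delay(dataset.id, 'add')

    # or enqueue removal of every document in the dataset from the vector index
    deal_dataset_vector_index_task.delay(dataset.id, 'remove')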