# deal_dataset_vector_index_task.py
  1. import logging
  2. import time
  3. import click
  4. from celery import shared_task
  5. from core.index.index import IndexBuilder
  6. from extensions.ext_database import db
  7. from langchain.schema import Document
  8. from models.dataset import Dataset
  9. from models.dataset import Document as DatasetDocument
  10. from models.dataset import DocumentSegment
  11. @shared_task(queue='dataset')
  12. def deal_dataset_vector_index_task(dataset_id: str, action: str):
  13. """
  14. Async deal dataset from index
  15. :param dataset_id: dataset_id
  16. :param action: action
  17. Usage: deal_dataset_vector_index_task.delay(dataset_id, action)
  18. """
  19. logging.info(click.style('Start deal dataset vector index: {}'.format(dataset_id), fg='green'))
  20. start_at = time.perf_counter()
  21. try:
  22. dataset = Dataset.query.filter_by(
  23. id=dataset_id
  24. ).first()
  25. if not dataset:
  26. raise Exception('Dataset not found')
  27. if action == "remove":
  28. index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
  29. index.delete_by_group_id(dataset.id)
  30. elif action == "add":
  31. dataset_documents = db.session.query(DatasetDocument).filter(
  32. DatasetDocument.dataset_id == dataset_id,
  33. DatasetDocument.indexing_status == 'completed',
  34. DatasetDocument.enabled == True,
  35. DatasetDocument.archived == False,
  36. ).all()
  37. if dataset_documents:
  38. # save vector index
  39. index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
  40. documents = []
  41. for dataset_document in dataset_documents:
  42. # delete from vector index
  43. segments = db.session.query(DocumentSegment).filter(
  44. DocumentSegment.document_id == dataset_document.id,
  45. DocumentSegment.enabled == True
  46. ) .order_by(DocumentSegment.position.asc()).all()
  47. for segment in segments:
  48. document = Document(
  49. page_content=segment.content,
  50. metadata={
  51. "doc_id": segment.index_node_id,
  52. "doc_hash": segment.index_node_hash,
  53. "document_id": segment.document_id,
  54. "dataset_id": segment.dataset_id,
  55. }
  56. )
  57. documents.append(document)
  58. # save vector index
  59. index.create(documents)
  60. end_at = time.perf_counter()
  61. logging.info(
  62. click.style('Deal dataset vector index: {} latency: {}'.format(dataset_id, end_at - start_at), fg='green'))
  63. except Exception:
  64. logging.exception("Deal dataset vector index failed")