import datetime
import logging
import time

import click
from celery import shared_task
from llama_index.data_structs import Node
from llama_index.data_structs.node_v2 import DocumentRelationship
from werkzeug.exceptions import NotFound

from core.index.keyword_table_index import KeywordTableIndex
from core.index.vector_index import VectorIndex
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
  14. @shared_task
  15. def add_segment_to_index_task(segment_id: str):
  16. """
  17. Async Add segment to index
  18. :param segment_id:
  19. Usage: add_segment_to_index.delay(segment_id)
  20. """
  21. logging.info(click.style('Start add segment to index: {}'.format(segment_id), fg='green'))
  22. start_at = time.perf_counter()
  23. segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
  24. if not segment:
  25. raise NotFound('Segment not found')
  26. if segment.status != 'completed':
  27. return
  28. indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
  29. try:
  30. relationships = {
  31. DocumentRelationship.SOURCE: segment.document_id,
  32. }
  33. previous_segment = segment.previous_segment
  34. if previous_segment:
  35. relationships[DocumentRelationship.PREVIOUS] = previous_segment.index_node_id
  36. next_segment = segment.next_segment
  37. if next_segment:
  38. relationships[DocumentRelationship.NEXT] = next_segment.index_node_id
  39. node = Node(
  40. doc_id=segment.index_node_id,
  41. doc_hash=segment.index_node_hash,
  42. text=segment.content,
  43. extra_info=None,
  44. node_info=None,
  45. relationships=relationships
  46. )
  47. dataset = segment.dataset
  48. if not dataset:
  49. raise Exception('Segment has no dataset')
  50. vector_index = VectorIndex(dataset=dataset)
  51. keyword_table_index = KeywordTableIndex(dataset=dataset)
  52. # save vector index
  53. if dataset.indexing_technique == "high_quality":
  54. vector_index.add_nodes(
  55. nodes=[node],
  56. duplicate_check=True
  57. )
  58. # save keyword index
  59. keyword_table_index.add_nodes([node])
  60. end_at = time.perf_counter()
  61. logging.info(click.style('Segment added to index: {} latency: {}'.format(segment.id, end_at - start_at), fg='green'))
  62. except Exception as e:
  63. logging.exception("add segment to index failed")
  64. segment.enabled = False
  65. segment.disabled_at = datetime.datetime.utcnow()
  66. segment.status = 'error'
  67. segment.error = str(e)
  68. db.session.commit()
  69. finally:
  70. redis_client.delete(indexing_cache_key)