Browse Source

update document and segment word count (#10449)

Jyong 5 months ago
parent
commit
4f1a56f0f0
2 changed files with 41 additions and 4 deletions
  1. 35 3
      api/services/dataset_service.py
  2. 6 1
      api/tasks/batch_create_segment_to_index_task.py

+ 35 - 3
api/services/dataset_service.py

@@ -1414,9 +1414,13 @@ class SegmentService:
                 created_by=current_user.id,
             )
             if document.doc_form == "qa_model":
+                segment_document.word_count += len(args["answer"])
                 segment_document.answer = args["answer"]
 
             db.session.add(segment_document)
+            # update document word count
+            document.word_count += segment_document.word_count
+            db.session.add(document)
             db.session.commit()
 
             # save vector index
@@ -1435,6 +1439,7 @@ class SegmentService:
     @classmethod
     def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):
         lock_name = "multi_add_segment_lock_document_id_{}".format(document.id)
+        increment_word_count = 0
         with redis_client.lock(lock_name, timeout=600):
             embedding_model = None
             if dataset.indexing_technique == "high_quality":
@@ -1460,7 +1465,10 @@ class SegmentService:
                 tokens = 0
                 if dataset.indexing_technique == "high_quality" and embedding_model:
                     # calc embedding use tokens
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    if document.doc_form == "qa_model":
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                    else:
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
                 segment_document = DocumentSegment(
                     tenant_id=current_user.current_tenant_id,
                     dataset_id=document.dataset_id,
@@ -1478,6 +1486,8 @@ class SegmentService:
                 )
                 if document.doc_form == "qa_model":
                     segment_document.answer = segment_item["answer"]
+                    segment_document.word_count += len(segment_item["answer"])
+                increment_word_count += segment_document.word_count
                 db.session.add(segment_document)
                 segment_data_list.append(segment_document)
 
@@ -1486,7 +1496,9 @@ class SegmentService:
                     keywords_list.append(segment_item["keywords"])
                 else:
                     keywords_list.append(None)
-
+            # update document word count
+            document.word_count += increment_word_count
+            db.session.add(document)
             try:
                 # save vector index
                 VectorService.create_segments_vector(keywords_list, pre_segment_data_list, dataset)
@@ -1527,10 +1539,14 @@ class SegmentService:
             else:
                 raise ValueError("Can't update disabled segment")
         try:
+            word_count_change = segment.word_count
             content = segment_update_entity.content
             if segment.content == content:
+                segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = segment_update_entity.answer
+                    segment.word_count += len(segment_update_entity.answer)
+                word_count_change = segment.word_count - word_count_change
                 if segment_update_entity.keywords:
                     segment.keywords = segment_update_entity.keywords
                 segment.enabled = True
@@ -1538,6 +1554,10 @@ class SegmentService:
                 segment.disabled_by = None
                 db.session.add(segment)
                 db.session.commit()
+                # update document word count
+                if word_count_change != 0:
+                    document.word_count = max(0, document.word_count + word_count_change)
+                    db.session.add(document)
                 # update segment index task
                 if segment_update_entity.enabled:
                     VectorService.create_segments_vector([segment_update_entity.keywords], [segment], dataset)
@@ -1554,7 +1574,10 @@ class SegmentService:
                     )
 
                     # calc embedding use tokens
-                    tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                    if document.doc_form == "qa_model":
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                    else:
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
                 segment.content = content
                 segment.index_node_hash = segment_hash
                 segment.word_count = len(content)
@@ -1569,6 +1592,12 @@ class SegmentService:
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = segment_update_entity.answer
+                    segment.word_count += len(segment_update_entity.answer)
+                word_count_change = segment.word_count - word_count_change
+                # update document word count
+                if word_count_change != 0:
+                    document.word_count = max(0, document.word_count + word_count_change)
+                    db.session.add(document)
                 db.session.add(segment)
                 db.session.commit()
                 # update segment vector index
@@ -1597,6 +1626,9 @@ class SegmentService:
             redis_client.setex(indexing_cache_key, 600, 1)
             delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)
         db.session.delete(segment)
+        # update document word count
+        document.word_count -= segment.word_count
+        db.session.add(document)
         db.session.commit()
 
 

+ 6 - 1
api/tasks/batch_create_segment_to_index_task.py

@@ -57,7 +57,7 @@ def batch_create_segment_to_index_task(
                 model_type=ModelType.TEXT_EMBEDDING,
                 model=dataset.embedding_model,
             )
-
+        word_count_change = 0
         for segment in content:
             content = segment["content"]
             doc_id = str(uuid.uuid4())
@@ -86,8 +86,13 @@ def batch_create_segment_to_index_task(
             )
             if dataset_document.doc_form == "qa_model":
                 segment_document.answer = segment["answer"]
+                segment_document.word_count += len(segment["answer"])
+            word_count_change += segment_document.word_count
             db.session.add(segment_document)
             document_segments.append(segment_document)
+        # update document word count
+        dataset_document.word_count += word_count_change
+        db.session.add(dataset_document)
         # add index to db
         indexing_runner = IndexingRunner()
         indexing_runner.batch_add_segments(document_segments, dataset)