import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import AbstractSet, Any, Collection, List, Literal, Optional, Type, Union, cast

from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.index.index import IndexBuilder
from core.model_manager import ModelManager, ModelInstance
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.spiltter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from flask import Flask, current_app
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import TS, TextSplitter, TokenTextSplitter
from libs import helper
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Document as DatasetDocument
from models.dataset import DocumentSegment
from models.model import UploadFile
from models.source import DataSourceBinding
from sqlalchemy.orm.exc import ObjectDeletedError


class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: List[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()

                # load file
                text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')

                # get embedding model instance
                embedding_model_instance = None
                if dataset.indexing_technique == 'high_quality':
                    if dataset.embedding_model_provider:
                        embedding_model_instance = self.model_manager.get_model_instance(
                            tenant_id=dataset.tenant_id,
                            provider=dataset.embedding_model_provider,
                            model_type=ModelType.TEXT_EMBEDDING,
                            model=dataset.embedding_model
                        )
                    else:
                        embedding_model_instance = self.model_manager.get_default_model_instance(
                            tenant_id=dataset.tenant_id,
                            model_type=ModelType.TEXT_EMBEDDING,
                        )

                # get splitter
                splitter = self._get_splitter(processing_rule, embedding_model_instance)

                # split to documents and save them as document segments
                documents = self._step_split(
                    text_docs=text_docs,
                    splitter=splitter,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    processing_rule=processing_rule
                )

                # build the index
                self._build_index(
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

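    # Usage sketch (illustrative only, not part of the original module): a background
    # worker would typically load the pending documents and hand them to `run`.
    # The exact task wiring (e.g. the Celery task and queue) lives outside this class.
    #
    #   dataset_documents = DatasetDocument.query.filter(
    #       DatasetDocument.id.in_(document_ids)
    #   ).all()
    #   IndexingRunner().run(dataset_documents)
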
       logging.exception("consume document failed")                dataset_document.indexing_status = 'error'                dataset_document.error = str(e)                dataset_document.stopped_at = datetime.datetime.utcnow()                db.session.commit()    def run_in_splitting_status(self, dataset_document: DatasetDocument):        """Run the indexing process when the index_status is splitting."""        try:                        dataset = Dataset.query.filter_by(                id=dataset_document.dataset_id            ).first()            if not dataset:                raise ValueError("no dataset found")                        document_segments = DocumentSegment.query.filter_by(                dataset_id=dataset.id,                document_id=dataset_document.id            ).all()            for document_segment in document_segments:                db.session.delete(document_segment)            db.session.commit()                        processing_rule = db.session.query(DatasetProcessRule). \                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \                first()                        text_docs = self._load_data(dataset_document, processing_rule.mode == 'automatic')                        embedding_model_instance = None            if dataset.indexing_technique == 'high_quality':                if dataset.embedding_model_provider:                    embedding_model_instance = self.model_manager.get_model_instance(                        tenant_id=dataset.tenant_id,                        provider=dataset.embedding_model_provider,                        model_type=ModelType.TEXT_EMBEDDING,                        model=dataset.embedding_model                    )                else:                    embedding_model_instance = self.model_manager.get_default_model_instance(                        tenant_id=dataset.tenant_id,                        model_type=ModelType.TEXT_EMBEDDING,                    )                        splitter = self._get_splitter(processing_rule, embedding_model_instance)                        documents = self._step_split(                text_docs=text_docs,                splitter=splitter,                dataset=dataset,                dataset_document=dataset_document,                processing_rule=processing_rule            )                        self._build_index(                dataset=dataset,                dataset_document=dataset_document,                documents=documents            )        except DocumentIsPausedException:            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))        except ProviderTokenNotInitError as e:            dataset_document.indexing_status = 'error'            dataset_document.error = str(e.description)            dataset_document.stopped_at = datetime.datetime.utcnow()            db.session.commit()        except Exception as e:            logging.exception("consume document failed")            dataset_document.indexing_status = 'error'            dataset_document.error = str(e)            dataset_document.stopped_at = datetime.datetime.utcnow()            db.session.commit()    def run_in_indexing_status(self, dataset_document: DatasetDocument):        """Run the indexing process when the index_status is indexing."""        try:                        dataset = Dataset.query.filter_by(                id=dataset_document.dataset_id            ).first()            if not dataset:                raise 
ValueError("no dataset found")                        document_segments = DocumentSegment.query.filter_by(                dataset_id=dataset.id,                document_id=dataset_document.id            ).all()            documents = []            if document_segments:                for document_segment in document_segments:                                        if document_segment.status != "completed":                        document = Document(                            page_content=document_segment.content,                            metadata={                                "doc_id": document_segment.index_node_id,                                "doc_hash": document_segment.index_node_hash,                                "document_id": document_segment.document_id,                                "dataset_id": document_segment.dataset_id,                            }                        )                        documents.append(document)                        self._build_index(                dataset=dataset,                dataset_document=dataset_document,                documents=documents            )        except DocumentIsPausedException:            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))        except ProviderTokenNotInitError as e:            dataset_document.indexing_status = 'error'            dataset_document.error = str(e.description)            dataset_document.stopped_at = datetime.datetime.utcnow()            db.session.commit()        except Exception as e:            logging.exception("consume document failed")            dataset_document.indexing_status = 'error'            dataset_document.error = str(e)            dataset_document.stopped_at = datetime.datetime.utcnow()            db.session.commit()    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,                               doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,                               indexing_technique: str = 'economy') -> dict:        """        Estimate the indexing for the document.        
"""        embedding_model_instance = None        if dataset_id:            dataset = Dataset.query.filter_by(                id=dataset_id            ).first()            if not dataset:                raise ValueError('Dataset not found.')            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':                if dataset.embedding_model_provider:                    embedding_model_instance = self.model_manager.get_model_instance(                        tenant_id=tenant_id,                        provider=dataset.embedding_model_provider,                        model_type=ModelType.TEXT_EMBEDDING,                        model=dataset.embedding_model                    )                else:                    embedding_model_instance = self.model_manager.get_default_model_instance(                        tenant_id=tenant_id,                        model_type=ModelType.TEXT_EMBEDDING,                    )        else:            if indexing_technique == 'high_quality':                embedding_model_instance = self.model_manager.get_default_model_instance(                    tenant_id=tenant_id,                    model_type=ModelType.TEXT_EMBEDDING,                )        tokens = 0        preview_texts = []        total_segments = 0        total_price = 0        currency = 'USD'        for file_detail in file_details:            processing_rule = DatasetProcessRule(                mode=tmp_processing_rule["mode"],                rules=json.dumps(tmp_processing_rule["rules"])            )                        text_docs = FileExtractor.load(file_detail, is_automatic=processing_rule.mode == 'automatic')                        splitter = self._get_splitter(processing_rule, embedding_model_instance)                        documents = self._split_to_documents_for_estimate(                text_docs=text_docs,                splitter=splitter,                processing_rule=processing_rule            )            total_segments += len(documents)            for document in documents:                if len(preview_texts) < 5:                    preview_texts.append(document.page_content)                if indexing_technique == 'high_quality' or embedding_model_instance:                    embedding_model_type_instance = embedding_model_instance.model_type_instance                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)                    tokens += embedding_model_type_instance.get_num_tokens(                        model=embedding_model_instance.model,                        credentials=embedding_model_instance.credentials,                        texts=[self.filter_string(document.page_content)]                    )        if doc_form and doc_form == 'qa_model':            model_instance = self.model_manager.get_default_model_instance(                tenant_id=tenant_id,                model_type=ModelType.LLM            )            model_type_instance = model_instance.model_type_instance            model_type_instance = cast(LargeLanguageModel, model_type_instance)            if len(preview_texts) > 0:                                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],                                                             doc_language)                document_qa_list = self.format_split_text(response)                price_info = model_type_instance.get_price(                    model=model_instance.model,                    
credentials=model_instance.credentials,                    price_type=PriceType.INPUT,                    tokens=total_segments * 2000,                )                return {                    "total_segments": total_segments * 20,                    "tokens": total_segments * 2000,                    "total_price": '{:f}'.format(price_info.total_amount),                    "currency": price_info.currency,                    "qa_preview": document_qa_list,                    "preview": preview_texts                }        if embedding_model_instance:            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)            embedding_price_info = embedding_model_type_instance.get_price(                model=embedding_model_instance.model,                credentials=embedding_model_instance.credentials,                price_type=PriceType.INPUT,                tokens=tokens            )            total_price = '{:f}'.format(embedding_price_info.total_amount)            currency = embedding_price_info.currency        return {            "total_segments": total_segments,            "tokens": tokens,            "total_price": total_price,            "currency": currency,            "preview": preview_texts        }    def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,                                 doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,                                 indexing_technique: str = 'economy') -> dict:        """        Estimate the indexing for the document.        """        embedding_model_instance = None        if dataset_id:            dataset = Dataset.query.filter_by(                id=dataset_id            ).first()            if not dataset:                raise ValueError('Dataset not found.')            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':                if dataset.embedding_model_provider:                    embedding_model_instance = self.model_manager.get_model_instance(                        tenant_id=tenant_id,                        provider=dataset.embedding_model_provider,                        model_type=ModelType.TEXT_EMBEDDING,                        model=dataset.embedding_model                    )                else:                    embedding_model_instance = self.model_manager.get_default_model_instance(                        tenant_id=tenant_id,                        model_type=ModelType.TEXT_EMBEDDING,                    )        else:            if indexing_technique == 'high_quality':                embedding_model_instance = self.model_manager.get_default_model_instance(                    tenant_id=tenant_id,                    model_type=ModelType.TEXT_EMBEDDING                )                tokens = 0        preview_texts = []        total_segments = 0        total_price = 0        currency = 'USD'        for notion_info in notion_info_list:            workspace_id = notion_info['workspace_id']            data_source_binding = DataSourceBinding.query.filter(                db.and_(                    DataSourceBinding.tenant_id == current_user.current_tenant_id,                    DataSourceBinding.provider == 'notion',                    DataSourceBinding.disabled == False,                    DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'                )            ).first()            if not data_source_binding:           
     raise ValueError('Data source binding not found.')            for page in notion_info['pages']:                loader = NotionLoader(                    notion_access_token=data_source_binding.access_token,                    notion_workspace_id=workspace_id,                    notion_obj_id=page['page_id'],                    notion_page_type=page['type']                )                documents = loader.load()                processing_rule = DatasetProcessRule(                    mode=tmp_processing_rule["mode"],                    rules=json.dumps(tmp_processing_rule["rules"])                )                                splitter = self._get_splitter(processing_rule, embedding_model_instance)                                documents = self._split_to_documents_for_estimate(                    text_docs=documents,                    splitter=splitter,                    processing_rule=processing_rule                )                total_segments += len(documents)                embedding_model_type_instance = None                if embedding_model_instance:                    embedding_model_type_instance = embedding_model_instance.model_type_instance                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)                for document in documents:                    if len(preview_texts) < 5:                        preview_texts.append(document.page_content)                    if indexing_technique == 'high_quality' and embedding_model_type_instance:                        tokens += embedding_model_type_instance.get_num_tokens(                            model=embedding_model_instance.model,                            credentials=embedding_model_instance.credentials,                            texts=[document.page_content]                        )        if doc_form and doc_form == 'qa_model':            model_instance = self.model_manager.get_default_model_instance(                tenant_id=tenant_id,                model_type=ModelType.LLM            )            model_type_instance = model_instance.model_type_instance            model_type_instance = cast(LargeLanguageModel, model_type_instance)            if len(preview_texts) > 0:                                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],                                                             doc_language)                document_qa_list = self.format_split_text(response)                price_info = model_type_instance.get_price(                    model=model_instance.model,                    credentials=model_instance.credentials,                    price_type=PriceType.INPUT,                    tokens=total_segments * 2000,                )                return {                    "total_segments": total_segments * 20,                    "tokens": total_segments * 2000,                    "total_price": '{:f}'.format(price_info.total_amount),                    "currency": price_info.currency,                    "qa_preview": document_qa_list,                    "preview": preview_texts                }        if embedding_model_instance:            embedding_model_type_instance = embedding_model_instance.model_type_instance            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)            embedding_price_info = embedding_model_type_instance.get_price(                model=embedding_model_instance.model,                
credentials=embedding_model_instance.credentials,                price_type=PriceType.INPUT,                tokens=tokens            )            total_price = '{:f}'.format(embedding_price_info.total_amount)            currency = embedding_price_info.currency        return {            "total_segments": total_segments,            "tokens": tokens,            "total_price": total_price,            "currency": currency,            "preview": preview_texts        }    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:                if dataset_document.data_source_type not in ["upload_file", "notion_import"]:            return []        data_source_info = dataset_document.data_source_info_dict        text_docs = []        if dataset_document.data_source_type == 'upload_file':            if not data_source_info or 'upload_file_id' not in data_source_info:                raise ValueError("no upload file found")            file_detail = db.session.query(UploadFile). \                filter(UploadFile.id == data_source_info['upload_file_id']). \                one_or_none()            if file_detail:                text_docs = FileExtractor.load(file_detail, is_automatic=automatic)        elif dataset_document.data_source_type == 'notion_import':            loader = NotionLoader.from_document(dataset_document)            text_docs = loader.load()                self._update_document_index_status(            document_id=dataset_document.id,            after_indexing_status="splitting",            extra_update_params={                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()            }        )                text_docs = cast(List[Document], text_docs)        for text_doc in text_docs:                        text_doc.page_content = self.filter_string(text_doc.page_content)            text_doc.metadata['document_id'] = dataset_document.id            text_doc.metadata['dataset_id'] = dataset_document.dataset_id        return text_docs    def filter_string(self, text):        text = re.sub(r'<\|', '<', text)        text = re.sub(r'\|>', '>', text)        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)                text = re.sub(u'\uFFFE', '', text)        return text    def _get_splitter(self, processing_rule: DatasetProcessRule,                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:        """        Get the NodeParser object according to the processing rule.        
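    # Example (illustrative): filter_string() neutralises the "<|" / "|>" special-token
    # delimiters and strips control characters before token counting, e.g.
    #   filter_string("<|endoftext|>\x00hello")  ->  "<endoftext>hello"
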
"""        if processing_rule.mode == "custom":                        rules = json.loads(processing_rule.rules)            segmentation = rules["segmentation"]            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:                raise ValueError("Custom segment length should be between 50 and 1000.")            separator = segmentation["separator"]            if separator:                separator = separator.replace('\\n', '\n')            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(                chunk_size=segmentation["max_tokens"],                chunk_overlap=segmentation.get('chunk_overlap', 0),                fixed_separator=separator,                separators=["\n\n", "。", ".", " ", ""],                embedding_model_instance=embedding_model_instance            )        else:                        character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],                separators=["\n\n", "。", ".", " ", ""],                embedding_model_instance=embedding_model_instance            )        return character_splitter    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \            -> List[Document]:        """        Split the text documents into documents and save them to the document segment.        """        documents = self._split_to_documents(            text_docs=text_docs,            splitter=splitter,            processing_rule=processing_rule,            tenant_id=dataset.tenant_id,            document_form=dataset_document.doc_form,            document_language=dataset_document.doc_language        )                doc_store = DatasetDocumentStore(            dataset=dataset,            user_id=dataset_document.created_by,            document_id=dataset_document.id        )                doc_store.add_documents(documents)                cur_time = datetime.datetime.utcnow()        self._update_document_index_status(            document_id=dataset_document.id,            after_indexing_status="indexing",            extra_update_params={                DatasetDocument.cleaning_completed_at: cur_time,                DatasetDocument.splitting_completed_at: cur_time,            }        )                self._update_segments_by_document(            dataset_document_id=dataset_document.id,            update_params={                DocumentSegment.status: "indexing",                DocumentSegment.indexing_at: datetime.datetime.utcnow()            }        )        return documents    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,                            processing_rule: DatasetProcessRule, tenant_id: str,                            document_form: str, document_language: str) -> List[Document]:        """        Split the text documents into nodes.        
"""        all_documents = []        all_qa_documents = []        for text_doc in text_docs:                        document_text = self._document_clean(text_doc.page_content, processing_rule)            text_doc.page_content = document_text                        documents = splitter.split_documents([text_doc])            split_documents = []            for document_node in documents:                if document_node.page_content.strip():                    doc_id = str(uuid.uuid4())                    hash = helper.generate_text_hash(document_node.page_content)                    document_node.metadata['doc_id'] = doc_id                    document_node.metadata['doc_hash'] = hash                                        page_content = document_node.page_content                    if page_content.startswith(".") or page_content.startswith("。"):                        page_content = page_content[1:]                    else:                        page_content = page_content                    document_node.page_content = page_content                    if document_node.page_content:                        split_documents.append(document_node)            all_documents.extend(split_documents)                if document_form == 'qa_model':            for i in range(0, len(all_documents), 10):                threads = []                sub_documents = all_documents[i:i + 10]                for doc in sub_documents:                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={                        'flask_app': current_app._get_current_object(),                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,                        'document_language': document_language})                    threads.append(document_format_thread)                    document_format_thread.start()                for thread in threads:                    thread.join()            return all_qa_documents        return all_documents    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):        format_documents = []        if document_node.page_content is None or not document_node.page_content.strip():            return        with flask_app.app_context():            try:                                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)                document_qa_list = self.format_split_text(response)                qa_documents = []                for result in document_qa_list:                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())                    doc_id = str(uuid.uuid4())                    hash = helper.generate_text_hash(result['question'])                    qa_document.metadata['answer'] = result['answer']                    qa_document.metadata['doc_id'] = doc_id                    qa_document.metadata['doc_hash'] = hash                    qa_documents.append(qa_document)                format_documents.extend(qa_documents)            except Exception as e:                logging.exception(e)            all_qa_documents.extend(format_documents)    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,                                         processing_rule: DatasetProcessRule) -> List[Document]:        """        Split the text documents into nodes.        
"""        all_documents = []        for text_doc in text_docs:                        document_text = self._document_clean(text_doc.page_content, processing_rule)            text_doc.page_content = document_text                        documents = splitter.split_documents([text_doc])            split_documents = []            for document in documents:                if document.page_content is None or not document.page_content.strip():                    continue                doc_id = str(uuid.uuid4())                hash = helper.generate_text_hash(document.page_content)                document.metadata['doc_id'] = doc_id                document.metadata['doc_hash'] = hash                split_documents.append(document)            all_documents.extend(split_documents)        return all_documents    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:        """        Clean the document text according to the processing rules.        """        if processing_rule.mode == "automatic":            rules = DatasetProcessRule.AUTOMATIC_RULES        else:            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}        if 'pre_processing_rules' in rules:            pre_processing_rules = rules["pre_processing_rules"]            for pre_processing_rule in pre_processing_rules:                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:                                        pattern = r'\n{3,}'                    text = re.sub(pattern, '\n\n', text)                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'                    text = re.sub(pattern, ' ', text)                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:                                        pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'                    text = re.sub(pattern, '', text)                                        pattern = r'https?://[^\s]+'                    text = re.sub(pattern, '', text)        return text    def format_split_text(self, text):        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"        matches = re.findall(regex, text, re.UNICODE)        return [            {                "question": q,                "answer": re.sub(r"\n\s*", "\n", a.strip())            }            for q, a in matches if q and a        ]    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:        """        Build the index for the document.        
"""        vector_index = IndexBuilder.get_index(dataset, 'high_quality')        keyword_table_index = IndexBuilder.get_index(dataset, 'economy')        embedding_model_instance = None        if dataset.indexing_technique == 'high_quality':            embedding_model_instance = self.model_manager.get_model_instance(                tenant_id=dataset.tenant_id,                provider=dataset.embedding_model_provider,                model_type=ModelType.TEXT_EMBEDDING,                model=dataset.embedding_model            )                indexing_start_at = time.perf_counter()        tokens = 0        chunk_size = 100        embedding_model_type_instance = None        if embedding_model_instance:            embedding_model_type_instance = embedding_model_instance.model_type_instance            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)        for i in range(0, len(documents), chunk_size):                        self._check_document_paused_status(dataset_document.id)            chunk_documents = documents[i:i + chunk_size]            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:                tokens += sum(                    embedding_model_type_instance.get_num_tokens(                        embedding_model_instance.model,                        embedding_model_instance.credentials,                        [document.page_content]                    )                    for document in chunk_documents                )                        if vector_index:                vector_index.add_texts(chunk_documents)                        keyword_table_index.add_texts(chunk_documents)            document_ids = [document.metadata['doc_id'] for document in chunk_documents]            db.session.query(DocumentSegment).filter(                DocumentSegment.document_id == dataset_document.id,                DocumentSegment.index_node_id.in_(document_ids),                DocumentSegment.status == "indexing"            ).update({                DocumentSegment.status: "completed",                DocumentSegment.enabled: True,                DocumentSegment.completed_at: datetime.datetime.utcnow()            })            db.session.commit()        indexing_end_at = time.perf_counter()                self._update_document_index_status(            document_id=dataset_document.id,            after_indexing_status="completed",            extra_update_params={                DatasetDocument.tokens: tokens,                DatasetDocument.completed_at: datetime.datetime.utcnow(),                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,            }        )    def _check_document_paused_status(self, document_id: str):        indexing_cache_key = 'document_{}_is_paused'.format(document_id)        result = redis_client.get(indexing_cache_key)        if result:            raise DocumentIsPausedException()    def _update_document_index_status(self, document_id: str, after_indexing_status: str,                                      extra_update_params: Optional[dict] = None) -> None:        """        Update the document indexing status.        
"""        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()        if count > 0:            raise DocumentIsPausedException()        document = DatasetDocument.query.filter_by(id=document_id).first()        if not document:            raise DocumentIsDeletedPausedException()        update_params = {            DatasetDocument.indexing_status: after_indexing_status        }        if extra_update_params:            update_params.update(extra_update_params)        DatasetDocument.query.filter_by(id=document_id).update(update_params)        db.session.commit()    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:        """        Update the document segment by document id.        """        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)        db.session.commit()    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):        """        Batch add segments index processing        """        documents = []        for segment in segments:            document = Document(                page_content=segment.content,                metadata={                    "doc_id": segment.index_node_id,                    "doc_hash": segment.index_node_hash,                    "document_id": segment.document_id,                    "dataset_id": segment.dataset_id,                }            )            documents.append(document)                index = IndexBuilder.get_index(dataset, 'high_quality')        if index:            index.add_texts(documents, duplicate_check=True)                index = IndexBuilder.get_index(dataset, 'economy')        if index:            index.add_texts(documents)class DocumentIsPausedException(Exception):    passclass DocumentIsDeletedPausedException(Exception):    pass