| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235 | import datetimeimport jsonimport loggingimport randomimport timeimport uuidfrom typing import List, Optional, castfrom core.errors.error import LLMBadRequestError, ProviderTokenNotInitErrorfrom core.index.index import IndexBuilderfrom core.model_manager import ModelManagerfrom core.model_runtime.entities.model_entities import ModelTypefrom core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModelfrom events.dataset_event import dataset_was_deletedfrom events.document_event import document_was_deletedfrom extensions.ext_database import dbfrom extensions.ext_redis import redis_clientfrom flask import current_appfrom flask_login import current_userfrom libs import helperfrom models.account import Accountfrom models.dataset import (AppDatasetJoin, Dataset, DatasetCollectionBinding, DatasetProcessRule, DatasetQuery,                            Document, DocumentSegment)from models.model import UploadFilefrom models.source import DataSourceBindingfrom services.errors.account import NoPermissionErrorfrom services.errors.dataset import DatasetNameDuplicateErrorfrom services.errors.document import DocumentIndexingErrorfrom services.errors.file import FileNotExistsErrorfrom services.vector_service import VectorServicefrom sqlalchemy import funcfrom tasks.clean_notion_document_task import clean_notion_document_taskfrom tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_taskfrom tasks.delete_segment_from_index_task import delete_segment_from_index_taskfrom tasks.document_indexing_task import document_indexing_taskfrom tasks.document_indexing_update_task import document_indexing_update_taskfrom tasks.recover_document_indexing_task import recover_document_indexing_taskclass DatasetService:    @staticmethod    def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None):        if user:            permission_filter = db.or_(Dataset.created_by == user.id,                                       Dataset.permission == 'all_team_members')        else:            permission_filter = Dataset.permission == 'all_team_members'        datasets = Dataset.query.filter(            db.and_(Dataset.provider == provider, Dataset.tenant_id == tenant_id, permission_filter)) \            .order_by(Dataset.created_at.desc()) \            .paginate(            page=page,            per_page=per_page,            max_per_page=100,            error_out=False        )        return datasets.items, datasets.total    @staticmethod    def get_process_rules(dataset_id):        # get the latest process rule        dataset_process_rule = db.session.query(DatasetProcessRule). \            filter(DatasetProcessRule.dataset_id == dataset_id). \            order_by(DatasetProcessRule.created_at.desc()). \            limit(1). \            one_or_none()        if dataset_process_rule:            mode = dataset_process_rule.mode            rules = dataset_process_rule.rules_dict        else:            mode = DocumentService.DEFAULT_RULES['mode']            rules = DocumentService.DEFAULT_RULES['rules']        return {            'mode': mode,            'rules': rules        }    @staticmethod    def get_datasets_by_ids(ids, tenant_id):        datasets = Dataset.query.filter(Dataset.id.in_(ids),                                        Dataset.tenant_id == tenant_id).paginate(            page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)        return datasets.items, datasets.total    @staticmethod    def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account):        # check if dataset name already exists        if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():            raise DatasetNameDuplicateError(                f'Dataset with name {name} already exists.')        embedding_model = None        if indexing_technique == 'high_quality':            model_manager = ModelManager()            embedding_model = model_manager.get_default_model_instance(                tenant_id=tenant_id,                model_type=ModelType.TEXT_EMBEDDING            )        dataset = Dataset(name=name, indexing_technique=indexing_technique)        # dataset = Dataset(name=name, provider=provider, config=config)        dataset.created_by = account.id        dataset.updated_by = account.id        dataset.tenant_id = tenant_id        dataset.embedding_model_provider = embedding_model.provider if embedding_model else None        dataset.embedding_model = embedding_model.model if embedding_model else None        db.session.add(dataset)        db.session.commit()        return dataset    @staticmethod    def get_dataset(dataset_id):        dataset = Dataset.query.filter_by(            id=dataset_id        ).first()        if dataset is None:            return None        else:            return dataset    @staticmethod    def check_dataset_model_setting(dataset):        if dataset.indexing_technique == 'high_quality':            try:                model_manager = ModelManager()                model_manager.get_model_instance(                    tenant_id=dataset.tenant_id,                    provider=dataset.embedding_model_provider,                    model_type=ModelType.TEXT_EMBEDDING,                    model=dataset.embedding_model                )            except LLMBadRequestError:                raise ValueError(                    f"No Embedding Model available. Please configure a valid provider "                    f"in the Settings -> Model Provider.")            except ProviderTokenNotInitError as ex:                raise ValueError(f"The dataset in unavailable, due to: "                                 f"{ex.description}")    @staticmethod    def update_dataset(dataset_id, data, user):        filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}        dataset = DatasetService.get_dataset(dataset_id)        DatasetService.check_dataset_permission(dataset, user)        action = None        if dataset.indexing_technique != data['indexing_technique']:            # if update indexing_technique            if data['indexing_technique'] == 'economy':                action = 'remove'                filtered_data['embedding_model'] = None                filtered_data['embedding_model_provider'] = None                filtered_data['collection_binding_id'] = None            elif data['indexing_technique'] == 'high_quality':                action = 'add'                # get embedding model setting                try:                    model_manager = ModelManager()                    embedding_model = model_manager.get_default_model_instance(                        tenant_id=current_user.current_tenant_id,                        model_type=ModelType.TEXT_EMBEDDING                    )                    filtered_data['embedding_model'] = embedding_model.model                    filtered_data['embedding_model_provider'] = embedding_model.provider                    dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(                        embedding_model.provider,                        embedding_model.model                    )                    filtered_data['collection_binding_id'] = dataset_collection_binding.id                except LLMBadRequestError:                    raise ValueError(                        f"No Embedding Model available. Please configure a valid provider "                        f"in the Settings -> Model Provider.")                except ProviderTokenNotInitError as ex:                    raise ValueError(ex.description)        filtered_data['updated_by'] = user.id        filtered_data['updated_at'] = datetime.datetime.now()        # update Retrieval model        filtered_data['retrieval_model'] = data['retrieval_model']        dataset.query.filter_by(id=dataset_id).update(filtered_data)        db.session.commit()        if action:            deal_dataset_vector_index_task.delay(dataset_id, action)        return dataset    @staticmethod    def delete_dataset(dataset_id, user):        # todo: cannot delete dataset if it is being processed        dataset = DatasetService.get_dataset(dataset_id)        if dataset is None:            return False        DatasetService.check_dataset_permission(dataset, user)        dataset_was_deleted.send(dataset)        db.session.delete(dataset)        db.session.commit()        return True    @staticmethod    def check_dataset_permission(dataset, user):        if dataset.tenant_id != user.current_tenant_id:            logging.debug(                f'User {user.id} does not have permission to access dataset {dataset.id}')            raise NoPermissionError(                'You do not have permission to access this dataset.')        if dataset.permission == 'only_me' and dataset.created_by != user.id:            logging.debug(                f'User {user.id} does not have permission to access dataset {dataset.id}')            raise NoPermissionError(                'You do not have permission to access this dataset.')    @staticmethod    def get_dataset_queries(dataset_id: str, page: int, per_page: int):        dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \            .order_by(db.desc(DatasetQuery.created_at)) \            .paginate(            page=page, per_page=per_page, max_per_page=100, error_out=False        )        return dataset_queries.items, dataset_queries.total    @staticmethod    def get_related_apps(dataset_id: str):        return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \            .order_by(db.desc(AppDatasetJoin.created_at)).all()class DocumentService:    DEFAULT_RULES = {        'mode': 'custom',        'rules': {            'pre_processing_rules': [                {'id': 'remove_extra_spaces', 'enabled': True},                {'id': 'remove_urls_emails', 'enabled': False}            ],            'segmentation': {                'delimiter': '\n',                'max_tokens': 500            }        }    }    DOCUMENT_METADATA_SCHEMA = {        "book": {            "title": str,            "language": str,            "author": str,            "publisher": str,            "publication_date": str,            "isbn": str,            "category": str,        },        "web_page": {            "title": str,            "url": str,            "language": str,            "publish_date": str,            "author/publisher": str,            "topic/keywords": str,            "description": str,        },        "paper": {            "title": str,            "language": str,            "author": str,            "publish_date": str,            "journal/conference_name": str,            "volume/issue/page_numbers": str,            "doi": str,            "topic/keywords": str,            "abstract": str,        },        "social_media_post": {            "platform": str,            "author/username": str,            "publish_date": str,            "post_url": str,            "topic/tags": str,        },        "wikipedia_entry": {            "title": str,            "language": str,            "web_page_url": str,            "last_edit_date": str,            "editor/contributor": str,            "summary/introduction": str,        },        "personal_document": {            "title": str,            "author": str,            "creation_date": str,            "last_modified_date": str,            "document_type": str,            "tags/category": str,        },        "business_document": {            "title": str,            "author": str,            "creation_date": str,            "last_modified_date": str,            "document_type": str,            "department/team": str,        },        "im_chat_log": {            "chat_platform": str,            "chat_participants/group_name": str,            "start_date": str,            "end_date": str,            "summary": str,        },        "synced_from_notion": {            "title": str,            "language": str,            "author/creator": str,            "creation_date": str,            "last_modified_date": str,            "notion_page_link": str,            "category/tags": str,            "description": str,        },        "synced_from_github": {            "repository_name": str,            "repository_description": str,            "repository_owner/organization": str,            "code_filename": str,            "code_file_path": str,            "programming_language": str,            "github_link": str,            "open_source_license": str,            "commit_date": str,            "commit_author": str,        },        "others": dict    }    @staticmethod    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:        document = db.session.query(Document).filter(            Document.id == document_id,            Document.dataset_id == dataset_id        ).first()        return document    @staticmethod    def get_document_by_id(document_id: str) -> Optional[Document]:        document = db.session.query(Document).filter(            Document.id == document_id        ).first()        return document    @staticmethod    def get_document_by_dataset_id(dataset_id: str) -> List[Document]:        documents = db.session.query(Document).filter(            Document.dataset_id == dataset_id,            Document.enabled == True        ).all()        return documents    @staticmethod    def get_batch_documents(dataset_id: str, batch: str) -> List[Document]:        documents = db.session.query(Document).filter(            Document.batch == batch,            Document.dataset_id == dataset_id,            Document.tenant_id == current_user.current_tenant_id        ).all()        return documents    @staticmethod    def get_document_file_detail(file_id: str):        file_detail = db.session.query(UploadFile). \            filter(UploadFile.id == file_id). \            one_or_none()        return file_detail    @staticmethod    def check_archived(document):        if document.archived:            return True        else:            return False    @staticmethod    def delete_document(document):        # trigger document_was_deleted signal        document_was_deleted.send(document.id, dataset_id=document.dataset_id)        db.session.delete(document)        db.session.commit()    @staticmethod    def pause_document(document):        if document.indexing_status not in ["waiting", "parsing", "cleaning", "splitting", "indexing"]:            raise DocumentIndexingError()        # update document to be paused        document.is_paused = True        document.paused_by = current_user.id        document.paused_at = datetime.datetime.utcnow()        db.session.add(document)        db.session.commit()        # set document paused flag        indexing_cache_key = 'document_{}_is_paused'.format(document.id)        redis_client.setnx(indexing_cache_key, "True")    @staticmethod    def recover_document(document):        if not document.is_paused:            raise DocumentIndexingError()        # update document to be recover        document.is_paused = False        document.paused_by = None        document.paused_at = None        db.session.add(document)        db.session.commit()        # delete paused flag        indexing_cache_key = 'document_{}_is_paused'.format(document.id)        redis_client.delete(indexing_cache_key)        # trigger async task        recover_document_indexing_task.delay(document.dataset_id, document.id)    @staticmethod    def get_documents_position(dataset_id):        document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()        if document:            return document.position + 1        else:            return 1    @staticmethod    def save_document_with_dataset_id(dataset: Dataset, document_data: dict,                                      account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,                                      created_from: str = 'web'):        # check document limit        if current_app.config['EDITION'] == 'CLOUD':            if 'original_document_id' not in document_data or not document_data['original_document_id']:                count = 0                if document_data["data_source"]["type"] == "upload_file":                    upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']                    count = len(upload_file_list)                elif document_data["data_source"]["type"] == "notion_import":                    notion_info_list = document_data["data_source"]['info_list']['notion_info_list']                    for notion_info in notion_info_list:                        count = count + len(notion_info['pages'])        # if dataset is empty, update dataset data_source_type        if not dataset.data_source_type:            dataset.data_source_type = document_data["data_source"]["type"]        if not dataset.indexing_technique:            if 'indexing_technique' not in document_data \                    or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST:                raise ValueError("Indexing technique is required")            dataset.indexing_technique = document_data["indexing_technique"]            if document_data["indexing_technique"] == 'high_quality':                model_manager = ModelManager()                embedding_model = model_manager.get_default_model_instance(                    tenant_id=current_user.current_tenant_id,                    model_type=ModelType.TEXT_EMBEDDING                )                dataset.embedding_model = embedding_model.model                dataset.embedding_model_provider = embedding_model.provider                dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(                    embedding_model.provider,                    embedding_model.model                )                dataset.collection_binding_id = dataset_collection_binding.id                if not dataset.retrieval_model:                    default_retrieval_model = {                        'search_method': 'semantic_search',                        'reranking_enable': False,                        'reranking_model': {                            'reranking_provider_name': '',                            'reranking_model_name': ''                        },                        'top_k': 2,                        'score_threshold_enabled': False                    }                    dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get(                        'retrieval_model') else default_retrieval_model        documents = []        batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))        if 'original_document_id' in document_data and document_data["original_document_id"]:            document = DocumentService.update_document_with_dataset_id(dataset, document_data, account)            documents.append(document)        else:            # save process rule            if not dataset_process_rule:                process_rule = document_data["process_rule"]                if process_rule["mode"] == "custom":                    dataset_process_rule = DatasetProcessRule(                        dataset_id=dataset.id,                        mode=process_rule["mode"],                        rules=json.dumps(process_rule["rules"]),                        created_by=account.id                    )                elif process_rule["mode"] == "automatic":                    dataset_process_rule = DatasetProcessRule(                        dataset_id=dataset.id,                        mode=process_rule["mode"],                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),                        created_by=account.id                    )                db.session.add(dataset_process_rule)                db.session.commit()            position = DocumentService.get_documents_position(dataset.id)            document_ids = []            if document_data["data_source"]["type"] == "upload_file":                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']                for file_id in upload_file_list:                    file = db.session.query(UploadFile).filter(                        UploadFile.tenant_id == dataset.tenant_id,                        UploadFile.id == file_id                    ).first()                    # raise error if file not found                    if not file:                        raise FileNotExistsError()                    file_name = file.name                    data_source_info = {                        "upload_file_id": file_id,                    }                    document = DocumentService.build_document(dataset, dataset_process_rule.id,                                                              document_data["data_source"]["type"],                                                              document_data["doc_form"],                                                              document_data["doc_language"],                                                              data_source_info, created_from, position,                                                              account, file_name, batch)                    db.session.add(document)                    db.session.flush()                    document_ids.append(document.id)                    documents.append(document)                    position += 1            elif document_data["data_source"]["type"] == "notion_import":                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']                exist_page_ids = []                exist_document = dict()                documents = Document.query.filter_by(                    dataset_id=dataset.id,                    tenant_id=current_user.current_tenant_id,                    data_source_type='notion_import',                    enabled=True                ).all()                if documents:                    for document in documents:                        data_source_info = json.loads(document.data_source_info)                        exist_page_ids.append(data_source_info['notion_page_id'])                        exist_document[data_source_info['notion_page_id']] = document.id                for notion_info in notion_info_list:                    workspace_id = notion_info['workspace_id']                    data_source_binding = DataSourceBinding.query.filter(                        db.and_(                            DataSourceBinding.tenant_id == current_user.current_tenant_id,                            DataSourceBinding.provider == 'notion',                            DataSourceBinding.disabled == False,                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'                        )                    ).first()                    if not data_source_binding:                        raise ValueError('Data source binding not found.')                    for page in notion_info['pages']:                        if page['page_id'] not in exist_page_ids:                            data_source_info = {                                "notion_workspace_id": workspace_id,                                "notion_page_id": page['page_id'],                                "notion_page_icon": page['page_icon'],                                "type": page['type']                            }                            document = DocumentService.build_document(dataset, dataset_process_rule.id,                                                                      document_data["data_source"]["type"],                                                                      document_data["doc_form"],                                                                      document_data["doc_language"],                                                                      data_source_info, created_from, position,                                                                      account, page['page_name'], batch)                            db.session.add(document)                            db.session.flush()                            document_ids.append(document.id)                            documents.append(document)                            position += 1                        else:                            exist_document.pop(page['page_id'])                # delete not selected documents                if len(exist_document) > 0:                    clean_notion_document_task.delay(list(exist_document.values()), dataset.id)            db.session.commit()            # trigger async task            document_indexing_task.delay(dataset.id, document_ids)        return documents, batch    @staticmethod    def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,                       document_language: str, data_source_info: dict, created_from: str, position: int,                       account: Account,                       name: str, batch: str):        document = Document(            tenant_id=dataset.tenant_id,            dataset_id=dataset.id,            position=position,            data_source_type=data_source_type,            data_source_info=json.dumps(data_source_info),            dataset_process_rule_id=process_rule_id,            batch=batch,            name=name,            created_from=created_from,            created_by=account.id,            doc_form=document_form,            doc_language=document_language        )        return document    @staticmethod    def get_tenant_documents_count():        documents_count = Document.query.filter(Document.completed_at.isnot(None),                                                Document.enabled == True,                                                Document.archived == False,                                                Document.tenant_id == current_user.current_tenant_id).count()        return documents_count    @staticmethod    def update_document_with_dataset_id(dataset: Dataset, document_data: dict,                                        account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,                                        created_from: str = 'web'):        DatasetService.check_dataset_model_setting(dataset)        document = DocumentService.get_document(dataset.id, document_data["original_document_id"])        if document.display_status != 'available':            raise ValueError("Document is not available")        # update document name        if 'name' in document_data and document_data['name']:            document.name = document_data['name']        # save process rule        if 'process_rule' in document_data and document_data['process_rule']:            process_rule = document_data["process_rule"]            if process_rule["mode"] == "custom":                dataset_process_rule = DatasetProcessRule(                    dataset_id=dataset.id,                    mode=process_rule["mode"],                    rules=json.dumps(process_rule["rules"]),                    created_by=account.id                )            elif process_rule["mode"] == "automatic":                dataset_process_rule = DatasetProcessRule(                    dataset_id=dataset.id,                    mode=process_rule["mode"],                    rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),                    created_by=account.id                )            db.session.add(dataset_process_rule)            db.session.commit()            document.dataset_process_rule_id = dataset_process_rule.id        # update document data source        if 'data_source' in document_data and document_data['data_source']:            file_name = ''            data_source_info = {}            if document_data["data_source"]["type"] == "upload_file":                upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']                for file_id in upload_file_list:                    file = db.session.query(UploadFile).filter(                        UploadFile.tenant_id == dataset.tenant_id,                        UploadFile.id == file_id                    ).first()                    # raise error if file not found                    if not file:                        raise FileNotExistsError()                    file_name = file.name                    data_source_info = {                        "upload_file_id": file_id,                    }            elif document_data["data_source"]["type"] == "notion_import":                notion_info_list = document_data["data_source"]['info_list']['notion_info_list']                for notion_info in notion_info_list:                    workspace_id = notion_info['workspace_id']                    data_source_binding = DataSourceBinding.query.filter(                        db.and_(                            DataSourceBinding.tenant_id == current_user.current_tenant_id,                            DataSourceBinding.provider == 'notion',                            DataSourceBinding.disabled == False,                            DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'                        )                    ).first()                    if not data_source_binding:                        raise ValueError('Data source binding not found.')                    for page in notion_info['pages']:                        data_source_info = {                            "notion_workspace_id": workspace_id,                            "notion_page_id": page['page_id'],                            "notion_page_icon": page['page_icon'],                            "type": page['type']                        }            document.data_source_type = document_data["data_source"]["type"]            document.data_source_info = json.dumps(data_source_info)            document.name = file_name        # update document to be waiting        document.indexing_status = 'waiting'        document.completed_at = None        document.processing_started_at = None        document.parsing_completed_at = None        document.cleaning_completed_at = None        document.splitting_completed_at = None        document.updated_at = datetime.datetime.utcnow()        document.created_from = created_from        document.doc_form = document_data['doc_form']        db.session.add(document)        db.session.commit()        # update document segment        update_params = {            DocumentSegment.status: 're_segment'        }        DocumentSegment.query.filter_by(document_id=document.id).update(update_params)        db.session.commit()        # trigger async task        document_indexing_update_task.delay(document.dataset_id, document.id)        return document    @staticmethod    def save_document_without_dataset_id(tenant_id: str, document_data: dict, account: Account):        count = 0        if document_data["data_source"]["type"] == "upload_file":            upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']            count = len(upload_file_list)        elif document_data["data_source"]["type"] == "notion_import":            notion_info_list = document_data["data_source"]['info_list']['notion_info_list']            for notion_info in notion_info_list:                count = count + len(notion_info['pages'])        embedding_model = None        dataset_collection_binding_id = None        retrieval_model = None        if document_data['indexing_technique'] == 'high_quality':            model_manager = ModelManager()            embedding_model = model_manager.get_default_model_instance(                tenant_id=current_user.current_tenant_id,                model_type=ModelType.TEXT_EMBEDDING            )            dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(                embedding_model.provider,                embedding_model.model            )            dataset_collection_binding_id = dataset_collection_binding.id            if 'retrieval_model' in document_data and document_data['retrieval_model']:                retrieval_model = document_data['retrieval_model']            else:                default_retrieval_model = {                    'search_method': 'semantic_search',                    'reranking_enable': False,                    'reranking_model': {                        'reranking_provider_name': '',                        'reranking_model_name': ''                    },                    'top_k': 2,                    'score_threshold_enabled': False                }                retrieval_model = default_retrieval_model        # save dataset        dataset = Dataset(            tenant_id=tenant_id,            name='',            data_source_type=document_data["data_source"]["type"],            indexing_technique=document_data["indexing_technique"],            created_by=account.id,            embedding_model=embedding_model.model if embedding_model else None,            embedding_model_provider=embedding_model.provider if embedding_model else None,            collection_binding_id=dataset_collection_binding_id,            retrieval_model=retrieval_model        )        db.session.add(dataset)        db.session.flush()        documents, batch = DocumentService.save_document_with_dataset_id(dataset, document_data, account)        cut_length = 18        cut_name = documents[0].name[:cut_length]        dataset.name = cut_name + '...'        dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name        db.session.commit()        return dataset, documents, batch    @classmethod    def document_create_args_validate(cls, args: dict):        if 'original_document_id' not in args or not args['original_document_id']:            DocumentService.data_source_args_validate(args)            DocumentService.process_rule_args_validate(args)        else:            if ('data_source' not in args and not args['data_source']) \                    and ('process_rule' not in args and not args['process_rule']):                raise ValueError("Data source or Process rule is required")            else:                if 'data_source' in args and args['data_source']:                    DocumentService.data_source_args_validate(args)                if 'process_rule' in args and args['process_rule']:                    DocumentService.process_rule_args_validate(args)    @classmethod    def data_source_args_validate(cls, args: dict):        if 'data_source' not in args or not args['data_source']:            raise ValueError("Data source is required")        if not isinstance(args['data_source'], dict):            raise ValueError("Data source is invalid")        if 'type' not in args['data_source'] or not args['data_source']['type']:            raise ValueError("Data source type is required")        if args['data_source']['type'] not in Document.DATA_SOURCES:            raise ValueError("Data source type is invalid")        if 'info_list' not in args['data_source'] or not args['data_source']['info_list']:            raise ValueError("Data source info is required")        if args['data_source']['type'] == 'upload_file':            if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][                'file_info_list']:                raise ValueError("File source info is required")        if args['data_source']['type'] == 'notion_import':            if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][                'notion_info_list']:                raise ValueError("Notion source info is required")    @classmethod    def process_rule_args_validate(cls, args: dict):        if 'process_rule' not in args or not args['process_rule']:            raise ValueError("Process rule is required")        if not isinstance(args['process_rule'], dict):            raise ValueError("Process rule is invalid")        if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:            raise ValueError("Process rule mode is required")        if args['process_rule']['mode'] not in DatasetProcessRule.MODES:            raise ValueError("Process rule mode is invalid")        if args['process_rule']['mode'] == 'automatic':            args['process_rule']['rules'] = {}        else:            if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:                raise ValueError("Process rule rules is required")            if not isinstance(args['process_rule']['rules'], dict):                raise ValueError("Process rule rules is invalid")            if 'pre_processing_rules' not in args['process_rule']['rules'] \                    or args['process_rule']['rules']['pre_processing_rules'] is None:                raise ValueError("Process rule pre_processing_rules is required")            if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):                raise ValueError("Process rule pre_processing_rules is invalid")            unique_pre_processing_rule_dicts = {}            for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:                if 'id' not in pre_processing_rule or not pre_processing_rule['id']:                    raise ValueError("Process rule pre_processing_rules id is required")                if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:                    raise ValueError("Process rule pre_processing_rules id is invalid")                if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:                    raise ValueError("Process rule pre_processing_rules enabled is required")                if not isinstance(pre_processing_rule['enabled'], bool):                    raise ValueError("Process rule pre_processing_rules enabled is invalid")                unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule            args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())            if 'segmentation' not in args['process_rule']['rules'] \                    or args['process_rule']['rules']['segmentation'] is None:                raise ValueError("Process rule segmentation is required")            if not isinstance(args['process_rule']['rules']['segmentation'], dict):                raise ValueError("Process rule segmentation is invalid")            if 'separator' not in args['process_rule']['rules']['segmentation'] \                    or not args['process_rule']['rules']['segmentation']['separator']:                raise ValueError("Process rule segmentation separator is required")            if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):                raise ValueError("Process rule segmentation separator is invalid")            if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \                    or not args['process_rule']['rules']['segmentation']['max_tokens']:                raise ValueError("Process rule segmentation max_tokens is required")            if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):                raise ValueError("Process rule segmentation max_tokens is invalid")    @classmethod    def estimate_args_validate(cls, args: dict):        if 'info_list' not in args or not args['info_list']:            raise ValueError("Data source info is required")        if not isinstance(args['info_list'], dict):            raise ValueError("Data info is invalid")        if 'process_rule' not in args or not args['process_rule']:            raise ValueError("Process rule is required")        if not isinstance(args['process_rule'], dict):            raise ValueError("Process rule is invalid")        if 'mode' not in args['process_rule'] or not args['process_rule']['mode']:            raise ValueError("Process rule mode is required")        if args['process_rule']['mode'] not in DatasetProcessRule.MODES:            raise ValueError("Process rule mode is invalid")        if args['process_rule']['mode'] == 'automatic':            args['process_rule']['rules'] = {}        else:            if 'rules' not in args['process_rule'] or not args['process_rule']['rules']:                raise ValueError("Process rule rules is required")            if not isinstance(args['process_rule']['rules'], dict):                raise ValueError("Process rule rules is invalid")            if 'pre_processing_rules' not in args['process_rule']['rules'] \                    or args['process_rule']['rules']['pre_processing_rules'] is None:                raise ValueError("Process rule pre_processing_rules is required")            if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list):                raise ValueError("Process rule pre_processing_rules is invalid")            unique_pre_processing_rule_dicts = {}            for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']:                if 'id' not in pre_processing_rule or not pre_processing_rule['id']:                    raise ValueError("Process rule pre_processing_rules id is required")                if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES:                    raise ValueError("Process rule pre_processing_rules id is invalid")                if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None:                    raise ValueError("Process rule pre_processing_rules enabled is required")                if not isinstance(pre_processing_rule['enabled'], bool):                    raise ValueError("Process rule pre_processing_rules enabled is invalid")                unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule            args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values())            if 'segmentation' not in args['process_rule']['rules'] \                    or args['process_rule']['rules']['segmentation'] is None:                raise ValueError("Process rule segmentation is required")            if not isinstance(args['process_rule']['rules']['segmentation'], dict):                raise ValueError("Process rule segmentation is invalid")            if 'separator' not in args['process_rule']['rules']['segmentation'] \                    or not args['process_rule']['rules']['segmentation']['separator']:                raise ValueError("Process rule segmentation separator is required")            if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str):                raise ValueError("Process rule segmentation separator is invalid")            if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \                    or not args['process_rule']['rules']['segmentation']['max_tokens']:                raise ValueError("Process rule segmentation max_tokens is required")            if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int):                raise ValueError("Process rule segmentation max_tokens is invalid")class SegmentService:    @classmethod    def segment_create_args_validate(cls, args: dict, document: Document):        if document.doc_form == 'qa_model':            if 'answer' not in args or not args['answer']:                raise ValueError("Answer is required")            if not args['answer'].strip():                raise ValueError("Answer is empty")        if 'content' not in args or not args['content'] or not args['content'].strip():            raise ValueError("Content is empty")    @classmethod    def create_segment(cls, args: dict, document: Document, dataset: Dataset):        content = args['content']        doc_id = str(uuid.uuid4())        segment_hash = helper.generate_text_hash(content)        tokens = 0        if dataset.indexing_technique == 'high_quality':            model_manager = ModelManager()            embedding_model = model_manager.get_model_instance(                tenant_id=current_user.current_tenant_id,                provider=dataset.embedding_model_provider,                model_type=ModelType.TEXT_EMBEDDING,                model=dataset.embedding_model            )            # calc embedding use tokens            model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)            tokens = model_type_instance.get_num_tokens(                model=embedding_model.model,                credentials=embedding_model.credentials,                texts=[content]            )        max_position = db.session.query(func.max(DocumentSegment.position)).filter(            DocumentSegment.document_id == document.id        ).scalar()        segment_document = DocumentSegment(            tenant_id=current_user.current_tenant_id,            dataset_id=document.dataset_id,            document_id=document.id,            index_node_id=doc_id,            index_node_hash=segment_hash,            position=max_position + 1 if max_position else 1,            content=content,            word_count=len(content),            tokens=tokens,            status='completed',            indexing_at=datetime.datetime.utcnow(),            completed_at=datetime.datetime.utcnow(),            created_by=current_user.id        )        if document.doc_form == 'qa_model':            segment_document.answer = args['answer']        db.session.add(segment_document)        db.session.commit()        # save vector index        try:            VectorService.create_segment_vector(args['keywords'], segment_document, dataset)        except Exception as e:            logging.exception("create segment index failed")            segment_document.enabled = False            segment_document.disabled_at = datetime.datetime.utcnow()            segment_document.status = 'error'            segment_document.error = str(e)            db.session.commit()        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()        return segment    @classmethod    def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset):        embedding_model = None        if dataset.indexing_technique == 'high_quality':            model_manager = ModelManager()            embedding_model = model_manager.get_model_instance(                tenant_id=current_user.current_tenant_id,                provider=dataset.embedding_model_provider,                model_type=ModelType.TEXT_EMBEDDING,                model=dataset.embedding_model            )        max_position = db.session.query(func.max(DocumentSegment.position)).filter(            DocumentSegment.document_id == document.id        ).scalar()        pre_segment_data_list = []        segment_data_list = []        for segment_item in segments:            content = segment_item['content']            doc_id = str(uuid.uuid4())            segment_hash = helper.generate_text_hash(content)            tokens = 0            if dataset.indexing_technique == 'high_quality' and embedding_model:                # calc embedding use tokens                model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)                tokens = model_type_instance.get_num_tokens(                    model=embedding_model.model,                    credentials=embedding_model.credentials,                    texts=[content]                )            segment_document = DocumentSegment(                tenant_id=current_user.current_tenant_id,                dataset_id=document.dataset_id,                document_id=document.id,                index_node_id=doc_id,                index_node_hash=segment_hash,                position=max_position + 1 if max_position else 1,                content=content,                word_count=len(content),                tokens=tokens,                status='completed',                indexing_at=datetime.datetime.utcnow(),                completed_at=datetime.datetime.utcnow(),                created_by=current_user.id            )            if document.doc_form == 'qa_model':                segment_document.answer = segment_item['answer']            db.session.add(segment_document)            segment_data_list.append(segment_document)            pre_segment_data = {                'segment': segment_document,                'keywords': segment_item['keywords']            }            pre_segment_data_list.append(pre_segment_data)        try:            # save vector index            VectorService.multi_create_segment_vector(pre_segment_data_list, dataset)        except Exception as e:            logging.exception("create segment index failed")            for segment_document in segment_data_list:                segment_document.enabled = False                segment_document.disabled_at = datetime.datetime.utcnow()                segment_document.status = 'error'                segment_document.error = str(e)        db.session.commit()        return segment_data_list    @classmethod    def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset):        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)        cache_result = redis_client.get(indexing_cache_key)        if cache_result is not None:            raise ValueError("Segment is indexing, please try again later")        try:            content = args['content']            if segment.content == content:                if document.doc_form == 'qa_model':                    segment.answer = args['answer']                if 'keywords' in args and args['keywords']:                    segment.keywords = args['keywords']                if'enabled' in args and args['enabled'] is not None:                    segment.enabled = args['enabled']                db.session.add(segment)                db.session.commit()                # update segment index task                if args['keywords']:                    kw_index = IndexBuilder.get_index(dataset, 'economy')                    # delete from keyword index                    kw_index.delete_by_ids([segment.index_node_id])                    # save keyword index                    kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)            else:                segment_hash = helper.generate_text_hash(content)                tokens = 0                if dataset.indexing_technique == 'high_quality':                    model_manager = ModelManager()                    embedding_model = model_manager.get_model_instance(                        tenant_id=current_user.current_tenant_id,                        provider=dataset.embedding_model_provider,                        model_type=ModelType.TEXT_EMBEDDING,                        model=dataset.embedding_model                    )                    # calc embedding use tokens                    model_type_instance = cast(TextEmbeddingModel, embedding_model.model_type_instance)                    tokens = model_type_instance.get_num_tokens(                        model=embedding_model.model,                        credentials=embedding_model.credentials,                        texts=[content]                    )                segment.content = content                segment.index_node_hash = segment_hash                segment.word_count = len(content)                segment.tokens = tokens                segment.status = 'completed'                segment.indexing_at = datetime.datetime.utcnow()                segment.completed_at = datetime.datetime.utcnow()                segment.updated_by = current_user.id                segment.updated_at = datetime.datetime.utcnow()                if document.doc_form == 'qa_model':                    segment.answer = args['answer']                db.session.add(segment)                db.session.commit()                # update segment vector index                VectorService.update_segment_vector(args['keywords'], segment, dataset)        except Exception as e:            logging.exception("update segment index failed")            segment.enabled = False            segment.disabled_at = datetime.datetime.utcnow()            segment.status = 'error'            segment.error = str(e)            db.session.commit()        segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()        return segment    @classmethod    def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset):        indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id)        cache_result = redis_client.get(indexing_cache_key)        if cache_result is not None:            raise ValueError("Segment is deleting.")        # enabled segment need to delete index        if segment.enabled:            # send delete segment index task            redis_client.setex(indexing_cache_key, 600, 1)            delete_segment_from_index_task.delay(segment.id, segment.index_node_id, dataset.id, document.id)        db.session.delete(segment)        db.session.commit()class DatasetCollectionBindingService:    @classmethod    def get_dataset_collection_binding(cls, provider_name: str, model_name: str,                                       collection_type: str = 'dataset') -> DatasetCollectionBinding:        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \            filter(DatasetCollectionBinding.provider_name == provider_name,                   DatasetCollectionBinding.model_name == model_name,                   DatasetCollectionBinding.type == collection_type). \            order_by(DatasetCollectionBinding.created_at). \            first()        if not dataset_collection_binding:            dataset_collection_binding = DatasetCollectionBinding(                provider_name=provider_name,                model_name=model_name,                collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node',                type=collection_type            )            db.session.add(dataset_collection_binding)            db.session.commit()        return dataset_collection_binding    @classmethod    def get_dataset_collection_binding_by_id_and_type(cls, collection_binding_id: str,                                                      collection_type: str = 'dataset') -> DatasetCollectionBinding:        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \            filter(DatasetCollectionBinding.id == collection_binding_id,                   DatasetCollectionBinding.type == collection_type). \            order_by(DatasetCollectionBinding.created_at). \            first()        return dataset_collection_binding
 |