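"""SQLAlchemy models for datasets (knowledge bases): documents, segments,
process rules, keyword tables, embeddings, and the bindings that connect
datasets to apps, tags, permissions, and external knowledge APIs."""
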
import base64
import enum
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage

from .account import Account
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID


class DatasetPermissionEnum(str, enum.Enum):
    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    PARTIAL_TEAM = "partial_members"
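    # Because this enum subclasses str, members compare equal to their raw
    # values (DatasetPermissionEnum.ONLY_ME == "only_me" is True), so they can
    # be matched directly against the string `datasets.permission` column.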


class Dataset(db.Model):
    __tablename__ = "datasets"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_pkey"),
        db.Index("dataset_tenant_idx", "tenant_id"),
        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
    )

    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
    PROVIDER_LIST = ["vendor", "external", None]

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        # .first() already returns None when no row matches, so no explicit
        # fallback branch is needed.
        return db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def external_retrieval_model(self):
        default_retrieval_model = {
            "top_k": 2,
            "score_threshold": 0.0,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def created_by_account(self):
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        return (
            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
            .order_by(DatasetProcessRule.created_at.desc())
            .first()
        )

    @property
    def app_count(self):
        return (
            db.session.query(func.count(AppDatasetJoin.id))
            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
            .scalar()
        )

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return (
            db.session.query(func.count(Document.id))
            .filter(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
            )
            .scalar()
        )

    @property
    def available_segment_count(self):
        return (
            db.session.query(func.count(DocumentSegment.id))
            .filter(
                DocumentSegment.dataset_id == self.id,
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,
            )
            .scalar()
        )

    @property
    def word_count(self):
        # COALESCE needs an explicit default to be useful: without the 0, an
        # empty dataset would yield NULL/None instead of a zero word count.
        return (
            Document.query.with_entities(func.coalesce(func.sum(Document.word_count), 0))
            .filter(Document.dataset_id == self.id)
            .scalar()
        )

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
            "reranking_enable": False,
            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def tags(self):
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge",
            )
            .all()
        )
        return tags or []

    @property
    def external_knowledge_info(self):
        if self.provider != "external":
            return None
        external_knowledge_binding = (
            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
        )
        if not external_knowledge_binding:
            return None
        external_knowledge_api = (
            db.session.query(ExternalKnowledgeApis)
            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
            .first()
        )
        if not external_knowledge_api:
            return None
        return {
            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
            "external_knowledge_api_id": external_knowledge_api.id,
            "external_knowledge_api_name": external_knowledge_api.name,
            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
        }

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized_dataset_id}_Node"


class DatasetProcessRule(db.Model):
    __tablename__ = "dataset_process_rules"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    MODES = ["automatic", "custom"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    AUTOMATIC_RULES = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
    }

    def to_dict(self):
        return {
            "id": self.id,
            "dataset_id": self.dataset_id,
            "mode": self.mode,
            "rules": self.rules_dict,
            "created_by": self.created_by,
            "created_at": self.created_at,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
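    # A custom-mode `rules` JSON payload presumably mirrors the shape of
    # AUTOMATIC_RULES above (this is an inference from PRE_PROCESSING_RULES,
    # not confirmed by this module), e.g.:
    #   {"pre_processing_rules": [{"id": "remove_stopwords", "enabled": true}],
    #    "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}}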


class Document(db.Model):
    __tablename__ = "documents"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pkey"),
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

    @property
    def display_status(self):
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status
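    # Summary of the mapping above: waiting -> queuing; any in-flight state
    # while is_paused -> paused; parsing/cleaning/splitting/indexing ->
    # indexing; error -> error; completed resolves to available, disabled, or
    # archived depending on the enabled/archived flags.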

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == "upload_file":
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
                    .one_or_none()
                )
                if file_detail:
                    return {
                        "upload_file": {
                            "id": file_detail.id,
                            "name": file_detail.name,
                            "size": file_detail.size,
                            "extension": file_detail.extension,
                            "mime_type": file_detail.mime_type,
                            "created_by": file_detail.created_by,
                            "created_at": file_detail.created_at.timestamp(),
                        }
                    }
            elif self.data_source_type in {"notion_import", "website_crawl"}:
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        # `x and x != 0` collapses to `x` for integer counts, so test
        # truthiness directly.
        if self.word_count and self.segment_count:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        # As with Dataset.word_count, give COALESCE a 0 default so documents
        # without segments report zero hits instead of None.
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            "dataset": self.dataset.to_dict() if self.dataset else None,
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            dataset_id=data.get("dataset_id"),
            position=data.get("position"),
            data_source_type=data.get("data_source_type"),
            data_source_info=data.get("data_source_info"),
            dataset_process_rule_id=data.get("dataset_process_rule_id"),
            batch=data.get("batch"),
            name=data.get("name"),
            created_from=data.get("created_from"),
            created_by=data.get("created_by"),
            created_api_request_id=data.get("created_api_request_id"),
            created_at=data.get("created_at"),
            processing_started_at=data.get("processing_started_at"),
            file_id=data.get("file_id"),
            word_count=data.get("word_count"),
            parsing_completed_at=data.get("parsing_completed_at"),
            cleaning_completed_at=data.get("cleaning_completed_at"),
            splitting_completed_at=data.get("splitting_completed_at"),
            tokens=data.get("tokens"),
            indexing_latency=data.get("indexing_latency"),
            completed_at=data.get("completed_at"),
            is_paused=data.get("is_paused"),
            paused_by=data.get("paused_by"),
            paused_at=data.get("paused_at"),
            error=data.get("error"),
            stopped_at=data.get("stopped_at"),
            indexing_status=data.get("indexing_status"),
            enabled=data.get("enabled"),
            disabled_at=data.get("disabled_at"),
            disabled_by=data.get("disabled_by"),
            archived=data.get("archived"),
            archived_reason=data.get("archived_reason"),
            archived_by=data.get("archived_by"),
            archived_at=data.get("archived_at"),
            updated_at=data.get("updated_at"),
            doc_type=data.get("doc_type"),
            doc_metadata=data.get("doc_metadata"),
            doc_form=data.get("doc_form"),
            doc_language=data.get("doc_language"),
        )
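    # Note: to_dict() also serializes derived properties (display_status,
    # segment_count, hit_count, ...), while from_dict() only accepts persisted
    # columns, so Document.from_dict(doc.to_dict()) is not a strict round trip.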


class DocumentSegment(db.Model):
    __tablename__ = "document_segments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
        db.Index("document_segment_dataset_id_idx", "dataset_id"),
        db.Index("document_segment_document_id_idx", "document_id"),
        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
        db.Index("document_segment_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
            .first()
        )

    @property
    def next_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
            .first()
        )

    def get_sign_content(self):
        signed_urls = []
        text = self.content

        # Sign both the pre-v0.10.0 image-preview URLs and the file-preview
        # URLs used from v0.10.0 onwards; the signing scheme is identical
        # apart from the operation name embedded in the signed payload.
        for operation in ("image-preview", "file-preview"):
            pattern = rf"/files/([a-f0-9\-]+)/{operation}"
            for match in re.finditer(pattern, text):
                upload_file_id = match.group(1)
                nonce = os.urandom(16).hex()
                timestamp = str(int(time.time()))
                data_to_sign = f"{operation}|{upload_file_id}|{timestamp}|{nonce}"
                secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
                sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
                encoded_sign = base64.urlsafe_b64encode(sign).decode()
                params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
                signed_urls.append((match.start(), match.end(), f"{match.group(0)}?{params}"))

        # Reconstruct the text with signed URLs. Apply replacements in
        # ascending position order so the running offset stays correct even
        # when both URL styles appear in the same segment.
        signed_urls.sort(key=lambda item: item[0])
        offset = 0
        for start, end, signed_url in signed_urls:
            text = text[: start + offset] + signed_url + text[end + offset :]
            offset += len(signed_url) - (end - start)
        return text
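    # The resulting URLs look like (illustrative values):
    #   /files/<upload_file_id>/file-preview?timestamp=1700000000&nonce=<hex>&sign=<base64url>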


class AppDatasetJoin(db.Model):
    __tablename__ = "app_dataset_joins"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return db.session.get(App, self.app_id)


class DatasetQuery(db.Model):
    __tablename__ = "dataset_queries"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
        db.Index("dataset_query_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):
    __tablename__ = "dataset_keyword_tables"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(
        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
    )

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # Ensure the owning dataset still exists before loading the table.
        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
        if not dataset:
            return None
        if self.data_source_type == "database":
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt"
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)
                return None
            except Exception:
                # logging.exception already records the active traceback, so
                # there is no need to stringify the exception into the message.
                logging.exception("Failed to load keyword table from storage, file_key: %s", file_key)
                return None
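    # The decoded keyword table maps each keyword to the set of index-node ids
    # that contain it, e.g. {"alpha": {"node-1", "node-2"}} (illustrative
    # values); SetDecoder restores the JSON lists as Python sets.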


class Embedding(db.Model):
    __tablename__ = "embeddings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="embedding_pkey"),
        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
        db.Index("created_at_idx", "created_at"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    model_name = db.Column(
        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)
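    # set_embedding/get_embedding round-trip the vector through pickle, e.g.
    # (illustrative values):
    #   emb = Embedding(model_name="text-embedding-ada-002", hash="abc123", provider_name="openai")
    #   emb.set_embedding([0.1, 0.2, 0.3])
    #   emb.get_embedding()  # -> [0.1, 0.2, 0.3]
    # pickle.loads must only be fed trusted data; here the bytes come from
    # this table's own `embedding` column.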


class DatasetCollectionBinding(db.Model):
    __tablename__ = "dataset_collection_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
        db.Index("provider_model_name_idx", "provider_name", "model_name"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class TidbAuthBinding(db.Model):
    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # The default must be a quoted SQL string literal; a bare CREATING would be
    # parsed by Postgres as an identifier rather than the string 'CREATING'.
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class Whitelist(db.Model):
    __tablename__ = "whitelists"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
        db.Index("whitelists_tenant_idx", "tenant_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class DatasetPermission(db.Model):
    __tablename__ = "dataset_permissions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
        db.Index("idx_dataset_permissions_account_id", "account_id"),
        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class ExternalKnowledgeApis(db.Model):
    __tablename__ = "external_knowledge_apis"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_apis_name_idx", "name"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.String(255), nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "name": self.name,
            "description": self.description,
            "settings": self.settings_dict,
            "dataset_bindings": self.dataset_bindings,
            "created_by": self.created_by,
            "created_at": self.created_at.isoformat(),
        }

    @property
    def settings_dict(self):
        try:
            return json.loads(self.settings) if self.settings else None
        except JSONDecodeError:
            return None

    @property
    def dataset_bindings(self):
        external_knowledge_bindings = (
            db.session.query(ExternalKnowledgeBindings)
            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
            .all()
        )
        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
        return [{"id": dataset.id, "name": dataset.name} for dataset in datasets]
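    # `settings` is stored as a JSON string; Dataset.external_knowledge_info
    # reads its "endpoint" key, so a minimal payload presumably looks like
    # {"endpoint": "https://example.com/retrieval"}. Any other keys (e.g. an
    # api_key) are an assumption; only "endpoint" is referenced in this module.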


class ExternalKnowledgeBindings(db.Model):
    __tablename__ = "external_knowledge_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    external_knowledge_api_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))