dataset.py
import base64
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from configs import dify_config
from core.rag.retrieval.retrival_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile
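
# SQLAlchemy models for the knowledge-base (dataset) feature: datasets and
# their processing rules, documents, document segments, and the supporting
# tables (keyword tables, cached embeddings, collection bindings, permissions).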


class Dataset(db.Model):
    __tablename__ = 'datasets'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_pkey'),
        db.Index('dataset_tenant_idx', 'tenant_id'),
        db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
    )

    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False,
                         server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False,
                           server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)
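
    # The properties below are derived on access from related tables
    # (documents, segments, process rules, tags) rather than stored on the
    # dataset row itself.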
    @property
    def dataset_keyword_table(self):
        dataset_keyword_table = db.session.query(DatasetKeywordTable).filter(
            DatasetKeywordTable.dataset_id == self.id).first()
        if dataset_keyword_table:
            return dataset_keyword_table

        return None

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def created_by_account(self):
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \
            .order_by(DatasetProcessRule.created_at.desc()).first()

    @property
    def app_count(self):
        return db.session.query(func.count(AppDatasetJoin.id)).filter(
            AppDatasetJoin.dataset_id == self.id,
            App.id == AppDatasetJoin.app_id
        ).scalar()

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return db.session.query(func.count(Document.id)).filter(
            Document.dataset_id == self.id,
            Document.indexing_status == 'completed',
            Document.enabled == True,
            Document.archived == False
        ).scalar()

    @property
    def available_segment_count(self):
        return db.session.query(func.count(DocumentSegment.id)).filter(
            DocumentSegment.dataset_id == self.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).scalar()

    @property
    def word_count(self):
        return Document.query.with_entities(func.coalesce(func.sum(Document.word_count), 0)) \
            .filter(Document.dataset_id == self.id).scalar()

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(
            Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None
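
    # When no retrieval configuration has been saved for the dataset, fall
    # back to a plain semantic-search setup with reranking disabled.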
    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }
        return self.retrieval_model if self.retrieval_model else default_retrieval_model

    @property
    def tags(self):
        tags = db.session.query(Tag).join(
            TagBinding,
            Tag.id == TagBinding.tag_id
        ).filter(
            TagBinding.target_id == self.id,
            TagBinding.tenant_id == self.tenant_id,
            Tag.tenant_id == self.tenant_id,
            Tag.type == 'knowledge'
        ).all()

        return tags if tags else []
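
    # The vector-store collection name is derived deterministically from the
    # dataset id, so each dataset maps to exactly one collection.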
    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f'Vector_index_{normalized_dataset_id}_Node'
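

# A processing rule records how a dataset's documents are cleaned and split:
# either the built-in 'automatic' defaults below or a 'custom' rule supplied
# by the user.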
class DatasetProcessRule(db.Model):
    __tablename__ = 'dataset_process_rules'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'),
        db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False,
                     server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    MODES = ['automatic', 'custom']
    PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails']
    AUTOMATIC_RULES = {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': False}
        ],
        'segmentation': {
            'delimiter': '\n',
            'max_tokens': 500,
            'chunk_overlap': 50
        }
    }

    def to_dict(self):
        return {
            'id': self.id,
            'dataset_id': self.dataset_id,
            'mode': self.mode,
            'rules': self.rules_dict,
            'created_by': self.created_by,
            'created_at': self.created_at,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
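

# A document is one imported source (uploaded file, Notion page, or crawled
# web page). Its timestamp columns trace the indexing pipeline: processing
# started -> parsing -> cleaning -> splitting -> indexing completed.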
class Document(db.Model):
    __tablename__ = 'documents'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_pkey'),
        db.Index('document_dataset_id_idx', 'dataset_id'),
        db.Index('document_is_paused_idx', 'is_paused'),
        db.Index('document_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False,
                                server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False,
                         server_default=db.text('false'))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False,
                         server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl']
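
    # Collapse the raw indexing_status plus the enabled/archived/paused flags
    # into the single status string shown in the UI: queuing, paused,
    # indexing, error, available, disabled, or archived.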
    @property
    def display_status(self):
        status = None
        if self.indexing_status == 'waiting':
            status = 'queuing'
        elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused:
            status = 'paused'
        elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']:
            status = 'indexing'
        elif self.indexing_status == 'error':
            status = 'error'
        elif self.indexing_status == 'completed' and not self.archived and self.enabled:
            status = 'available'
        elif self.indexing_status == 'completed' and not self.archived and not self.enabled:
            status = 'disabled'
        elif self.indexing_status == 'completed' and self.archived:
            status = 'archived'
        return status

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == 'upload_file':
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = db.session.query(UploadFile). \
                    filter(UploadFile.id == data_source_info_dict['upload_file_id']). \
                    one_or_none()
                if file_detail:
                    return {
                        'upload_file': {
                            'id': file_detail.id,
                            'name': file_detail.name,
                            'size': file_detail.size,
                            'extension': file_detail.extension,
                            'mime_type': file_detail.mime_type,
                            'created_by': file_detail.created_by,
                            'created_at': file_detail.created_at.timestamp()
                        }
                    }
            elif self.data_source_type in ('notion_import', 'website_crawl'):
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) \
            .filter(DocumentSegment.document_id == self.id).scalar()
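
    # to_dict serializes the document (derived properties included) as a
    # plain dict; from_dict rebuilds an instance from column values.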
    def to_dict(self):
        return {
            'id': self.id,
            'tenant_id': self.tenant_id,
            'dataset_id': self.dataset_id,
            'position': self.position,
            'data_source_type': self.data_source_type,
            'data_source_info': self.data_source_info,
            'dataset_process_rule_id': self.dataset_process_rule_id,
            'batch': self.batch,
            'name': self.name,
            'created_from': self.created_from,
            'created_by': self.created_by,
            'created_api_request_id': self.created_api_request_id,
            'created_at': self.created_at,
            'processing_started_at': self.processing_started_at,
            'file_id': self.file_id,
            'word_count': self.word_count,
            'parsing_completed_at': self.parsing_completed_at,
            'cleaning_completed_at': self.cleaning_completed_at,
            'splitting_completed_at': self.splitting_completed_at,
            'tokens': self.tokens,
            'indexing_latency': self.indexing_latency,
            'completed_at': self.completed_at,
            'is_paused': self.is_paused,
            'paused_by': self.paused_by,
            'paused_at': self.paused_at,
            'error': self.error,
            'stopped_at': self.stopped_at,
            'indexing_status': self.indexing_status,
            'enabled': self.enabled,
            'disabled_at': self.disabled_at,
            'disabled_by': self.disabled_by,
            'archived': self.archived,
            'archived_reason': self.archived_reason,
            'archived_by': self.archived_by,
            'archived_at': self.archived_at,
            'updated_at': self.updated_at,
            'doc_type': self.doc_type,
            'doc_metadata': self.doc_metadata,
            'doc_form': self.doc_form,
            'doc_language': self.doc_language,
            'display_status': self.display_status,
            'data_source_info_dict': self.data_source_info_dict,
            'average_segment_length': self.average_segment_length,
            'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            'dataset': self.dataset.to_dict() if self.dataset else None,
            'segment_count': self.segment_count,
            'hit_count': self.hit_count
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get('id'),
            tenant_id=data.get('tenant_id'),
            dataset_id=data.get('dataset_id'),
            position=data.get('position'),
            data_source_type=data.get('data_source_type'),
            data_source_info=data.get('data_source_info'),
            dataset_process_rule_id=data.get('dataset_process_rule_id'),
            batch=data.get('batch'),
            name=data.get('name'),
            created_from=data.get('created_from'),
            created_by=data.get('created_by'),
            created_api_request_id=data.get('created_api_request_id'),
            created_at=data.get('created_at'),
            processing_started_at=data.get('processing_started_at'),
            file_id=data.get('file_id'),
            word_count=data.get('word_count'),
            parsing_completed_at=data.get('parsing_completed_at'),
            cleaning_completed_at=data.get('cleaning_completed_at'),
            splitting_completed_at=data.get('splitting_completed_at'),
            tokens=data.get('tokens'),
            indexing_latency=data.get('indexing_latency'),
            completed_at=data.get('completed_at'),
            is_paused=data.get('is_paused'),
            paused_by=data.get('paused_by'),
            paused_at=data.get('paused_at'),
            error=data.get('error'),
            stopped_at=data.get('stopped_at'),
            indexing_status=data.get('indexing_status'),
            enabled=data.get('enabled'),
            disabled_at=data.get('disabled_at'),
            disabled_by=data.get('disabled_by'),
            archived=data.get('archived'),
            archived_reason=data.get('archived_reason'),
            archived_by=data.get('archived_by'),
            archived_at=data.get('archived_at'),
            updated_at=data.get('updated_at'),
            doc_type=data.get('doc_type'),
            doc_metadata=data.get('doc_metadata'),
            doc_form=data.get('doc_form'),
            doc_language=data.get('doc_language')
        )
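

# A segment is one retrievable chunk of a document: its text (plus an
# optional answer for Q&A-form documents), word/token counts, extracted
# keywords, and the id and hash of the node stored in the vector index.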
class DocumentSegment(db.Model):
    __tablename__ = 'document_segments'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_segment_pkey'),
        db.Index('document_segment_dataset_id_idx', 'dataset_id'),
        db.Index('document_segment_document_id_idx', 'document_id'),
        db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'),
        db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'),
        db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'),
        db.Index('document_segment_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False,
                       server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position - 1
        ).first()

    @property
    def next_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position + 1
        ).first()
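
    # Rewrite any "/files/<uuid>/image-preview" URLs embedded in the segment
    # content so they carry a timestamp, a random nonce, and an HMAC-SHA256
    # signature derived from SECRET_KEY, so the file endpoint can check the
    # link before serving the image.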
    def get_sign_content(self):
        pattern = r"/files/([a-f0-9\-]+)/image-preview"
        text = self.content
        matches = re.finditer(pattern, text)
        signed_urls = []
        for match in matches:
            upload_file_id = match.group(1)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            signed_url = f"{match.group(0)}?{params}"
            signed_urls.append((match.start(), match.end(), signed_url))

        # Reconstruct the text with signed URLs
        offset = 0
        for start, end, signed_url in signed_urls:
            text = text[:start + offset] + signed_url + text[end + offset:]
            offset += len(signed_url) - (end - start)

        return text
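
    # A minimal sketch of the matching verification step, assuming the same
    # query parameters come back unchanged (illustrative only; the actual
    # check lives in the file-serving layer, not in this model):
    #
    #     def verify_image_preview_sign(upload_file_id, timestamp, nonce, sign):
    #         data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
    #         secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b''
    #         expected = base64.urlsafe_b64encode(
    #             hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    #         ).decode()
    #         return hmac.compare_digest(expected, sign)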


class AppDatasetJoin(db.Model):
    __tablename__ = 'app_dataset_joins'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'),
        db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return db.session.get(App, self.app_id)


class DatasetQuery(db.Model):
    __tablename__ = 'dataset_queries'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_query_pkey'),
        db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
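

# A dataset's keyword table maps each keyword to the set of segment
# index-node ids that contain it. It is stored either inline in the database
# or as a JSON file in object storage, depending on data_source_type; list
# values are decoded back into sets on load.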
class DatasetKeywordTable(db.Model):
    __tablename__ = 'dataset_keyword_tables'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'),
        db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False,
                                 server_default=db.text("'database'::character varying"))

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(
            id=self.dataset_id
        ).first()
        if not dataset:
            return None
        if self.data_source_type == 'database':
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt'
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder)
                return None
            except Exception as e:
                logging.exception(str(e))
                return None


class Embedding(db.Model):
    __tablename__ = 'embeddings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='embedding_pkey'),
        db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    model_name = db.Column(db.String(40), nullable=False,
                           server_default=db.text("'text-embedding-ada-002'::character varying"))
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    provider_name = db.Column(db.String(40), nullable=False,
                              server_default=db.text("''::character varying"))
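
    # Cached embedding vectors are stored as pickled lists of floats, keyed by
    # the (model_name, hash, provider_name) unique constraint above.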
    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)


class DatasetCollectionBinding(db.Model):
    __tablename__ = 'dataset_collection_bindings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
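

# Per-account permission records controlling which accounts in a tenant can
# access a given dataset.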
class DatasetPermission(db.Model):
    __tablename__ = 'dataset_permissions'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_permission_pkey'),
        db.Index('idx_dataset_permissions_dataset_id', 'dataset_id'),
        db.Index('idx_dataset_permissions_account_id', 'account_id'),
        db.Index('idx_dataset_permissions_tenant_id', 'tenant_id')
    )

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text('true'))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))