# dataset.py

import base64
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError

from flask import current_app
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile


class Dataset(db.Model):
    __tablename__ = 'datasets'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_pkey'),
        db.Index('dataset_tenant_idx', 'tenant_id'),
        db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
    )

    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False,
                         server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False,
                           server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        dataset_keyword_table = db.session.query(DatasetKeywordTable).filter(
            DatasetKeywordTable.dataset_id == self.id).first()
        if dataset_keyword_table:
            return dataset_keyword_table

        return None

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def created_by_account(self):
        return Account.query.get(self.created_by)

    @property
    def latest_process_rule(self):
        return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \
            .order_by(DatasetProcessRule.created_at.desc()).first()

    @property
    def app_count(self):
        return db.session.query(func.count(AppDatasetJoin.id)).filter(AppDatasetJoin.dataset_id == self.id).scalar()

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return db.session.query(func.count(Document.id)).filter(
            Document.dataset_id == self.id,
            Document.indexing_status == 'completed',
            Document.enabled == True,
            Document.archived == False
        ).scalar()

    @property
    def available_segment_count(self):
        return db.session.query(func.count(DocumentSegment.id)).filter(
            DocumentSegment.dataset_id == self.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).scalar()

    @property
    def word_count(self):
        return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
            .filter(Document.dataset_id == self.id).scalar()

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(
            Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }
        return self.retrieval_model if self.retrieval_model else default_retrieval_model

    @property
    def tags(self):
        tags = db.session.query(Tag).join(
            TagBinding,
            Tag.id == TagBinding.tag_id
        ).filter(
            TagBinding.target_id == self.id,
            TagBinding.tenant_id == self.tenant_id,
            Tag.tenant_id == self.tenant_id,
            Tag.type == 'knowledge'
        ).all()

        return tags if tags else []

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f'Vector_index_{normalized_dataset_id}_Node'
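
# Illustrative only (not part of the original file): a sketch of how callers might
# use the helpers above. The dataset id is a made-up placeholder and `dataset` is a
# hypothetical instance loaded elsewhere.
#
#   Dataset.gen_collection_name_by_id("2d9f8a4c-1b2e-4f3a-9c5d-7e6f5a4b3c2d")
#   # -> "Vector_index_2d9f8a4c_1b2e_4f3a_9c5d_7e6f5a4b3c2d_Node"
#
#   dataset.retrieval_model_dict   # stored JSONB config, or the semantic-search default
#   dataset.index_struct_dict      # parsed index_struct JSON, or None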


class DatasetProcessRule(db.Model):
    __tablename__ = 'dataset_process_rules'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'),
        db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False,
                     server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    MODES = ['automatic', 'custom']
    PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails']
    AUTOMATIC_RULES = {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': False}
        ],
        'segmentation': {
            'delimiter': '\n',
            'max_tokens': 500,
            'chunk_overlap': 50
        }
    }

    def to_dict(self):
        return {
            'id': self.id,
            'dataset_id': self.dataset_id,
            'mode': self.mode,
            'rules': self.rules_dict,
            'created_by': self.created_by,
            'created_at': self.created_at,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
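
# Illustrative only: `rules` holds a JSON string, and `rules_dict` parses it back
# (None on empty or malformed input). A plausible round-trip, assuming a hypothetical
# `rule` instance built by caller code (how the service layer actually populates
# `rules` is not shown in this file):
#
#   rule = DatasetProcessRule(mode='automatic',
#                             rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES))
#   rule.rules_dict['segmentation']['max_tokens']   # -> 500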


class Document(db.Model):
    __tablename__ = 'documents'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_pkey'),
        db.Index('document_dataset_id_idx', 'dataset_id'),
        db.Index('document_is_paused_idx', 'is_paused'),
        db.Index('document_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(
        255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False,
                         server_default=db.text('false'))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(
        255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ['upload_file', 'notion_import']

    @property
    def display_status(self):
        status = None
        if self.indexing_status == 'waiting':
            status = 'queuing'
        elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused:
            status = 'paused'
        elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']:
            status = 'indexing'
        elif self.indexing_status == 'error':
            status = 'error'
        elif self.indexing_status == 'completed' and not self.archived and self.enabled:
            status = 'available'
        elif self.indexing_status == 'completed' and not self.archived and not self.enabled:
            status = 'disabled'
        elif self.indexing_status == 'completed' and self.archived:
            status = 'archived'
        return status

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == 'upload_file':
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = db.session.query(UploadFile). \
                    filter(UploadFile.id == data_source_info_dict['upload_file_id']). \
                    one_or_none()
                if file_detail:
                    return {
                        'upload_file': {
                            'id': file_detail.id,
                            'name': file_detail.name,
                            'size': file_detail.size,
                            'extension': file_detail.extension,
                            'mime_type': file_detail.mime_type,
                            'created_by': file_detail.created_by,
                            'created_at': file_detail.created_at.timestamp()
                        }
                    }
            elif self.data_source_type == 'notion_import':
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return DatasetProcessRule.query.get(self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \
            .filter(DocumentSegment.document_id == self.id).scalar()
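
# Illustrative only: `data_source_info` is stored as JSON text, so the two data source
# types are read back differently. `doc` below is a hypothetical uploaded-file document;
# the placeholder id is made up.
#
#   doc = Document(data_source_type='upload_file',
#                  data_source_info=json.dumps({'upload_file_id': '<uuid>'}))
#   doc.data_source_info_dict      # -> {'upload_file_id': '<uuid>'}
#   doc.data_source_detail_dict    # -> {'upload_file': {...}} if the UploadFile row exists, else {}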


class DocumentSegment(db.Model):
    __tablename__ = 'document_segments'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_segment_pkey'),
        db.Index('document_segment_dataset_id_idx', 'dataset_id'),
        db.Index('document_segment_document_id_idx', 'document_id'),
        db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'),
        db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'),
        db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'),
        db.Index('document_segment_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False,
                       server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position - 1
        ).first()

    @property
    def next_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position + 1
        ).first()

    def get_sign_content(self):
        pattern = r"/files/([a-f0-9\-]+)/image-preview"
        text = self.content
        match = re.search(pattern, text)

        if match:
            upload_file_id = match.group(1)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
            secret_key = current_app.config['SECRET_KEY'].encode()
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            replacement = r"\g<0>?{params}".format(params=params)
            text = re.sub(pattern, replacement, text)

        return text
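
# Illustrative only: the verifying side of `get_sign_content` is not part of this file,
# but a check consistent with the signing scheme above would look roughly like the sketch
# below. The function name `verify_image_preview_sign` is hypothetical.
#
#   def verify_image_preview_sign(upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
#       data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
#       secret_key = current_app.config['SECRET_KEY'].encode()
#       recalculated = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
#       expected = base64.urlsafe_b64encode(recalculated).decode()
#       return hmac.compare_digest(expected, sign)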


class AppDatasetJoin(db.Model):
    __tablename__ = 'app_dataset_joins'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'),
        db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return App.query.get(self.app_id)


class DatasetQuery(db.Model):
    __tablename__ = 'dataset_queries'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_query_pkey'),
        db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):
    __tablename__ = 'dataset_keyword_tables'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'),
        db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False,
                                 server_default=db.text("'database'::character varying"))

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(
            id=self.dataset_id
        ).first()
        if not dataset:
            return None
        if self.data_source_type == 'database':
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt'
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder)
                return None
            except Exception as e:
                logging.exception(str(e))
                return None
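
# Illustrative only: JSON has no set type, so `SetDecoder` turns every list value into a
# set on load. Whatever writes `keyword_table` must therefore convert sets back to lists
# before dumping; a hypothetical encoder doing that could look like:
#
#   class SetEncoder(json.JSONEncoder):
#       def default(self, obj):
#           if isinstance(obj, set):
#               return list(obj)
#           return super().default(obj)
#
#   json.dumps({'keyword': {'node-1', 'node-2'}}, cls=SetEncoder)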


class Embedding(db.Model):
    __tablename__ = 'embeddings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='embedding_pkey'),
        db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    model_name = db.Column(db.String(40), nullable=False,
                           server_default=db.text("'text-embedding-ada-002'::character varying"))
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    provider_name = db.Column(db.String(40), nullable=False,
                              server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)
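
# Illustrative only: the embedding vector is pickled into the LargeBinary column, so
# reads and writes go through the two helpers above. `cache_entry`, the hash value,
# and the sample vector below are hypothetical.
#
#   cache_entry = Embedding(model_name='text-embedding-ada-002', hash='<content hash>')
#   cache_entry.set_embedding([0.12, -0.03, 0.98])
#   cache_entry.get_embedding()   # -> [0.12, -0.03, 0.98]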


class DatasetCollectionBinding(db.Model):
    __tablename__ = 'dataset_collection_bindings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))