notion.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. import json
  2. import logging
  3. from typing import Any, Dict, List, Optional
  4. import requests
  5. from flask import current_app
  6. from langchain.document_loaders.base import BaseLoader
  7. from langchain.schema import Document
  8. from extensions.ext_database import db
  9. from models.dataset import Document as DocumentModel
  10. from models.source import DataSourceBinding
  logger = logging.getLogger(__name__)

  # Notion REST endpoints; the "{...}" placeholders are filled with str.format().
  RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
  RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
  DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
  BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
  SEARCH_URL = "https://api.notion.com/v1/search"

  # Block types treated as section headings when rendering page text.
  HEADING_TYPE = [
      "heading_1",
      "heading_2",
      "heading_3",
  ]
  class NotionLoader(BaseLoader):
      """Load the text of a Notion page or database as LangChain documents.

      Pages are rendered block by block (text, nested children, tables);
      databases are rendered as one document per row of ``name:value`` lines.
      """

      # Notion-Version header sent with every request.
      _NOTION_VERSION = "2022-06-28"
      # Seconds to wait for a Notion response; without a timeout `requests`
      # can block indefinitely on a stalled connection.
      _REQUEST_TIMEOUT = 60

      def __init__(
              self,
              notion_access_token: str,
              notion_workspace_id: str,
              notion_obj_id: str,
              notion_page_type: str,
              document_model: Optional[DocumentModel] = None
      ):
          """Create a loader for a single Notion page or database.

          Args:
              notion_access_token: Notion API token; when empty, the Flask config
                  value ``NOTION_INTEGRATION_TOKEN`` is used instead.
              notion_workspace_id: Workspace the object belongs to.
              notion_obj_id: ID of the page or database to load.
              notion_page_type: ``'page'`` or ``'database'``.
              document_model: Optional dataset document whose stored
                  ``last_edited_time`` is refreshed by :meth:`load`.

          Raises:
              ValueError: If no token is given and none is configured.
          """
          self._document_model = document_model
          self._notion_workspace_id = notion_workspace_id
          self._notion_obj_id = notion_obj_id
          self._notion_page_type = notion_page_type
          self._notion_access_token = notion_access_token

          if not self._notion_access_token:
              integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
              # An empty configured value is as unusable as a missing one.
              if not integration_token:
                  raise ValueError(
                      "Must specify `integration_token` or set environment "
                      "variable `NOTION_INTEGRATION_TOKEN`."
                  )
              self._notion_access_token = integration_token

      @classmethod
      def from_document(cls, document_model: DocumentModel):
          """Build a loader from a dataset document's ``data_source_info``.

          Raises:
              ValueError: If the page ID, workspace ID or object type is missing.
              Exception: If the tenant has no enabled Notion binding for the
                  workspace (raised by :meth:`_get_access_token`).
          """
          data_source_info = document_model.data_source_info_dict
          # 'type' is read below too, so validate it here rather than failing
          # later with a bare KeyError.
          required_keys = ('notion_page_id', 'notion_workspace_id', 'type')
          if not data_source_info or any(key not in data_source_info for key in required_keys):
              raise ValueError("no notion page found")

          notion_workspace_id = data_source_info['notion_workspace_id']
          notion_obj_id = data_source_info['notion_page_id']
          notion_page_type = data_source_info['type']
          notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)

          return cls(
              notion_access_token=notion_access_token,
              notion_workspace_id=notion_workspace_id,
              notion_obj_id=notion_obj_id,
              notion_page_type=notion_page_type,
              document_model=document_model
          )

      def load(self) -> List[Document]:
          """Refresh the document's stored last-edited time, then load the content."""
          self.update_last_edited_time(self._document_model)
          return self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)

      def _load_data_as_documents(
              self, notion_obj_id: str, notion_page_type: str
      ) -> List[Document]:
          """Load ``notion_obj_id`` according to its type.

          Raises:
              ValueError: If ``notion_page_type`` is neither 'database' nor 'page'.
          """
          if notion_page_type == 'database':
              # one document per database row
              return self._get_notion_database_data(notion_obj_id)
          if notion_page_type == 'page':
              return [
                  Document(page_content=page_text)
                  for page_text in self._get_notion_block_data(notion_obj_id)
              ]
          raise ValueError("notion page type not supported")

      def _request_headers(self) -> Dict[str, str]:
          """Return the HTTP headers sent with every Notion API call."""
          return {
              "Authorization": "Bearer " + self._notion_access_token,
              "Content-Type": "application/json",
              "Notion-Version": self._NOTION_VERSION,
          }

      def _get_notion_database_data(
              self, database_id: str, query_dict: Optional[Dict[str, Any]] = None
      ) -> List[Document]:
          """Get all the pages (rows) of a Notion database, one document per row.

          Follows the ``has_more`` / ``next_cursor`` pagination of the query
          endpoint, so rows beyond the first response page are loaded too.

          Args:
              database_id: ID of the database to query.
              query_dict: Optional query body (e.g. filter/sorts). It is copied,
                  never mutated.

          Returns:
              One Document per row. The list is empty if the first response has
              no ``results``; if a later page fails, the rows loaded so far are
              returned.
          """
          # Copy so adding start_cursor never leaks into the caller's dict.
          body: Dict[str, Any] = dict(query_dict or {})
          database_content_list: List[Document] = []
          while True:
              res = requests.post(
                  DATABASE_URL_TMPL.format(database_id=database_id),
                  headers=self._request_headers(),
                  json=body,
                  timeout=self._REQUEST_TIMEOUT,
              )
              data = res.json()
              results = data.get('results')
              if results is None:
                  # Error response: keep the original best-effort behaviour.
                  logger.warning(
                      "Notion database query for %s returned no results: %s",
                      database_id, data.get('message'),
                  )
                  break
              for result in results:
                  row_content = self._format_database_row(result['properties'])
                  database_content_list.append(Document(page_content=row_content))
              next_cursor = data.get('next_cursor')
              if not data.get('has_more') or not next_cursor:
                  break
              body['start_cursor'] = next_cursor
          return database_content_list

      @staticmethod
      def _format_database_row(properties: Dict[str, Any]) -> str:
          """Render one database row's properties as ``name:value`` lines.

          Properties whose value is falsy are omitted. Dict values are flattened
          to ``key:value `` pairs, with their own falsy entries dropped.
          """
          row_dict: Dict[str, Any] = {}
          for property_name, property_value in properties.items():
              property_type = property_value['type']
              raw_value = property_value[property_type]
              if property_type == 'multi_select':
                  value = [option['name'] for option in raw_value]
              elif property_type in ('rich_text', 'title'):
                  # A value may be split over several segments; keep all of them.
                  value = ''.join(segment['plain_text'] for segment in raw_value)
              elif property_type in ('select', 'status'):
                  value = raw_value['name'] if raw_value else ''
              else:
                  value = raw_value
              row_dict[property_name] = value

          lines = []
          for key, value in row_dict.items():
              if not value:
                  continue
              if isinstance(value, dict):
                  value_content = ''.join(f'{k}:{v} ' for k, v in value.items() if v)
                  lines.append(f'{key}:{value_content}\n')
              else:
                  lines.append(f'{key}:{value}\n')
          return ''.join(lines)

      @staticmethod
      def _rich_text_to_str(rich_texts: Optional[List[Dict[str, Any]]]) -> str:
          """Concatenate the ``text`` segments of a Notion rich-text array.

          Segments without a ``text`` object are skipped, as block text
          extraction has always done. ``None`` or an empty list yields ''.
          """
          return ''.join(
              segment["text"]["content"]
              for segment in rich_texts or []
              if "text" in segment
          )

      def _fetch_block_children(self, block_id: str, *, strict: bool = False) -> List[Dict[str, Any]]:
          """Return every child block of ``block_id``, following pagination.

          The pagination cursor is sent as the ``start_cursor`` query parameter
          on the same block's children URL. It is not itself a block ID.

          Args:
              block_id: Page or block whose children are listed.
              strict: If True, raise when a response has no ``results``.
                  Otherwise log a warning and return what was collected so far.

          Raises:
              ValueError: If ``strict`` and Notion returns no ``results``.
          """
          block_url = BLOCK_CHILD_URL_TMPL.format(block_id=block_id)
          children: List[Dict[str, Any]] = []
          params: Dict[str, Any] = {}
          while True:
              res = requests.get(
                  block_url,
                  headers=self._request_headers(),
                  params=params,
                  timeout=self._REQUEST_TIMEOUT,
              )
              data = res.json()
              results = data.get('results')
              if results is None:
                  if strict:
                      raise ValueError(
                          f"Failed to fetch children of Notion block {block_id}: "
                          f"{data.get('message', data)}"
                      )
                  logger.warning(
                      "Notion block %s returned no children: %s",
                      block_id, data.get('message'),
                  )
                  break
              children.extend(results)
              next_cursor = data.get('next_cursor')
              if not next_cursor:
                  break
              params = {'start_cursor': next_cursor}
          return children

      def _collect_block_texts(
              self, block_id: str, num_tabs: int, *, strict: bool, suffix: str
      ) -> List[str]:
          """Render each child block of ``block_id`` as one text chunk.

          - Tables are rendered by :meth:`_read_table_rows`.
          - Other blocks contribute their text, prefixed with ``num_tabs`` tabs.
          - Nested children are read recursively, except for ``child_page``
            blocks.
          - Non-heading chunks are prefixed with the most recent heading.
          - ``suffix`` is appended to every chunk body.
          """
          chunks: List[str] = []
          # Heading context persists across response pages of the same parent.
          heading = ''
          prefix = "\t" * num_tabs
          for block in self._fetch_block_children(block_id, strict=strict):
              block_type = block["type"]
              if block_type == 'table':
                  chunks.append(self._read_table_rows(block["id"]) + suffix)
                  continue

              block_obj = block.get(block_type) or {}
              parts: List[str] = []
              # All segments of one block form one line, not one line per segment.
              text = self._rich_text_to_str(block_obj.get("rich_text"))
              if text:
                  parts.append(prefix + text)
                  if block_type in HEADING_TYPE:
                      heading = text
              if block.get("has_children") and block_type != 'child_page':
                  parts.append(self._read_block(block["id"], num_tabs=num_tabs + 1))

              body = "\n".join(parts) + suffix
              chunks.append(body if block_type in HEADING_TYPE else f'{heading}\n{body}')
          return chunks

      def _get_notion_block_data(self, page_id: str) -> List[str]:
          """Render a page's top-level blocks as a list of text chunks.

          Each chunk ends with a blank line. Non-heading chunks start with the
          most recent heading text.

          Raises:
              ValueError: If Notion does not return the page's children.
          """
          return self._collect_block_texts(page_id, num_tabs=0, strict=True, suffix="\n\n")

      def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
          """Read a block's children as one newline-joined string.

          Text is indented by ``num_tabs`` tabs.
          """
          return "\n".join(
              self._collect_block_texts(block_id, num_tabs=num_tabs, strict=False, suffix="")
          )

      def _read_table_rows(self, block_id: str) -> str:
          """Read a table block as ``header:value`` lines.

          The first row is treated as the column header. Each later row yields
          one ``header:value`` line per non-empty cell, and rows are joined with
          newlines. A table without rows yields ''.
          """
          rows = self._fetch_block_children(block_id)
          if not rows:
              return ''
          # Render each cell (a list of rich-text segments) as one string. Empty
          # cells are kept as '' so every value stays aligned with its header.
          rendered_rows = [
              [self._rich_text_to_str(cell) for cell in (row.get('table_row') or {}).get('cells', [])]
              for row in rows
          ]
          header_texts = rendered_rows[0]
          result_lines_arr = []
          for cells in rendered_rows[1:]:
              column_texts = []
              for index, cell_text in enumerate(cells):
                  if not cell_text:
                      continue
                  header = header_texts[index] if index < len(header_texts) else ''
                  column_texts.append(f'{header}:{cell_text}')
              result_lines_arr.append("\n".join(column_texts))
          return "\n".join(result_lines_arr)

      def update_last_edited_time(self, document_model: Optional[DocumentModel]):
          """Store the Notion object's current ``last_edited_time`` on the document.

          Does nothing when ``document_model`` is falsy. Otherwise it rewrites
          the document's ``data_source_info`` JSON and commits the session.
          """
          if not document_model:
              return

          last_edited_time = self.get_notion_last_edited_time()
          data_source_info = document_model.data_source_info_dict
          data_source_info['last_edited_time'] = last_edited_time
          update_params = {
              DocumentModel.data_source_info: json.dumps(data_source_info)
          }

          DocumentModel.query.filter_by(id=document_model.id).update(update_params)
          db.session.commit()

      def get_notion_last_edited_time(self) -> str:
          """Return the ``last_edited_time`` Notion reports for the loaded object.

          Raises:
              ValueError: If the response has no ``last_edited_time``.
          """
          obj_id = self._notion_obj_id
          if self._notion_page_type == 'database':
              retrieve_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
          else:
              retrieve_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)

          res = requests.get(
              retrieve_url,
              headers=self._request_headers(),
              timeout=self._REQUEST_TIMEOUT,
          )
          data = res.json()
          if 'last_edited_time' not in data:
              raise ValueError(
                  f"Failed to retrieve last edited time of Notion object {obj_id}: "
                  f"{data.get('message', data)}"
              )
          return data["last_edited_time"]

      @classmethod
      def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
          """Return the access token of the tenant's enabled Notion binding.

          The binding must have provider 'notion' and match ``notion_workspace_id``.

          Raises:
              Exception: If no such binding exists.
          """
          data_source_binding = DataSourceBinding.query.filter(
              db.and_(
                  DataSourceBinding.tenant_id == tenant_id,
                  DataSourceBinding.provider == 'notion',
                  # SQLAlchemy needs `==` here to build the SQL expression.
                  DataSourceBinding.disabled == False,  # noqa: E712
                  # JSON field compared against its serialized (quoted) form.
                  DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
              )
          ).first()

          if not data_source_binding:
              raise Exception(f'No notion data source binding found for tenant {tenant_id} '
                              f'and notion workspace {notion_workspace_id}')

          return data_source_binding.access_token