# notion.py - loader that reads Notion pages and databases into LangChain documents.
import json
import logging
from typing import Any, Dict, Iterator, List, Optional, Tuple

import requests
from flask import current_app
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document

from extensions.ext_database import db
from models.dataset import Document as DocumentModel
from models.source import DataSourceBinding
  11. logger = logging.getLogger(__name__)
  12. BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
  13. DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
  14. SEARCH_URL = "https://api.notion.com/v1/search"
  15. RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
  16. RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
  17. HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
  18. class NotionLoader(BaseLoader):
  19. def __init__(
  20. self,
  21. notion_access_token: str,
  22. notion_workspace_id: str,
  23. notion_obj_id: str,
  24. notion_page_type: str,
  25. document_model: Optional[DocumentModel] = None
  26. ):
  27. self._document_model = document_model
  28. self._notion_workspace_id = notion_workspace_id
  29. self._notion_obj_id = notion_obj_id
  30. self._notion_page_type = notion_page_type
  31. self._notion_access_token = notion_access_token
  32. if not self._notion_access_token:
  33. integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
  34. if integration_token is None:
  35. raise ValueError(
  36. "Must specify `integration_token` or set environment "
  37. "variable `NOTION_INTEGRATION_TOKEN`."
  38. )
  39. self._notion_access_token = integration_token
  40. @classmethod
  41. def from_document(cls, document_model: DocumentModel):
  42. data_source_info = document_model.data_source_info_dict
  43. if not data_source_info or 'notion_page_id' not in data_source_info \
  44. or 'notion_workspace_id' not in data_source_info:
  45. raise ValueError("no notion page found")
  46. notion_workspace_id = data_source_info['notion_workspace_id']
  47. notion_obj_id = data_source_info['notion_page_id']
  48. notion_page_type = data_source_info['type']
  49. notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
  50. return cls(
  51. notion_access_token=notion_access_token,
  52. notion_workspace_id=notion_workspace_id,
  53. notion_obj_id=notion_obj_id,
  54. notion_page_type=notion_page_type,
  55. document_model=document_model
  56. )
  57. def load(self) -> List[Document]:
  58. self.update_last_edited_time(
  59. self._document_model
  60. )
  61. text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
  62. return text_docs
  63. def _load_data_as_documents(
  64. self, notion_obj_id: str, notion_page_type: str
  65. ) -> List[Document]:
  66. docs = []
  67. if notion_page_type == 'database':
  68. # get all the pages in the database
  69. page_text = self._get_notion_database_data(notion_obj_id)
  70. docs.append(Document(page_content=page_text))
  71. elif notion_page_type == 'page':
  72. page_text_list = self._get_notion_block_data(notion_obj_id)
  73. for page_text in page_text_list:
  74. docs.append(Document(page_content=page_text))
  75. else:
  76. raise ValueError("notion page type not supported")
  77. return docs
  78. def _get_notion_database_data(
  79. self, database_id: str, query_dict: Dict[str, Any] = {}
  80. ) -> str:
  81. """Get all the pages from a Notion database."""
  82. res = requests.post(
  83. DATABASE_URL_TMPL.format(database_id=database_id),
  84. headers={
  85. "Authorization": "Bearer " + self._notion_access_token,
  86. "Content-Type": "application/json",
  87. "Notion-Version": "2022-06-28",
  88. },
  89. json=query_dict,
  90. )
  91. data = res.json()
  92. database_content_list = []
  93. if 'results' not in data or data["results"] is None:
  94. return ""
  95. for result in data["results"]:
  96. properties = result['properties']
  97. data = {}
  98. for property_name, property_value in properties.items():
  99. type = property_value['type']
  100. if type == 'multi_select':
  101. value = []
  102. multi_select_list = property_value[type]
  103. for multi_select in multi_select_list:
  104. value.append(multi_select['name'])
  105. elif type == 'rich_text' or type == 'title':
  106. if len(property_value[type]) > 0:
  107. value = property_value[type][0]['plain_text']
  108. else:
  109. value = ''
  110. elif type == 'select' or type == 'status':
  111. if property_value[type]:
  112. value = property_value[type]['name']
  113. else:
  114. value = ''
  115. else:
  116. value = property_value[type]
  117. data[property_name] = value
  118. database_content_list.append(json.dumps(data, ensure_ascii=False))
  119. return "\n\n".join(database_content_list)
  120. def _get_notion_block_data(self, page_id: str) -> List[str]:
  121. result_lines_arr = []
  122. cur_block_id = page_id
  123. while True:
  124. block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
  125. query_dict: Dict[str, Any] = {}
  126. res = requests.request(
  127. "GET",
  128. block_url,
  129. headers={
  130. "Authorization": "Bearer " + self._notion_access_token,
  131. "Content-Type": "application/json",
  132. "Notion-Version": "2022-06-28",
  133. },
  134. json=query_dict
  135. )
  136. data = res.json()
  137. # current block's heading
  138. heading = ''
  139. for result in data["results"]:
  140. result_type = result["type"]
  141. result_obj = result[result_type]
  142. cur_result_text_arr = []
  143. if result_type == 'table':
  144. result_block_id = result["id"]
  145. text = self._read_table_rows(result_block_id)
  146. text += "\n\n"
  147. result_lines_arr.append(text)
  148. else:
  149. if "rich_text" in result_obj:
  150. for rich_text in result_obj["rich_text"]:
  151. # skip if doesn't have text object
  152. if "text" in rich_text:
  153. text = rich_text["text"]["content"]
  154. cur_result_text_arr.append(text)
  155. if result_type in HEADING_TYPE:
  156. heading = text
  157. result_block_id = result["id"]
  158. has_children = result["has_children"]
  159. block_type = result["type"]
  160. if has_children and block_type != 'child_page':
  161. children_text = self._read_block(
  162. result_block_id, num_tabs=1
  163. )
  164. cur_result_text_arr.append(children_text)
  165. cur_result_text = "\n".join(cur_result_text_arr)
  166. cur_result_text += "\n\n"
  167. if result_type in HEADING_TYPE:
  168. result_lines_arr.append(cur_result_text)
  169. else:
  170. result_lines_arr.append(f'{heading}\n{cur_result_text}')
  171. if data["next_cursor"] is None:
  172. break
  173. else:
  174. cur_block_id = data["next_cursor"]
  175. return result_lines_arr
  176. def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
  177. """Read a block."""
  178. result_lines_arr = []
  179. cur_block_id = block_id
  180. while True:
  181. block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
  182. query_dict: Dict[str, Any] = {}
  183. res = requests.request(
  184. "GET",
  185. block_url,
  186. headers={
  187. "Authorization": "Bearer " + self._notion_access_token,
  188. "Content-Type": "application/json",
  189. "Notion-Version": "2022-06-28",
  190. },
  191. json=query_dict
  192. )
  193. data = res.json()
  194. if 'results' not in data or data["results"] is None:
  195. break
  196. heading = ''
  197. for result in data["results"]:
  198. result_type = result["type"]
  199. result_obj = result[result_type]
  200. cur_result_text_arr = []
  201. if result_type == 'table':
  202. result_block_id = result["id"]
  203. text = self._read_table_rows(result_block_id)
  204. result_lines_arr.append(text)
  205. else:
  206. if "rich_text" in result_obj:
  207. for rich_text in result_obj["rich_text"]:
  208. # skip if doesn't have text object
  209. if "text" in rich_text:
  210. text = rich_text["text"]["content"]
  211. prefix = "\t" * num_tabs
  212. cur_result_text_arr.append(prefix + text)
  213. if result_type in HEADING_TYPE:
  214. heading = text
  215. result_block_id = result["id"]
  216. has_children = result["has_children"]
  217. block_type = result["type"]
  218. if has_children and block_type != 'child_page':
  219. children_text = self._read_block(
  220. result_block_id, num_tabs=num_tabs + 1
  221. )
  222. cur_result_text_arr.append(children_text)
  223. cur_result_text = "\n".join(cur_result_text_arr)
  224. if result_type in HEADING_TYPE:
  225. result_lines_arr.append(cur_result_text)
  226. else:
  227. result_lines_arr.append(f'{heading}\n{cur_result_text}')
  228. if data["next_cursor"] is None:
  229. break
  230. else:
  231. cur_block_id = data["next_cursor"]
  232. result_lines = "\n".join(result_lines_arr)
  233. return result_lines
  234. def _read_table_rows(self, block_id: str) -> str:
  235. """Read table rows."""
  236. done = False
  237. result_lines_arr = []
  238. cur_block_id = block_id
  239. while not done:
  240. block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
  241. query_dict: Dict[str, Any] = {}
  242. res = requests.request(
  243. "GET",
  244. block_url,
  245. headers={
  246. "Authorization": "Bearer " + self._notion_access_token,
  247. "Content-Type": "application/json",
  248. "Notion-Version": "2022-06-28",
  249. },
  250. json=query_dict
  251. )
  252. data = res.json()
  253. # get table headers text
  254. table_header_cell_texts = []
  255. tabel_header_cells = data["results"][0]['table_row']['cells']
  256. for tabel_header_cell in tabel_header_cells:
  257. if tabel_header_cell:
  258. for table_header_cell_text in tabel_header_cell:
  259. text = table_header_cell_text["text"]["content"]
  260. table_header_cell_texts.append(text)
  261. # get table columns text and format
  262. results = data["results"]
  263. for i in range(len(results) - 1):
  264. column_texts = []
  265. tabel_column_cells = data["results"][i + 1]['table_row']['cells']
  266. for j in range(len(tabel_column_cells)):
  267. if tabel_column_cells[j]:
  268. for table_column_cell_text in tabel_column_cells[j]:
  269. column_text = table_column_cell_text["text"]["content"]
  270. column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
  271. cur_result_text = "\n".join(column_texts)
  272. result_lines_arr.append(cur_result_text)
  273. if data["next_cursor"] is None:
  274. done = True
  275. break
  276. else:
  277. cur_block_id = data["next_cursor"]
  278. result_lines = "\n".join(result_lines_arr)
  279. return result_lines
  280. def update_last_edited_time(self, document_model: DocumentModel):
  281. if not document_model:
  282. return
  283. last_edited_time = self.get_notion_last_edited_time()
  284. data_source_info = document_model.data_source_info_dict
  285. data_source_info['last_edited_time'] = last_edited_time
  286. update_params = {
  287. DocumentModel.data_source_info: json.dumps(data_source_info)
  288. }
  289. DocumentModel.query.filter_by(id=document_model.id).update(update_params)
  290. db.session.commit()
  291. def get_notion_last_edited_time(self) -> str:
  292. obj_id = self._notion_obj_id
  293. page_type = self._notion_page_type
  294. if page_type == 'database':
  295. retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
  296. else:
  297. retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
  298. query_dict: Dict[str, Any] = {}
  299. res = requests.request(
  300. "GET",
  301. retrieve_page_url,
  302. headers={
  303. "Authorization": "Bearer " + self._notion_access_token,
  304. "Content-Type": "application/json",
  305. "Notion-Version": "2022-06-28",
  306. },
  307. json=query_dict
  308. )
  309. data = res.json()
  310. return data["last_edited_time"]
  311. @classmethod
  312. def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
  313. data_source_binding = DataSourceBinding.query.filter(
  314. db.and_(
  315. DataSourceBinding.tenant_id == tenant_id,
  316. DataSourceBinding.provider == 'notion',
  317. DataSourceBinding.disabled == False,
  318. DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
  319. )
  320. ).first()
  321. if not data_source_binding:
  322. raise Exception(f'No notion data source binding found for tenant {tenant_id} '
  323. f'and notion workspace {notion_workspace_id}')
  324. return data_source_binding.access_token