oauth_data_source.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. import json
  2. import urllib.parse
  3. import requests
  4. from flask_login import current_user
  5. from extensions.ext_database import db
  6. from models.source import DataSourceBinding
  class OAuthDataSource:
      """Base class for OAuth-backed third-party data sources.

      Stores the OAuth client credentials and callback URL; concrete
      providers override the consent-URL and token-exchange hooks.
      """

      def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
          # Credentials and callback URL registered with the provider.
          self.client_id, self.client_secret, self.redirect_uri = (
              client_id,
              client_secret,
              redirect_uri,
          )

      def get_authorization_url(self):
          """Build the provider's consent URL; must be overridden by subclasses."""
          raise NotImplementedError

      def get_access_token(self, code: str):
          """Exchange the authorization ``code``; must be overridden by subclasses."""
          raise NotImplementedError
  class NotionOAuth(OAuthDataSource):
      """Notion implementation of ``OAuthDataSource``.

      Exchanges OAuth codes (or accepts internal-integration tokens), lists the
      pages and databases a token can read, and stores the result on a
      ``DataSourceBinding`` row with provider ``'notion'`` for the current
      tenant.
      """

      _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
      _TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
      _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
      _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
      _NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
      # Notion API version sent with every data request; the response shapes
      # parsed below (icons, parents, title properties) follow this version.
      _NOTION_VERSION = '2022-06-28'
      # Per-request timeout in seconds. Without one, a stalled connection to
      # Notion would hang the calling request indefinitely.
      _REQUEST_TIMEOUT = 30
      # Largest page size the search endpoint accepts.
      _SEARCH_PAGE_SIZE = 100

      def get_authorization_url(self):
          """Return the Notion consent-screen URL for this OAuth client."""
          params = {
              'client_id': self.client_id,
              'response_type': 'code',
              'redirect_uri': self.redirect_uri,
              'owner': 'user',
          }
          return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"

      def get_access_token(self, code: str):
          """Exchange an OAuth ``code`` for a token and save the tenant binding.

          Raises:
              ValueError: if Notion's token response has no ``access_token``.
          """
          data = {
              'code': code,
              'grant_type': 'authorization_code',
              'redirect_uri': self.redirect_uri,
          }
          headers = {'Accept': 'application/json'}
          auth = (self.client_id, self.client_secret)
          response = requests.post(
              self._TOKEN_URL, data=data, auth=auth, headers=headers,
              timeout=self._REQUEST_TIMEOUT,
          )
          response_json = response.json()
          access_token = response_json.get('access_token')
          if not access_token:
              raise ValueError(f"Error in Notion OAuth: {response_json}")
          source_info = self._build_source_info(
              workspace_name=response_json.get('workspace_name'),
              workspace_icon=response_json.get('workspace_icon'),
              workspace_id=response_json.get('workspace_id'),
              pages=self.get_authorized_pages(access_token),
          )
          self._save_binding(access_token, source_info)

      def save_internal_access_token(self, access_token: str):
          """Save a binding for an internal-integration token (no OAuth exchange).

          The tenant id is stored as the workspace id and no icon is recorded.
          """
          workspace_name = self.notion_workspace_name(access_token)
          source_info = self._build_source_info(
              workspace_name=workspace_name,
              workspace_icon=None,
              workspace_id=current_user.current_tenant_id,
              pages=self.get_authorized_pages(access_token),
          )
          self._save_binding(access_token, source_info)

      def sync_data_source(self, binding_id: str):
          """Re-fetch the page list of an enabled Notion binding of the current tenant.

          Workspace name, icon and id are carried over from the stored
          ``source_info``; only ``pages`` and ``total`` are refreshed.

          Raises:
              ValueError: if no enabled binding with ``binding_id`` exists.
          """
          data_source_binding = DataSourceBinding.query.filter(
              db.and_(
                  DataSourceBinding.tenant_id == current_user.current_tenant_id,
                  DataSourceBinding.provider == 'notion',
                  DataSourceBinding.id == binding_id,
                  DataSourceBinding.disabled == False,  # noqa: E712 - SQLAlchemy column expression
              )
          ).first()
          if not data_source_binding:
              raise ValueError('Data source binding not found')
          pages = self.get_authorized_pages(data_source_binding.access_token)
          source_info = data_source_binding.source_info
          data_source_binding.source_info = self._build_source_info(
              workspace_name=source_info['workspace_name'],
              workspace_icon=source_info['workspace_icon'],
              workspace_id=source_info['workspace_id'],
              pages=pages,
          )
          data_source_binding.disabled = False
          db.session.commit()

      def get_authorized_pages(self, access_token: str):
          """Return the pages and databases ``access_token`` can read.

          Each entry is a dict with ``page_id``, ``page_name``, ``page_icon``,
          ``parent_id`` (``'root'`` for workspace-level items) and ``type``
          (``'page'`` or ``'database'``). Pages are listed before databases.
          """
          page_results = self.notion_page_search(access_token)
          database_results = self.notion_database_search(access_token)
          pages = [
              self._to_page_entry(access_token, result, self._page_title(result), 'page')
              for result in page_results
          ]
          pages.extend(
              self._to_page_entry(
                  access_token, result, self._plain_text(result.get('title')), 'database'
              )
              for result in database_results
          )
          return pages

      def notion_page_search(self, access_token: str):
          """Return every page object visible to ``access_token``."""
          return self._notion_search(access_token, 'page')

      def notion_block_parent_page_id(self, access_token: str, block_id: str):
          """Walk up a block's ancestors and return the first non-block parent id.

          Returns ``'root'`` when the chain ends at the workspace, matching how
          workspace-level items are reported by ``get_authorized_pages``.

          Raises:
              ValueError: if a block lookup returns no ``parent`` (for example
                  an API error response) or the parent chain loops.
          """
          headers = self._auth_headers(access_token)
          seen = set()
          current_id = block_id
          # Iterative walk: avoids unbounded recursion and detects cycles.
          while current_id not in seen:
              seen.add(current_id)
              response = requests.get(
                  url=f'{self._NOTION_BLOCK_SEARCH}/{current_id}',
                  headers=headers,
                  timeout=self._REQUEST_TIMEOUT,
              )
              response_json = response.json()
              parent = response_json.get('parent')
              if not parent:
                  raise ValueError(f"Error retrieving Notion block {current_id}: {response_json}")
              parent_type = parent['type']
              if parent_type == 'workspace':
                  return 'root'
              if parent_type != 'block_id':
                  return parent[parent_type]
              current_id = parent[parent_type]
          raise ValueError(f"Cycle detected in Notion block parents at {current_id}")

      def notion_workspace_name(self, access_token: str):
          """Return the workspace name reported for the token's bot user.

          Falls back to ``'workspace'`` when the response is not a user object
          or carries no ``workspace_name``.
          """
          response = requests.get(
              url=self._NOTION_BOT_USER,
              headers=self._auth_headers(access_token),
              timeout=self._REQUEST_TIMEOUT,
          )
          response_json = response.json()
          if response_json.get('object') == 'user':
              user_type = response_json.get('type')
              user_info = response_json.get(user_type) or {}
              if 'workspace_name' in user_info:
                  return user_info['workspace_name']
          return 'workspace'

      def notion_database_search(self, access_token: str):
          """Return every database object visible to ``access_token``."""
          return self._notion_search(access_token, 'database')

      # ----------------------------------------------------------------- helpers

      def _auth_headers(self, access_token: str):
          """Build the bearer-auth and version headers shared by all data calls."""
          return {
              'Authorization': f"Bearer {access_token}",
              'Notion-Version': self._NOTION_VERSION,
          }

      def _notion_search(self, access_token: str, object_type: str):
          """Return all search results whose ``object`` is ``object_type``.

          The search endpoint is paginated, so this keeps requesting with
          ``start_cursor`` while the response reports ``has_more``. A response
          without ``results`` (e.g. an error body) contributes nothing and, since
          it has no ``has_more``, ends the scan.
          """
          headers = {'Content-Type': 'application/json', **self._auth_headers(access_token)}
          payload = {
              'filter': {
                  "value": object_type,
                  "property": "object",
              },
              'page_size': self._SEARCH_PAGE_SIZE,
          }
          results = []
          while True:
              response = requests.post(
                  url=self._NOTION_PAGE_SEARCH,
                  json=payload,
                  headers=headers,
                  timeout=self._REQUEST_TIMEOUT,
              )
              response_json = response.json()
              results.extend(response_json.get('results') or [])
              next_cursor = response_json.get('next_cursor')
              if not response_json.get('has_more') or not next_cursor:
                  return results
              payload['start_cursor'] = next_cursor

      def _to_page_entry(self, access_token: str, result: dict, page_name: str, entry_type: str):
          """Convert one search result into the entry stored in ``source_info['pages']``."""
          return {
              'page_id': result['id'],
              'page_name': page_name,
              'page_icon': self._parse_icon(result.get('icon')),
              'parent_id': self._resolve_parent_id(access_token, result['parent']),
              'type': entry_type,
          }

      def _resolve_parent_id(self, access_token: str, parent: dict):
          """Map a Notion ``parent`` object to a page/database id or ``'root'``."""
          parent_type = parent['type']
          if parent_type == 'block_id':
              # Items nested inside blocks (columns, toggles, ...) are attached
              # to the page that ultimately contains those blocks.
              return self.notion_block_parent_page_id(access_token, parent[parent_type])
          if parent_type == 'workspace':
              return 'root'
          return parent[parent_type]

      @staticmethod
      def _parse_icon(icon):
          """Normalize a Notion icon object; ``None`` when the item has no icon.

          ``external``/``file`` icons become ``{'type': 'url', 'url': ...}``
          (relative URLs are made absolute against notion.so); any other icon
          type, e.g. ``emoji``, is passed through as ``{'type': t, t: value}``.
          """
          if not icon:
              return None
          icon_type = icon['type']
          if icon_type in ('external', 'file'):
              url = icon[icon_type]['url']
              return {
                  'type': 'url',
                  'url': url if url.startswith('http') else f'https://www.notion.so{url}',
              }
          return {
              'type': icon_type,
              icon_type: icon[icon_type],
          }

      @classmethod
      def _page_title(cls, page_result: dict):
          """Return a page's title, read from its single ``title``-typed property.

          The title property can have any user-chosen name (``Name``, ``Task``,
          ...), so it is located by type rather than by key.
          """
          for prop in (page_result.get('properties') or {}).values():
              if isinstance(prop, dict) and prop.get('type') == 'title':
                  return cls._plain_text(prop.get('title'))
          return 'Untitled'

      @staticmethod
      def _plain_text(rich_text):
          """Join every rich-text segment's ``plain_text``; ``'Untitled'`` if empty."""
          text = ''.join(segment.get('plain_text', '') for segment in rich_text or [])
          return text or 'Untitled'

      @staticmethod
      def _build_source_info(workspace_name, workspace_icon, workspace_id, pages):
          """Assemble the ``source_info`` payload stored on a binding."""
          return {
              'workspace_name': workspace_name,
              'workspace_icon': workspace_icon,
              'workspace_id': workspace_id,
              'pages': pages,
              'total': len(pages),
          }

      @staticmethod
      def _save_binding(access_token: str, source_info: dict):
          """Upsert the current tenant's Notion binding for ``access_token`` and commit.

          An existing binding with the same token is refreshed and re-enabled;
          otherwise a new binding row is added.
          """
          tenant_id = current_user.current_tenant_id
          data_source_binding = DataSourceBinding.query.filter(
              db.and_(
                  DataSourceBinding.tenant_id == tenant_id,
                  DataSourceBinding.provider == 'notion',
                  DataSourceBinding.access_token == access_token,
              )
          ).first()
          if data_source_binding:
              data_source_binding.source_info = source_info
              data_source_binding.disabled = False
          else:
              db.session.add(DataSourceBinding(
                  tenant_id=tenant_id,
                  access_token=access_token,
                  source_info=source_info,
                  provider='notion',
              ))
          db.session.commit()