# oauth_data_source.py
  1. import urllib.parse
  2. import requests
  3. from flask_login import current_user
  4. from extensions.ext_database import db
  5. from models.source import DataSourceBinding
  6. class OAuthDataSource:
  7. def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
  8. self.client_id = client_id
  9. self.client_secret = client_secret
  10. self.redirect_uri = redirect_uri
  11. def get_authorization_url(self):
  12. raise NotImplementedError()
  13. def get_access_token(self, code: str):
  14. raise NotImplementedError()
  15. class NotionOAuth(OAuthDataSource):
  16. _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize'
  17. _TOKEN_URL = 'https://api.notion.com/v1/oauth/token'
  18. _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search"
  19. _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
  20. _NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
  21. def get_authorization_url(self):
  22. params = {
  23. 'client_id': self.client_id,
  24. 'response_type': 'code',
  25. 'redirect_uri': self.redirect_uri,
  26. 'owner': 'user'
  27. }
  28. return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
  29. def get_access_token(self, code: str):
  30. data = {
  31. 'code': code,
  32. 'grant_type': 'authorization_code',
  33. 'redirect_uri': self.redirect_uri
  34. }
  35. headers = {'Accept': 'application/json'}
  36. auth = (self.client_id, self.client_secret)
  37. response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
  38. response_json = response.json()
  39. access_token = response_json.get('access_token')
  40. if not access_token:
  41. raise ValueError(f"Error in Notion OAuth: {response_json}")
  42. workspace_name = response_json.get('workspace_name')
  43. workspace_icon = response_json.get('workspace_icon')
  44. workspace_id = response_json.get('workspace_id')
  45. # get all authorized pages
  46. pages = self.get_authorized_pages(access_token)
  47. source_info = {
  48. 'workspace_name': workspace_name,
  49. 'workspace_icon': workspace_icon,
  50. 'workspace_id': workspace_id,
  51. 'pages': pages,
  52. 'total': len(pages)
  53. }
  54. # save data source binding
  55. data_source_binding = DataSourceBinding.query.filter(
  56. db.and_(
  57. DataSourceBinding.tenant_id == current_user.current_tenant_id,
  58. DataSourceBinding.provider == 'notion',
  59. DataSourceBinding.access_token == access_token
  60. )
  61. ).first()
  62. if data_source_binding:
  63. data_source_binding.source_info = source_info
  64. data_source_binding.disabled = False
  65. db.session.commit()
  66. else:
  67. new_data_source_binding = DataSourceBinding(
  68. tenant_id=current_user.current_tenant_id,
  69. access_token=access_token,
  70. source_info=source_info,
  71. provider='notion'
  72. )
  73. db.session.add(new_data_source_binding)
  74. db.session.commit()
  75. def save_internal_access_token(self, access_token: str):
  76. workspace_name = self.notion_workspace_name(access_token)
  77. workspace_icon = None
  78. workspace_id = current_user.current_tenant_id
  79. # get all authorized pages
  80. pages = self.get_authorized_pages(access_token)
  81. source_info = {
  82. 'workspace_name': workspace_name,
  83. 'workspace_icon': workspace_icon,
  84. 'workspace_id': workspace_id,
  85. 'pages': pages,
  86. 'total': len(pages)
  87. }
  88. # save data source binding
  89. data_source_binding = DataSourceBinding.query.filter(
  90. db.and_(
  91. DataSourceBinding.tenant_id == current_user.current_tenant_id,
  92. DataSourceBinding.provider == 'notion',
  93. DataSourceBinding.access_token == access_token
  94. )
  95. ).first()
  96. if data_source_binding:
  97. data_source_binding.source_info = source_info
  98. data_source_binding.disabled = False
  99. db.session.commit()
  100. else:
  101. new_data_source_binding = DataSourceBinding(
  102. tenant_id=current_user.current_tenant_id,
  103. access_token=access_token,
  104. source_info=source_info,
  105. provider='notion'
  106. )
  107. db.session.add(new_data_source_binding)
  108. db.session.commit()
  109. def sync_data_source(self, binding_id: str):
  110. # save data source binding
  111. data_source_binding = DataSourceBinding.query.filter(
  112. db.and_(
  113. DataSourceBinding.tenant_id == current_user.current_tenant_id,
  114. DataSourceBinding.provider == 'notion',
  115. DataSourceBinding.id == binding_id,
  116. DataSourceBinding.disabled == False
  117. )
  118. ).first()
  119. if data_source_binding:
  120. # get all authorized pages
  121. pages = self.get_authorized_pages(data_source_binding.access_token)
  122. source_info = data_source_binding.source_info
  123. new_source_info = {
  124. 'workspace_name': source_info['workspace_name'],
  125. 'workspace_icon': source_info['workspace_icon'],
  126. 'workspace_id': source_info['workspace_id'],
  127. 'pages': pages,
  128. 'total': len(pages)
  129. }
  130. data_source_binding.source_info = new_source_info
  131. data_source_binding.disabled = False
  132. db.session.commit()
  133. else:
  134. raise ValueError('Data source binding not found')
  135. def get_authorized_pages(self, access_token: str):
  136. pages = []
  137. page_results = self.notion_page_search(access_token)
  138. database_results = self.notion_database_search(access_token)
  139. # get page detail
  140. for page_result in page_results:
  141. page_id = page_result['id']
  142. if 'Name' in page_result['properties']:
  143. if len(page_result['properties']['Name']['title']) > 0:
  144. page_name = page_result['properties']['Name']['title'][0]['plain_text']
  145. else:
  146. page_name = 'Untitled'
  147. elif 'title' in page_result['properties']:
  148. if len(page_result['properties']['title']['title']) > 0:
  149. page_name = page_result['properties']['title']['title'][0]['plain_text']
  150. else:
  151. page_name = 'Untitled'
  152. elif 'Title' in page_result['properties']:
  153. if len(page_result['properties']['Title']['title']) > 0:
  154. page_name = page_result['properties']['Title']['title'][0]['plain_text']
  155. else:
  156. page_name = 'Untitled'
  157. else:
  158. page_name = 'Untitled'
  159. page_icon = page_result['icon']
  160. if page_icon:
  161. icon_type = page_icon['type']
  162. if icon_type == 'external' or icon_type == 'file':
  163. url = page_icon[icon_type]['url']
  164. icon = {
  165. 'type': 'url',
  166. 'url': url if url.startswith('http') else f'https://www.notion.so{url}'
  167. }
  168. else:
  169. icon = {
  170. 'type': 'emoji',
  171. 'emoji': page_icon[icon_type]
  172. }
  173. else:
  174. icon = None
  175. parent = page_result['parent']
  176. parent_type = parent['type']
  177. if parent_type == 'block_id':
  178. parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
  179. elif parent_type == 'workspace':
  180. parent_id = 'root'
  181. else:
  182. parent_id = parent[parent_type]
  183. page = {
  184. 'page_id': page_id,
  185. 'page_name': page_name,
  186. 'page_icon': icon,
  187. 'parent_id': parent_id,
  188. 'type': 'page'
  189. }
  190. pages.append(page)
  191. # get database detail
  192. for database_result in database_results:
  193. page_id = database_result['id']
  194. if len(database_result['title']) > 0:
  195. page_name = database_result['title'][0]['plain_text']
  196. else:
  197. page_name = 'Untitled'
  198. page_icon = database_result['icon']
  199. if page_icon:
  200. icon_type = page_icon['type']
  201. if icon_type == 'external' or icon_type == 'file':
  202. url = page_icon[icon_type]['url']
  203. icon = {
  204. 'type': 'url',
  205. 'url': url if url.startswith('http') else f'https://www.notion.so{url}'
  206. }
  207. else:
  208. icon = {
  209. 'type': icon_type,
  210. icon_type: page_icon[icon_type]
  211. }
  212. else:
  213. icon = None
  214. parent = database_result['parent']
  215. parent_type = parent['type']
  216. if parent_type == 'block_id':
  217. parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type])
  218. elif parent_type == 'workspace':
  219. parent_id = 'root'
  220. else:
  221. parent_id = parent[parent_type]
  222. page = {
  223. 'page_id': page_id,
  224. 'page_name': page_name,
  225. 'page_icon': icon,
  226. 'parent_id': parent_id,
  227. 'type': 'database'
  228. }
  229. pages.append(page)
  230. return pages
  231. def notion_page_search(self, access_token: str):
  232. data = {
  233. 'filter': {
  234. "value": "page",
  235. "property": "object"
  236. }
  237. }
  238. headers = {
  239. 'Content-Type': 'application/json',
  240. 'Authorization': f"Bearer {access_token}",
  241. 'Notion-Version': '2022-06-28',
  242. }
  243. response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
  244. response_json = response.json()
  245. if 'results' in response_json:
  246. results = response_json['results']
  247. else:
  248. results = []
  249. return results
  250. def notion_block_parent_page_id(self, access_token: str, block_id: str):
  251. headers = {
  252. 'Authorization': f"Bearer {access_token}",
  253. 'Notion-Version': '2022-06-28',
  254. }
  255. response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers)
  256. response_json = response.json()
  257. parent = response_json['parent']
  258. parent_type = parent['type']
  259. if parent_type == 'block_id':
  260. return self.notion_block_parent_page_id(access_token, parent[parent_type])
  261. return parent[parent_type]
  262. def notion_workspace_name(self, access_token: str):
  263. headers = {
  264. 'Authorization': f"Bearer {access_token}",
  265. 'Notion-Version': '2022-06-28',
  266. }
  267. response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
  268. response_json = response.json()
  269. if 'object' in response_json and response_json['object'] == 'user':
  270. user_type = response_json['type']
  271. user_info = response_json[user_type]
  272. if 'workspace_name' in user_info:
  273. return user_info['workspace_name']
  274. return 'workspace'
  275. def notion_database_search(self, access_token: str):
  276. data = {
  277. 'filter': {
  278. "value": "database",
  279. "property": "object"
  280. }
  281. }
  282. headers = {
  283. 'Content-Type': 'application/json',
  284. 'Authorization': f"Bearer {access_token}",
  285. 'Notion-Version': '2022-06-28',
  286. }
  287. response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
  288. response_json = response.json()
  289. if 'results' in response_json:
  290. results = response_json['results']
  291. else:
  292. results = []
  293. return results