account.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. import datetime
  2. import pytz
  3. from flask import request
  4. from flask_login import current_user
  5. from flask_restful import Resource, fields, marshal_with, reqparse
  6. from configs import dify_config
  7. from constants.languages import supported_language
  8. from controllers.console import api
  9. from controllers.console.setup import setup_required
  10. from controllers.console.workspace.error import (
  11. AccountAlreadyInitedError,
  12. CurrentPasswordIncorrectError,
  13. InvalidInvitationCodeError,
  14. RepeatPasswordNotMatchError,
  15. )
  16. from controllers.console.wraps import account_initialization_required
  17. from extensions.ext_database import db
  18. from fields.member_fields import account_fields
  19. from libs.helper import TimestampField, timezone
  20. from libs.login import login_required
  21. from models.account import AccountIntegrate, InvitationCode
  22. from services.account_service import AccountService
  23. from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
  24. class AccountInitApi(Resource):
  25. @setup_required
  26. @login_required
  27. def post(self):
  28. account = current_user
  29. if account.status == 'active':
  30. raise AccountAlreadyInitedError()
  31. parser = reqparse.RequestParser()
  32. if dify_config.EDITION == 'CLOUD':
  33. parser.add_argument('invitation_code', type=str, location='json')
  34. parser.add_argument(
  35. 'interface_language', type=supported_language, required=True, location='json')
  36. parser.add_argument('timezone', type=timezone,
  37. required=True, location='json')
  38. args = parser.parse_args()
  39. if dify_config.EDITION == 'CLOUD':
  40. if not args['invitation_code']:
  41. raise ValueError('invitation_code is required')
  42. # check invitation code
  43. invitation_code = db.session.query(InvitationCode).filter(
  44. InvitationCode.code == args['invitation_code'],
  45. InvitationCode.status == 'unused',
  46. ).first()
  47. if not invitation_code:
  48. raise InvalidInvitationCodeError()
  49. invitation_code.status = 'used'
  50. invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  51. invitation_code.used_by_tenant_id = account.current_tenant_id
  52. invitation_code.used_by_account_id = account.id
  53. account.interface_language = args['interface_language']
  54. account.timezone = args['timezone']
  55. account.interface_theme = 'light'
  56. account.status = 'active'
  57. account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
  58. db.session.commit()
  59. return {'result': 'success'}
  60. class AccountProfileApi(Resource):
  61. @setup_required
  62. @login_required
  63. @account_initialization_required
  64. @marshal_with(account_fields)
  65. def get(self):
  66. return current_user
  67. class AccountNameApi(Resource):
  68. @setup_required
  69. @login_required
  70. @account_initialization_required
  71. @marshal_with(account_fields)
  72. def post(self):
  73. parser = reqparse.RequestParser()
  74. parser.add_argument('name', type=str, required=True, location='json')
  75. args = parser.parse_args()
  76. # Validate account name length
  77. if len(args['name']) < 3 or len(args['name']) > 30:
  78. raise ValueError(
  79. "Account name must be between 3 and 30 characters.")
  80. updated_account = AccountService.update_account(current_user, name=args['name'])
  81. return updated_account
  82. class AccountAvatarApi(Resource):
  83. @setup_required
  84. @login_required
  85. @account_initialization_required
  86. @marshal_with(account_fields)
  87. def post(self):
  88. parser = reqparse.RequestParser()
  89. parser.add_argument('avatar', type=str, required=True, location='json')
  90. args = parser.parse_args()
  91. updated_account = AccountService.update_account(current_user, avatar=args['avatar'])
  92. return updated_account
  93. class AccountInterfaceLanguageApi(Resource):
  94. @setup_required
  95. @login_required
  96. @account_initialization_required
  97. @marshal_with(account_fields)
  98. def post(self):
  99. parser = reqparse.RequestParser()
  100. parser.add_argument(
  101. 'interface_language', type=supported_language, required=True, location='json')
  102. args = parser.parse_args()
  103. updated_account = AccountService.update_account(current_user, interface_language=args['interface_language'])
  104. return updated_account
  105. class AccountInterfaceThemeApi(Resource):
  106. @setup_required
  107. @login_required
  108. @account_initialization_required
  109. @marshal_with(account_fields)
  110. def post(self):
  111. parser = reqparse.RequestParser()
  112. parser.add_argument('interface_theme', type=str, choices=[
  113. 'light', 'dark'], required=True, location='json')
  114. args = parser.parse_args()
  115. updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme'])
  116. return updated_account
  117. class AccountTimezoneApi(Resource):
  118. @setup_required
  119. @login_required
  120. @account_initialization_required
  121. @marshal_with(account_fields)
  122. def post(self):
  123. parser = reqparse.RequestParser()
  124. parser.add_argument('timezone', type=str,
  125. required=True, location='json')
  126. args = parser.parse_args()
  127. # Validate timezone string, e.g. America/New_York, Asia/Shanghai
  128. if args['timezone'] not in pytz.all_timezones:
  129. raise ValueError("Invalid timezone string.")
  130. updated_account = AccountService.update_account(current_user, timezone=args['timezone'])
  131. return updated_account
  132. class AccountPasswordApi(Resource):
  133. @setup_required
  134. @login_required
  135. @account_initialization_required
  136. @marshal_with(account_fields)
  137. def post(self):
  138. parser = reqparse.RequestParser()
  139. parser.add_argument('password', type=str,
  140. required=False, location='json')
  141. parser.add_argument('new_password', type=str,
  142. required=True, location='json')
  143. parser.add_argument('repeat_new_password', type=str,
  144. required=True, location='json')
  145. args = parser.parse_args()
  146. if args['new_password'] != args['repeat_new_password']:
  147. raise RepeatPasswordNotMatchError()
  148. try:
  149. AccountService.update_account_password(
  150. current_user, args['password'], args['new_password'])
  151. except ServiceCurrentPasswordIncorrectError:
  152. raise CurrentPasswordIncorrectError()
  153. return {"result": "success"}
  154. class AccountIntegrateApi(Resource):
  155. integrate_fields = {
  156. 'provider': fields.String,
  157. 'created_at': TimestampField,
  158. 'is_bound': fields.Boolean,
  159. 'link': fields.String
  160. }
  161. integrate_list_fields = {
  162. 'data': fields.List(fields.Nested(integrate_fields)),
  163. }
  164. @setup_required
  165. @login_required
  166. @account_initialization_required
  167. @marshal_with(integrate_list_fields)
  168. def get(self):
  169. account = current_user
  170. account_integrates = db.session.query(AccountIntegrate).filter(
  171. AccountIntegrate.account_id == account.id).all()
  172. base_url = request.url_root.rstrip('/')
  173. oauth_base_path = "/console/api/oauth/login"
  174. providers = ["github", "google"]
  175. integrate_data = []
  176. for provider in providers:
  177. existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None)
  178. if existing_integrate:
  179. integrate_data.append({
  180. 'id': existing_integrate.id,
  181. 'provider': provider,
  182. 'created_at': existing_integrate.created_at,
  183. 'is_bound': True,
  184. 'link': None
  185. })
  186. else:
  187. integrate_data.append({
  188. 'id': None,
  189. 'provider': provider,
  190. 'created_at': None,
  191. 'is_bound': False,
  192. 'link': f'{base_url}{oauth_base_path}/{provider}'
  193. })
  194. return {'data': integrate_data}
  195. # Register API resources
  196. api.add_resource(AccountInitApi, '/account/init')
  197. api.add_resource(AccountProfileApi, '/account/profile')
  198. api.add_resource(AccountNameApi, '/account/name')
  199. api.add_resource(AccountAvatarApi, '/account/avatar')
  200. api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language')
  201. api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme')
  202. api.add_resource(AccountTimezoneApi, '/account/timezone')
  203. api.add_resource(AccountPasswordApi, '/account/password')
  204. api.add_resource(AccountIntegrateApi, '/account/integrates')
  205. # api.add_resource(AccountEmailApi, '/account/email')
  206. # api.add_resource(AccountEmailVerifyApi, '/account/email-verify')