app.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. import os
  2. from configs.app_config import DifyConfig
  # Unless DEBUG is explicitly "true" (case-insensitive), switch the process to
  # gevent cooperative I/O and enable gRPC's gevent integration. An unset or
  # empty DEBUG falls back to "false", so no separate "is it set?" test is needed.
  # This sits above the remaining imports, presumably so they load the patched
  # stdlib modules -- keep it here.
  if os.environ.get("DEBUG", "false").lower() != 'true':
      from gevent import monkey

      monkey.patch_all()

      import grpc.experimental.gevent

      grpc.experimental.gevent.init_gevent()
  8. import json
  9. import logging
  10. import sys
  11. import threading
  12. import time
  13. import warnings
  14. from logging.handlers import RotatingFileHandler
  15. from flask import Flask, Response, request
  16. from flask_cors import CORS
  17. from werkzeug.exceptions import Unauthorized
  18. from commands import register_commands
  19. # DO NOT REMOVE BELOW
  20. from extensions import (
  21. ext_celery,
  22. ext_code_based_extension,
  23. ext_compress,
  24. ext_database,
  25. ext_hosting_provider,
  26. ext_login,
  27. ext_mail,
  28. ext_migrate,
  29. ext_redis,
  30. ext_sentry,
  31. ext_storage,
  32. )
  33. from extensions.ext_database import db
  34. from extensions.ext_login import login_manager
  35. from libs.passport import PassportService
  36. from services.account_service import AccountService
  37. # DO NOT REMOVE ABOVE
  warnings.simplefilter("ignore", ResourceWarning)

  # Run the process in UTC.
  if os.name != "nt":
      os.environ['TZ'] = 'UTC'
      # Re-read TZ so the time module picks up the new zone (Unix-only API).
      time.tzset()
  else:
      # time.tzset() is unavailable on Windows, so the platform tool is used.
      # NOTE(review): `tzutil /s` changes the host's system-wide time zone, not
      # just this process, and usually needs admin rights -- confirm intended.
      os.system('tzutil /s "UTC"')
  class DifyApp(Flask):
      """Flask subclass used as the Dify application class.

      It currently adds nothing to ``Flask``; ``create_flask_app_with_configs``
      instantiates it.
      """
  47. # -------------
  48. # Configuration
  49. # -------------
  # Deployment edition; defaults to SELF_HOSTED (original note: "ce edition first").
  config_type = os.environ.get('EDITION', 'SELF_HOSTED')
  51. # ----------------------------
  52. # Application Factory Function
  53. # ----------------------------
  def create_flask_app_with_configs() -> Flask:
      """Create a bare ``DifyApp`` with its configuration loaded.

      Settings come from ``DifyConfig().model_dump()`` (presumably populated from
      the environment / ``.env`` file -- see DifyConfig). Each setting whose value
      is a ``str``, ``int``, ``float``, ``bool`` or ``None`` is then copied into
      ``os.environ``. ``None`` becomes ``''`` and numbers/bools go through
      ``str()``. Settings of any other type are not exported.

      Returns:
          The new app. No extensions or blueprints are attached here.
      """
      dify_app = DifyApp(__name__)
      dify_app.config.from_mapping(DifyConfig().model_dump())

      # Mirror scalar settings into the process environment, presumably for code
      # that reads os.environ directly rather than app.config.
      for name, setting in dify_app.config.items():
          if setting is None:
              os.environ[name] = ''
          elif isinstance(setting, str):
              os.environ[name] = setting
          elif isinstance(setting, (int, float, bool)):
              os.environ[name] = str(setting)

      return dify_app
  def create_app() -> Flask:
      """Build and fully initialise the Dify Flask application.

      Loads configuration, sets the session secret and configures root logging.
      It then binds the extensions and registers the blueprints and CLI commands.

      Returns:
          The configured application instance.
      """
      app = create_flask_app_with_configs()
      app.secret_key = app.config['SECRET_KEY']

      _configure_logging(app)

      initialize_extensions(app)
      register_blueprints(app)
      register_commands(app)

      return app


  def _configure_logging(app: Flask) -> None:
      """Configure the root logger from the LOG_* settings in ``app.config``.

      When ``LOG_FILE`` is set, log to a rotating file and to stdout. Otherwise
      ``logging.basicConfig``'s default handler is used. When ``LOG_TZ`` is set,
      log timestamps are rendered in that time zone.
      """
      log_handlers = None
      log_file = app.config.get('LOG_FILE')
      if log_file:
          log_dir = os.path.dirname(log_file)
          # A bare file name such as "app.log" has an empty dirname, and
          # os.makedirs("") raises FileNotFoundError, so only create real dirs.
          if log_dir:
              os.makedirs(log_dir, exist_ok=True)
          log_handlers = [
              RotatingFileHandler(
                  filename=log_file,
                  maxBytes=1024 * 1024 * 1024,  # rotate at 1 GiB
                  backupCount=5,
              ),
              logging.StreamHandler(sys.stdout),
          ]

      # force=True removes and closes any handlers already on the root logger.
      logging.basicConfig(
          level=app.config.get('LOG_LEVEL'),
          format=app.config.get('LOG_FORMAT'),
          datefmt=app.config.get('LOG_DATEFORMAT'),
          handlers=log_handlers,
          force=True,
      )

      log_tz = app.config.get('LOG_TZ')
      if log_tz:
          from datetime import datetime

          import pytz

          timezone = pytz.timezone(log_tz)

          def time_converter(seconds):
              # Convert the epoch timestamp directly into the target zone. The
              # old utcfromtimestamp(...).astimezone(...) produced a *naive*
              # datetime that astimezone() treats as local time, which is only
              # right while the process TZ is UTC; utcfromtimestamp is also
              # deprecated.
              return datetime.fromtimestamp(seconds, tz=timezone).timetuple()

          # basicConfig gave every handler it installed a formatter; make them
          # all use the zone-aware converter for %(asctime)s.
          for handler in logging.root.handlers:
              handler.formatter.converter = time_converter
  def initialize_extensions(app):
      """Bind each Flask extension to the freshly created *app*.

      The steps run strictly in the listed order.
      NOTE(review): later steps may depend on earlier ones (e.g. ``ext_migrate``
      receives ``db`` after ``ext_database`` is initialised) -- confirm before
      reordering.
      """
      # Zero-argument callables let the differing extension signatures
      # (init(), init(app, db), init_app(app)) share one ordered sequence; each
      # attribute is still looked up only when its step runs.
      setup_steps = (
          lambda: ext_compress.init_app(app),
          lambda: ext_code_based_extension.init(),
          lambda: ext_database.init_app(app),
          lambda: ext_migrate.init(app, db),
          lambda: ext_redis.init_app(app),
          lambda: ext_storage.init_app(app),
          lambda: ext_celery.init_app(app),
          lambda: ext_login.init_app(app),
          lambda: ext_mail.init_app(app),
          lambda: ext_hosting_provider.init_app(app),
          lambda: ext_sentry.init_app(app),
      )
      for run_step in setup_steps:
          run_step()
  # Flask-Login configuration
  @login_manager.request_loader
  def load_user_from_request(request_from_flask_login):
      """Load the logged-in account for ``console`` / ``inner_api`` requests.

      The token is taken from an ``Authorization: Bearer <token>`` header. When
      that header is absent or empty, it comes from the ``_token`` query
      parameter.

      Args:
          request_from_flask_login: Request passed in by Flask-Login. It is
              unused; the module-level ``flask.request`` is read instead.

      Returns:
          ``None`` for requests outside the ``console`` and ``inner_api``
          blueprints. Otherwise, the result of
          ``AccountService.load_logged_in_account`` for the token's ``user_id``.

      Raises:
          Unauthorized: If no token is supplied, or the Authorization header is
              not of the form ``Bearer <token>``. ``PassportService.verify`` may
              raise on an invalid token as well (not visible here -- verify).
      """
      if request.blueprint not in ['console', 'inner_api']:
          return None

      header_format_error = 'Invalid Authorization header format. Expected \'Bearer <api-key>\' format.'

      auth_header = request.headers.get('Authorization', '')
      if not auth_header:
          # No header: fall back to a token passed in the query string.
          auth_token = request.args.get('_token')
          if not auth_token:
              raise Unauthorized('Invalid Authorization token.')
      else:
          if ' ' not in auth_header:
              raise Unauthorized(header_format_error)
          parts = auth_header.split(None, 1)
          # Values such as "Bearer " or " Bearer" contain a space but split into
          # a single field; unpacking them directly raised ValueError (HTTP 500)
          # instead of rejecting the request with 401.
          if len(parts) != 2:
              raise Unauthorized(header_format_error)
          auth_scheme, auth_token = parts
          if auth_scheme.lower() != 'bearer':
              raise Unauthorized(header_format_error)

      decoded = PassportService().verify(auth_token)
      user_id = decoded.get('user_id')

      return AccountService.load_logged_in_account(account_id=user_id, token=auth_token)
  @login_manager.unauthorized_handler
  def unauthorized_handler():
      """Answer a request Flask-Login rejected with a JSON 401 response."""
      body = json.dumps({'code': 'unauthorized', 'message': "Unauthorized."})
      return Response(body, status=401, content_type="application/json")
  # register blueprint routers
  def register_blueprints(app):
      """Register every API blueprint on *app*, configuring CORS where used.

      CORS is configured on each blueprint immediately before that blueprint is
      registered, in the same sequence as before: service API, web, console,
      files, then the inner API (which gets no CORS).
      """
      from controllers.console import bp as console_app_bp
      from controllers.files import bp as files_bp
      from controllers.inner_api import bp as inner_api_bp
      from controllers.service_api import bp as service_api_bp
      from controllers.web import bp as web_bp

      http_methods = ('GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH')

      def cors_then_register(blueprint, **cors_options):
          # Every CORS-enabled blueprint allows the same methods; pass a fresh
          # list each time, as the original literals did.
          CORS(blueprint, methods=list(http_methods), **cors_options)
          app.register_blueprint(blueprint)

      cors_then_register(
          service_api_bp,
          allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
      )
      cors_then_register(
          web_bp,
          resources={r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
          supports_credentials=True,
          allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
          expose_headers=['X-Version', 'X-Env'],
      )
      cors_then_register(
          console_app_bp,
          resources={r"/*": {"origins": app.config['CONSOLE_CORS_ALLOW_ORIGINS']}},
          supports_credentials=True,
          allow_headers=['Content-Type', 'Authorization'],
          expose_headers=['X-Version', 'X-Env'],
      )
      cors_then_register(files_bp, allow_headers=['Content-Type'])

      # The inner API is registered without any CORS configuration.
      app.register_blueprint(inner_api_bp)
  # Build the application once at import time. The route declarations below
  # attach to this instance.
  app = create_app()

  # Celery instance stored on the app's extensions (presumably by
  # ext_celery.init_app); exposed as a module-level name.
  celery = app.extensions["celery"]

  if app.config.get('TESTING', False):
      print("App is running in TESTING mode")
  @app.after_request
  def after_request(response):
      """Stamp version and environment headers onto every response.

      Also sets ``remember_token`` to an empty value with ``expires=0``, asking
      the client to drop it (presumably Flask-Login's remember-me cookie).
      """
      response.set_cookie('remember_token', '', expires=0)
      for header_name, config_key in (('X-Version', 'CURRENT_VERSION'), ('X-Env', 'DEPLOY_ENV')):
          response.headers.add(header_name, app.config[config_key])
      return response
  @app.route('/health')
  def health():
      """Report that the service is up, along with the running version."""
      payload = {'status': 'ok', 'version': app.config['CURRENT_VERSION']}
      return Response(json.dumps(payload), status=200, content_type="application/json")
  @app.route('/threads')
  def threads():
      """Report the live-thread count and the name/id/liveness of each thread.

      NOTE(review): no authentication is visible on this route -- confirm it
      is not reachable from untrusted networks.
      """
      num_threads = threading.active_count()
      thread_list = [
          {'name': t.name, 'id': t.ident, 'is_alive': t.is_alive()}
          for t in threading.enumerate()
      ]
      return {'thread_num': num_threads, 'threads': thread_list}
  @app.route('/db-pool-stat')
  def pool_stat():
      """Report connection-pool statistics for ``db.engine``.

      NOTE(review): no authentication is visible on this route, and
      ``recycle_time`` reads the pool's private ``_recycle`` attribute.
      """
      pool = db.engine.pool
      return {
          'pool_size': pool.size(),
          'checked_in_connections': pool.checkedin(),
          'checked_out_connections': pool.checkedout(),
          'overflow_connections': pool.overflow(),
          'connection_timeout': pool.timeout(),
          'recycle_time': pool._recycle,
      }
  if __name__ == '__main__':
      # Direct `python app.py` run: built-in server on all interfaces, port 5001.
      app.run(port=5001, host='0.0.0.0')