file_factory.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. import mimetypes
  2. from collections.abc import Callable, Mapping, Sequence
  3. from typing import Any
  4. import httpx
  5. from sqlalchemy import select
  6. from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
  7. from core.helper import ssrf_proxy
  8. from extensions.ext_database import db
  9. from models import MessageFile, ToolFile, UploadFile
  10. def build_from_message_files(
  11. *,
  12. message_files: Sequence["MessageFile"],
  13. tenant_id: str,
  14. config: FileUploadConfig,
  15. ) -> Sequence[File]:
  16. results = [
  17. build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
  18. for file in message_files
  19. if file.belongs_to != FileBelongsTo.ASSISTANT
  20. ]
  21. return results
  22. def build_from_message_file(
  23. *,
  24. message_file: "MessageFile",
  25. tenant_id: str,
  26. config: FileUploadConfig,
  27. ):
  28. mapping = {
  29. "transfer_method": message_file.transfer_method,
  30. "url": message_file.url,
  31. "id": message_file.id,
  32. "type": message_file.type,
  33. "upload_file_id": message_file.upload_file_id,
  34. }
  35. return build_from_mapping(
  36. mapping=mapping,
  37. tenant_id=tenant_id,
  38. config=config,
  39. )
  40. def build_from_mapping(
  41. *,
  42. mapping: Mapping[str, Any],
  43. tenant_id: str,
  44. config: FileUploadConfig | None = None,
  45. ) -> File:
  46. config = config or FileUploadConfig()
  47. transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
  48. build_functions: dict[FileTransferMethod, Callable] = {
  49. FileTransferMethod.LOCAL_FILE: _build_from_local_file,
  50. FileTransferMethod.REMOTE_URL: _build_from_remote_url,
  51. FileTransferMethod.TOOL_FILE: _build_from_tool_file,
  52. }
  53. build_func = build_functions.get(transfer_method)
  54. if not build_func:
  55. raise ValueError(f"Invalid file transfer method: {transfer_method}")
  56. file = build_func(
  57. mapping=mapping,
  58. tenant_id=tenant_id,
  59. transfer_method=transfer_method,
  60. )
  61. if not _is_file_valid_with_config(file=file, config=config):
  62. raise ValueError(f"File validation failed for file: {file.filename}")
  63. return file
  64. def build_from_mappings(
  65. *,
  66. mappings: Sequence[Mapping[str, Any]],
  67. config: FileUploadConfig | None,
  68. tenant_id: str,
  69. ) -> Sequence[File]:
  70. if not config:
  71. return []
  72. files = [
  73. build_from_mapping(
  74. mapping=mapping,
  75. tenant_id=tenant_id,
  76. config=config,
  77. )
  78. for mapping in mappings
  79. ]
  80. if (
  81. # If image config is set.
  82. config.image_config
  83. # And the number of image files exceeds the maximum limit
  84. and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
  85. ):
  86. raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
  87. if config.number_limits and len(files) > config.number_limits:
  88. raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
  89. return files
  90. def _build_from_local_file(
  91. *,
  92. mapping: Mapping[str, Any],
  93. tenant_id: str,
  94. transfer_method: FileTransferMethod,
  95. ) -> File:
  96. file_type = FileType.value_of(mapping.get("type"))
  97. stmt = select(UploadFile).where(
  98. UploadFile.id == mapping.get("upload_file_id"),
  99. UploadFile.tenant_id == tenant_id,
  100. )
  101. row = db.session.scalar(stmt)
  102. if row is None:
  103. raise ValueError("Invalid upload file")
  104. return File(
  105. id=mapping.get("id"),
  106. filename=row.name,
  107. extension="." + row.extension,
  108. mime_type=row.mime_type,
  109. tenant_id=tenant_id,
  110. type=file_type,
  111. transfer_method=transfer_method,
  112. remote_url=row.source_url,
  113. related_id=mapping.get("upload_file_id"),
  114. size=row.size,
  115. )
  116. def _build_from_remote_url(
  117. *,
  118. mapping: Mapping[str, Any],
  119. tenant_id: str,
  120. transfer_method: FileTransferMethod,
  121. ) -> File:
  122. url = mapping.get("url")
  123. if not url:
  124. raise ValueError("Invalid file url")
  125. mime_type, filename, file_size = _get_remote_file_info(url)
  126. extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
  127. return File(
  128. id=mapping.get("id"),
  129. filename=filename,
  130. tenant_id=tenant_id,
  131. type=FileType.value_of(mapping.get("type")),
  132. transfer_method=transfer_method,
  133. remote_url=url,
  134. mime_type=mime_type,
  135. extension=extension,
  136. size=file_size,
  137. )
  138. def _get_remote_file_info(url: str):
  139. mime_type = mimetypes.guess_type(url)[0] or ""
  140. file_size = -1
  141. filename = url.split("/")[-1].split("?")[0] or "unknown_file"
  142. resp = ssrf_proxy.head(url, follow_redirects=True)
  143. if resp.status_code == httpx.codes.OK:
  144. if content_disposition := resp.headers.get("Content-Disposition"):
  145. filename = str(content_disposition.split("filename=")[-1].strip('"'))
  146. file_size = int(resp.headers.get("Content-Length", file_size))
  147. mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
  148. return mime_type, filename, file_size
  149. def _build_from_tool_file(
  150. *,
  151. mapping: Mapping[str, Any],
  152. tenant_id: str,
  153. transfer_method: FileTransferMethod,
  154. ) -> File:
  155. tool_file = (
  156. db.session.query(ToolFile)
  157. .filter(
  158. ToolFile.id == mapping.get("tool_file_id"),
  159. ToolFile.tenant_id == tenant_id,
  160. )
  161. .first()
  162. )
  163. if tool_file is None:
  164. raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
  165. extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
  166. return File(
  167. id=mapping.get("id"),
  168. tenant_id=tenant_id,
  169. filename=tool_file.name,
  170. type=FileType.value_of(mapping.get("type")),
  171. transfer_method=transfer_method,
  172. remote_url=tool_file.original_url,
  173. related_id=tool_file.id,
  174. extension=extension,
  175. mime_type=tool_file.mimetype,
  176. size=tool_file.size,
  177. )
  178. def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
  179. if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
  180. return False
  181. if config.allowed_extensions and file.extension not in config.allowed_extensions:
  182. return False
  183. if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
  184. return False
  185. if file.type == FileType.IMAGE and config.image_config:
  186. if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods:
  187. return False
  188. return True