file_factory.py 7.6 KB


import email.message
import mimetypes
import urllib.parse
from collections.abc import Mapping, Sequence
from typing import Any

import httpx
from sqlalchemy import select

from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile
from models.enums import CreatedByRole
  12. def build_from_message_files(
  13. *,
  14. message_files: Sequence["MessageFile"],
  15. tenant_id: str,
  16. config: FileExtraConfig,
  17. ) -> Sequence[File]:
  18. results = [
  19. build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
  20. for file in message_files
  21. if file.belongs_to != FileBelongsTo.ASSISTANT
  22. ]
  23. return results
  24. def build_from_message_file(
  25. *,
  26. message_file: "MessageFile",
  27. tenant_id: str,
  28. config: FileExtraConfig,
  29. ):
  30. mapping = {
  31. "transfer_method": message_file.transfer_method,
  32. "url": message_file.url,
  33. "id": message_file.id,
  34. "type": message_file.type,
  35. "upload_file_id": message_file.upload_file_id,
  36. }
  37. return build_from_mapping(
  38. mapping=mapping,
  39. tenant_id=tenant_id,
  40. user_id=message_file.created_by,
  41. role=CreatedByRole(message_file.created_by_role),
  42. config=config,
  43. )
  44. def build_from_mapping(
  45. *,
  46. mapping: Mapping[str, Any],
  47. tenant_id: str,
  48. user_id: str,
  49. role: "CreatedByRole",
  50. config: FileExtraConfig,
  51. ):
  52. transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
  53. match transfer_method:
  54. case FileTransferMethod.REMOTE_URL:
  55. file = _build_from_remote_url(
  56. mapping=mapping,
  57. tenant_id=tenant_id,
  58. config=config,
  59. transfer_method=transfer_method,
  60. )
  61. case FileTransferMethod.LOCAL_FILE:
  62. file = _build_from_local_file(
  63. mapping=mapping,
  64. tenant_id=tenant_id,
  65. user_id=user_id,
  66. role=role,
  67. config=config,
  68. transfer_method=transfer_method,
  69. )
  70. case FileTransferMethod.TOOL_FILE:
  71. file = _build_from_tool_file(
  72. mapping=mapping,
  73. tenant_id=tenant_id,
  74. user_id=user_id,
  75. config=config,
  76. transfer_method=transfer_method,
  77. )
  78. case _:
  79. raise ValueError(f"Invalid file transfer method: {transfer_method}")
  80. return file
  81. def build_from_mappings(
  82. *,
  83. mappings: Sequence[Mapping[str, Any]],
  84. config: FileExtraConfig | None,
  85. tenant_id: str,
  86. user_id: str,
  87. role: "CreatedByRole",
  88. ) -> Sequence[File]:
  89. if not config:
  90. return []
  91. files = [
  92. build_from_mapping(
  93. mapping=mapping,
  94. tenant_id=tenant_id,
  95. user_id=user_id,
  96. role=role,
  97. config=config,
  98. )
  99. for mapping in mappings
  100. ]
  101. if (
  102. # If image config is set.
  103. config.image_config
  104. # And the number of image files exceeds the maximum limit
  105. and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits
  106. ):
  107. raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}")
  108. if config.number_limits and len(files) > config.number_limits:
  109. raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}")
  110. return files
  111. def _build_from_local_file(
  112. *,
  113. mapping: Mapping[str, Any],
  114. tenant_id: str,
  115. user_id: str,
  116. role: "CreatedByRole",
  117. config: FileExtraConfig,
  118. transfer_method: FileTransferMethod,
  119. ):
  120. # check if the upload file exists.
  121. file_type = FileType.value_of(mapping.get("type"))
  122. stmt = select(UploadFile).where(
  123. UploadFile.id == mapping.get("upload_file_id"),
  124. UploadFile.tenant_id == tenant_id,
  125. UploadFile.created_by == user_id,
  126. UploadFile.created_by_role == role,
  127. )
  128. if file_type == FileType.IMAGE:
  129. stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
  130. elif file_type == FileType.VIDEO:
  131. stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
  132. elif file_type == FileType.AUDIO:
  133. stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
  134. elif file_type == FileType.DOCUMENT:
  135. stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
  136. row = db.session.scalar(stmt)
  137. if row is None:
  138. raise ValueError("Invalid upload file")
  139. file = File(
  140. id=mapping.get("id"),
  141. filename=row.name,
  142. extension="." + row.extension,
  143. mime_type=row.mime_type,
  144. tenant_id=tenant_id,
  145. type=file_type,
  146. transfer_method=transfer_method,
  147. remote_url=None,
  148. related_id=mapping.get("upload_file_id"),
  149. _extra_config=config,
  150. size=row.size,
  151. )
  152. return file
  153. def _build_from_remote_url(
  154. *,
  155. mapping: Mapping[str, Any],
  156. tenant_id: str,
  157. config: FileExtraConfig,
  158. transfer_method: FileTransferMethod,
  159. ):
  160. url = mapping.get("url")
  161. if not url:
  162. raise ValueError("Invalid file url")
  163. resp = ssrf_proxy.head(url, follow_redirects=True)
  164. if resp.status_code == httpx.codes.OK:
  165. # Try to extract filename from response headers or URL
  166. content_disposition = resp.headers.get("Content-Disposition")
  167. if content_disposition:
  168. filename = content_disposition.split("filename=")[-1].strip('"')
  169. else:
  170. filename = url.split("/")[-1].split("?")[0]
  171. # Create the File object
  172. file_size = int(resp.headers.get("Content-Length", -1))
  173. mime_type = str(resp.headers.get("Content-Type", ""))
  174. else:
  175. filename = ""
  176. file_size = -1
  177. mime_type = ""
  178. # If filename is empty, set a default one
  179. if not filename:
  180. filename = "unknown_file"
  181. # Determine file extension
  182. extension = "." + filename.split(".")[-1] if "." in filename else ".bin"
  183. if not mime_type:
  184. mime_type, _ = mimetypes.guess_type(url)
  185. file = File(
  186. id=mapping.get("id"),
  187. filename=filename,
  188. tenant_id=tenant_id,
  189. type=FileType.value_of(mapping.get("type")),
  190. transfer_method=transfer_method,
  191. remote_url=url,
  192. _extra_config=config,
  193. mime_type=mime_type,
  194. extension=extension,
  195. size=file_size,
  196. )
  197. return file
  198. def _build_from_tool_file(
  199. *,
  200. mapping: Mapping[str, Any],
  201. tenant_id: str,
  202. user_id: str,
  203. config: FileExtraConfig,
  204. transfer_method: FileTransferMethod,
  205. ):
  206. tool_file = (
  207. db.session.query(ToolFile)
  208. .filter(
  209. ToolFile.id == mapping.get("tool_file_id"),
  210. ToolFile.tenant_id == tenant_id,
  211. ToolFile.user_id == user_id,
  212. )
  213. .first()
  214. )
  215. if tool_file is None:
  216. raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
  217. path = tool_file.file_key
  218. if "." in path:
  219. extension = "." + path.split("/")[-1].split(".")[-1]
  220. else:
  221. extension = ".bin"
  222. file = File(
  223. id=mapping.get("id"),
  224. tenant_id=tenant_id,
  225. filename=tool_file.name,
  226. type=FileType.value_of(mapping.get("type")),
  227. transfer_method=transfer_method,
  228. remote_url=tool_file.original_url,
  229. related_id=tool_file.id,
  230. extension=extension,
  231. mime_type=tool_file.mimetype,
  232. size=tool_file.size,
  233. _extra_config=config,
  234. )
  235. return file