message_file_parser.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import re
  2. from collections.abc import Mapping, Sequence
  3. from typing import Any, Union
  4. from urllib.parse import parse_qs, urlparse
  5. import requests
  6. from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
  7. from extensions.ext_database import db
  8. from models.account import Account
  9. from models.model import EndUser, MessageFile, UploadFile
  10. from services.file_service import IMAGE_EXTENSIONS
  11. class MessageFileParser:
  12. def __init__(self, tenant_id: str, app_id: str) -> None:
  13. self.tenant_id = tenant_id
  14. self.app_id = app_id
  15. def validate_and_transform_files_arg(
  16. self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
  17. ) -> list[FileVar]:
  18. """
  19. validate and transform files arg
  20. :param files:
  21. :param file_extra_config:
  22. :param user:
  23. :return:
  24. """
  25. for file in files:
  26. if not isinstance(file, dict):
  27. raise ValueError("Invalid file format, must be dict")
  28. if not file.get("type"):
  29. raise ValueError("Missing file type")
  30. FileType.value_of(file.get("type"))
  31. if not file.get("transfer_method"):
  32. raise ValueError("Missing file transfer method")
  33. FileTransferMethod.value_of(file.get("transfer_method"))
  34. if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
  35. if not file.get("url"):
  36. raise ValueError("Missing file url")
  37. if not file.get("url").startswith("http"):
  38. raise ValueError("Invalid file url")
  39. if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
  40. raise ValueError("Missing file upload_file_id")
  41. if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
  42. raise ValueError("Missing file tool_file_id")
  43. # transform files to file objs
  44. type_file_objs = self._to_file_objs(files, file_extra_config)
  45. # validate files
  46. new_files = []
  47. for file_type, file_objs in type_file_objs.items():
  48. if file_type == FileType.IMAGE:
  49. # parse and validate files
  50. image_config = file_extra_config.image_config
  51. # check if image file feature is enabled
  52. if not image_config:
  53. continue
  54. # Validate number of files
  55. if len(files) > image_config["number_limits"]:
  56. raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
  57. for file_obj in file_objs:
  58. # Validate transfer method
  59. if file_obj.transfer_method.value not in image_config["transfer_methods"]:
  60. raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
  61. # Validate file type
  62. if file_obj.type != FileType.IMAGE:
  63. raise ValueError(f"Invalid file type: {file_obj.type}")
  64. if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
  65. # check remote url valid and is image
  66. result, error = self._check_image_remote_url(file_obj.url)
  67. if result is False:
  68. raise ValueError(error)
  69. elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
  70. # get upload file from upload_file_id
  71. upload_file = (
  72. db.session.query(UploadFile)
  73. .filter(
  74. UploadFile.id == file_obj.related_id,
  75. UploadFile.tenant_id == self.tenant_id,
  76. UploadFile.created_by == user.id,
  77. UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
  78. UploadFile.extension.in_(IMAGE_EXTENSIONS),
  79. )
  80. .first()
  81. )
  82. # check upload file is belong to tenant and user
  83. if not upload_file:
  84. raise ValueError("Invalid upload file")
  85. new_files.append(file_obj)
  86. # return all file objs
  87. return new_files
  88. def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
  89. """
  90. transform message files
  91. :param files:
  92. :param file_extra_config:
  93. :return:
  94. """
  95. # transform files to file objs
  96. type_file_objs = self._to_file_objs(files, file_extra_config)
  97. # return all file objs
  98. return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
  99. def _to_file_objs(
  100. self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
  101. ) -> dict[FileType, list[FileVar]]:
  102. """
  103. transform files to file objs
  104. :param files:
  105. :param file_extra_config:
  106. :return:
  107. """
  108. type_file_objs: dict[FileType, list[FileVar]] = {
  109. # Currently only support image
  110. FileType.IMAGE: []
  111. }
  112. if not files:
  113. return type_file_objs
  114. # group by file type and convert file args or message files to FileObj
  115. for file in files:
  116. if isinstance(file, MessageFile):
  117. if file.belongs_to == FileBelongsTo.ASSISTANT.value:
  118. continue
  119. file_obj = self._to_file_obj(file, file_extra_config)
  120. if file_obj.type not in type_file_objs:
  121. continue
  122. type_file_objs[file_obj.type].append(file_obj)
  123. return type_file_objs
  124. def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
  125. """
  126. transform file to file obj
  127. :param file:
  128. :return:
  129. """
  130. if isinstance(file, dict):
  131. transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
  132. if transfer_method != FileTransferMethod.TOOL_FILE:
  133. return FileVar(
  134. tenant_id=self.tenant_id,
  135. type=FileType.value_of(file.get("type")),
  136. transfer_method=transfer_method,
  137. url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
  138. related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
  139. extra_config=file_extra_config,
  140. )
  141. return FileVar(
  142. tenant_id=self.tenant_id,
  143. type=FileType.value_of(file.get("type")),
  144. transfer_method=transfer_method,
  145. url=None,
  146. related_id=file.get("tool_file_id"),
  147. extra_config=file_extra_config,
  148. )
  149. else:
  150. return FileVar(
  151. id=file.id,
  152. tenant_id=self.tenant_id,
  153. type=FileType.value_of(file.type),
  154. transfer_method=FileTransferMethod.value_of(file.transfer_method),
  155. url=file.url,
  156. related_id=file.upload_file_id or None,
  157. extra_config=file_extra_config,
  158. )
  159. def _check_image_remote_url(self, url):
  160. try:
  161. headers = {
  162. "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
  163. " Chrome/91.0.4472.124 Safari/537.36"
  164. }
  165. def is_s3_presigned_url(url):
  166. try:
  167. parsed_url = urlparse(url)
  168. if "amazonaws.com" not in parsed_url.netloc:
  169. return False
  170. query_params = parse_qs(parsed_url.query)
  171. required_params = ["Signature", "Expires"]
  172. for param in required_params:
  173. if param not in query_params:
  174. return False
  175. if not query_params["Expires"][0].isdigit():
  176. return False
  177. signature = query_params["Signature"][0]
  178. if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
  179. return False
  180. return True
  181. except Exception:
  182. return False
  183. if is_s3_presigned_url(url):
  184. response = requests.get(url, headers=headers, allow_redirects=True)
  185. if response.status_code in {200, 304}:
  186. return True, ""
  187. response = requests.head(url, headers=headers, allow_redirects=True)
  188. if response.status_code in {200, 304}:
  189. return True, ""
  190. else:
  191. return False, "URL does not exist."
  192. except requests.RequestException as e:
  193. return False, f"Error checking URL: {e}"