file_obj.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import enum
  2. from typing import Optional
  3. from pydantic import BaseModel
  4. from core.file.upload_file_parser import UploadFileParser
  5. from core.model_runtime.entities.message_entities import ImagePromptMessageContent
  6. from extensions.ext_database import db
  7. from models.model import UploadFile
  8. class FileType(enum.Enum):
  9. IMAGE = 'image'
  10. @staticmethod
  11. def value_of(value):
  12. for member in FileType:
  13. if member.value == value:
  14. return member
  15. raise ValueError(f"No matching enum found for value '{value}'")
  16. class FileTransferMethod(enum.Enum):
  17. REMOTE_URL = 'remote_url'
  18. LOCAL_FILE = 'local_file'
  19. TOOL_FILE = 'tool_file'
  20. @staticmethod
  21. def value_of(value):
  22. for member in FileTransferMethod:
  23. if member.value == value:
  24. return member
  25. raise ValueError(f"No matching enum found for value '{value}'")
  26. class FileBelongsTo(enum.Enum):
  27. USER = 'user'
  28. ASSISTANT = 'assistant'
  29. @staticmethod
  30. def value_of(value):
  31. for member in FileBelongsTo:
  32. if member.value == value:
  33. return member
  34. raise ValueError(f"No matching enum found for value '{value}'")
  35. class FileObj(BaseModel):
  36. id: Optional[str]
  37. tenant_id: str
  38. type: FileType
  39. transfer_method: FileTransferMethod
  40. url: Optional[str]
  41. upload_file_id: Optional[str]
  42. file_config: dict
  43. @property
  44. def data(self) -> Optional[str]:
  45. return self._get_data()
  46. @property
  47. def preview_url(self) -> Optional[str]:
  48. return self._get_data(force_url=True)
  49. @property
  50. def prompt_message_content(self) -> ImagePromptMessageContent:
  51. if self.type == FileType.IMAGE:
  52. image_config = self.file_config.get('image')
  53. return ImagePromptMessageContent(
  54. data=self.data,
  55. detail=ImagePromptMessageContent.DETAIL.HIGH
  56. if image_config.get("detail") == "high" else ImagePromptMessageContent.DETAIL.LOW
  57. )
  58. def _get_data(self, force_url: bool = False) -> Optional[str]:
  59. if self.type == FileType.IMAGE:
  60. if self.transfer_method == FileTransferMethod.REMOTE_URL:
  61. return self.url
  62. elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
  63. upload_file = (db.session.query(UploadFile)
  64. .filter(
  65. UploadFile.id == self.upload_file_id,
  66. UploadFile.tenant_id == self.tenant_id
  67. ).first())
  68. return UploadFileParser.get_image_data(
  69. upload_file=upload_file,
  70. force_url=force_url
  71. )
  72. return None