entities.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from collections.abc import Sequence
  2. from typing import Any, Optional
  3. from pydantic import BaseModel, Field, field_validator
  4. from core.model_runtime.entities import ImagePromptMessageContent
  5. from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
  6. from core.workflow.entities.variable_entities import VariableSelector
  7. from core.workflow.nodes.base import BaseNodeData
  8. class ModelConfig(BaseModel):
  9. provider: str
  10. name: str
  11. mode: str
  12. completion_params: dict[str, Any] = {}
  13. class ContextConfig(BaseModel):
  14. enabled: bool
  15. variable_selector: Optional[list[str]] = None
  16. class VisionConfigOptions(BaseModel):
  17. variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
  18. detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
  19. class VisionConfig(BaseModel):
  20. enabled: bool = False
  21. configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
  22. @field_validator("configs", mode="before")
  23. @classmethod
  24. def convert_none_configs(cls, v: Any):
  25. if v is None:
  26. return VisionConfigOptions()
  27. return v
  28. class PromptConfig(BaseModel):
  29. jinja2_variables: Optional[list[VariableSelector]] = None
  30. class LLMNodeChatModelMessage(ChatModelMessage):
  31. jinja2_text: Optional[str] = None
  32. class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
  33. jinja2_text: Optional[str] = None
  34. class LLMNodeData(BaseNodeData):
  35. model: ModelConfig
  36. prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
  37. prompt_config: Optional[PromptConfig] = None
  38. memory: Optional[MemoryConfig] = None
  39. context: ContextConfig
  40. vision: VisionConfig = Field(default_factory=VisionConfig)