from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel):
    """LLM implementation for OpenRouter, layered on the OpenAI-API-compatible base.

    Every entry point first builds a normalized copy of the credentials
    (fixed OpenRouter endpoint URL, per-model completion mode, tool-call
    capability flag) and then delegates to the base-class implementation.
    """

    def _update_credential(self, model: str, credentials: dict) -> dict:
        """Return a copy of *credentials* augmented with OpenRouter defaults.

        A shallow copy is returned so the caller-supplied dict is never
        mutated (the original in-place mutation leaked ``endpoint_url``,
        ``mode`` and ``function_calling_type`` back to the caller).

        :param model: model identifier, used to resolve the completion mode
        :param credentials: caller-supplied credential mapping (left unmodified)
        :return: new dict with the OpenRouter-specific keys set
        """
        updated = dict(credentials)
        # All OpenRouter models are served from this single fixed endpoint.
        updated['endpoint_url'] = "https://openrouter.ai/api/v1"
        # The base class needs the resolved completion mode (e.g. chat vs.
        # completion) as a plain string value.
        updated['mode'] = self.get_model_mode(model).value
        # OpenRouter exposes OpenAI-style tool calling.
        updated['function_calling_type'] = 'tool_call'
        return updated

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """Invoke the model with normalized OpenRouter credentials."""
        cred = self._update_credential(model, credentials)
        return super()._invoke(model, cred, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """Validate credentials after applying OpenRouter defaults."""
        cred = self._update_credential(model, credentials)
        return super().validate_credentials(model, cred)

    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """Generate a completion with normalized OpenRouter credentials."""
        cred = self._update_credential(model, credentials)
        return super()._generate(model, cred, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """Build the customizable model schema with normalized credentials."""
        cred = self._update_credential(model, credentials)
        return super().get_customizable_model_schema(model, cred)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """Count prompt tokens with normalized OpenRouter credentials."""
        cred = self._update_credential(model, credentials)
        return super().get_num_tokens(model, cred, prompt_messages, tools)