import os
from typing import Any, Dict, List, Optional, Tuple, Union

from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, LLMResult
from pydantic import root_validator

from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant with explicit OpenAI connection defaults and
    unified exception handling around ``generate``.

    Differences from the base class:
    - a (connect, read) timeout tuple instead of a single timeout,
    - at most one retry,
    - connection parameters (api base/key/org) injected into the default
      request params so each call is self-contained,
    - OpenAI errors normalized by ``handle_openai_exceptions``.
    """

    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API:
    5 seconds to connect, 300 seconds to read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            # Bind the chat-completion client; only present in openai>=0.27.
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        # Streaming yields a single choice; n > 1 is unsupported then.
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API.

        Extends the base params with per-request connection settings so the
        call does not depend on openai module-level globals.
        """
        return {
            **super()._default_params,
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            # Normalize empty/falsy org to None so it is omitted upstream.
            "organization": self.openai_organization if self.openai_organization else None,
        }

    @handle_openai_exceptions
    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Delegate to ChatOpenAI.generate, with OpenAI errors normalized
        by the ``handle_openai_exceptions`` decorator."""
        return super().generate(messages, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict) -> dict:
        """Move sampling parameters from *params* into ``params['model_kwargs']``.

        Mutates and returns *params*: ``top_p``, ``frequency_penalty`` and
        ``presence_penalty`` are removed from the top level and nested under
        ``model_kwargs`` (with defaults 1, 0, 0 when absent).
        """
        # pop() tolerates absent keys; the previous get()-then-del pair
        # supplied defaults but still raised KeyError on a missing key.
        params['model_kwargs'] = {
            'top_p': params.pop('top_p', 1),
            'frequency_penalty': params.pop('frequency_penalty', 0),
            'presence_penalty': params.pop('presence_penalty', 0),
        }
        return params
|