import os
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, LLMResult
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
class StreamableChatOpenAI(ChatOpenAI):
    """ChatOpenAI variant whose generate calls are wrapped with the
    project's LLM exception handlers and whose per-request parameters
    (api base, key, organization) are resolved explicitly."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )

        # An openai package that predates the chat API lacks this attribute.
        if not hasattr(openai, "ChatCompletion"):
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        values["client"] = openai.ChatCompletion

        num_completions = values["n"]
        if num_completions < 1:
            raise ValueError("n must be at least 1.")
        # Streaming responses cannot interleave multiple completions.
        if num_completions > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        params = dict(super()._default_params)
        params.update(
            api_type='openai',
            api_base=os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            api_version=None,
            api_key=self.openai_api_key,
            # Normalize an empty organization to None.
            organization=self.openai_organization or None,
        )
        return params

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Estimate the number of tokens in a list of messages.

        Uses fixed per-request (3) and per-message (5) overheads plus a
        single token count over all message contents concatenated, so the
        tokenizer is invoked only once.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        per_message_overhead = 5
        per_request_overhead = 3

        combined_text = ''.join(message.content for message in messages)
        total = per_request_overhead + per_message_overhead * len(messages)
        # Tokenize the concatenated contents in one pass.
        total += self.get_num_tokens(combined_text)
        return total

    @handle_llm_exceptions
    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Delegate to the parent generate; the decorator normalizes LLM errors."""
        return super().generate(messages, stop, callbacks, **kwargs)

    @handle_llm_exceptions_async
    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Async counterpart of generate; errors normalized by the decorator."""
        return await super().agenerate(messages, stop, callbacks, **kwargs)
|