import re
import string
import threading
from decimal import Decimal
from typing import Any, Dict, List, Mapping, Optional

from langchain.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatMessage,
    ChatResult,
    HumanMessage,
    SystemMessage,
)
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator

from core.third_party.spark.spark_llm import SparkLLMClient


class ChatSpark(BaseChatModel):
    r"""Wrapper around the Spark large language model.

    To use, pass `app_id`, `api_key` and `api_secret` as named parameters
    to the constructor.

    Example:
        .. code-block:: python

            chat = ChatSpark(
                model_name="<model_name>",
                app_id="<app_id>",
                api_key="<api_key>",
                api_secret="<api_secret>"
            )
    """
    client: Any = None  #: :meta private:

    model_name: str = "spark"
    """The Spark model name."""
    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""
    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""
    top_k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""
    user_id: Optional[str] = None
    """User ID to use for the model."""
    streaming: bool = False
    """Whether to stream the results."""

    app_id: Optional[str] = None
    api_key: Optional[str] = None
    api_secret: Optional[str] = None
    api_domain: Optional[str] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the app credentials exist in the environment and build the client."""
        values["app_id"] = get_from_dict_or_env(
            values, "app_id", "SPARK_APP_ID"
        )
        values["api_key"] = get_from_dict_or_env(
            values, "api_key", "SPARK_API_KEY"
        )
        values["api_secret"] = get_from_dict_or_env(
            values, "api_secret", "SPARK_API_SECRET"
        )

        values["client"] = SparkLLMClient(
            model_name=values["model_name"],
            app_id=values["app_id"],
            api_key=values["api_key"],
            api_secret=values["api_secret"],
            api_domain=values.get("api_domain"),
        )
        return values
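
    # Credentials can also be supplied via the environment instead of the
    # constructor, e.g. (illustrative):
    #   export SPARK_APP_ID=... SPARK_API_KEY=... SPARK_API_SECRET=...
    #   chat = ChatSpark()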

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling the Spark API."""
        d = {
            "max_tokens": self.max_tokens
        }
        if self.temperature is not None:
            d["temperature"] = self.temperature
        if self.top_k is not None:
            d["top_k"] = self.top_k
        return d
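
    # For example (illustrative values): with temperature=0.7 and top_k=4 set,
    # this returns {"max_tokens": 256, "temperature": 0.7, "top_k": 4}.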

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**self._default_params}

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY", "api_secret": "API_SECRET"}

    @property
    def _llm_type(self) -> str:
        """Return the type of chat model."""
        return "spark-chat"

    @property
    def lc_serializable(self) -> bool:
        return True

    def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> List[dict]:
        """Format a list of messages into a list of role/content dicts.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to convert.

        Returns:
            List[dict]
        """
        new_messages = []
        for message in messages:
            if isinstance(message, (ChatMessage, HumanMessage, SystemMessage)):
                new_messages.append({'role': 'user', 'content': message.content})
            elif isinstance(message, AIMessage):
                new_messages.append({'role': 'assistant', 'content': message.content})
            else:
                raise ValueError(f"Got unknown message type: {message}")
        return new_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        messages = self._convert_messages_to_dicts(messages)

        # Run the client in a background thread and consume the generated
        # chunks from its subscription on this thread.
        thread = threading.Thread(target=self.client.run, args=(
            messages,
            self.user_id,
            self._default_params,
            self.streaming
        ))
        thread.start()

        completion = ""
        for content in self.client.subscribe():
            # Chunks arrive either as plain strings or as dicts carrying the
            # delta under the 'data' key.
            if isinstance(content, dict):
                delta = content['data']
            else:
                delta = content
            completion += delta
            if self.streaming and run_manager:
                run_manager.on_llm_new_token(delta)
        thread.join()

        # Truncate at the first occurrence of any stop sequence.
        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # Async generation is not supported; return an empty completion.
        message = AIMessage(content='')
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Approximate the number of tokens in `text`.

        Chinese characters are weighted at 1.5 tokens each; every other
        matched word, punctuation mark or whitespace character at 0.8.
        """
        total = Decimal(0)
        words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text)
        for word in words:
            if word:
                if '\u4e00' <= word[0] <= '\u9fff':  # Chinese
                    # \w+ groups consecutive Chinese characters into a single
                    # match, so weight each character in the run.
                    total += Decimal('1.5') * len(word)
                else:
                    total += Decimal('0.8')
        return int(total)
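

# Minimal usage sketch (assumes valid Spark credentials; the values below are
# placeholders and the call needs network access to the Spark service):
if __name__ == "__main__":
    chat = ChatSpark(
        app_id="<app_id>",          # placeholder
        api_key="<api_key>",        # placeholder
        api_secret="<api_secret>",  # placeholder
    )
    result = chat([HumanMessage(content="Hello!")])
    print(result.content)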