import re
import string
import threading
# Fix: import from the public `decimal` module, not the private `_decimal`
# C-accelerator module (an implementation detail that may be absent).
from decimal import ROUND_HALF_UP, Decimal
from typing import Any, Dict, List, Mapping, Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatMessage,
    ChatResult,
    HumanMessage,
    SystemMessage,
)
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator

from core.third_party.spark.spark_llm import SparkLLMClient


class ChatSpark(BaseChatModel):
    r"""Wrapper around the iFLYTEK Spark large language model.

    ``app_id``, ``api_key`` and ``api_secret`` must be supplied, either as
    named constructor parameters or via the ``SPARK_APP_ID`` /
    ``SPARK_API_KEY`` / ``SPARK_API_SECRET`` environment variables.

    Example:
        .. code-block:: python

            client = SparkLLMClient(
                model_name="<model_name>",
                app_id="<app_id>",
                api_key="<api_key>",
                api_secret="<api_secret>"
            )
    """

    client: Any = None  #: :meta private:

    model_name: str = "spark"
    """The Spark model name."""

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""

    top_k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""

    user_id: Optional[str] = None
    """User ID to use for the model."""

    streaming: bool = False
    """Whether to stream the results."""

    app_id: Optional[str] = None
    api_key: Optional[str] = None
    api_secret: Optional[str] = None
    api_domain: Optional[str] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve credentials (kwargs or env vars) and build the Spark client.

        Raises:
            ValueError: If any required credential is missing from both the
                constructor kwargs and the environment
                (raised by ``get_from_dict_or_env``).
        """
        values["app_id"] = get_from_dict_or_env(values, "app_id", "SPARK_APP_ID")
        values["api_key"] = get_from_dict_or_env(values, "api_key", "SPARK_API_KEY")
        values["api_secret"] = get_from_dict_or_env(
            values, "api_secret", "SPARK_API_SECRET"
        )

        values["client"] = SparkLLMClient(
            model_name=values["model_name"],
            app_id=values["app_id"],
            api_key=values["api_key"],
            api_secret=values["api_secret"],
            api_domain=values.get('api_domain'),
        )
        return values

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling the Spark API."""
        # Fix: the original docstring said "Anthropic API" — copy-paste error.
        params: Dict[str, Any] = {"max_tokens": self.max_tokens}
        # Only forward optional sampling knobs that were explicitly set.
        if self.temperature is not None:
            params["temperature"] = self.temperature
        if self.top_k is not None:
            params["top_k"] = self.top_k
        return params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**self._default_params}

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret-bearing fields to their serialized placeholder names."""
        return {"api_key": "API_KEY", "api_secret": "API_SECRET"}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "spark-chat"

    @property
    def lc_serializable(self) -> bool:
        """Allow LangChain to serialize this model."""
        return True

    def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> List[dict]:
        """Convert LangChain messages into the role/content dicts Spark expects.

        Args:
            messages: List of BaseMessage to convert (not mutated).

        Returns:
            A list of ``{'role': ..., 'content': ...}`` dicts.

        Raises:
            ValueError: If a message is of an unsupported type.
        """
        new_messages = []
        for message in messages:
            if isinstance(message, (ChatMessage, HumanMessage, SystemMessage)):
                # Spark has no separate system role here: chat, human and
                # system messages are all sent as user turns.
                new_messages.append({'role': 'user', 'content': message.content})
            elif isinstance(message, AIMessage):
                new_messages.append({'role': 'assistant', 'content': message.content})
            else:
                raise ValueError(f"Got unknown type {message}")
        return new_messages

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send the conversation to Spark and assemble the streamed reply.

        The request runs on a worker thread while this thread consumes the
        chunks published via ``client.subscribe()``.
        """
        message_dicts = self._convert_messages_to_dicts(messages)

        thread = threading.Thread(
            target=self.client.run,
            args=(message_dicts, self.user_id, self._default_params, self.streaming),
        )
        thread.start()

        completion = ""
        try:
            for content in self.client.subscribe():
                # subscribe() may yield either a dict payload or a raw string.
                delta = content['data'] if isinstance(content, dict) else content
                completion += delta
                if self.streaming and run_manager:
                    run_manager.on_llm_new_token(delta)
        finally:
            # Fix: always reap the worker thread, even if subscribe() or a
            # callback raised; the original leaked the thread on error.
            thread.join()

        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generation is not implemented; returns an empty result.

        NOTE(review): this stub silently yields an empty completion — confirm
        callers never reach the async path, or implement it.
        """
        message = AIMessage(content='')
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Estimate the number of tokens in *text*.

        Heuristic: matches in the CJK Unified Ideographs range cost 1.5
        tokens each, everything else 0.8. ``Decimal`` avoids float drift
        across many additions. Fix: annotated ``-> int`` (the method returns
        ``int(total)``; the original ``-> float`` annotation was wrong and
        disagreed with LangChain's base-class contract).

        NOTE(review): the regex groups consecutive word characters, so a run
        of several Chinese characters matches as one item and is billed once
        — confirm whether per-character counting was intended.
        """
        total = Decimal(0)
        pattern = r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation))
        for word in re.findall(pattern, text):
            if not word:
                continue
            if '\u4e00' <= word <= '\u9fff':  # if chinese
                total += Decimal('1.5')
            else:
                total += Decimal('0.8')
        return int(total)