# spark.py — ChatSpark: LangChain chat-model wrapper around iFlytek Spark.
  1. import re
  2. import string
  3. import threading
  4. from _decimal import Decimal, ROUND_HALF_UP
  5. from typing import Dict, List, Optional, Any, Mapping
  6. from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
  7. from langchain.chat_models.base import BaseChatModel
  8. from langchain.llms.utils import enforce_stop_tokens
  9. from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \
  10. ChatGeneration
  11. from langchain.utils import get_from_dict_or_env
  12. from pydantic import root_validator
  13. from core.third_party.spark.spark_llm import SparkLLMClient
  14. class ChatSpark(BaseChatModel):
  15. r"""Wrapper around Spark's large language model.
  16. To use, you should pass `app_id`, `api_key`, `api_secret`
  17. as a named parameter to the constructor.
  18. Example:
  19. .. code-block:: python
  20. client = SparkLLMClient(
  21. model_name="<model_name>",
  22. app_id="<app_id>",
  23. api_key="<api_key>",
  24. api_secret="<api_secret>"
  25. )
  26. """
  27. client: Any = None #: :meta private:
  28. model_name: str = "spark"
  29. """The Spark model name."""
  30. max_tokens: int = 256
  31. """Denotes the number of tokens to predict per generation."""
  32. temperature: Optional[float] = None
  33. """A non-negative float that tunes the degree of randomness in generation."""
  34. top_k: Optional[int] = None
  35. """Number of most likely tokens to consider at each step."""
  36. user_id: Optional[str] = None
  37. """User ID to use for the model."""
  38. streaming: bool = False
  39. """Whether to stream the results."""
  40. app_id: Optional[str] = None
  41. api_key: Optional[str] = None
  42. api_secret: Optional[str] = None
  43. api_domain: Optional[str] = None
  44. @root_validator()
  45. def validate_environment(cls, values: Dict) -> Dict:
  46. """Validate that api key and python package exists in environment."""
  47. values["app_id"] = get_from_dict_or_env(
  48. values, "app_id", "SPARK_APP_ID"
  49. )
  50. values["api_key"] = get_from_dict_or_env(
  51. values, "api_key", "SPARK_API_KEY"
  52. )
  53. values["api_secret"] = get_from_dict_or_env(
  54. values, "api_secret", "SPARK_API_SECRET"
  55. )
  56. values["client"] = SparkLLMClient(
  57. model_name=values["model_name"],
  58. app_id=values["app_id"],
  59. api_key=values["api_key"],
  60. api_secret=values["api_secret"],
  61. api_domain=values.get('api_domain')
  62. )
  63. return values
  64. @property
  65. def _default_params(self) -> Mapping[str, Any]:
  66. """Get the default parameters for calling Anthropic API."""
  67. d = {
  68. "max_tokens": self.max_tokens
  69. }
  70. if self.temperature is not None:
  71. d["temperature"] = self.temperature
  72. if self.top_k is not None:
  73. d["top_k"] = self.top_k
  74. return d
  75. @property
  76. def _identifying_params(self) -> Mapping[str, Any]:
  77. """Get the identifying parameters."""
  78. return {**{}, **self._default_params}
  79. @property
  80. def lc_secrets(self) -> Dict[str, str]:
  81. return {"api_key": "API_KEY", "api_secret": "API_SECRET"}
  82. @property
  83. def _llm_type(self) -> str:
  84. """Return type of chat model."""
  85. return "spark-chat"
  86. @property
  87. def lc_serializable(self) -> bool:
  88. return True
  89. def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]:
  90. """Format a list of messages into a full dict list.
  91. Args:
  92. messages (List[BaseMessage]): List of BaseMessage to combine.
  93. Returns:
  94. list[dict]
  95. """
  96. messages = messages.copy() # don't mutate the original list
  97. new_messages = []
  98. for message in messages:
  99. if isinstance(message, ChatMessage):
  100. new_messages.append({'role': 'user', 'content': message.content})
  101. elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage):
  102. new_messages.append({'role': 'user', 'content': message.content})
  103. elif isinstance(message, AIMessage):
  104. new_messages.append({'role': 'assistant', 'content': message.content})
  105. else:
  106. raise ValueError(f"Got unknown type {message}")
  107. return new_messages
  108. def _generate(
  109. self,
  110. messages: List[BaseMessage],
  111. stop: Optional[List[str]] = None,
  112. run_manager: Optional[CallbackManagerForLLMRun] = None,
  113. **kwargs: Any,
  114. ) -> ChatResult:
  115. messages = self._convert_messages_to_dicts(messages)
  116. thread = threading.Thread(target=self.client.run, args=(
  117. messages,
  118. self.user_id,
  119. self._default_params,
  120. self.streaming
  121. ))
  122. thread.start()
  123. completion = ""
  124. for content in self.client.subscribe():
  125. if isinstance(content, dict):
  126. delta = content['data']
  127. else:
  128. delta = content
  129. completion += delta
  130. if self.streaming and run_manager:
  131. run_manager.on_llm_new_token(
  132. delta,
  133. )
  134. thread.join()
  135. if stop is not None:
  136. completion = enforce_stop_tokens(completion, stop)
  137. message = AIMessage(content=completion)
  138. return ChatResult(generations=[ChatGeneration(message=message)])
  139. async def _agenerate(
  140. self,
  141. messages: List[BaseMessage],
  142. stop: Optional[List[str]] = None,
  143. run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
  144. **kwargs: Any,
  145. ) -> ChatResult:
  146. message = AIMessage(content='')
  147. return ChatResult(generations=[ChatGeneration(message=message)])
  148. def get_num_tokens(self, text: str) -> float:
  149. """Calculate number of tokens."""
  150. total = Decimal(0)
  151. words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text)
  152. for word in words:
  153. if word:
  154. if '\u4e00' <= word <= '\u9fff': # if chinese
  155. total += Decimal('1.5')
  156. else:
  157. total += Decimal('0.8')
  158. return int(total)