# spark.py
  1. import re
  2. import string
  3. import threading
  4. from _decimal import Decimal, ROUND_HALF_UP
  5. from typing import Dict, List, Optional, Any, Mapping
  6. from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
  7. from langchain.chat_models.base import BaseChatModel
  8. from langchain.llms.utils import enforce_stop_tokens
  9. from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \
  10. ChatGeneration
  11. from langchain.utils import get_from_dict_or_env
  12. from pydantic import root_validator
  13. from core.third_party.spark.spark_llm import SparkLLMClient
  14. class ChatSpark(BaseChatModel):
  15. r"""Wrapper around Spark's large language model.
  16. To use, you should pass `app_id`, `api_key`, `api_secret`
  17. as a named parameter to the constructor.
  18. Example:
  19. .. code-block:: python
  20. client = SparkLLMClient(
  21. app_id="<app_id>",
  22. api_key="<api_key>",
  23. api_secret="<api_secret>"
  24. )
  25. """
  26. client: Any = None #: :meta private:
  27. max_tokens: int = 256
  28. """Denotes the number of tokens to predict per generation."""
  29. temperature: Optional[float] = None
  30. """A non-negative float that tunes the degree of randomness in generation."""
  31. top_k: Optional[int] = None
  32. """Number of most likely tokens to consider at each step."""
  33. user_id: Optional[str] = None
  34. """User ID to use for the model."""
  35. streaming: bool = False
  36. """Whether to stream the results."""
  37. app_id: Optional[str] = None
  38. api_key: Optional[str] = None
  39. api_secret: Optional[str] = None
  40. @root_validator()
  41. def validate_environment(cls, values: Dict) -> Dict:
  42. """Validate that api key and python package exists in environment."""
  43. values["app_id"] = get_from_dict_or_env(
  44. values, "app_id", "SPARK_APP_ID"
  45. )
  46. values["api_key"] = get_from_dict_or_env(
  47. values, "api_key", "SPARK_API_KEY"
  48. )
  49. values["api_secret"] = get_from_dict_or_env(
  50. values, "api_secret", "SPARK_API_SECRET"
  51. )
  52. values["client"] = SparkLLMClient(
  53. app_id=values["app_id"],
  54. api_key=values["api_key"],
  55. api_secret=values["api_secret"],
  56. )
  57. return values
  58. @property
  59. def _default_params(self) -> Mapping[str, Any]:
  60. """Get the default parameters for calling Anthropic API."""
  61. d = {
  62. "max_tokens": self.max_tokens
  63. }
  64. if self.temperature is not None:
  65. d["temperature"] = self.temperature
  66. if self.top_k is not None:
  67. d["top_k"] = self.top_k
  68. return d
  69. @property
  70. def _identifying_params(self) -> Mapping[str, Any]:
  71. """Get the identifying parameters."""
  72. return {**{}, **self._default_params}
  73. @property
  74. def lc_secrets(self) -> Dict[str, str]:
  75. return {"api_key": "API_KEY", "api_secret": "API_SECRET"}
  76. @property
  77. def _llm_type(self) -> str:
  78. """Return type of chat model."""
  79. return "spark-chat"
  80. @property
  81. def lc_serializable(self) -> bool:
  82. return True
  83. def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]:
  84. """Format a list of messages into a full dict list.
  85. Args:
  86. messages (List[BaseMessage]): List of BaseMessage to combine.
  87. Returns:
  88. list[dict]
  89. """
  90. messages = messages.copy() # don't mutate the original list
  91. new_messages = []
  92. for message in messages:
  93. if isinstance(message, ChatMessage):
  94. new_messages.append({'role': 'user', 'content': message.content})
  95. elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage):
  96. new_messages.append({'role': 'user', 'content': message.content})
  97. elif isinstance(message, AIMessage):
  98. new_messages.append({'role': 'assistant', 'content': message.content})
  99. else:
  100. raise ValueError(f"Got unknown type {message}")
  101. return new_messages
  102. def _generate(
  103. self,
  104. messages: List[BaseMessage],
  105. stop: Optional[List[str]] = None,
  106. run_manager: Optional[CallbackManagerForLLMRun] = None,
  107. **kwargs: Any,
  108. ) -> ChatResult:
  109. messages = self._convert_messages_to_dicts(messages)
  110. thread = threading.Thread(target=self.client.run, args=(
  111. messages,
  112. self.user_id,
  113. self._default_params,
  114. self.streaming
  115. ))
  116. thread.start()
  117. completion = ""
  118. for content in self.client.subscribe():
  119. if isinstance(content, dict):
  120. delta = content['data']
  121. else:
  122. delta = content
  123. completion += delta
  124. if self.streaming and run_manager:
  125. run_manager.on_llm_new_token(
  126. delta,
  127. )
  128. thread.join()
  129. if stop is not None:
  130. completion = enforce_stop_tokens(completion, stop)
  131. message = AIMessage(content=completion)
  132. return ChatResult(generations=[ChatGeneration(message=message)])
  133. async def _agenerate(
  134. self,
  135. messages: List[BaseMessage],
  136. stop: Optional[List[str]] = None,
  137. run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
  138. **kwargs: Any,
  139. ) -> ChatResult:
  140. message = AIMessage(content='')
  141. return ChatResult(generations=[ChatGeneration(message=message)])
  142. def get_num_tokens(self, text: str) -> float:
  143. """Calculate number of tokens."""
  144. total = Decimal(0)
  145. words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text)
  146. for word in words:
  147. if word:
  148. if '\u4e00' <= word <= '\u9fff': # if chinese
  149. total += Decimal('1.5')
  150. else:
  151. total += Decimal('0.8')
  152. return int(total)