openai_embedding.py

from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    text = text.replace("\n", " ")

    return openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)["data"][0]["embedding"]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")

    return (await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs))["data"][0][
        "embedding"
    ]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


@retry(reraise=True, wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
async def aget_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.
    """
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = (await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]


class OpenAIEmbedding(BaseEmbedding):
    """OpenAI embedding wrapper that supports both OpenAI and Azure OpenAI
    (via `deployment_name` and `openai_api_type='azure'`)."""

    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params."""
        new_kwargs = {}

        if 'embed_batch_size' in kwargs:
            new_kwargs['embed_batch_size'] = kwargs['embed_batch_size']

        if 'tokenizer' in kwargs:
            new_kwargs['tokenizer'] = kwargs['tokenizer']

        super().__init__(**new_kwargs)

        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _QUERY_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _QUERY_MODE_MODEL_DICT[key]

        return get_embedding(query, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]

        return get_embedding(text, engine=engine, api_key=self.openai_api_key,
                             api_type=self.openai_api_type, api_version=self.openai_api_version,
                             api_base=self.openai_api_base)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]

        return await aget_embedding(text, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        By default, this is a wrapper around _get_text_embedding.
        Can be overridden for batch queries.
        """
        if self.openai_api_type == 'azure':
            # For Azure OpenAI deployments, request embeddings one text at a time.
            embeddings = []
            for text in texts:
                embeddings.append(self._get_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]

        embeddings = get_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                    api_type=self.openai_api_type, api_version=self.openai_api_version,
                                    api_base=self.openai_api_base)

        return embeddings

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings."""
        if self.openai_api_type == 'azure':
            # For Azure OpenAI deployments, request embeddings one text at a time.
            embeddings = []
            for text in texts:
                embeddings.append(await self._aget_text_embedding(text))

            return embeddings

        if self.deployment_name is not None:
            engine = self.deployment_name
        else:
            key = (self.mode, self.model)
            if key not in _TEXT_MODE_MODEL_DICT:
                raise ValueError(f"Invalid mode, model combination: {key}")
            engine = _TEXT_MODE_MODEL_DICT[key]

        embeddings = await aget_embeddings(texts, engine=engine, api_key=self.openai_api_key,
                                           api_type=self.openai_api_type, api_version=self.openai_api_version,
                                           api_base=self.openai_api_base)

        return embeddings
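

# Usage sketch (illustrative): a minimal way to exercise the wrapper directly,
# assuming a valid OpenAI API key is available; the mode/model defaults below
# are resolved to an engine via llama_index's mode/model dictionaries.
if __name__ == "__main__":
    embedding_model = OpenAIEmbedding(
        mode=OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model=OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        openai_api_key="sk-...",  # placeholder; supply a real key
    )

    # Embed a single document through the helper defined above and report its dimensionality.
    vector = embedding_model._get_text_embedding("Hello, world!")
    print(len(vector))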