__init__.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. from typing import Optional
  3. import langchain
  4. from flask import Flask
  5. from jieba.analyse import default_tfidf
  6. from langchain import set_handler
  7. from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
  8. from llama_index import IndexStructType, QueryMode
  9. from llama_index.indices.registry import INDEX_STRUT_TYPE_TO_QUERY_MAP
  10. from pydantic import BaseModel
  11. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  12. from core.index.keyword_table.jieba_keyword_table import GPTJIEBAKeywordTableIndex
  13. from core.index.keyword_table.stopwords import STOPWORDS
  14. from core.prompt.prompt_template import OneLineFormatter
  15. from core.vector_store.vector_store import VectorStore
  16. from core.vector_store.vector_store_index_query import EnhanceGPTVectorStoreIndexQuery
  17. class HostedOpenAICredential(BaseModel):
  18. api_key: str
  19. class HostedLLMCredentials(BaseModel):
  20. openai: Optional[HostedOpenAICredential] = None
  21. hosted_llm_credentials = HostedLLMCredentials()
  22. def init_app(app: Flask):
  23. formatter = OneLineFormatter()
  24. DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
  25. INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.KEYWORD_TABLE] = GPTJIEBAKeywordTableIndex.get_query_map()
  26. INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.WEAVIATE] = {
  27. QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
  28. QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
  29. }
  30. INDEX_STRUT_TYPE_TO_QUERY_MAP[IndexStructType.QDRANT] = {
  31. QueryMode.DEFAULT: EnhanceGPTVectorStoreIndexQuery,
  32. QueryMode.EMBEDDING: EnhanceGPTVectorStoreIndexQuery,
  33. }
  34. default_tfidf.stop_words = STOPWORDS
  35. if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
  36. langchain.verbose = True
  37. set_handler(DifyStdOutCallbackHandler())
  38. if app.config.get("OPENAI_API_KEY"):
  39. hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))