# synthesizer.py
  1. from typing import (
  2. Any,
  3. Dict,
  4. Optional, Sequence,
  5. )
  6. from llama_index.indices.response.response_synthesis import ResponseSynthesizer
  7. from llama_index.indices.response.response_builder import ResponseMode, BaseResponseBuilder, get_response_builder
  8. from llama_index.indices.service_context import ServiceContext
  9. from llama_index.optimization.optimizer import BaseTokenUsageOptimizer
  10. from llama_index.prompts.prompts import (
  11. QuestionAnswerPrompt,
  12. RefinePrompt,
  13. SimpleInputPrompt,
  14. )
  15. from llama_index.types import RESPONSE_TEXT_TYPE
  16. class EnhanceResponseSynthesizer(ResponseSynthesizer):
  17. @classmethod
  18. def from_args(
  19. cls,
  20. service_context: ServiceContext,
  21. streaming: bool = False,
  22. use_async: bool = False,
  23. text_qa_template: Optional[QuestionAnswerPrompt] = None,
  24. refine_template: Optional[RefinePrompt] = None,
  25. simple_template: Optional[SimpleInputPrompt] = None,
  26. response_mode: ResponseMode = ResponseMode.DEFAULT,
  27. response_kwargs: Optional[Dict] = None,
  28. optimizer: Optional[BaseTokenUsageOptimizer] = None,
  29. ) -> "ResponseSynthesizer":
  30. response_builder: Optional[BaseResponseBuilder] = None
  31. if response_mode != ResponseMode.NO_TEXT:
  32. if response_mode == 'no_synthesizer':
  33. response_builder = NoSynthesizer(
  34. service_context=service_context,
  35. simple_template=simple_template,
  36. streaming=streaming,
  37. )
  38. else:
  39. response_builder = get_response_builder(
  40. service_context,
  41. text_qa_template,
  42. refine_template,
  43. simple_template,
  44. response_mode,
  45. use_async=use_async,
  46. streaming=streaming,
  47. )
  48. return cls(response_builder, response_mode, response_kwargs, optimizer)
  49. class NoSynthesizer(BaseResponseBuilder):
  50. def __init__(
  51. self,
  52. service_context: ServiceContext,
  53. simple_template: Optional[SimpleInputPrompt] = None,
  54. streaming: bool = False,
  55. ) -> None:
  56. super().__init__(service_context, streaming)
  57. async def aget_response(
  58. self,
  59. query_str: str,
  60. text_chunks: Sequence[str],
  61. prev_response: Optional[str] = None,
  62. **response_kwargs: Any,
  63. ) -> RESPONSE_TEXT_TYPE:
  64. return "\n".join(text_chunks)
  65. def get_response(
  66. self,
  67. query_str: str,
  68. text_chunks: Sequence[str],
  69. prev_response: Optional[str] = None,
  70. **response_kwargs: Any,
  71. ) -> RESPONSE_TEXT_TYPE:
  72. return "\n".join(text_chunks)