message.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # -*- coding:utf-8 -*-
  2. import json
  3. import logging
  4. from typing import Generator, Union
  5. from flask import Response, stream_with_context
  6. from flask_restful import fields, marshal_with, reqparse
  7. from flask_restful.inputs import int_range
  8. from werkzeug.exceptions import InternalServerError, NotFound
  9. import services
  10. from controllers.web import api
  11. from controllers.web.error import (
  12. AppMoreLikeThisDisabledError,
  13. AppSuggestedQuestionsAfterAnswerDisabledError,
  14. CompletionRequestError,
  15. NotChatAppError,
  16. NotCompletionAppError,
  17. ProviderModelCurrentlyNotSupportError,
  18. ProviderNotInitializeError,
  19. ProviderQuotaExceededError,
  20. )
  21. from controllers.web.wraps import WebApiResource
  22. from core.entities.application_entities import InvokeFrom
  23. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  24. from core.model_runtime.errors.invoke import InvokeError
  25. from fields.conversation_fields import message_file_fields
  26. from fields.message_fields import agent_thought_fields
  27. from libs.helper import TimestampField, uuid_value
  28. from services.completion_service import CompletionService
  29. from services.errors.app import MoreLikeThisDisabledError
  30. from services.errors.conversation import ConversationNotExistsError
  31. from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
  32. from services.message_service import MessageService
  33. class MessageListApi(WebApiResource):
  34. feedback_fields = {
  35. 'rating': fields.String
  36. }
  37. retriever_resource_fields = {
  38. 'id': fields.String,
  39. 'message_id': fields.String,
  40. 'position': fields.Integer,
  41. 'dataset_id': fields.String,
  42. 'dataset_name': fields.String,
  43. 'document_id': fields.String,
  44. 'document_name': fields.String,
  45. 'data_source_type': fields.String,
  46. 'segment_id': fields.String,
  47. 'score': fields.Float,
  48. 'hit_count': fields.Integer,
  49. 'word_count': fields.Integer,
  50. 'segment_position': fields.Integer,
  51. 'index_node_hash': fields.String,
  52. 'content': fields.String,
  53. 'created_at': TimestampField
  54. }
  55. message_fields = {
  56. 'id': fields.String,
  57. 'conversation_id': fields.String,
  58. 'inputs': fields.Raw,
  59. 'query': fields.String,
  60. 'answer': fields.String,
  61. 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
  62. 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
  63. 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
  64. 'created_at': TimestampField,
  65. 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields))
  66. }
  67. message_infinite_scroll_pagination_fields = {
  68. 'limit': fields.Integer,
  69. 'has_more': fields.Boolean,
  70. 'data': fields.List(fields.Nested(message_fields))
  71. }
  72. @marshal_with(message_infinite_scroll_pagination_fields)
  73. def get(self, app_model, end_user):
  74. if app_model.mode != 'chat':
  75. raise NotChatAppError()
  76. parser = reqparse.RequestParser()
  77. parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
  78. parser.add_argument('first_id', type=uuid_value, location='args')
  79. parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
  80. args = parser.parse_args()
  81. try:
  82. return MessageService.pagination_by_first_id(app_model, end_user,
  83. args['conversation_id'], args['first_id'], args['limit'])
  84. except services.errors.conversation.ConversationNotExistsError:
  85. raise NotFound("Conversation Not Exists.")
  86. except services.errors.message.FirstMessageNotExistsError:
  87. raise NotFound("First Message Not Exists.")
  88. class MessageFeedbackApi(WebApiResource):
  89. def post(self, app_model, end_user, message_id):
  90. message_id = str(message_id)
  91. parser = reqparse.RequestParser()
  92. parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
  93. args = parser.parse_args()
  94. try:
  95. MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
  96. except services.errors.message.MessageNotExistsError:
  97. raise NotFound("Message Not Exists.")
  98. return {'result': 'success'}
  99. class MessageMoreLikeThisApi(WebApiResource):
  100. def get(self, app_model, end_user, message_id):
  101. if app_model.mode != 'completion':
  102. raise NotCompletionAppError()
  103. message_id = str(message_id)
  104. parser = reqparse.RequestParser()
  105. parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
  106. args = parser.parse_args()
  107. streaming = args['response_mode'] == 'streaming'
  108. try:
  109. response = CompletionService.generate_more_like_this(
  110. app_model=app_model,
  111. user=end_user,
  112. message_id=message_id,
  113. invoke_from=InvokeFrom.WEB_APP,
  114. streaming=streaming
  115. )
  116. return compact_response(response)
  117. except MessageNotExistsError:
  118. raise NotFound("Message Not Exists.")
  119. except MoreLikeThisDisabledError:
  120. raise AppMoreLikeThisDisabledError()
  121. except ProviderTokenNotInitError as ex:
  122. raise ProviderNotInitializeError(ex.description)
  123. except QuotaExceededError:
  124. raise ProviderQuotaExceededError()
  125. except ModelCurrentlyNotSupportError:
  126. raise ProviderModelCurrentlyNotSupportError()
  127. except InvokeError as e:
  128. raise CompletionRequestError(e.description)
  129. except ValueError as e:
  130. raise e
  131. except Exception:
  132. logging.exception("internal server error.")
  133. raise InternalServerError()
  134. def compact_response(response: Union[dict, Generator]) -> Response:
  135. if isinstance(response, dict):
  136. return Response(response=json.dumps(response), status=200, mimetype='application/json')
  137. else:
  138. def generate() -> Generator:
  139. for chunk in response:
  140. yield chunk
  141. return Response(stream_with_context(generate()), status=200,
  142. mimetype='text/event-stream')
  143. class MessageSuggestedQuestionApi(WebApiResource):
  144. def get(self, app_model, end_user, message_id):
  145. if app_model.mode != 'chat':
  146. raise NotCompletionAppError()
  147. message_id = str(message_id)
  148. try:
  149. questions = MessageService.get_suggested_questions_after_answer(
  150. app_model=app_model,
  151. user=end_user,
  152. message_id=message_id
  153. )
  154. except MessageNotExistsError:
  155. raise NotFound("Message not found")
  156. except ConversationNotExistsError:
  157. raise NotFound("Conversation not found")
  158. except SuggestedQuestionsAfterAnswerDisabledError:
  159. raise AppSuggestedQuestionsAfterAnswerDisabledError()
  160. except ProviderTokenNotInitError as ex:
  161. raise ProviderNotInitializeError(ex.description)
  162. except QuotaExceededError:
  163. raise ProviderQuotaExceededError()
  164. except ModelCurrentlyNotSupportError:
  165. raise ProviderModelCurrentlyNotSupportError()
  166. except InvokeError as e:
  167. raise CompletionRequestError(e.description)
  168. except Exception:
  169. logging.exception("internal server error.")
  170. raise InternalServerError()
  171. return {'data': questions}
  172. api.add_resource(MessageListApi, '/messages')
  173. api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
  174. api.add_resource(MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this')
  175. api.add_resource(MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions')