message.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # -*- coding:utf-8 -*-
  2. import json
  3. import logging
  4. from typing import Generator, Union
  5. import services
  6. from controllers.web import api
  7. from controllers.web.error import (AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError,
  8. CompletionRequestError, NotChatAppError, NotCompletionAppError,
  9. ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError,
  10. ProviderQuotaExceededError)
  11. from controllers.web.wraps import WebApiResource
  12. from core.entities.application_entities import InvokeFrom
  13. from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
  14. from core.model_runtime.errors.invoke import InvokeError
  15. from fields.conversation_fields import message_file_fields
  16. from flask import Response, stream_with_context
  17. from flask_restful import fields, marshal_with, reqparse
  18. from flask_restful.inputs import int_range
  19. from libs.helper import TimestampField, uuid_value
  20. from services.completion_service import CompletionService
  21. from services.errors.app import MoreLikeThisDisabledError
  22. from services.errors.conversation import ConversationNotExistsError
  23. from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
  24. from services.message_service import MessageService
  25. from werkzeug.exceptions import InternalServerError, NotFound
  class MessageListApi(WebApiResource):
      """Paginated message history for one conversation of a chat app."""

      # Feedback entry attached to a message (only its rating is exposed).
      feedback_fields = {
          'rating': fields.String,
      }

      # A retrieved segment referenced by a message.
      retriever_resource_fields = {
          'id': fields.String,
          'message_id': fields.String,
          'position': fields.Integer,
          'dataset_id': fields.String,
          'dataset_name': fields.String,
          'document_id': fields.String,
          'document_name': fields.String,
          'data_source_type': fields.String,
          'segment_id': fields.String,
          'score': fields.Float,
          'hit_count': fields.Integer,
          'word_count': fields.Integer,
          'segment_position': fields.Integer,
          'index_node_hash': fields.String,
          'content': fields.String,
          'created_at': TimestampField,
      }

      # A single message; files come from the model's ``files`` attribute and
      # feedback from ``user_feedback`` (null when absent).
      message_fields = {
          'id': fields.String,
          'conversation_id': fields.String,
          'inputs': fields.Raw,
          'query': fields.String,
          'answer': fields.String,
          'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
          'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
          'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
          'created_at': TimestampField,
      }

      # Envelope returned by ``get``.
      message_infinite_scroll_pagination_fields = {
          'limit': fields.Integer,
          'has_more': fields.Boolean,
          'data': fields.List(fields.Nested(message_fields)),
      }

      @marshal_with(message_infinite_scroll_pagination_fields)
      def get(self, app_model, end_user):
          """Return one page of messages for a conversation.

          Query arguments: ``conversation_id`` (required UUID), ``first_id``
          (optional UUID) and ``limit`` (integer in 1..100, default 20).

          Raises:
              NotChatAppError: the app is not in 'chat' mode.
              NotFound: the conversation, or the message named by ``first_id``,
                  does not exist.
          """
          if app_model.mode != 'chat':
              raise NotChatAppError()

          query_parser = reqparse.RequestParser()
          query_parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
          query_parser.add_argument('first_id', type=uuid_value, location='args')
          query_parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
          query = query_parser.parse_args()

          try:
              return MessageService.pagination_by_first_id(
                  app_model,
                  end_user,
                  query['conversation_id'],
                  query['first_id'],
                  query['limit'],
              )
          except ConversationNotExistsError:
              raise NotFound("Conversation Not Exists.")
          except services.errors.message.FirstMessageNotExistsError:
              raise NotFound("First Message Not Exists.")
  class MessageFeedbackApi(WebApiResource):
      """Record an end user's rating of a message."""

      def post(self, app_model, end_user, message_id):
          """Store the JSON body's ``rating`` ('like', 'dislike' or None) for a message.

          Returns ``{'result': 'success'}``.

          Raises:
              NotFound: the message does not exist.
          """
          body_parser = reqparse.RequestParser()
          body_parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
          payload = body_parser.parse_args()

          try:
              MessageService.create_feedback(app_model, str(message_id), end_user, payload['rating'])
          except MessageNotExistsError:
              raise NotFound("Message Not Exists.")

          return {'result': 'success'}
  class MessageMoreLikeThisApi(WebApiResource):
      """Regenerate an alternative answer for a message of a completion app."""

      def get(self, app_model, end_user, message_id):
          """Generate a "more like this" answer, blocking or streamed.

          Query argument ``response_mode`` must be 'blocking' or 'streaming'.
          The service result is wrapped by ``compact_response``.

          Raises:
              NotCompletionAppError: the app is not in 'completion' mode.
              NotFound: the message does not exist.
              AppMoreLikeThisDisabledError, ProviderNotInitializeError,
              ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError,
              CompletionRequestError: translated service/model errors.
              ValueError: re-raised unchanged.
              InternalServerError: any other failure (logged first).
          """
          if app_model.mode != 'completion':
              raise NotCompletionAppError()

          mode_parser = reqparse.RequestParser()
          mode_parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args')
          wants_stream = mode_parser.parse_args()['response_mode'] == 'streaming'

          try:
              generated = CompletionService.generate_more_like_this(
                  app_model=app_model,
                  user=end_user,
                  message_id=str(message_id),
                  invoke_from=InvokeFrom.WEB_APP,
                  streaming=wants_stream,
              )
              return compact_response(generated)
          except MessageNotExistsError:
              raise NotFound("Message Not Exists.")
          except MoreLikeThisDisabledError:
              raise AppMoreLikeThisDisabledError()
          except ProviderTokenNotInitError as err:
              raise ProviderNotInitializeError(err.description)
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()
          except InvokeError as err:
              raise CompletionRequestError(err.description)
          except ValueError:
              # Propagate untouched so the API layer handles it.
              raise
          except Exception:
              logging.exception("internal server error.")
              raise InternalServerError()
  def compact_response(response: Union[dict, Generator]) -> Response:
      """Wrap a service result in a Flask ``Response``.

      A dict is serialized as one ``application/json`` body. Anything else is
      iterated and streamed as ``text/event-stream``. If iteration raises, the
      error is mapped to a web API error, rendered through ``api.handle_error``,
      and emitted as a final ``data: ...`` event instead of aborting the stream.
      """
      if not isinstance(response, dict):
          def _sse_error(error) -> str:
              # Render the error via the API handler and frame it as one SSE event.
              # Called from inside each ``except`` so the active exception context
              # is the same as when handle_error was invoked inline.
              return "data: " + json.dumps(api.handle_error(error).get_json()) + "\n\n"

          def generate() -> Generator:
              try:
                  for piece in response:
                      yield piece
              except MessageNotExistsError:
                  yield _sse_error(NotFound("Message Not Exists."))
              except MoreLikeThisDisabledError:
                  yield _sse_error(AppMoreLikeThisDisabledError())
              except ProviderTokenNotInitError as err:
                  yield _sse_error(ProviderNotInitializeError(err.description))
              except QuotaExceededError:
                  yield _sse_error(ProviderQuotaExceededError())
              except ModelCurrentlyNotSupportError:
                  yield _sse_error(ProviderModelCurrentlyNotSupportError())
              except InvokeError as err:
                  yield _sse_error(CompletionRequestError(err.description))
              except ValueError as err:
                  yield _sse_error(err)
              except Exception:
                  logging.exception("internal server error.")
                  yield _sse_error(InternalServerError())

          return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')

      return Response(response=json.dumps(response), status=200, mimetype='application/json')
  class MessageSuggestedQuestionApi(WebApiResource):
      """Suggested follow-up questions for a message of a chat app."""

      def get(self, app_model, end_user, message_id):
          """Return ``{'data': questions}`` for the given message.

          Raises:
              NotChatAppError: the app is not in 'chat' mode.
              NotFound: the message or its conversation does not exist.
              AppSuggestedQuestionsAfterAnswerDisabledError: the feature is disabled.
              ProviderNotInitializeError, ProviderQuotaExceededError,
              ProviderModelCurrentlyNotSupportError, CompletionRequestError:
                  translated provider/model errors.
              InternalServerError: any other failure (logged first).
          """
          # This endpoint is chat-only, so a mode mismatch must report the
          # chat-specific error. Previously NotCompletionAppError was raised here,
          # which told clients the app needed to be a *completion* app.
          if app_model.mode != 'chat':
              raise NotChatAppError()

          message_id = str(message_id)

          try:
              questions = MessageService.get_suggested_questions_after_answer(
                  app_model=app_model,
                  user=end_user,
                  message_id=message_id
              )
          except MessageNotExistsError:
              raise NotFound("Message not found")
          except ConversationNotExistsError:
              raise NotFound("Conversation not found")
          except SuggestedQuestionsAfterAnswerDisabledError:
              raise AppSuggestedQuestionsAfterAnswerDisabledError()
          except ProviderTokenNotInitError as ex:
              raise ProviderNotInitializeError(ex.description)
          except QuotaExceededError:
              raise ProviderQuotaExceededError()
          except ModelCurrentlyNotSupportError:
              raise ProviderModelCurrentlyNotSupportError()
          except InvokeError as e:
              raise CompletionRequestError(e.description)
          except Exception:
              logging.exception("internal server error.")
              raise InternalServerError()

          return {'data': questions}
  # Register the message endpoints on the web API, in a fixed order.
  for _resource, _route in (
      (MessageListApi, '/messages'),
      (MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks'),
      (MessageMoreLikeThisApi, '/messages/<uuid:message_id>/more-like-this'),
      (MessageSuggestedQuestionApi, '/messages/<uuid:message_id>/suggested-questions'),
  ):
      api.add_resource(_resource, _route)
  # Don't leave the loop temporaries behind as module attributes.
  del _resource, _route