# test_llm.py — integration tests for the Volcengine MaaS LLM provider.
# The tests read VOLC_API_KEY, VOLC_SECRET_KEY and VOLC_MODEL_ENDPOINT_ID from the environment.
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.volcengine_maas.llm.llm import VolcengineMaaSLargeLanguageModel
  def test_validate_credentials_for_chat_model():
      """Check credential validation for the Volcengine MaaS chat model.

      Bogus access keys and endpoint id must be rejected with
      ``CredentialsValidateFailedError``. The credentials read from the
      VOLC_* environment variables must then validate without raising.
      """
      llm = VolcengineMaaSLargeLanguageModel()

      # Endpoint/region settings shared by the invalid and the valid attempt.
      region_settings = {
          'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
          'volc_region': 'cn-beijing',
      }

      bogus_credentials = {
          **region_settings,
          'volc_access_key_id': 'INVALID',
          'volc_secret_access_key': 'INVALID',
          'endpoint_id': 'INVALID',
      }
      with pytest.raises(CredentialsValidateFailedError):
          llm.validate_credentials(model='NOT IMPORTANT', credentials=bogus_credentials)

      # Outside the raises-block: real credentials must validate cleanly.
      real_credentials = {
          **region_settings,
          'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
          'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
          'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
      }
      llm.validate_credentials(model='NOT IMPORTANT', credentials=real_credentials)
  def test_invoke_model():
      """A blocking (non-streaming) invocation returns a populated ``LLMResult``.

      Sends a single user message with ``stream=False``. It then checks that the
      reply message has non-empty content and that a positive total token
      count is reported.
      """
      llm = VolcengineMaaSLargeLanguageModel()

      credentials = {
          'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
          'volc_region': 'cn-beijing',
          'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
          'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
          'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
          'base_model_name': 'Skylark2-pro-4k',
      }
      sampling = {'temperature': 0.7, 'top_p': 1.0, 'top_k': 1}
      prompt = [UserPromptMessage(content='Hello World!')]

      result = llm.invoke(
          model='NOT IMPORTANT',
          credentials=credentials,
          prompt_messages=prompt,
          model_parameters=sampling,
          stop=['you'],
          user="abc-123",
          stream=False,
      )

      assert isinstance(result, LLMResult)
      reply_text = result.message.content
      assert len(reply_text) > 0
      assert result.usage.total_tokens > 0
  def test_invoke_stream_model():
      """A streaming invocation yields well-formed ``LLMResultChunk`` objects.

      Every chunk must wrap an ``LLMResultChunkDelta`` whose message is an
      ``AssistantPromptMessage``. Chunks with no ``finish_reason`` must carry
      non-empty content. Chunks that report a ``finish_reason`` are exempt.

      The stream must also produce at least one chunk. Otherwise every
      per-chunk assertion would be skipped and the test would pass vacuously.
      """
      model = VolcengineMaaSLargeLanguageModel()
      response = model.invoke(
          model='NOT IMPORTANT',
          credentials={
              'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
              'volc_region': 'cn-beijing',
              'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
              'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
              'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
              'base_model_name': 'Skylark2-pro-4k',
          },
          prompt_messages=[
              UserPromptMessage(content='Hello World!'),
          ],
          model_parameters={
              'temperature': 0.7,
              'top_p': 1.0,
              'top_k': 1,
          },
          stop=['you'],
          stream=True,
          user="abc-123",
      )

      assert isinstance(response, Generator)

      chunk_count = 0
      for chunk in response:
          chunk_count += 1
          assert isinstance(chunk, LLMResultChunk)
          assert isinstance(chunk.delta, LLMResultChunkDelta)
          assert isinstance(chunk.delta.message, AssistantPromptMessage)
          # Only chunks without a finish_reason are required to carry text.
          if chunk.delta.finish_reason is None:
              assert len(chunk.delta.message.content) > 0

      # An empty stream would otherwise skip every check above and pass silently.
      assert chunk_count > 0, 'streaming invocation produced no chunks'
  def test_get_num_tokens():
      """Token counting for the prompt 'Hello World!' returns exactly 6.

      No tools are passed (``tools=[]``). The result must be a plain ``int``.
      """
      llm = VolcengineMaaSLargeLanguageModel()

      credentials = {
          'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
          'volc_region': 'cn-beijing',
          'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
          'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
          'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
          'base_model_name': 'Skylark2-pro-4k',
      }
      prompt = [UserPromptMessage(content='Hello World!')]

      num_tokens = llm.get_num_tokens(
          model='NOT IMPORTANT',
          credentials=credentials,
          prompt_messages=prompt,
          tools=[],
      )

      assert isinstance(num_tokens, int)
      assert num_tokens == 6