test_llm.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import os
  2. from collections.abc import Generator
  3. import pytest
  4. from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
  5. from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
  6. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  7. from core.model_runtime.model_providers.spark.llm.llm import SparkLargeLanguageModel
  8. def test_validate_credentials():
  9. model = SparkLargeLanguageModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='spark-1.5',
  13. credentials={
  14. 'app_id': 'invalid_key'
  15. }
  16. )
  17. model.validate_credentials(
  18. model='spark-1.5',
  19. credentials={
  20. 'app_id': os.environ.get('SPARK_APP_ID'),
  21. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  22. 'api_key': os.environ.get('SPARK_API_KEY')
  23. }
  24. )
  25. def test_invoke_model():
  26. model = SparkLargeLanguageModel()
  27. response = model.invoke(
  28. model='spark-1.5',
  29. credentials={
  30. 'app_id': os.environ.get('SPARK_APP_ID'),
  31. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  32. 'api_key': os.environ.get('SPARK_API_KEY')
  33. },
  34. prompt_messages=[
  35. UserPromptMessage(
  36. content='Who are you?'
  37. )
  38. ],
  39. model_parameters={
  40. 'temperature': 0.5,
  41. 'max_tokens': 10
  42. },
  43. stop=['How'],
  44. stream=False,
  45. user="abc-123"
  46. )
  47. assert isinstance(response, LLMResult)
  48. assert len(response.message.content) > 0
  49. def test_invoke_stream_model():
  50. model = SparkLargeLanguageModel()
  51. response = model.invoke(
  52. model='spark-1.5',
  53. credentials={
  54. 'app_id': os.environ.get('SPARK_APP_ID'),
  55. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  56. 'api_key': os.environ.get('SPARK_API_KEY')
  57. },
  58. prompt_messages=[
  59. UserPromptMessage(
  60. content='Hello World!'
  61. )
  62. ],
  63. model_parameters={
  64. 'temperature': 0.5,
  65. 'max_tokens': 100
  66. },
  67. stream=True,
  68. user="abc-123"
  69. )
  70. assert isinstance(response, Generator)
  71. for chunk in response:
  72. assert isinstance(chunk, LLMResultChunk)
  73. assert isinstance(chunk.delta, LLMResultChunkDelta)
  74. assert isinstance(chunk.delta.message, AssistantPromptMessage)
  75. assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
  76. def test_get_num_tokens():
  77. model = SparkLargeLanguageModel()
  78. num_tokens = model.get_num_tokens(
  79. model='spark-1.5',
  80. credentials={
  81. 'app_id': os.environ.get('SPARK_APP_ID'),
  82. 'api_secret': os.environ.get('SPARK_API_SECRET'),
  83. 'api_key': os.environ.get('SPARK_API_KEY')
  84. },
  85. prompt_messages=[
  86. SystemPromptMessage(
  87. content='You are a helpful AI assistant.',
  88. ),
  89. UserPromptMessage(
  90. content='Hello World!'
  91. )
  92. ]
  93. )
  94. assert num_tokens == 14