浏览代码

fix: replace os.path.join with yarl (#2690)

Yeuoly 1 年之前
父节点
当前提交
95733796f0

+ 5 - 3
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -1,10 +1,10 @@
-from os import path
 from threading import Lock
 from time import time
 
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
 from requests.sessions import Session
+from yarl import URL
 
 
 class XinferenceModelExtraParameter:
@@ -55,7 +55,10 @@ class XinferenceHelper:
             get xinference model extra parameter like model_format and model_handle_type
         """
 
-        url = path.join(server_url, 'v1/models', model_uid)
+        if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
+            raise RuntimeError('model_uid is empty')
+
+        url = str(URL(server_url) / 'v1' / 'models' / model_uid)
 
         # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()
@@ -66,7 +69,6 @@ class XinferenceHelper:
             response = session.get(url, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
-
         if response.status_code != 200:
             raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
         

+ 2 - 1
api/requirements.txt

@@ -68,4 +68,5 @@ pydub~=0.25.1
 gmpy2~=2.1.5
 numexpr~=2.9.0
 duckduckgo-search==4.4.3
-arxiv==2.1.0
+arxiv==2.1.0
+yarl~=1.9.4

+ 41 - 39
api/tests/integration_tests/model_runtime/__mock/xinference.py

@@ -32,68 +32,70 @@ class MockXinferenceClass(object):
         response = Response()
         if 'v1/models/' in url:
             # get model uid
-            model_uid = url.split('/')[-1]
+            model_uid = url.split('/')[-1] or ''
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
+                response._content = b'{}'
                 return response
 
             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
+                response._content = b'{}'
                 return response
             
             if model_uid in ['generate', 'chat']:
                 response.status_code = 200
                 response._content = b'''{
-        "model_type": "LLM",
-        "address": "127.0.0.1:43877",
-        "accelerators": [
-            "0",
-            "1"
-        ],
-        "model_name": "chatglm3-6b",
-        "model_lang": [
-            "en"
-        ],
-        "model_ability": [
-            "generate",
-            "chat"
-        ],
-        "model_description": "latest chatglm3",
-        "model_format": "pytorch",
-        "model_size_in_billions": 7,
-        "quantization": "none",
-        "model_hub": "huggingface",
-        "revision": null,
-        "context_length": 2048,
-        "replica": 1
-    }'''
+                    "model_type": "LLM",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "chatglm3-6b",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "model_ability": [
+                        "generate",
+                        "chat"
+                    ],
+                    "model_description": "latest chatglm3",
+                    "model_format": "pytorch",
+                    "model_size_in_billions": 7,
+                    "quantization": "none",
+                    "model_hub": "huggingface",
+                    "revision": null,
+                    "context_length": 2048,
+                    "replica": 1
+                }'''
                 return response
             
             elif model_uid == 'embedding':
                 response.status_code = 200
                 response._content = b'''{
-        "model_type": "embedding",
-        "address": "127.0.0.1:43877",
-        "accelerators": [
-            "0",
-            "1"
-        ],
-        "model_name": "bge",
-        "model_lang": [
-            "en"
-        ],
-        "revision": null,
-        "max_tokens": 512
-}'''
+                    "model_type": "embedding",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "bge",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "revision": null,
+                    "max_tokens": 512
+                }'''
                 return response
             
         elif 'v1/cluster/auth' in url:
             response.status_code = 200
             response._content = b'''{
-    "auth": true
-}'''
+                "auth": true
+            }'''
             return response
         
     def _check_cluster_authenticated(self):