Browse Source

chore: fix indentation violations by applying ruff rules E111 to E117 (#4925)

Bowen Liang 10 months ago
parent
commit
f32b440c4a

+ 1 - 1
.github/workflows/style.yml

@@ -36,7 +36,7 @@ jobs:
 
       - name: Ruff check
         if: steps.changed-files.outputs.any_changed == 'true'
-        run: ruff check ./api
+        run: ruff check --preview ./api
 
       - name: Dotenv check
         if: steps.changed-files.outputs.any_changed == 'true'

+ 0 - 1
api/core/agent/base_agent_runner.py

@@ -528,4 +528,3 @@ class BaseAgentRunner(AppRunner):
                 return UserPromptMessage(content=prompt_message_contents)
         else:
             return UserPromptMessage(content=message.query)
-         

+ 16 - 16
api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py

@@ -57,23 +57,23 @@ class BaichuanModel:
         }[model]
 
     def _handle_chat_generate_response(self, response) -> BaichuanMessage:
-            resp = response.json()
-            choices = resp.get('choices', [])
-            message = BaichuanMessage(content='', role='assistant')
-            for choice in choices:
-                message.content += choice['message']['content']
-                message.role = choice['message']['role']
-                if choice['finish_reason']:
-                    message.stop_reason = choice['finish_reason']
+        resp = response.json()
+        choices = resp.get('choices', [])
+        message = BaichuanMessage(content='', role='assistant')
+        for choice in choices:
+            message.content += choice['message']['content']
+            message.role = choice['message']['role']
+            if choice['finish_reason']:
+                message.stop_reason = choice['finish_reason']
+
+        if 'usage' in resp:
+            message.usage = {
+                'prompt_tokens': resp['usage']['prompt_tokens'],
+                'completion_tokens': resp['usage']['completion_tokens'],
+                'total_tokens': resp['usage']['total_tokens'],
+            }
 
-            if 'usage' in resp:
-                message.usage = {
-                    'prompt_tokens': resp['usage']['prompt_tokens'],
-                    'completion_tokens': resp['usage']['completion_tokens'],
-                    'total_tokens': resp['usage']['total_tokens'],
-                }
-            
-            return message
+        return message
     
     def _handle_chat_stream_generate_response(self, response) -> Generator:
         for line in response.iter_lines():

+ 26 - 26
api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py

@@ -59,15 +59,15 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         model_prefix = model.split('.')[0]
          
         if model_prefix == "amazon" :
-           for text in texts:
-              body = {
+            for text in texts:
+                body = {
                  "inputText": text,
-              }
-              response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
-              embeddings.extend([response_body.get('embedding')])
-              token_usage += response_body.get('inputTextTokenCount')
-           logger.warning(f'Total Tokens: {token_usage}')
-           result = TextEmbeddingResult(
+                }
+                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+                embeddings.extend([response_body.get('embedding')])
+                token_usage += response_body.get('inputTextTokenCount')
+            logger.warning(f'Total Tokens: {token_usage}')
+            result = TextEmbeddingResult(
                 model=model,
                 embeddings=embeddings,
                 usage=self._calc_response_usage(
@@ -75,20 +75,20 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
                     credentials=credentials,
                     tokens=token_usage
                 )
-           )
-           return result
-           
+            )
+            return result
+
         if model_prefix == "cohere" :
-           input_type = 'search_document' if len(texts) > 1 else 'search_query'
-           for text in texts:
-              body = {
+            input_type = 'search_document' if len(texts) > 1 else 'search_query'
+            for text in texts:
+                body = {
                  "texts": [text],
                  "input_type": input_type,
-              }
-              response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
-              embeddings.extend(response_body.get('embeddings'))
-              token_usage += len(text)
-           result = TextEmbeddingResult(
+                }
+                response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
+                embeddings.extend(response_body.get('embeddings'))
+                token_usage += len(text)
+            result = TextEmbeddingResult(
                 model=model,
                 embeddings=embeddings,
                 usage=self._calc_response_usage(
@@ -96,9 +96,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
                     credentials=credentials,
                     tokens=token_usage
                 )
-           )
-           return result
-        
+            )
+            return result
+
         #others
         raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
 
@@ -183,7 +183,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         )
 
         return usage
-    
+
     def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[InvokeError]:
         """
         Map client error to invoke error
@@ -212,9 +212,9 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
         content_type = 'application/json'
         try:
             response = bedrock_runtime.invoke_model(
-                body=json.dumps(body), 
-                modelId=model, 
-                accept=accept, 
+                body=json.dumps(body),
+                modelId=model,
+                accept=accept,
                 contentType=content_type
             )
             response_body = json.loads(response.get('body').read().decode('utf-8'))

+ 1 - 1
api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py

@@ -54,7 +54,7 @@ class PGVectoRS(BaseVector):
 
         class _Table(CollectionORM):
             __tablename__ = collection_name
-            __table_args__ = {"extend_existing": True}  # noqa: RUF012
+            __table_args__ = {"extend_existing": True}
             id: Mapped[UUID] = mapped_column(
                 postgresql.UUID(as_uuid=True),
                 primary_key=True,

+ 1 - 1
api/core/rag/datasource/vdb/relyt/relyt_vector.py

@@ -190,7 +190,7 @@ class RelytVector(BaseVector):
                     conn.execute(chunks_table.delete().where(delete_condition))
                     return True
         except Exception as e:
-            print("Delete operation failed:", str(e))  # noqa: T201
+            print("Delete operation failed:", str(e))
             return False
 
     def delete_by_metadata_field(self, key: str, value: str):

+ 1 - 1
api/core/rag/models/document.py

@@ -50,7 +50,7 @@ class BaseDocumentTransformer(ABC):
                 ) -> Sequence[Document]:
                     raise NotImplementedError
 
-    """  # noqa: E501
+    """
 
     @abstractmethod
     def transform_documents(

+ 1 - 1
api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py

@@ -68,7 +68,7 @@ class ArxivAPIWrapper(BaseModel):
 
         Args:
             query: a plaintext search query
-        """  # noqa: E501
+        """
         try:
             results = self.arxiv_search(  # type: ignore
                 query[: self.ARXIV_MAX_QUERY_LENGTH], max_results=self.top_k_results

+ 2 - 1
api/core/tools/provider/builtin/searxng/tools/searxng_search.py

@@ -121,4 +121,5 @@ class SearXNGSearchTool(BuiltinTool):
             query=query, 
             search_type=search_type, 
             result_type=result_type, 
-            topK=num_results)
+            topK=num_results
+        )

+ 2 - 2
api/core/tools/provider/builtin/twilio/tools/send_message.py

@@ -30,7 +30,7 @@ class TwilioAPIWrapper(BaseModel):
         Twilio also work here. You cannot, for example, spoof messages from a private 
         cell phone number. If you are using `messaging_service_sid`, this parameter 
         must be empty.
-    """  # noqa: E501
+    """
 
     @validator("client", pre=True, always=True)
     def set_validator(cls, values: dict) -> dict:
@@ -60,7 +60,7 @@ class TwilioAPIWrapper(BaseModel):
                 SMS/MMS or
                 [Channel user address](https://www.twilio.com/docs/sms/channels#channel-addresses)
                 for other 3rd-party channels.
-        """  # noqa: E501
+        """
         message = self.client.messages.create(to, from_=self.from_number, body=body)
         return message.sid
 

+ 10 - 8
api/core/tools/tool/tool.py

@@ -332,10 +332,11 @@ class Tool(BaseModel, ABC):
             :param text: the text
             :return: the text message
         """
-        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, 
-                                 message=text,
-                                 save_as=save_as
-                                 )
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.TEXT,
+            message=text,
+            save_as=save_as
+        )
     
     def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
         """
@@ -344,7 +345,8 @@ class Tool(BaseModel, ABC):
             :param blob: the blob
             :return: the blob message
         """
-        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, 
-                                 message=blob, meta=meta,
-                                 save_as=save_as
-                                 )
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.BLOB,
+            message=blob, meta=meta,
+            save_as=save_as
+        )

+ 10 - 0
api/pyproject.toml

@@ -13,8 +13,18 @@ select = [
     "F", # pyflakes rules
     "I", # isort rules
     "UP",   # pyupgrade rules
+    "E101", # mixed-spaces-and-tabs
+    "E111", # indentation-with-invalid-multiple
+    "E112", # no-indented-block
+    "E113", # unexpected-indentation
+    "E115", # no-indented-block-comment
+    "E116", # unexpected-indentation-comment
+    "E117", # over-indented
     "RUF019", # unnecessary-key-check
+    "RUF100", # unused-noqa
+    "RUF101", # redirected-noqa
     "S506", # unsafe-yaml-load
+    "W191", # tab-indentation
     "W605", # invalid-escape-sequence
 ]
 ignore = [

+ 1 - 1
dev/reformat

@@ -9,7 +9,7 @@ if ! command -v ruff &> /dev/null; then
 fi
 
 # run ruff linter
-ruff check --fix ./api
+ruff check --fix --preview ./api
 
 # env files linting relies on `dotenv-linter` in path
 if ! command -v dotenv-linter &> /dev/null; then

+ 1 - 1
web/.husky/pre-commit

@@ -31,7 +31,7 @@ if $api_modified; then
         pip install ruff
     fi
 
-    ruff check ./api || status=$?
+    ruff check --preview ./api || status=$?
 
     status=${status:-0}