|  | @@ -4,19 +4,15 @@ from typing import Any, Union
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from novita_client import (
 | 
	
		
			
				|  |  |      NovitaClient,
 | 
	
		
			
				|  |  | -    Txt2ImgV3Embedding,
 | 
	
		
			
				|  |  | -    Txt2ImgV3HiresFix,
 | 
	
		
			
				|  |  | -    Txt2ImgV3LoRA,
 | 
	
		
			
				|  |  | -    Txt2ImgV3Refiner,
 | 
	
		
			
				|  |  | -    V3TaskImage,
 | 
	
		
			
				|  |  |  )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  from core.tools.entities.tool_entities import ToolInvokeMessage
 | 
	
		
			
				|  |  |  from core.tools.errors import ToolProviderCredentialValidationError
 | 
	
		
			
				|  |  | +from core.tools.provider.builtin.novitaai._novita_tool_base import NovitaAiToolBase
 | 
	
		
			
				|  |  |  from core.tools.tool.builtin_tool import BuiltinTool
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -class NovitaAiTxt2ImgTool(BuiltinTool):
 | 
	
		
			
				|  |  | +class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
 | 
	
		
			
				|  |  |      def _invoke(self,
 | 
	
		
			
				|  |  |                  user_id: str,
 | 
	
		
			
				|  |  |                  tool_parameters: dict[str, Any],
 | 
	
	
		
			
				|  | @@ -73,65 +69,19 @@ class NovitaAiTxt2ImgTool(BuiltinTool):
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # process loras
 | 
	
		
			
				|  |  |          if 'loras' in res_parameters:
 | 
	
		
			
				|  |  | -            loras_ori_list = res_parameters.get('loras').strip().split(';')
 | 
	
		
			
				|  |  | -            locals_list = []
 | 
	
		
			
				|  |  | -            for lora_str in loras_ori_list:
 | 
	
		
			
				|  |  | -                lora_info = lora_str.strip().split(',')
 | 
	
		
			
				|  |  | -                lora = Txt2ImgV3LoRA(
 | 
	
		
			
				|  |  | -                    model_name=lora_info[0].strip(),
 | 
	
		
			
				|  |  | -                    strength=float(lora_info[1]),
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  | -                locals_list.append(lora)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            res_parameters['loras'] = locals_list
 | 
	
		
			
				|  |  | +            res_parameters['loras'] = self._extract_loras(res_parameters.get('loras'))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # process embeddings
 | 
	
		
			
				|  |  |          if 'embeddings' in res_parameters:
 | 
	
		
			
				|  |  | -            embeddings_ori_list = res_parameters.get('embeddings').strip().split(';')
 | 
	
		
			
				|  |  | -            locals_list = []
 | 
	
		
			
				|  |  | -            for embedding_str in embeddings_ori_list:
 | 
	
		
			
				|  |  | -                embedding = Txt2ImgV3Embedding(
 | 
	
		
			
				|  |  | -                    model_name=embedding_str.strip()
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  | -                locals_list.append(embedding)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            res_parameters['embeddings'] = locals_list
 | 
	
		
			
				|  |  | +            res_parameters['embeddings'] = self._extract_embeddings(res_parameters.get('embeddings'))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          # process hires_fix
 | 
	
		
			
				|  |  |          if 'hires_fix' in res_parameters:
 | 
	
		
			
				|  |  | -            hires_fix_ori = res_parameters.get('hires_fix')
 | 
	
		
			
				|  |  | -            hires_fix_info = hires_fix_ori.strip().split(',')
 | 
	
		
			
				|  |  | -            if 'upscaler' in hires_fix_info:
 | 
	
		
			
				|  |  | -                hires_fix = Txt2ImgV3HiresFix(
 | 
	
		
			
				|  |  | -                    target_width=int(hires_fix_info[0]),
 | 
	
		
			
				|  |  | -                    target_height=int(hires_fix_info[1]),
 | 
	
		
			
				|  |  | -                    strength=float(hires_fix_info[2]),
 | 
	
		
			
				|  |  | -                    upscaler=hires_fix_info[3].strip()
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  | -            else:
 | 
	
		
			
				|  |  | -                hires_fix = Txt2ImgV3HiresFix(
 | 
	
		
			
				|  |  | -                    target_width=int(hires_fix_info[0]),
 | 
	
		
			
				|  |  | -                    target_height=int(hires_fix_info[1]),
 | 
	
		
			
				|  |  | -                    strength=float(hires_fix_info[2])
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            res_parameters['hires_fix'] = hires_fix
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -            if 'refiner_switch_at' in res_parameters:
 | 
	
		
			
				|  |  | -                refiner = Txt2ImgV3Refiner(
 | 
	
		
			
				|  |  | -                    switch_at=float(res_parameters.get('refiner_switch_at'))
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  | -                del res_parameters['refiner_switch_at']
 | 
	
		
			
				|  |  | -                res_parameters['refiner'] = refiner
 | 
	
		
			
				|  |  | +            res_parameters['hires_fix'] = self._extract_hires_fix(res_parameters.get('hires_fix'))
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        return res_parameters
 | 
	
		
			
				|  |  | +        # process refiner
 | 
	
		
			
				|  |  | +        if 'refiner_switch_at' in res_parameters:
 | 
	
		
			
				|  |  | +            res_parameters['refiner'] = self._extract_refiner(res_parameters.get('refiner_switch_at'))
 | 
	
		
			
				|  |  | +            del res_parameters['refiner_switch_at']
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    def _is_hit_nsfw_detection(self, image: V3TaskImage, confidence_threshold: float) -> bool:
 | 
	
		
			
				|  |  | -        """
 | 
	
		
			
				|  |  | -            is hit nsfw
 | 
	
		
			
				|  |  | -        """
 | 
	
		
			
				|  |  | -        if image.nsfw_detection_result is None:
 | 
	
		
			
				|  |  | -            return False
 | 
	
		
			
				|  |  | -        if image.nsfw_detection_result.valid and image.nsfw_detection_result.confidence >= confidence_threshold:
 | 
	
		
			
				|  |  | -            return True
 | 
	
		
			
				|  |  | -        return False
 | 
	
		
			
				|  |  | +        return res_parameters
 |