Sfoglia il codice sorgente

[Feature]: Update clas task opt and load pre_wight

geoyee 3 anni fa
parent
commit
167580337d

+ 15 - 15
paddlers/tasks/classifier.py

@@ -28,7 +28,7 @@ from .base import BaseModel
 from paddlers.models.ppcls.metric import build_metrics
 from paddlers.models.ppcls.loss import build_loss
 from paddlers.models.ppcls.data.postprocess import build_postprocess
-from paddlers.utils.checkpoint import imagenet_weights
+from paddlers.utils.checkpoint import cls_pretrain_weights_dict
 from paddlers.transforms import Decode, Resize
 
 __all__ = ["ResNet50_vd", "MobileNetV3_small_x1_0", "HRNet_W18_C"]
@@ -60,9 +60,9 @@ class BaseClassifier(BaseModel):
         self.find_unused_parameters = True
 
     def build_net(self, **params):
-        # TODO: when using paddle.utils.unique_name.guard,
-        net = paddleclas.arch.backbone.__dict__[self.model_name](
-            class_num=self.num_classes, **params)
+        with paddle.utils.unique_name.guard():
+            net = paddleclas.arch.backbone.__dict__[self.model_name](
+                class_num=self.num_classes, **params)
         return net
 
     def _fix_transforms_shape(self, image_shape):
@@ -121,12 +121,11 @@ class BaseClassifier(BaseModel):
         return outputs
 
     def default_metric(self):
-        # TODO: other metrics
         default_config = [{"TopkAcc":{"topk": [1, 5]}}]
         return build_metrics(default_config) 
 
     def default_loss(self):
-        # TODO: mixed_loss
+        # TODO: use mixed loss and other loss
         default_config = [{"CELoss":{"weight": 1.0}}]
         return build_loss(default_config)
 
@@ -135,15 +134,16 @@ class BaseClassifier(BaseModel):
                           learning_rate,
                           num_epochs,
                           num_steps_each_epoch,
-                          lr_decay_power=0.9):
+                          last_epoch=-1,
+                          L2_coeff=0.00007):
         decay_step = num_epochs * num_steps_each_epoch
-        lr_scheduler = paddle.optimizer.lr.PolynomialDecay(
-            learning_rate, decay_step, end_lr=0, power=lr_decay_power)
+        lr_scheduler = paddle.optimizer.lr.CosineAnnealingDecay(
+            learning_rate, T_max=decay_step, eta_min=0, last_epoch=last_epoch)
         optimizer = paddle.optimizer.Momentum(
             learning_rate=lr_scheduler,
             parameters=parameters,
             momentum=0.9,
-            weight_decay=4e-5)
+            weight_decay=paddle.regularizer.L2Decay(L2_coeff))
         return optimizer
 
     def default_postprocess(self, class_id_map_file):
@@ -164,7 +164,7 @@ class BaseClassifier(BaseModel):
               log_interval_steps=2,
               save_dir='output',
               pretrain_weights='IMAGENET',
-              learning_rate=0.01,
+              learning_rate=0.1,
               lr_decay_power=0.9,
               early_stop=False,
               early_stop_patience=5,
@@ -219,7 +219,7 @@ class BaseClassifier(BaseModel):
             self.optimizer = optimizer
 
         if pretrain_weights is not None and not osp.exists(pretrain_weights):
-            if pretrain_weights not in imagenet_weights[
+            if pretrain_weights not in cls_pretrain_weights_dict[
                     self.model_name]:
                 logging.warning(
                     "Path of pretrain_weights('{}') does not exist!".format(
@@ -227,9 +227,9 @@ class BaseClassifier(BaseModel):
                 logging.warning("Pretrain_weights is forcibly set to '{}'. "
                                 "If don't want to use pretrain weights, "
                                 "set pretrain_weights to be None.".format(
-                                    imagenet_weights[self.model_name][
+                                    cls_pretrain_weights_dict[self.model_name][
                                         0]))
-                pretrain_weights = imagenet_weights[self.model_name][
+                pretrain_weights = cls_pretrain_weights_dict[self.model_name][
                     0]
         elif pretrain_weights is not None and osp.exists(pretrain_weights):
             if osp.splitext(pretrain_weights)[-1] != '.pdparams':
@@ -237,7 +237,7 @@ class BaseClassifier(BaseModel):
                     "Invalid pretrain weights. Please specify a '.pdparams' file.",
                     exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
-        is_backbone_weights = pretrain_weights == 'IMAGENET'
+        is_backbone_weights = False  # pretrain_weights == 'IMAGENET'  # TODO: this is backbone
         self.net_initialize(
             pretrain_weights=pretrain_weights,
             save_dir=pretrained_dir,

+ 6 - 0
paddlers/utils/checkpoint.py

@@ -19,6 +19,12 @@ import paddle
 from . import logging
 from .download import download_and_decompress
 
+cls_pretrain_weights_dict = {
+    'ResNet50_vd': ['IMAGENET'],
+    'MobileNetV3_small_x1_0': ['IMAGENET'],
+    'HRNet_W18_C': ['IMAGENET'],
+}
+
 seg_pretrain_weights_dict = {
     'UNet': ['CITYSCAPES'],
     'DeepLabV3P': ['CITYSCAPES', 'PascalVOC', 'IMAGENET'],

+ 2 - 3
tutorials/train/classification/resnet50_vd_rs.py

@@ -53,6 +53,5 @@ model.train(
     train_dataset=train_dataset,
     train_batch_size=4,
     eval_dataset=eval_dataset,
-    learning_rate=0.01,
-    pretrain_weights=None,
-    save_dir='output/resnet_vd')
+    learning_rate=0.1,
+    save_dir='output/resnet_vd')