Browse Source

Save/load optimizer state only when an optimizer is present, and cast mask labels to int64 before loss and metric computation

Bobholamovic 2 years ago
parent
commit
a4f46bee92

+ 3 - 2
paddlers/tasks/base.py

@@ -230,8 +230,9 @@ class BaseModel(metaclass=ModelMeta):
         model_info['status'] = self.status
 
         paddle.save(self.net.state_dict(), osp.join(save_dir, 'model.pdparams'))
-        paddle.save(self.optimizer.state_dict(),
-                    osp.join(save_dir, 'model.pdopt'))
+        if self.optimizer is not None:
+            paddle.save(self.optimizer.state_dict(),
+                        osp.join(save_dir, 'model.pdopt'))
 
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',

+ 3 - 3
paddlers/tasks/change_detector.py

@@ -139,7 +139,7 @@ class BaseChangeDetector(BaseModel):
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
-            label = inputs['mask']
+            label = inputs['mask'].astype('int64')
             if label.ndim == 3:
                 paddle.unsqueeze_(label, axis=1)
             if label.ndim != 4:
@@ -160,7 +160,7 @@ class BaseChangeDetector(BaseModel):
                 if 'aux_masks' not in inputs:
                     raise ValueError("Auxiliary masks not found.")
                 labels_list = [
-                    inputs['aux_masks'][idx]
+                    inputs['aux_masks'][idx].astype('int64')
                     for idx in map(attrgetter('value'), net.OUT_TYPES)
                 ]
                 loss_list = metrics.multitask_loss_computation(
@@ -170,7 +170,7 @@ class BaseChangeDetector(BaseModel):
             else:
                 loss_list = metrics.loss_computation(
                     logits_list=net_out,
-                    labels=inputs['mask'],
+                    labels=inputs['mask'].astype('int64'),
                     losses=self.losses)
             loss = sum(loss_list)
             outputs['loss'] = loss

+ 7 - 4
paddlers/tasks/segmenter.py

@@ -142,7 +142,7 @@ class BaseSegmenter(BaseModel):
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
-            label = inputs['mask']
+            label = inputs['mask'].astype('int64')
             if label.ndim == 3:
                 paddle.unsqueeze_(label, axis=1)
             if label.ndim != 4:
@@ -158,7 +158,9 @@ class BaseSegmenter(BaseModel):
                                                            self.num_classes)
         if mode == 'train':
             loss_list = metrics.loss_computation(
-                logits_list=net_out, labels=inputs['mask'], losses=self.losses)
+                logits_list=net_out,
+                labels=inputs['mask'].astype('int64'),
+                losses=self.losses)
             loss = sum(loss_list)
             outputs['loss'] = loss
         return outputs
@@ -947,7 +949,7 @@ class C2FNet(BaseSegmenter):
                 pred = paddle.unsqueeze(net_out[0], axis=1)  # NCHW
             else:
                 pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')
-            label = inputs['mask']
+            label = inputs['mask'].astype('int64')
             if label.ndim == 3:
                 paddle.unsqueeze_(label, axis=1)
             if label.ndim != 4:
@@ -962,7 +964,8 @@ class C2FNet(BaseSegmenter):
             outputs['conf_mat'] = metrics.confusion_matrix(pred, label,
                                                            self.num_classes)
         if mode == 'train':
-            net_out = net(inputs['image'], heatmaps, inputs['mask'])
+            net_out = net(inputs['image'], heatmaps,
+                          inputs['mask'].astype('int64'))
             logit = [net_out[0], ]
             labels = net_out[1]
             outputs = OrderedDict()

+ 3 - 2
paddlers/utils/checkpoint.py

@@ -563,5 +563,6 @@ def load_checkpoint(model,
         pretrain_weights=osp.join(checkpoint, 'model.pdparams'),
         model_name=model_name)
     if load_optim_state:
-        load_optimizer(
-            optimizer, state_dict_path=osp.join(checkpoint, "model.pdopt"))
+        optim_path = osp.join(checkpoint, 'model.pdopt')
+        if osp.exists(optim_path):
+            load_optimizer(optimizer, state_dict_path=optim_path)