Ver Fonte

Revert "update batch evaluate (#154)" (#171)

This reverts commit b5c716c8c872eaf5ff3f882b36b5b91e4dd75d10.
Lin Manhui há 1 ano
pai
commit
42ac0a912c

+ 16 - 15
paddlers/tasks/change_detector.py

@@ -29,7 +29,7 @@ import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.transforms import Resize, decode_image, construct_sample
-from paddlers.utils import to_data_parallel
+from paddlers.utils import get_single_card_bs
 from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils import seg_metrics as metrics
@@ -447,22 +447,25 @@ class BaseChangeDetector(BaseModel):
         """
         """
 
 
         self._check_transforms(eval_dataset.transforms)
         self._check_transforms(eval_dataset.transforms)
-        net = self.net
-        net.eval()
 
 
-        # XXX: Hard-coding
+        self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
         if nranks > 1:
             # Initialize parallel environment if not done.
             # Initialize parallel environment if not done.
-            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
-            ):
+            if not (paddle.distributed.parallel.parallel_helper.
+                    _is_parallel_ctx_initialized()):
                 paddle.distributed.init_parallel_env()
                 paddle.distributed.init_parallel_env()
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
-            else:
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
 
 
+        batch_size_each_card = get_single_card_bs(batch_size)
+        if batch_size_each_card > 1:
+            batch_size_each_card = 1
+            batch_size = batch_size_each_card * paddlers.env_info['num']
+            logging.warning(
+                "ChangeDetector only supports batch_size=1 for each gpu/cpu card " \
+                "during evaluation, so batch_size " \
+                "is forcibly set to {}.".format(batch_size)
+            )
         self.eval_data_loader = self.build_data_loader(
         self.eval_data_loader = self.build_data_loader(
             eval_dataset, batch_size=batch_size, mode='eval')
             eval_dataset, batch_size=batch_size, mode='eval')
 
 
@@ -482,9 +485,9 @@ class BaseChangeDetector(BaseModel):
                             enable=True,
                             enable=True,
                             custom_white_list=self.custom_white_list,
                             custom_white_list=self.custom_white_list,
                             custom_black_list=self.custom_black_list):
                             custom_black_list=self.custom_black_list):
-                        outputs = self.run(net, data, 'eval')
+                        outputs = self.run(self.net, data, 'eval')
                 else:
                 else:
-                    outputs = self.run(net, data, 'eval')
+                    outputs = self.run(self.net, data, 'eval')
                 pred_area = outputs['pred_area']
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
                 label_area = outputs['label_area']
                 intersect_area = outputs['intersect_area']
                 intersect_area = outputs['intersect_area']
@@ -691,8 +694,6 @@ class BaseChangeDetector(BaseModel):
                 else:
                 else:
                     raise RuntimeError
                     raise RuntimeError
             results.append(pred)
             results.append(pred)
-        if len(results) > 1:
-            results = [paddle.concat(results, axis=0)]
         return results
         return results
 
 
     def _infer_postprocess(self, batch_label_map, batch_score_map,
     def _infer_postprocess(self, batch_label_map, batch_score_map,

+ 40 - 54
paddlers/tasks/classifier.py

@@ -25,7 +25,6 @@ import paddlers.models.ppcls as ppcls
 import paddlers.rs_models.clas as cmcls
 import paddlers.rs_models.clas as cmcls
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from paddlers.models.ppcls.metric import build_metrics
 from paddlers.models.ppcls.metric import build_metrics
-from paddlers.utils import to_data_parallel
 from paddlers.models import clas_losses
 from paddlers.models import clas_losses
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
@@ -403,67 +402,54 @@ class BaseClassifier(BaseModel):
         """
         """
 
 
         self._check_transforms(eval_dataset.transforms)
         self._check_transforms(eval_dataset.transforms)
-        net = self.net
-        net.eval()
 
 
-        # XXX: Hard-coding
+        self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
         if nranks > 1:
             # Initialize parallel environment if not done.
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
             ):
                 paddle.distributed.init_parallel_env()
                 paddle.distributed.init_parallel_env()
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
-            else:
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
-
-        self.eval_data_loader = self.build_data_loader(
-            eval_dataset, batch_size=batch_size, mode='eval')
-        logging.info("Start to evaluate (total_samples={}, total_steps={})...".
-                     format(eval_dataset.num_samples, eval_dataset.num_samples))
-
-        top1s = []
-        top5s = []
-        with paddle.no_grad():
-            for step, data in enumerate(self.eval_data_loader):
-                if self.precision == 'fp16':
-                    with paddle.amp.auto_cast(
-                            level=self.amp_level,
-                            enable=True,
-                            custom_white_list=self.custom_white_list,
-                            custom_black_list=self.custom_black_list):
-                        outputs = self.run(net, data, 'eval')
-                else:
-                    outputs = self.run(net, data, 'eval')
-                if nranks > 1:
-                    t1 = outputs["top1"]
-                    t5 = outputs["top5"]
-                    t1s = []
-                    t5s = []
-                    paddle.distributed.all_gather(t1s, t1)
-                    paddle.distributed.all_gather(t5s, t5)
-                    for rank_id in range(nranks):
-                        top1 = t1s[rank_id]
-                        top5 = t5s[rank_id]
-                        for i in range(data['image'].shape[0]):
-                            top1s.append(top1)
-                            top5s.append(top5)
-                else:
-                    for i in range(data['image'].shape[0]):
-                        top1s.append(outputs["top1"])
-                        top5s.append(outputs["top5"])
-
-        top1 = np.mean(top1s)
-        top5 = np.mean(top5s)
-        eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
-
-        if return_details:
-            # TODO: Add details
-            return eval_metrics, None
 
 
-        return eval_metrics
+        if batch_size > 1:
+            logging.warning(
+                "Classifier only supports single card evaluation with batch_size=1 "
+                "during evaluation, so batch_size is forcibly set to 1.")
+            batch_size = 1
+
+        if nranks < 2 or local_rank == 0:
+            self.eval_data_loader = self.build_data_loader(
+                eval_dataset, batch_size=batch_size, mode='eval')
+            logging.info(
+                "Start to evaluate (total_samples={}, total_steps={})...".
+                format(eval_dataset.num_samples, eval_dataset.num_samples))
+
+            top1s = []
+            top5s = []
+            with paddle.no_grad():
+                for step, data in enumerate(self.eval_data_loader):
+                    if self.precision == 'fp16':
+                        with paddle.amp.auto_cast(
+                                level=self.amp_level,
+                                enable=True,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list):
+                            outputs = self.run(self.net, data, 'eval')
+                    else:
+                        outputs = self.run(self.net, data, 'eval')
+                    top1s.append(outputs["top1"])
+                    top5s.append(outputs["top5"])
+
+            top1 = np.mean(top1s)
+            top5 = np.mean(top5s)
+            eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
+
+            if return_details:
+                # TODO: Add details
+                return eval_metrics, None
+
+            return eval_metrics
 
 
     @paddle.no_grad()
     @paddle.no_grad()
     def predict(self, img_file, transforms=None):
     def predict(self, img_file, transforms=None):

+ 57 - 210
paddlers/tasks/object_detector.py

@@ -31,7 +31,6 @@ from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
 from paddlers.transforms.batch_operators import BatchCompose, _BatchPad, _Gt2YoloTarget, BatchPadRGT, BatchNormalizeImage
 from paddlers.transforms.batch_operators import BatchCompose, _BatchPad, _Gt2YoloTarget, BatchPadRGT, BatchNormalizeImage
 from paddlers.models.ppdet.optimizer import ModelEMA
 from paddlers.models.ppdet.optimizer import ModelEMA
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
-from paddlers.utils import to_data_parallel
 from paddlers.utils.checkpoint import det_pretrain_weights_dict
 from paddlers.utils.checkpoint import det_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils.det_metrics import VOCMetric, COCOMetric, RBoxMetric
 from .utils.det_metrics import VOCMetric, COCOMetric, RBoxMetric
@@ -630,223 +629,71 @@ class BaseDetector(BaseModel):
                                       self._default_collate_fn)
                                       self._default_collate_fn)
 
 
         self._check_transforms(eval_dataset.transforms)
         self._check_transforms(eval_dataset.transforms)
-        net = self.net
-        net.eval()
 
 
-        # XXX: Hard-coding
+        self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
         if nranks > 1:
             # Initialize parallel environment if not done.
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
             ):
                 paddle.distributed.init_parallel_env()
                 paddle.distributed.init_parallel_env()
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
+
+        if batch_size > 1:
+            logging.warning(
+                "Detector only supports single card evaluation with batch_size=1 "
+                "during evaluation, so batch_size is forcibly set to 1.")
+            batch_size = 1
+
+        if nranks < 2 or local_rank == 0:
+            self.eval_data_loader = self.build_data_loader(
+                eval_dataset,
+                batch_size=batch_size,
+                mode='eval',
+                collate_fn=eval_dataset.collate_fn)
+            is_bbox_normalized = False
+            if hasattr(eval_dataset, 'batch_transforms'):
+                is_bbox_normalized = any(
+                    isinstance(t, _NormalizeBox)
+                    for t in eval_dataset.batch_transforms.batch_transforms)
+            if self.metric == 'voc':
+                eval_metric = VOCMetric(
+                    labels=eval_dataset.labels,
+                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
+                    is_bbox_normalized=is_bbox_normalized,
+                    classwise=False)
+            elif self.metric == 'coco':
+                eval_metric = COCOMetric(
+                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
+                    classwise=False)
             else:
             else:
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
-
-        self.eval_data_loader = self.build_data_loader(
-            eval_dataset,
-            batch_size=batch_size,
-            mode='eval',
-            collate_fn=eval_dataset.collate_fn)
-        is_bbox_normalized = False
-        if hasattr(eval_dataset, 'batch_transforms'):
-            is_bbox_normalized = any(
-                isinstance(t, _NormalizeBox)
-                for t in eval_dataset.batch_transforms.batch_transforms)
-        if self.metric == 'voc':
-            eval_metric = VOCMetric(
-                labels=eval_dataset.labels,
-                coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                is_bbox_normalized=is_bbox_normalized,
-                classwise=False)
-        elif self.metric == 'coco':
-            eval_metric = COCOMetric(
-                coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False)
-        else:
-            assert hasattr(eval_dataset, 'get_anno_path')
-            eval_metric = RBoxMetric(
-                anno_file=eval_dataset.get_anno_path(), classwise=False)
-        scores = collections.OrderedDict()
-        logging.info("Start to evaluate (total_samples={}, total_steps={})...".
-                     format(eval_dataset.num_samples, eval_dataset.num_samples))
-        with paddle.no_grad():
-            for step, data in enumerate(self.eval_data_loader):
-                if self.precision == 'fp16':
-                    with paddle.amp.auto_cast(
-                            level=self.amp_level,
-                            enable=True,
-                            custom_white_list=self.custom_white_list,
-                            custom_black_list=self.custom_black_list):
-                        outputs = self.run(net, data, 'eval')
-                else:
-                    outputs = self.run(net, data, 'eval')
-
-                sum_num = 0
-                if nranks > 1:
-                    for i in range(outputs['bbox_num'].shape[0]):
-                        output_bbox_num = outputs['bbox_num'][i:i + 1].cuda()
-
-                        start_id = sum_num
-                        sum_num += int(output_bbox_num)
-                        end_id = sum_num
-                        output_bbox = outputs['bbox'][start_id:end_id + 1].cuda(
-                        )
-
-                        data_single_im_id = data['im_id'][i].unsqueeze(0)
-                        data_single_image = data['image'][i].unsqueeze(0)
-                        data_single_image_shape = data['image_shape'][
-                            i].unsqueeze(0)
-                        data_single_im_shape = data['im_shape'][i].unsqueeze(0)
-                        data_single_scale_factor = data['scale_factor'][
-                            i].unsqueeze(0)
-                        data_single_permuted = data['permuted'][i]
-                        data_single_gt_bbox = data['gt_bbox'][i]
-                        data_single_gt_bbox_num = paddle.to_tensor(
-                            data_single_gt_bbox.shape[0])
-                        data_single_difficult = data['difficult'][i]
-                        data_single_gt_class = data['gt_class'][i]
-
-                        output_bbox_num_list = []
-                        paddle.distributed.all_gather(output_bbox_num_list,
-                                                      output_bbox_num)
-                        max_num = paddle.max(
-                            paddle.concat(output_bbox_num_list))
-                        if len(output_bbox) < max_num:
-                            tp_box = output_bbox[0:1].clone()
-                            pad_box = tp_box.tile(
-                                (max_num - len(output_bbox), 1))
-                            output_bbox_pad = paddle.concat(
-                                [output_bbox, pad_box], axis=0)
-                        else:
-                            output_bbox_pad = output_bbox
-                        output_bbox_list = []
-                        paddle.distributed.all_gather(output_bbox_list,
-                                                      output_bbox_pad)
-                        data_single_im_id_list = []
-                        paddle.distributed.all_gather(data_single_im_id_list,
-                                                      data_single_im_id)
-                        data_single_image_list = []
-                        paddle.distributed.all_gather(data_single_image_list,
-                                                      data_single_image)
-                        data_single_image_shape_list = []
-                        paddle.distributed.all_gather(
-                            data_single_image_shape_list,
-                            data_single_image_shape)
-                        data_single_im_shape_list = []
-                        paddle.distributed.all_gather(data_single_im_shape_list,
-                                                      data_single_im_shape)
-                        data_single_scale_factor_list = []
-                        paddle.distributed.all_gather(
-                            data_single_scale_factor_list,
-                            data_single_scale_factor)
-                        data_single_permuted_list = []
-                        paddle.distributed.all_gather(data_single_permuted_list,
-                                                      data_single_permuted)
-                        data_single_gt_bbox_num_list = []
-                        paddle.distributed.all_gather(
-                            data_single_gt_bbox_num_list,
-                            data_single_gt_bbox_num)
-                        max_num = paddle.max(
-                            paddle.concat(output_bbox_num_list))
-                        if data_single_gt_bbox.shape[0] < max_num:
-                            tp_box = data_single_gt_bbox[0:1].clone()
-                            pad_box = tp_box.tile(
-                                (max_num - data_single_gt_bbox.shape[0], 1))
-                            data_single_gt_bbox_pad = paddle.concat(
-                                [data_single_gt_bbox, pad_box], axis=0)
-                            tp_diff = data_single_difficult[0:1].clone()
-                            pad_diff = tp_diff.tile(
-                                (max_num - data_single_gt_bbox.shape[0], 1))
-                            data_single_difficult_pad = paddle.concat(
-                                [data_single_difficult, pad_diff], axis=0)
-                            tp_glass = data_single_gt_class[0:1].clone()
-                            pad_glass = tp_glass.tile(
-                                (max_num - data_single_gt_bbox.shape[0], 1))
-                            data_single_gt_class_pad = paddle.concat(
-                                [data_single_gt_class, pad_glass], axis=0)
-                        else:
-                            data_single_gt_bbox_pad = data_single_gt_bbox_pad
-                        data_single_gt_bbox_list = []
-                        paddle.distributed.all_gather(data_single_gt_bbox_list,
-                                                      data_single_gt_bbox_pad)
-                        data_single_difficult_list = []
-                        paddle.distributed.all_gather(
-                            data_single_difficult_list,
-                            data_single_difficult_pad)
-                        data_single_gt_class_list = []
-                        paddle.distributed.all_gather(data_single_gt_class_list,
-                                                      data_single_gt_class_pad)
-
-                        for rank_id in range(nranks):
-                            output = {}
-                            data_single = {}
-                            output['bbox_num'] = output_bbox_num_list[rank_id]
-                            output['bbox'] = output_bbox_list[
-                                rank_id][:output_bbox_num_list[rank_id]]
-
-                            data_single['im_id'] = data_single_im_id_list[
-                                rank_id]
-                            data_single['image'] = data_single_image_list[
-                                rank_id]
-                            data_single[
-                                'image_shape'] = data_single_image_shape_list[
-                                    rank_id]
-                            data_single['im_shape'] = data_single_im_shape_list[
-                                rank_id]
-                            data_single[
-                                'scale_factor'] = data_single_scale_factor_list[
-                                    rank_id]
-                            data_single['permuted'] = data_single_permuted_list[
-                                rank_id]
-                            box_num = data_single_gt_bbox_num_list[rank_id]
-                            data_single['gt_bbox'] = [
-                                data_single_gt_bbox_list[rank_id][:box_num]
-                            ]
-                            data_single['difficult'] = [
-                                data_single_difficult_list[rank_id][:box_num]
-                            ]
-                            data_single['gt_class'] = [
-                                data_single_gt_class_list[rank_id][:box_num]
-                            ]
-
-                            eval_metric.update(data_single, output)
-                else:
-                    for i in range(outputs['bbox_num'].shape[0]):
-                        output = {}
-                        output['bbox_num'] = outputs['bbox_num'][i:i + 1]
-
-                        start_id = sum_num
-                        sum_num += int(output['bbox_num'])
-                        end_id = sum_num
-                        output['bbox'] = outputs['bbox'][start_id:end_id + 1]
-
-                        data_single = {}
-                        data_single['im_id'] = data['im_id'][i].unsqueeze(0)
-                        data_single['image'] = data['image'][i].unsqueeze(0)
-                        data_single['image_shape'] = data['image_shape'][
-                            i].unsqueeze(0)
-                        data_single['im_shape'] = data['im_shape'][i].unsqueeze(
-                            0)
-                        data_single['scale_factor'] = data['scale_factor'][
-                            i].unsqueeze(0)
-                        data_single['permuted'] = data['permuted'][i]
-                        data_single['gt_bbox'] = [data['gt_bbox'][i]]
-                        data_single['difficult'] = [data['difficult'][i]]
-                        data_single['gt_class'] = [data['gt_class'][i]]
-
-                        eval_metric.update(data_single, output)
-
-            eval_metric.accumulate()
-            self.eval_details = eval_metric.details
-            scores.update(eval_metric.get())
-            eval_metric.reset()
-
-        if return_details:
-            return scores, self.eval_details
-        return scores
+                assert hasattr(eval_dataset, 'get_anno_path')
+                eval_metric = RBoxMetric(
+                    anno_file=eval_dataset.get_anno_path(), classwise=False)
+            scores = collections.OrderedDict()
+            logging.info(
+                "Start to evaluate (total_samples={}, total_steps={})...".
+                format(eval_dataset.num_samples, eval_dataset.num_samples))
+            with paddle.no_grad():
+                for step, data in enumerate(self.eval_data_loader):
+                    if self.precision == 'fp16':
+                        with paddle.amp.auto_cast(
+                                level=self.amp_level,
+                                enable=True,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list):
+                            outputs = self.run(self.net, data, 'eval')
+                    else:
+                        outputs = self.run(self.net, data, 'eval')
+                    eval_metric.update(data, outputs)
+                eval_metric.accumulate()
+                self.eval_details = eval_metric.details
+                scores.update(eval_metric.get())
+                eval_metric.reset()
+
+            if return_details:
+                return scores, self.eval_details
+            return scores
 
 
     @paddle.no_grad()
     @paddle.no_grad()
     def predict(self, img_file, transforms=None):
     def predict(self, img_file, transforms=None):

+ 41 - 45
paddlers/tasks/restorer.py

@@ -30,7 +30,6 @@ from paddlers.models import res_losses
 from paddlers.models.ppgan.modules.init import init_weights
 from paddlers.models.ppgan.modules.init import init_weights
 from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.transforms.functions import calc_hr_shape
 from paddlers.transforms.functions import calc_hr_shape
-from paddlers.utils import to_data_parallel
 from paddlers.utils.checkpoint import res_pretrain_weights_dict
 from paddlers.utils.checkpoint import res_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils.res_adapters import GANAdapter, OptimizerAdapter
 from .utils.res_adapters import GANAdapter, OptimizerAdapter
@@ -415,59 +414,58 @@ class BaseRestorer(BaseModel):
         """
         """
 
 
         self._check_transforms(eval_dataset.transforms)
         self._check_transforms(eval_dataset.transforms)
-        net = self.net
-        net.eval()
 
 
-        # XXX: Hard-coding
+        self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
         if nranks > 1:
             # Initialize parallel environment if not done.
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
             ):
                 paddle.distributed.init_parallel_env()
                 paddle.distributed.init_parallel_env()
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
-            else:
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
-
-        self.eval_data_loader = self.build_data_loader(
-            eval_dataset, batch_size=batch_size, mode='eval')
-        # XXX: Hard-code crop_border and test_y_channel
-        psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
-        ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
-        logging.info("Start to evaluate (total_samples={}, total_steps={})...".
-                     format(eval_dataset.num_samples, eval_dataset.num_samples))
-        with paddle.no_grad():
-            for step, data in enumerate(self.eval_data_loader):
-                if self.precision == 'fp16':
-                    with paddle.amp.auto_cast(
-                            level=self.amp_level,
-                            enable=True,
-                            custom_white_list=self.custom_white_list,
-                            custom_black_list=self.custom_black_list):
-                        outputs = self.run(net, data, 'eval')
-                else:
-                    outputs = self.run(net, data, 'eval')
-                if len(outputs['pred'].shape) > 3:
-                    for i in range(batch_size):
-                        psnr.update(outputs['pred'][i], outputs['tar'][i])
-                        ssim.update(outputs['pred'][i], outputs['tar'][i])
-                else:
+
+        # TODO: Distributed evaluation
+        if batch_size > 1:
+            logging.warning(
+                "Restorer only supports single card evaluation with batch_size=1 "
+                "during evaluation, so batch_size is forcibly set to 1.")
+            batch_size = 1
+
+        if nranks < 2 or local_rank == 0:
+            self.eval_data_loader = self.build_data_loader(
+                eval_dataset, batch_size=batch_size, mode='eval')
+            # XXX: Hard-code crop_border and test_y_channel
+            psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
+            ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
+            logging.info(
+                "Start to evaluate (total_samples={}, total_steps={})...".
+                format(eval_dataset.num_samples, eval_dataset.num_samples))
+            with paddle.no_grad():
+                for step, data in enumerate(self.eval_data_loader):
+                    if self.precision == 'fp16':
+                        with paddle.amp.auto_cast(
+                                level=self.amp_level,
+                                enable=True,
+                                custom_white_list=self.custom_white_list,
+                                custom_black_list=self.custom_black_list):
+                            outputs = self.run(self.net, data, 'eval')
+                    else:
+                        outputs = self.run(self.net, data, 'eval')
                     psnr.update(outputs['pred'], outputs['tar'])
                     psnr.update(outputs['pred'], outputs['tar'])
                     ssim.update(outputs['pred'], outputs['tar'])
                     ssim.update(outputs['pred'], outputs['tar'])
-        # DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training.
-        assert len(psnr.results) > 0
-        assert len(ssim.results) > 0
-        eval_metrics = OrderedDict(
-            zip(['psnr', 'ssim'],
-                [np.mean(psnr.results), np.mean(ssim.results)]))
 
 
-        if return_details:
-            # TODO: Add details
-            return eval_metrics, None
+            # DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training.
+            assert len(psnr.results) > 0
+            assert len(ssim.results) > 0
+            eval_metrics = OrderedDict(
+                zip(['psnr', 'ssim'],
+                    [np.mean(psnr.results), np.mean(ssim.results)]))
+
+            if return_details:
+                # TODO: Add details
+                return eval_metrics, None
 
 
-        return eval_metrics
+            return eval_metrics
 
 
     @paddle.no_grad()
     @paddle.no_grad()
     def predict(self, img_file, transforms=None):
     def predict(self, img_file, transforms=None):
@@ -555,8 +553,6 @@ class BaseRestorer(BaseModel):
                 else:
                 else:
                     pass
                     pass
             results.append(pred)
             results.append(pred)
-        if len(results) > 1:
-            results = [paddle.concat(results, axis=0)]
         return results
         return results
 
 
     def _infer_postprocess(self, batch_res_map, batch_restore_list):
     def _infer_postprocess(self, batch_res_map, batch_restore_list):

+ 13 - 14
paddlers/tasks/segmenter.py

@@ -28,7 +28,7 @@ import paddlers.rs_models.seg as cmseg
 import paddlers.utils.logging as logging
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.transforms import Resize, decode_image, construct_sample
-from paddlers.utils import DisablePrint, to_data_parallel
+from paddlers.utils import get_single_card_bs, DisablePrint
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .base import BaseModel
 from .utils import seg_metrics as metrics
 from .utils import seg_metrics as metrics
@@ -429,22 +429,23 @@ class BaseSegmenter(BaseModel):
         """
         """
 
 
         self._check_transforms(eval_dataset.transforms)
         self._check_transforms(eval_dataset.transforms)
-        net = self.net
-        net.eval()
-
-        # XXX: Hard-coding
+        self.net.eval()
         nranks = paddle.distributed.get_world_size()
         nranks = paddle.distributed.get_world_size()
+        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
         if nranks > 1:
             # Initialize parallel environment if not done.
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
             ):
                 paddle.distributed.init_parallel_env()
                 paddle.distributed.init_parallel_env()
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
-            else:
-                net = to_data_parallel(
-                    net, find_unused_parameters=self.find_unused_parameters)
 
 
+        batch_size_each_card = get_single_card_bs(batch_size)
+        if batch_size_each_card > 1:
+            batch_size_each_card = 1
+            batch_size = batch_size_each_card * paddlers.env_info['num']
+            logging.warning(
+                "Segmenter only supports batch_size=1 for each gpu/cpu card " \
+                "during evaluation, so batch_size " \
+                "is forcibly set to {}.".format(batch_size))
         self.eval_data_loader = self.build_data_loader(
         self.eval_data_loader = self.build_data_loader(
             eval_dataset, batch_size=batch_size, mode='eval')
             eval_dataset, batch_size=batch_size, mode='eval')
 
 
@@ -464,9 +465,9 @@ class BaseSegmenter(BaseModel):
                             enable=True,
                             enable=True,
                             custom_white_list=self.custom_white_list,
                             custom_white_list=self.custom_white_list,
                             custom_black_list=self.custom_black_list):
                             custom_black_list=self.custom_black_list):
-                        outputs = self.run(net, data, 'eval')
+                        outputs = self.run(self.net, data, 'eval')
                 else:
                 else:
-                    outputs = self.run(net, data, 'eval')
+                    outputs = self.run(self.net, data, 'eval')
                 pred_area = outputs['pred_area']
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
                 label_area = outputs['label_area']
                 intersect_area = outputs['intersect_area']
                 intersect_area = outputs['intersect_area']
@@ -657,8 +658,6 @@ class BaseSegmenter(BaseModel):
                 else:
                 else:
                     raise RuntimeError
                     raise RuntimeError
             results.append(pred)
             results.append(pred)
-        if len(results) > 1:
-            results = [paddle.concat(results, axis=0)]
         return results
         return results
 
 
     def _infer_postprocess(self, batch_label_map, batch_score_map,
     def _infer_postprocess(self, batch_label_map, batch_score_map,

+ 4 - 15
paddlers/tasks/utils/slider_predict.py

@@ -512,21 +512,10 @@ def slider_predict(predict_func,
                 batch_out = predict_func(batch_data, transforms=transforms)
                 batch_out = predict_func(batch_data, transforms=transforms)
 
 
                 for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
                 for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
-                    if len(out['label_map'].shape) == 3:
-                        for i in range(out['label_map'].shape[0]):
-                            out_single = {}
-                            out_single['label_map'] = out['label_map'][i]
-                            out_single['score_map'] = out['score_map'][i]
-                            # Get processed result
-                            pred = overlap_processor.process_pred(out_single,
-                                                                  xoff_, yoff_)
-                            # Write to file
-                            band.WriteArray(pred, xoff_, yoff_)
-                    else:
-                        # Get processed result
-                        pred = overlap_processor.process_pred(out, xoff_, yoff_)
-                        # Write to file
-                        band.WriteArray(pred, xoff_, yoff_)
+                    # Get processed result
+                    pred = overlap_processor.process_pred(out, xoff_, yoff_)
+                    # Write to file
+                    band.WriteArray(pred, xoff_, yoff_)
 
 
                 batch_data.clear()
                 batch_data.clear()
                 batch_offsets.clear()
                 batch_offsets.clear()