
update batch evaluate (#154)

* update batch evaluate

* update batch evaluate

* update batch evaluate for classifier and object_detector

* update style of slider_predict.py
huilin 1 year ago
Parent
Current commit
b5c716c8c8

+ 15 - 16
paddlers/tasks/change_detector.py

@@ -29,7 +29,7 @@ import paddlers.rs_models.cd as cmcd
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image, construct_sample
-from paddlers.utils import get_single_card_bs
+from paddlers.utils import to_data_parallel
 from paddlers.utils.checkpoint import cd_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
@@ -447,25 +447,22 @@ class BaseChangeDetector(BaseModel):
         """
 
         self._check_transforms(eval_dataset.transforms)
+        net = self.net
+        net.eval()
 
-        self.net.eval()
+        # XXX: Hard-coding
         nranks = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
             # Initialize parallel environment if not done.
-            if not (paddle.distributed.parallel.parallel_helper.
-                    _is_parallel_ctx_initialized()):
+            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+            ):
                 paddle.distributed.init_parallel_env()
+            net = to_data_parallel(
+                net, find_unused_parameters=self.find_unused_parameters)
 
-        batch_size_each_card = get_single_card_bs(batch_size)
-        if batch_size_each_card > 1:
-            batch_size_each_card = 1
-            batch_size = batch_size_each_card * paddlers.env_info['num']
-            logging.warning(
-                "ChangeDetector only supports batch_size=1 for each gpu/cpu card " \
-                "during evaluation, so batch_size " \
-                "is forcibly set to {}.".format(batch_size)
-            )
         self.eval_data_loader = self.build_data_loader(
             eval_dataset, batch_size=batch_size, mode='eval')
 
@@ -485,9 +482,9 @@ class BaseChangeDetector(BaseModel):
                             enable=True,
                             custom_white_list=self.custom_white_list,
                             custom_black_list=self.custom_black_list):
-                        outputs = self.run(self.net, data, 'eval')
+                        outputs = self.run(net, data, 'eval')
                 else:
-                    outputs = self.run(self.net, data, 'eval')
+                    outputs = self.run(net, data, 'eval')
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
                 intersect_area = outputs['intersect_area']
@@ -694,6 +691,8 @@ class BaseChangeDetector(BaseModel):
                 else:
                     raise RuntimeError
             results.append(pred)
+        if len(results) > 1:
+            results = [paddle.concat(results, axis=0)]
         return results
 
     def _infer_postprocess(self, batch_label_map, batch_score_map,
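
Note on the pattern used throughout this commit: evaluation now wraps a local reference to the network instead of mutating `self.net`, and `to_data_parallel` is presumably a thin wrapper over `paddle.DataParallel`. A minimal sketch of the intended setup under that assumption (`prepare_eval_net` is a hypothetical helper, not part of the codebase):

    import paddle

    def prepare_eval_net(net, find_unused_parameters=False):
        # Sketch: wrap a net for multi-card evaluation without mutating it.
        net.eval()
        if paddle.distributed.get_world_size() > 1:
            # Initialize the parallel context once per process.
            if not (paddle.distributed.parallel.parallel_helper
                    ._is_parallel_ctx_initialized()):
                paddle.distributed.init_parallel_env()
            # Assumption: to_data_parallel(...) behaves like this call.
            net = paddle.DataParallel(
                net, find_unused_parameters=find_unused_parameters)
        return net

The wrapped copy is discarded after evaluation, so later training or prediction still sees the original `self.net`.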

+ 54 - 40
paddlers/tasks/classifier.py

@@ -25,6 +25,7 @@ import paddlers.models.ppcls as ppcls
 import paddlers.rs_models.clas as cmcls
 import paddlers.utils.logging as logging
 from paddlers.models.ppcls.metric import build_metrics
+from paddlers.utils import to_data_parallel
 from paddlers.models import clas_losses
 from paddlers.models.ppcls.data.postprocess import build_postprocess
 from paddlers.utils.checkpoint import cls_pretrain_weights_dict
@@ -402,54 +403,67 @@ class BaseClassifier(BaseModel):
         """
 
         self._check_transforms(eval_dataset.transforms)
+        net = self.net
+        net.eval()
 
-        self.net.eval()
+        # XXX: Hard-coding
         nranks = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
                 paddle.distributed.init_parallel_env()
+            net = to_data_parallel(
+                net, find_unused_parameters=self.find_unused_parameters)
+
+        self.eval_data_loader = self.build_data_loader(
+            eval_dataset, batch_size=batch_size, mode='eval')
+        logging.info("Start to evaluate (total_samples={}, total_steps={})...".
+                     format(eval_dataset.num_samples, eval_dataset.num_samples))
+
+        top1s = []
+        top5s = []
+        with paddle.no_grad():
+            for step, data in enumerate(self.eval_data_loader):
+                if self.precision == 'fp16':
+                    with paddle.amp.auto_cast(
+                            level=self.amp_level,
+                            enable=True,
+                            custom_white_list=self.custom_white_list,
+                            custom_black_list=self.custom_black_list):
+                        outputs = self.run(net, data, 'eval')
+                else:
+                    outputs = self.run(net, data, 'eval')
+                if nranks > 1:
+                    t1 = outputs["top1"]
+                    t5 = outputs["top5"]
+                    t1s = []
+                    t5s = []
+                    paddle.distributed.all_gather(t1s, t1)
+                    paddle.distributed.all_gather(t5s, t5)
+                    for rank_id in range(nranks):
+                        top1 = t1s[rank_id]
+                        top5 = t5s[rank_id]
+                        for i in range(data['image'].shape[0]):
+                            top1s.append(top1)
+                            top5s.append(top5)
+                else:
+                    for i in range(data['image'].shape[0]):
+                        top1s.append(outputs["top1"])
+                        top5s.append(outputs["top5"])
+
+        top1 = np.mean(top1s)
+        top5 = np.mean(top5s)
+        eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
+
+        if return_details:
+            # TODO: Add details
+            return eval_metrics, None
 
-        if batch_size > 1:
-            logging.warning(
-                "Classifier only supports single card evaluation with batch_size=1 "
-                "during evaluation, so batch_size is forcibly set to 1.")
-            batch_size = 1
-
-        if nranks < 2 or local_rank == 0:
-            self.eval_data_loader = self.build_data_loader(
-                eval_dataset, batch_size=batch_size, mode='eval')
-            logging.info(
-                "Start to evaluate (total_samples={}, total_steps={})...".
-                format(eval_dataset.num_samples, eval_dataset.num_samples))
-
-            top1s = []
-            top5s = []
-            with paddle.no_grad():
-                for step, data in enumerate(self.eval_data_loader):
-                    if self.precision == 'fp16':
-                        with paddle.amp.auto_cast(
-                                level=self.amp_level,
-                                enable=True,
-                                custom_white_list=self.custom_white_list,
-                                custom_black_list=self.custom_black_list):
-                            outputs = self.run(self.net, data, 'eval')
-                    else:
-                        outputs = self.run(self.net, data, 'eval')
-                    top1s.append(outputs["top1"])
-                    top5s.append(outputs["top5"])
-
-            top1 = np.mean(top1s)
-            top5 = np.mean(top5s)
-            eval_metrics = OrderedDict(zip(['top1', 'top5'], [top1, top5]))
-
-            if return_details:
-                # TODO: Add details
-                return eval_metrics, None
-
-            return eval_metrics
+        return eval_metrics
 
     @paddle.no_grad()
     def predict(self, img_file, transforms=None):
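
For reference, the core of the new multi-card aggregation above: every rank broadcasts its per-batch metric with `paddle.distributed.all_gather`, and each gathered value is repeated once per sample so that the final `np.mean` weights batches by size. A condensed sketch (hypothetical helper; `metric_tensor` stands for `outputs['top1']` or `outputs['top5']`):

    import paddle

    def gather_batch_metric(metric_tensor, samples_in_batch, nranks):
        # Sketch: collect one per-batch scalar metric from every rank and
        # repeat it per sample, mirroring the loop in the diff above.
        values = []
        if nranks > 1:
            gathered = []
            paddle.distributed.all_gather(gathered, metric_tensor)
            for rank_value in gathered:
                values.extend([rank_value] * samples_in_batch)
        else:
            values.extend([metric_tensor] * samples_in_batch)
        return values

Usage would be along the lines of `top1s += gather_batch_metric(outputs['top1'], data['image'].shape[0], nranks)`.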

+ 210 - 57
paddlers/tasks/object_detector.py

@@ -31,6 +31,7 @@ from paddlers.transforms.operators import _NormalizeBox, _PadBox, _BboxXYXY2XYWH
 from paddlers.transforms.batch_operators import BatchCompose, _BatchPad, _Gt2YoloTarget, BatchPadRGT, BatchNormalizeImage
 from paddlers.models.ppdet.optimizer import ModelEMA
 import paddlers.utils.logging as logging
+from paddlers.utils import to_data_parallel
 from paddlers.utils.checkpoint import det_pretrain_weights_dict
 from .base import BaseModel
 from .utils.det_metrics import VOCMetric, COCOMetric, RBoxMetric
@@ -629,71 +630,223 @@ class BaseDetector(BaseModel):
                                       self._default_collate_fn)
 
         self._check_transforms(eval_dataset.transforms)
+        net = self.net
+        net.eval()
 
-        self.net.eval()
+        # XXX: Hard-coding
         nranks = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
                 paddle.distributed.init_parallel_env()
-
-        if batch_size > 1:
-            logging.warning(
-                "Detector only supports single card evaluation with batch_size=1 "
-                "during evaluation, so batch_size is forcibly set to 1.")
-            batch_size = 1
-
-        if nranks < 2 or local_rank == 0:
-            self.eval_data_loader = self.build_data_loader(
-                eval_dataset,
-                batch_size=batch_size,
-                mode='eval',
-                collate_fn=eval_dataset.collate_fn)
-            is_bbox_normalized = False
-            if hasattr(eval_dataset, 'batch_transforms'):
-                is_bbox_normalized = any(
-                    isinstance(t, _NormalizeBox)
-                    for t in eval_dataset.batch_transforms.batch_transforms)
-            if self.metric == 'voc':
-                eval_metric = VOCMetric(
-                    labels=eval_dataset.labels,
-                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                    is_bbox_normalized=is_bbox_normalized,
-                    classwise=False)
-            elif self.metric == 'coco':
-                eval_metric = COCOMetric(
-                    coco_gt=copy.deepcopy(eval_dataset.coco_gt),
-                    classwise=False)
+            net = to_data_parallel(
+                net, find_unused_parameters=self.find_unused_parameters)
-            else:
-                assert hasattr(eval_dataset, 'get_anno_path')
-                eval_metric = RBoxMetric(
-                    anno_file=eval_dataset.get_anno_path(), classwise=False)
-            scores = collections.OrderedDict()
-            logging.info(
-                "Start to evaluate (total_samples={}, total_steps={})...".
-                format(eval_dataset.num_samples, eval_dataset.num_samples))
-            with paddle.no_grad():
-                for step, data in enumerate(self.eval_data_loader):
-                    if self.precision == 'fp16':
-                        with paddle.amp.auto_cast(
-                                level=self.amp_level,
-                                enable=True,
-                                custom_white_list=self.custom_white_list,
-                                custom_black_list=self.custom_black_list):
-                            outputs = self.run(self.net, data, 'eval')
-                    else:
-                        outputs = self.run(self.net, data, 'eval')
-                    eval_metric.update(data, outputs)
-                eval_metric.accumulate()
-                self.eval_details = eval_metric.details
-                scores.update(eval_metric.get())
-                eval_metric.reset()
-
-            if return_details:
-                return scores, self.eval_details
-            return scores
+
+        self.eval_data_loader = self.build_data_loader(
+            eval_dataset,
+            batch_size=batch_size,
+            mode='eval',
+            collate_fn=eval_dataset.collate_fn)
+        is_bbox_normalized = False
+        if hasattr(eval_dataset, 'batch_transforms'):
+            is_bbox_normalized = any(
+                isinstance(t, _NormalizeBox)
+                for t in eval_dataset.batch_transforms.batch_transforms)
+        if self.metric == 'voc':
+            eval_metric = VOCMetric(
+                labels=eval_dataset.labels,
+                coco_gt=copy.deepcopy(eval_dataset.coco_gt),
+                is_bbox_normalized=is_bbox_normalized,
+                classwise=False)
+        elif self.metric == 'coco':
+            eval_metric = COCOMetric(
+                coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False)
+        else:
+            assert hasattr(eval_dataset, 'get_anno_path')
+            eval_metric = RBoxMetric(
+                anno_file=eval_dataset.get_anno_path(), classwise=False)
+        scores = collections.OrderedDict()
+        logging.info("Start to evaluate (total_samples={}, total_steps={})...".
+                     format(eval_dataset.num_samples, eval_dataset.num_samples))
+        with paddle.no_grad():
+            for step, data in enumerate(self.eval_data_loader):
+                if self.precision == 'fp16':
+                    with paddle.amp.auto_cast(
+                            level=self.amp_level,
+                            enable=True,
+                            custom_white_list=self.custom_white_list,
+                            custom_black_list=self.custom_black_list):
+                        outputs = self.run(net, data, 'eval')
+                else:
+                    outputs = self.run(net, data, 'eval')
+
+                sum_num = 0
+                if nranks > 1:
+                    for i in range(outputs['bbox_num'].shape[0]):
+                        output_bbox_num = outputs['bbox_num'][i:i + 1].cuda()
+
+                        start_id = sum_num
+                        sum_num += int(output_bbox_num)
+                        end_id = sum_num
+                        output_bbox = outputs['bbox'][start_id:end_id].cuda()
+
+                        data_single_im_id = data['im_id'][i].unsqueeze(0)
+                        data_single_image = data['image'][i].unsqueeze(0)
+                        data_single_image_shape = data['image_shape'][
+                            i].unsqueeze(0)
+                        data_single_im_shape = data['im_shape'][i].unsqueeze(0)
+                        data_single_scale_factor = data['scale_factor'][
+                            i].unsqueeze(0)
+                        data_single_permuted = data['permuted'][i]
+                        data_single_gt_bbox = data['gt_bbox'][i]
+                        data_single_gt_bbox_num = paddle.to_tensor(
+                            data_single_gt_bbox.shape[0])
+                        data_single_difficult = data['difficult'][i]
+                        data_single_gt_class = data['gt_class'][i]
+
+                        output_bbox_num_list = []
+                        paddle.distributed.all_gather(output_bbox_num_list,
+                                                      output_bbox_num)
+                        max_num = paddle.max(
+                            paddle.concat(output_bbox_num_list))
+                        if len(output_bbox) < max_num:
+                            tp_box = output_bbox[0:1].clone()
+                            pad_box = tp_box.tile(
+                                (max_num - len(output_bbox), 1))
+                            output_bbox_pad = paddle.concat(
+                                [output_bbox, pad_box], axis=0)
+                        else:
+                            output_bbox_pad = output_bbox
+                        output_bbox_list = []
+                        paddle.distributed.all_gather(output_bbox_list,
+                                                      output_bbox_pad)
+                        data_single_im_id_list = []
+                        paddle.distributed.all_gather(data_single_im_id_list,
+                                                      data_single_im_id)
+                        data_single_image_list = []
+                        paddle.distributed.all_gather(data_single_image_list,
+                                                      data_single_image)
+                        data_single_image_shape_list = []
+                        paddle.distributed.all_gather(
+                            data_single_image_shape_list,
+                            data_single_image_shape)
+                        data_single_im_shape_list = []
+                        paddle.distributed.all_gather(data_single_im_shape_list,
+                                                      data_single_im_shape)
+                        data_single_scale_factor_list = []
+                        paddle.distributed.all_gather(
+                            data_single_scale_factor_list,
+                            data_single_scale_factor)
+                        data_single_permuted_list = []
+                        paddle.distributed.all_gather(data_single_permuted_list,
+                                                      data_single_permuted)
+                        data_single_gt_bbox_num_list = []
+                        paddle.distributed.all_gather(
+                            data_single_gt_bbox_num_list,
+                            data_single_gt_bbox_num)
+                        max_num = paddle.max(
+                            paddle.concat(data_single_gt_bbox_num_list))
+                        if data_single_gt_bbox.shape[0] < max_num:
+                            tp_box = data_single_gt_bbox[0:1].clone()
+                            pad_box = tp_box.tile(
+                                (max_num - data_single_gt_bbox.shape[0], 1))
+                            data_single_gt_bbox_pad = paddle.concat(
+                                [data_single_gt_bbox, pad_box], axis=0)
+                            tp_diff = data_single_difficult[0:1].clone()
+                            pad_diff = tp_diff.tile(
+                                (max_num - data_single_gt_bbox.shape[0], 1))
+                            data_single_difficult_pad = paddle.concat(
+                                [data_single_difficult, pad_diff], axis=0)
+                            tp_class = data_single_gt_class[0:1].clone()
+                            pad_class = tp_class.tile(
+                                (max_num - data_single_gt_bbox.shape[0], 1))
+                            data_single_gt_class_pad = paddle.concat(
+                                [data_single_gt_class, pad_class], axis=0)
+                        else:
+                            data_single_gt_bbox_pad = data_single_gt_bbox
+                            data_single_difficult_pad = data_single_difficult
+                            data_single_gt_class_pad = data_single_gt_class
+                        data_single_gt_bbox_list = []
+                        paddle.distributed.all_gather(data_single_gt_bbox_list,
+                                                      data_single_gt_bbox_pad)
+                        data_single_difficult_list = []
+                        paddle.distributed.all_gather(
+                            data_single_difficult_list,
+                            data_single_difficult_pad)
+                        data_single_gt_class_list = []
+                        paddle.distributed.all_gather(data_single_gt_class_list,
+                                                      data_single_gt_class_pad)
+
+                        for rank_id in range(nranks):
+                            output = {}
+                            data_single = {}
+                            output['bbox_num'] = output_bbox_num_list[rank_id]
+                            output['bbox'] = output_bbox_list[
+                                rank_id][:output_bbox_num_list[rank_id]]
+
+                            data_single['im_id'] = data_single_im_id_list[
+                                rank_id]
+                            data_single['image'] = data_single_image_list[
+                                rank_id]
+                            data_single[
+                                'image_shape'] = data_single_image_shape_list[
+                                    rank_id]
+                            data_single['im_shape'] = data_single_im_shape_list[
+                                rank_id]
+                            data_single[
+                                'scale_factor'] = data_single_scale_factor_list[
+                                    rank_id]
+                            data_single['permuted'] = data_single_permuted_list[
+                                rank_id]
+                            box_num = data_single_gt_bbox_num_list[rank_id]
+                            data_single['gt_bbox'] = [
+                                data_single_gt_bbox_list[rank_id][:box_num]
+                            ]
+                            data_single['difficult'] = [
+                                data_single_difficult_list[rank_id][:box_num]
+                            ]
+                            data_single['gt_class'] = [
+                                data_single_gt_class_list[rank_id][:box_num]
+                            ]
+
+                            eval_metric.update(data_single, output)
+                else:
+                    for i in range(outputs['bbox_num'].shape[0]):
+                        output = {}
+                        output['bbox_num'] = outputs['bbox_num'][i:i + 1]
+
+                        start_id = sum_num
+                        sum_num += int(output['bbox_num'])
+                        end_id = sum_num
+                        output['bbox'] = outputs['bbox'][start_id:end_id]
+
+                        data_single = {}
+                        data_single['im_id'] = data['im_id'][i].unsqueeze(0)
+                        data_single['image'] = data['image'][i].unsqueeze(0)
+                        data_single['image_shape'] = data['image_shape'][
+                            i].unsqueeze(0)
+                        data_single['im_shape'] = data['im_shape'][i].unsqueeze(
+                            0)
+                        data_single['scale_factor'] = data['scale_factor'][
+                            i].unsqueeze(0)
+                        data_single['permuted'] = data['permuted'][i]
+                        data_single['gt_bbox'] = [data['gt_bbox'][i]]
+                        data_single['difficult'] = [data['difficult'][i]]
+                        data_single['gt_class'] = [data['gt_class'][i]]
+
+                        eval_metric.update(data_single, output)
+
+            eval_metric.accumulate()
+            self.eval_details = eval_metric.details
+            scores.update(eval_metric.get())
+            eval_metric.reset()
+
+        if return_details:
+            return scores, self.eval_details
+        return scores
 
     @paddle.no_grad()
     def predict(self, img_file, transforms=None):
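
`all_gather` requires identically shaped tensors on every rank, while the number of predicted (and ground-truth) boxes varies per image. The detector therefore pads each variable-length tensor to the maximum count across ranks before gathering, then cuts each gathered slice back to its true length. A condensed sketch of that technique (hypothetical helper; assumes at least one row per rank, as the diff does):

    import paddle

    def all_gather_varlen(x):
        # Sketch: all_gather a 2-D tensor whose row count differs per rank.
        # 1. Share every rank's row count.
        num = paddle.to_tensor([x.shape[0]])
        num_list = []
        paddle.distributed.all_gather(num_list, num)
        max_num = int(paddle.max(paddle.concat(num_list)))
        # 2. Pad with a repeated first row so shapes match across ranks.
        if x.shape[0] < max_num:
            pad = x[0:1].tile((max_num - x.shape[0], 1))
            x = paddle.concat([x, pad], axis=0)
        # 3. Gather, then truncate each rank's slice to its true length.
        x_list = []
        paddle.distributed.all_gather(x_list, x)
        return [t[:int(n)] for t, n in zip(x_list, num_list)]

Padding with a copied row (rather than zeros) keeps dtypes and value ranges valid; the padded rows are dropped again by the final truncation.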

+ 45 - 41
paddlers/tasks/restorer.py

@@ -30,6 +30,7 @@ from paddlers.models import res_losses
 from paddlers.models.ppgan.modules.init import init_weights
 from paddlers.transforms import Resize, decode_image, construct_sample
 from paddlers.transforms.functions import calc_hr_shape
+from paddlers.utils import to_data_parallel
 from paddlers.utils.checkpoint import res_pretrain_weights_dict
 from .base import BaseModel
 from .utils.res_adapters import GANAdapter, OptimizerAdapter
@@ -414,58 +415,59 @@ class BaseRestorer(BaseModel):
         """
 
         self._check_transforms(eval_dataset.transforms)
+        net = self.net
+        net.eval()
 
-        self.net.eval()
+        # XXX: Hard-coding
         nranks = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
                 paddle.distributed.init_parallel_env()
-
-        # TODO: Distributed evaluation
-        if batch_size > 1:
-            logging.warning(
-                "Restorer only supports single card evaluation with batch_size=1 "
-                "during evaluation, so batch_size is forcibly set to 1.")
-            batch_size = 1
-
-        if nranks < 2 or local_rank == 0:
-            self.eval_data_loader = self.build_data_loader(
-                eval_dataset, batch_size=batch_size, mode='eval')
-            # XXX: Hard-code crop_border and test_y_channel
-            psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
-            ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
-            logging.info(
-                "Start to evaluate (total_samples={}, total_steps={})...".
-                format(eval_dataset.num_samples, eval_dataset.num_samples))
-            with paddle.no_grad():
-                for step, data in enumerate(self.eval_data_loader):
-                    if self.precision == 'fp16':
-                        with paddle.amp.auto_cast(
-                                level=self.amp_level,
-                                enable=True,
-                                custom_white_list=self.custom_white_list,
-                                custom_black_list=self.custom_black_list):
-                            outputs = self.run(self.net, data, 'eval')
-                    else:
-                        outputs = self.run(self.net, data, 'eval')
+            net = to_data_parallel(
+                net, find_unused_parameters=self.find_unused_parameters)
+
+        self.eval_data_loader = self.build_data_loader(
+            eval_dataset, batch_size=batch_size, mode='eval')
+        # XXX: Hard-code crop_border and test_y_channel
+        psnr = metrics.PSNR(crop_border=4, test_y_channel=True)
+        ssim = metrics.SSIM(crop_border=4, test_y_channel=True)
+        logging.info("Start to evaluate (total_samples={}, total_steps={})...".
+                     format(eval_dataset.num_samples, eval_dataset.num_samples))
+        with paddle.no_grad():
+            for step, data in enumerate(self.eval_data_loader):
+                if self.precision == 'fp16':
+                    with paddle.amp.auto_cast(
+                            level=self.amp_level,
+                            enable=True,
+                            custom_white_list=self.custom_white_list,
+                            custom_black_list=self.custom_black_list):
+                        outputs = self.run(net, data, 'eval')
+                else:
+                    outputs = self.run(net, data, 'eval')
+                if len(outputs['pred'].shape) > 3:
+                    for i in range(outputs['pred'].shape[0]):
+                        psnr.update(outputs['pred'][i], outputs['tar'][i])
+                        ssim.update(outputs['pred'][i], outputs['tar'][i])
+                else:
                     psnr.update(outputs['pred'], outputs['tar'])
                     ssim.update(outputs['pred'], outputs['tar'])
+        # DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training.
+        assert len(psnr.results) > 0
+        assert len(ssim.results) > 0
+        eval_metrics = OrderedDict(
+            zip(['psnr', 'ssim'],
+                [np.mean(psnr.results), np.mean(ssim.results)]))
 
-            # DO NOT use psnr.accumulate() here, otherwise the program hangs in multi-card training.
-            assert len(psnr.results) > 0
-            assert len(ssim.results) > 0
-            eval_metrics = OrderedDict(
-                zip(['psnr', 'ssim'],
-                    [np.mean(psnr.results), np.mean(ssim.results)]))
-
-            if return_details:
-                # TODO: Add details
-                return eval_metrics, None
+        if return_details:
+            # TODO: Add details
+            return eval_metrics, None
 
-            return eval_metrics
+        return eval_metrics
 
     @paddle.no_grad()
     def predict(self, img_file, transforms=None):
@@ -553,6 +555,8 @@ class BaseRestorer(BaseModel):
                 else:
                     pass
             results.append(pred)
+        if len(results) > 1:
+            results = [paddle.concat(results, axis=0)]
         return results
 
     def _infer_postprocess(self, batch_res_map, batch_restore_list):
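
The retained comment documents the design choice here: `psnr.accumulate()` hangs under multi-card training, presumably because it issues a collective op that not every process reaches, so the diff averages the raw per-image results locally instead. A hedged sketch of that aggregation (assuming `psnr.results` and `ssim.results` are plain lists of per-image scores, as the `np.mean` calls imply):

    import numpy as np
    from collections import OrderedDict

    def local_eval_metrics(psnr_results, ssim_results):
        # Sketch: aggregate per-image scores without collective ops.
        assert len(psnr_results) > 0 and len(ssim_results) > 0
        return OrderedDict(
            [('psnr', np.mean(psnr_results)),
             ('ssim', np.mean(ssim_results))])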

+ 14 - 13
paddlers/tasks/segmenter.py

@@ -28,7 +28,7 @@ import paddlers.rs_models.seg as cmseg
 import paddlers.utils.logging as logging
 from paddlers.models import seg_losses
 from paddlers.transforms import Resize, decode_image, construct_sample
-from paddlers.utils import get_single_card_bs, DisablePrint
+from paddlers.utils import DisablePrint, to_data_parallel
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from .base import BaseModel
 from .utils import seg_metrics as metrics
@@ -429,23 +429,22 @@ class BaseSegmenter(BaseModel):
         """
 
         self._check_transforms(eval_dataset.transforms)
-        self.net.eval()
+        net = self.net
+        net.eval()
+
+        # XXX: Hard-coding
         nranks = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
         if nranks > 1:
             # Initialize parallel environment if not done.
             if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
             ):
                 paddle.distributed.init_parallel_env()
+            net = to_data_parallel(
+                net, find_unused_parameters=self.find_unused_parameters)
 
-        batch_size_each_card = get_single_card_bs(batch_size)
-        if batch_size_each_card > 1:
-            batch_size_each_card = 1
-            batch_size = batch_size_each_card * paddlers.env_info['num']
-            logging.warning(
-                "Segmenter only supports batch_size=1 for each gpu/cpu card " \
-                "during evaluation, so batch_size " \
-                "is forcibly set to {}.".format(batch_size))
         self.eval_data_loader = self.build_data_loader(
             eval_dataset, batch_size=batch_size, mode='eval')
 
@@ -465,9 +464,9 @@ class BaseSegmenter(BaseModel):
                             enable=True,
                             custom_white_list=self.custom_white_list,
                             custom_black_list=self.custom_black_list):
-                        outputs = self.run(self.net, data, 'eval')
+                        outputs = self.run(net, data, 'eval')
                 else:
-                    outputs = self.run(self.net, data, 'eval')
+                    outputs = self.run(net, data, 'eval')
                 pred_area = outputs['pred_area']
                 label_area = outputs['label_area']
                 intersect_area = outputs['intersect_area']
@@ -658,6 +657,8 @@ class BaseSegmenter(BaseModel):
                 else:
                     raise RuntimeError
             results.append(pred)
+        if len(results) > 1:
+            results = [paddle.concat(results, axis=0)]
         return results
 
     def _infer_postprocess(self, batch_label_map, batch_score_map,

+ 15 - 4
paddlers/tasks/utils/slider_predict.py

@@ -512,10 +512,21 @@ def slider_predict(predict_func,
                 batch_out = predict_func(batch_data, transforms=transforms)
 
                 for out, (xoff_, yoff_) in zip(batch_out, batch_offsets):
-                    # Get processed result
-                    pred = overlap_processor.process_pred(out, xoff_, yoff_)
-                    # Write to file
-                    band.WriteArray(pred, xoff_, yoff_)
+                    if len(out['label_map'].shape) == 3:
+                        for i in range(out['label_map'].shape[0]):
+                            out_single = {}
+                            out_single['label_map'] = out['label_map'][i]
+                            out_single['score_map'] = out['score_map'][i]
+                            # Get processed result
+                            pred = overlap_processor.process_pred(out_single,
+                                                                  xoff_, yoff_)
+                            # Write to file
+                            band.WriteArray(pred, xoff_, yoff_)
+                    else:
+                        # Get processed result
+                        pred = overlap_processor.process_pred(out, xoff_, yoff_)
+                        # Write to file
+                        band.WriteArray(pred, xoff_, yoff_)
 
                 batch_data.clear()
                 batch_offsets.clear()
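
The new branch handles predictors that now return batched maps: a 3-D `label_map` carries a leading batch dimension and is split back into per-tile 2-D maps before merging and writing. A standalone sketch of the same unpacking (hypothetical helper; assumes dict outputs with array-like maps):

    def split_batched_output(out):
        # Sketch: yield one {'label_map', 'score_map'} dict per sample.
        # A 3-D label_map is a batch of 2-D maps; 2-D passes through.
        if out['label_map'].ndim == 3:
            for label_map, score_map in zip(out['label_map'],
                                            out['score_map']):
                yield {'label_map': label_map, 'score_map': score_map}
        else:
            yield out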