소스 검색

[Feature] pass change detection run test

geoyee 3 년 전
부모
커밋
26ec44f1d9
6개의 변경된 파일에 115개의 추가 그리고 24개의 삭제
  1. 3 0
      .gitignore
  2. 1 0
      paddlers/tasks/__init__.py
  3. 5 12
      paddlers/tasks/changedetector.py
  4. 1 1
      paddlers/transforms/__init__.py
  5. 47 11
      paddlers/transforms/operators.py
  6. 58 0
      tutorials/train/change_detection/cdnet_build.py

+ 3 - 0
.gitignore

@@ -128,3 +128,6 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+# testdata
+tutorials/train/change_detection/DataSet/

+ 1 - 0
paddlers/tasks/__init__.py

@@ -14,4 +14,5 @@
 
 from . import det
 from .segmenter import *
+from .changedetector import *
 from .load_model import load_model

+ 5 - 12
paddlers/tasks/changedetector.py

@@ -29,7 +29,7 @@ from .base import BaseModel
 from .utils import seg_metrics as metrics
 from paddlers.utils.checkpoint import seg_pretrain_weights_dict
 from paddlers.transforms import Decode, Resize
-from paddlers.models.ppcd import CDNet
+from paddlers.models.ppcd import CDNet as _CDNet
 
 __all__ = ["CDNet"]
 
@@ -59,7 +59,7 @@ class BaseChangeDetector(BaseModel):
 
     def build_net(self, **params):
         # TODO: add other model
-        net = CDNet(num_classes=self.num_classes, **params)
+        net = _CDNet(num_classes=self.num_classes, **params)
         return net
 
     def _fix_transforms_shape(self, image_shape):
@@ -174,14 +174,7 @@ class BaseChangeDetector(BaseModel):
                 paddleseg.models.MixedLoss(
                     losses=losses, coef=list(coef))
             ]
-        if self.model_name == 'FastSCNN':
-            loss_type *= 2
-            loss_coef = [1.0, 0.4]
-        elif self.model_name == 'BiSeNetV2':
-            loss_type *= 5
-            loss_coef = [1.0] * 5
-        else:
-            loss_coef = [1.0]
+        loss_coef = [1.0]
         losses = {'types': loss_type, 'coef': loss_coef}
         return losses
 
@@ -584,7 +577,7 @@ class BaseChangeDetector(BaseModel):
         return batch_restore_list
 
     def _postprocess(self, batch_pred, batch_origin_shape, transforms):
-        batch_restore_list = BaseSegmenter.get_transforms_shape_info(
+        batch_restore_list = BaseChangeDetector.get_transforms_shape_info(
             batch_origin_shape, transforms)
         if isinstance(batch_pred, (tuple, list)) and self.status == 'Infer':
             return self._infer_postprocess(
@@ -665,7 +658,7 @@ class CDNet(BaseChangeDetector):
                  **params):
         params.update({'in_channels': in_channels})
         super(CDNet, self).__init__(
-            model_name='UNet',
+            model_name='CDNet',
             num_classes=num_classes,
             use_mixed_loss=use_mixed_loss,
             **params)

+ 1 - 1
paddlers/transforms/__init__.py

@@ -25,7 +25,7 @@ def arrange_transforms(model_type, transforms, mode='train'):
         else:
             transforms.apply_im_only = False
         arrange_transform = ArrangeSegmenter(mode)
-    elif model_type == 'changedetctor':
+    elif model_type == 'changedetector':
         if mode == 'eval':
             transforms.apply_im_only = True
         else:

+ 47 - 11
paddlers/transforms/operators.py

@@ -35,8 +35,8 @@ __all__ = [
     "RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
     "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
     "RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
-    "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeClassifier",
-    "ArrangeDetector"
+    "RandomDistort", "RandomBlur", "ArrangeSegmenter", "ArrangeChangeDetector", 
+    "ArrangeClassifier", "ArrangeDetector"
 ]
 
 interp_dict = {
@@ -69,7 +69,11 @@ class Transform(object):
         pass
 
     def apply(self, sample):
-        sample['image'] = self.apply_im(sample['image'])
+        if 'image' in sample:
+            sample['image'] = self.apply_im(sample['image'])
+        else:  # image_tx
+            sample['image'] = self.apply_im(sample['image_t1'])
+            sample['image2'] = self.apply_im(sample['image_t2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
         if 'gt_bbox' in sample:
@@ -175,7 +179,7 @@ class Decode(Transform):
                 return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                   cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
             else:
-                return cv2.imread(im_file, cv2.IMREAD_ANYDEPTH |
+                return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
                                   cv2.IMREAD_ANYCOLOR)
         elif ext == '.npy':
             return np.load(img_path)
@@ -218,7 +222,11 @@ class Decode(Transform):
             dict: Decoded sample.
 
         """
-        sample['image'] = self.apply_im(sample['image'])
+        if 'image' in sample:
+            sample['image'] = self.apply_im(sample['image'])
+        else:  # image_tx
+            sample['image'] = self.apply_im(sample['image_t1'])
+            sample['image2'] = self.apply_im(sample['image_t2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
             im_height, im_width, _ = sample['image'].shape
@@ -323,6 +331,8 @@ class Resize(Transform):
             im_scale_x = target_w / im_w
 
         sample['image'] = self.apply_im(sample['image'], interp, target_size)
+        if 'image2' in sample:
+            sample['image2'] = self.apply_im(sample['image2'], interp, target_size)
 
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], target_size)
@@ -523,6 +533,8 @@ class RandomHorizontalFlip(Transform):
         if random.random() < self.prob:
             im_h, im_w = sample['image'].shape[:2]
             sample['image'] = self.apply_im(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'])
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
@@ -576,6 +588,8 @@ class RandomVerticalFlip(Transform):
         if random.random() < self.prob:
             im_h, im_w = sample['image'].shape[:2]
             sample['image'] = self.apply_im(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'])
             if 'mask' in sample:
                 sample['mask'] = self.apply_mask(sample['mask'])
             if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
@@ -636,6 +650,8 @@ class Normalize(Transform):
 
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'])
 
         return sample
 
@@ -665,6 +681,8 @@ class CenterCrop(Transform):
 
     def apply(self, sample):
         sample['image'] = self.apply_im(sample['image'])
+        if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'])
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'])
         return sample
@@ -819,6 +837,8 @@ class RandomCrop(Transform):
             crop_box, cropped_box, valid_ids = crop_info
             im_h, im_w = sample['image'].shape[:2]
             sample['image'] = self.apply_im(sample['image'], crop_box)
+            if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'], crop_box)
             if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
                 crop_polys = self._crop_segm(
                     sample['gt_poly'],
@@ -1045,6 +1065,8 @@ class Padding(Transform):
             offsets = [w - im_w, h - im_h]
 
         sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
+        if 'image2' in sample:
+                sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
         if 'mask' in sample:
             sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
         if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
@@ -1239,22 +1261,33 @@ class RandomDistort(Transform):
             distortions = np.random.permutation(functions)[:self.count]
             for func in distortions:
                 sample['image'] = func(sample['image'])
+                if 'image2' in sample:
+                    sample['image2'] = func(sample['image2'])
             return sample
 
         sample['image'] = self.apply_brightness(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_brightness(sample['image2'])
         mode = np.random.randint(0, 2)
         if mode:
             sample['image'] = self.apply_contrast(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_contrast(sample['image2'])
         sample['image'] = self.apply_saturation(sample['image'])
         sample['image'] = self.apply_hue(sample['image'])
+        if 'image2' in sample:
+            sample['image2'] = self.apply_saturation(sample['image2'])
+            sample['image2'] = self.apply_hue(sample['image2'])
         if not mode:
             sample['image'] = self.apply_contrast(sample['image'])
+            if 'image2' in sample:
+                sample['image2'] = self.apply_contrast(sample['image2'])
 
         if self.shuffle_channel:
             if np.random.randint(0, 2):
-                sample['image'] = sample['image'][..., np.random.permutation(
-                    3)]
-
+                sample['image'] = sample['image'][..., np.random.permutation(3)]
+                if 'image2' in sample:
+                    sample['image2'] = sample['image2'][..., np.random.permutation(3)]
         return sample
 
 
@@ -1289,7 +1322,8 @@ class RandomBlur(Transform):
                 if radius > 9:
                     radius = 9
                 sample['image'] = self.apply_im(sample['image'], radius)
-
+                if 'image2' in sample:
+                    sample['image2'] = self.apply_im(sample['image2'], radius)
         return sample
 
 
@@ -1374,6 +1408,8 @@ class _Permute(Transform):
 
     def apply(self, sample):
         sample['image'] = permute(sample['image'], False)
+        if 'image2' in sample:
+            sample['image2'] = permute(sample['image2'], False)
         return sample
 
 
@@ -1415,8 +1451,8 @@ class ArrangeChangeDetector(Transform):
         if 'mask' in sample:
             mask = sample['mask']
 
-        image_t1 = permute(sample['image_t1'], False)
-        image_t2 = permute(sample['image_t2'], False)
+        image_t1 = permute(sample['image'], False)
+        image_t2 = permute(sample['image2'], False)
         if self.mode == 'train':
             mask = mask.astype('int64')
             return image_t1, image_t2, mask

+ 58 - 0
tutorials/train/change_detection/cdnet_build.py

@@ -0,0 +1,58 @@
+import sys
+
+sys.path.append("E:/dataFiles/github/PaddleRS")
+
+import paddlers as pdrs
+from paddlers import transforms as T
+
+# 下载aistudio的数据到当前文件夹并解压、整理
+# https://aistudio.baidu.com/aistudio/datasetdetail/53795
+
+# 定义训练和验证时的transforms
+# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/transforms/transforms.md
+train_transforms = T.Compose([
+    T.Resize(target_size=512),
+    T.RandomHorizontalFlip(),
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+])
+
+eval_transforms = T.Compose([
+    T.Resize(target_size=512),
+    T.Normalize(
+        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/datasets.md
+train_dataset = pdrs.datasets.CDDataset(
+    data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/change_detection/DataSet',
+    file_list='tutorials/train/change_detection/DataSet/train.txt',
+    label_list='tutorials/train/change_detection/DataSet/labels.txt',
+    transforms=train_transforms,
+    num_workers=0,
+    shuffle=True)
+
+eval_dataset = pdrs.datasets.CDDataset(
+    data_dir='E:/dataFiles/github/PaddleRS/tutorials/train/change_detection/DataSet',
+    file_list='tutorials/train/change_detection/DataSet/val.txt',
+    label_list='tutorials/train/change_detection/DataSet/labels.txt',
+    transforms=eval_transforms,
+    num_workers=0,
+    shuffle=False)
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/paddlers/blob/develop/docs/visualdl.md
+num_classes = len(train_dataset.labels)
+model = pdrs.tasks.CDNet(num_classes=num_classes, in_channels=6)
+
+# API说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/apis/models/semantic_segmentation.md
+# 各参数介绍与调整说明:https://github.com/PaddlePaddle/paddlers/blob/develop/docs/parameters.md
+model.train(
+    num_epochs=1,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=0.01,
+    pretrain_weights=None,
+    save_dir='output/cdnet')