Procházet zdrojové kódy

[feature] add stanet reproduction (#73)

sun222 před 2 roky
rodič
revize
574e5180cc

+ 11 - 6
paddlers/custom_models/cd/stanet.py

@@ -215,7 +215,7 @@ class BAM(nn.Layer):
 
         out = F.interpolate(out, scale_factor=self.ds)
         out = out + x
-        return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
+        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))
 
 
 class PAMBlock(nn.Layer):
@@ -241,7 +241,8 @@ class PAMBlock(nn.Layer):
         value = self.conv_v(x_rs)
 
         # Split the whole image into subregions.
-        b, c, h, w = paddle.shape(x_rs)
+        b, c, h, w = x_rs.shape
+
         query = self._split_subregions(query)
         key = self._split_subregions(key)
         value = self._split_subregions(value)
@@ -263,12 +264,14 @@ class PAMBlock(nn.Layer):
         return out
 
     def _split_subregions(self, x):
-        b, c, h, w = paddle.shape(x)
+        b, c, h, w = x.shape
         assert h % self.scale == 0 and w % self.scale == 0
         x = x.reshape(
             (b, c, self.scale, h // self.scale, self.scale, w // self.scale))
-        x = x.transpose((0, 2, 4, 1, 3, 5)).reshape(
-            (b * self.scale * self.scale, c, -1))
+
+        x = x.transpose((0, 2, 4, 1, 3, 5))
+
+        x = x.reshape((b * self.scale * self.scale, c, -1))
         return x
 
     def _recons_whole(self, x, b, c, h, w):
@@ -290,8 +293,10 @@ class PAM(nn.Layer):
     def forward(self, x):
         x = x.flatten(-2)
         res = [stage(x) for stage in self.stages]
+
         out = self.conv_out(paddle.concat(res, axis=1))
-        return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
+
+        return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))
 
 
 class Attention(nn.Layer):

+ 1 - 0
paddlers/tasks/base.py

@@ -166,6 +166,7 @@ class BaseModel(metaclass=ModelMeta):
              ('fixed_input_shape', self.fixed_input_shape),
              ('best_accuracy', self.best_accuracy),
              ('best_model_epoch', self.best_model_epoch)])
+
         if 'self' in init_params:
             del init_params['self']
         if '__class__' in init_params:

+ 1 - 1
tools/spliter.py

@@ -52,4 +52,4 @@ parser.add_argument("--save_folder", type=str, default="output", \
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    split_data(args.image_path, args.block_size, args.save_folder)
+    split_data(args.image_path, args.block_size, args.save_folder)

+ 1 - 1
tutorials/train/change_detection/stanet.py

@@ -85,4 +85,4 @@ model.train(
     # 是否启用VisualDL日志功能
     use_vdl=True,
     # 指定从某个检查点继续训练
-    resume_checkpoint=None)
+    resume_checkpoint=None)