|
@@ -215,7 +215,7 @@ class BAM(nn.Layer):
|
|
|
|
|
|
out = F.interpolate(out, scale_factor=self.ds)
|
|
|
out = out + x
|
|
|
- return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
|
|
|
+ return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))
|
|
|
|
|
|
|
|
|
class PAMBlock(nn.Layer):
|
|
@@ -241,7 +241,8 @@ class PAMBlock(nn.Layer):
|
|
|
value = self.conv_v(x_rs)
|
|
|
|
|
|
# Split the whole image into subregions.
|
|
|
- b, c, h, w = paddle.shape(x_rs)
|
|
|
+ b, c, h, w = x_rs.shape
|
|
|
+
|
|
|
query = self._split_subregions(query)
|
|
|
key = self._split_subregions(key)
|
|
|
value = self._split_subregions(value)
|
|
@@ -263,12 +264,14 @@ class PAMBlock(nn.Layer):
|
|
|
return out
|
|
|
|
|
|
def _split_subregions(self, x):
|
|
|
- b, c, h, w = paddle.shape(x)
|
|
|
+ b, c, h, w = x.shape
|
|
|
assert h % self.scale == 0 and w % self.scale == 0
|
|
|
x = x.reshape(
|
|
|
(b, c, self.scale, h // self.scale, self.scale, w // self.scale))
|
|
|
- x = x.transpose((0, 2, 4, 1, 3, 5)).reshape(
|
|
|
- (b * self.scale * self.scale, c, -1))
|
|
|
+
|
|
|
+ x = x.transpose((0, 2, 4, 1, 3, 5))
|
|
|
+
|
|
|
+ x = x.reshape((b * self.scale * self.scale, c, -1))
|
|
|
return x
|
|
|
|
|
|
def _recons_whole(self, x, b, c, h, w):
|
|
@@ -290,8 +293,10 @@ class PAM(nn.Layer):
|
|
|
def forward(self, x):
|
|
|
x = x.flatten(-2)
|
|
|
res = [stage(x) for stage in self.stages]
|
|
|
+
|
|
|
out = self.conv_out(paddle.concat(res, axis=1))
|
|
|
- return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
|
|
|
+
|
|
|
+ return out.reshape(tuple(out.shape[:-1]) + (out.shape[-1] // 2, 2))
|
|
|
|
|
|
|
|
|
class Attention(nn.Layer):
|