|
@@ -72,7 +72,7 @@ class STANet(nn.Layer):
|
|
|
f1, f2 = self.attend(f1, f2)
|
|
|
|
|
|
y = paddle.abs(f1- f2)
|
|
|
- y = F.interpolate(y, size=t1.shape[2:], mode='bilinear', align_corners=True)
|
|
|
+ y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
|
|
|
|
|
|
pred = self.conv_out(y)
|
|
|
return pred,
|
|
@@ -166,9 +166,9 @@ class Decoder(nn.Layer, KaimingInitMixin):
|
|
|
f3 = self.dr3(feats[2])
|
|
|
f4 = self.dr4(feats[3])
|
|
|
|
|
|
- f2 = F.interpolate(f2, size=f1.shape[2:], mode='bilinear', align_corners=True)
|
|
|
- f3 = F.interpolate(f3, size=f1.shape[2:], mode='bilinear', align_corners=True)
|
|
|
- f4 = F.interpolate(f4, size=f1.shape[2:], mode='bilinear', align_corners=True)
|
|
|
+ f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
|
|
|
+ f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
|
|
|
+ f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
|
|
|
|
|
|
x = paddle.concat([f1, f2, f3, f4], axis=1)
|
|
|
y = self.conv_out(x)
|
|
@@ -195,7 +195,7 @@ class BAM(nn.Layer):
|
|
|
x = x.flatten(-2)
|
|
|
x_rs = self.pool(x)
|
|
|
|
|
|
- b, c, h, w = x_rs.shape
|
|
|
+ b, c, h, w = paddle.shape(x_rs)
|
|
|
query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1))
|
|
|
key = self.conv_k(x_rs).reshape((b,-1,h*w))
|
|
|
energy = paddle.bmm(query, key)
|
|
@@ -236,7 +236,7 @@ class PAMBlock(nn.Layer):
|
|
|
value = self.conv_v(x_rs)
|
|
|
|
|
|
# Split the whole image into subregions.
|
|
|
- b, c, h, w = x_rs.shape
|
|
|
+ b, c, h, w = paddle.shape(x_rs)
|
|
|
query = self._split_subregions(query)
|
|
|
key = self._split_subregions(key)
|
|
|
value = self._split_subregions(value)
|
|
@@ -257,7 +257,7 @@ class PAMBlock(nn.Layer):
|
|
|
return out
|
|
|
|
|
|
def _split_subregions(self, x):
|
|
|
- b, c, h, w = x.shape
|
|
|
+ b, c, h, w = paddle.shape(x)
|
|
|
assert h % self.scale == 0 and w % self.scale == 0
|
|
|
x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale))
|
|
|
x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1))
|