Эх сурвалжийг харах

Replace x.shape with paddle.shape(x)

Bobholamovic 3 жил өмнө
parent
commit
ab59d5cacb

+ 5 - 5
paddlers/models/cd/models/bit.py

@@ -122,7 +122,7 @@ class BIT(nn.Layer):
         )
 
     def _get_semantic_tokens(self, x):
-        b, c = x.shape[:2]
+        b, c = paddle.shape(x)[:2]
         att_map = self.conv_att(x)
         att_map = att_map.reshape((b,self.token_len,1,-1))
         att_map = F.softmax(att_map, axis=-1)
@@ -147,7 +147,7 @@ class BIT(nn.Layer):
         return x
 
     def decode(self, x, m):
-        b, c, h, w = x.shape
+        b, c, h, w = paddle.shape(x)
         x = x.transpose((0,2,3,1)).flatten(1,2)
         x = self.decoder(x, m)
         x = x.transpose((0,2,1)).reshape((b,c,h,w))
@@ -257,7 +257,7 @@ class CrossAttention(nn.Layer):
         )
 
     def forward(self, x, ref):
-        b, n = x.shape[:2]
+        b, n = paddle.shape(x)[:2]
         h = self.n_heads
 
         q = self.fc_q(x)
@@ -265,8 +265,8 @@ class CrossAttention(nn.Layer):
         v = self.fc_v(ref)
 
         q = q.reshape((b,n,h,-1)).transpose((0,2,1,3))
-        k = k.reshape((b,ref.shape[1],h,-1)).transpose((0,2,1,3))
-        v = v.reshape((b,ref.shape[1],h,-1)).transpose((0,2,1,3))
+        k = k.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3))
+        v = v.reshape((b,paddle.shape(ref)[1],h,-1)).transpose((0,2,1,3))
 
         mult = paddle.matmul(q, k, transpose_y=True) * self.scale
 

+ 7 - 7
paddlers/models/cd/models/stanet.py

@@ -72,7 +72,7 @@ class STANet(nn.Layer):
         f1, f2 = self.attend(f1, f2)
 
         y = paddle.abs(f1- f2)
-        y = F.interpolate(y, size=t1.shape[2:], mode='bilinear', align_corners=True)
+        y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
 
         pred = self.conv_out(y)
         return pred,
@@ -166,9 +166,9 @@ class Decoder(nn.Layer, KaimingInitMixin):
         f3 = self.dr3(feats[2])
         f4 = self.dr4(feats[3])
 
-        f2 = F.interpolate(f2, size=f1.shape[2:], mode='bilinear', align_corners=True)
-        f3 = F.interpolate(f3, size=f1.shape[2:], mode='bilinear', align_corners=True)
-        f4 = F.interpolate(f4, size=f1.shape[2:], mode='bilinear', align_corners=True)
+        f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
+        f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
+        f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
 
         x = paddle.concat([f1, f2, f3, f4], axis=1)
         y = self.conv_out(x)
@@ -195,7 +195,7 @@ class BAM(nn.Layer):
         x = x.flatten(-2)
         x_rs = self.pool(x)
         
-        b, c, h, w = x_rs.shape
+        b, c, h, w = paddle.shape(x_rs)
         query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1))
         key = self.conv_k(x_rs).reshape((b,-1,h*w))
         energy = paddle.bmm(query, key)
@@ -236,7 +236,7 @@ class PAMBlock(nn.Layer):
         value = self.conv_v(x_rs)
         
         # Split the whole image into subregions.
-        b, c, h, w = x_rs.shape
+        b, c, h, w = paddle.shape(x_rs)
         query = self._split_subregions(query)
         key = self._split_subregions(key)
         value = self._split_subregions(value)
@@ -257,7 +257,7 @@ class PAMBlock(nn.Layer):
         return out
 
     def _split_subregions(self, x):
-        b, c, h, w = x.shape
+        b, c, h, w = paddle.shape(x)
         assert h % self.scale == 0 and w % self.scale == 0
         x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale))
         x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1))