@@ -129,7 +129,7 @@ class BIT(nn.Layer):
             Conv3x3(EBD_DIM, num_classes))
 
     def _get_semantic_tokens(self, x):
-        b, c = paddle.shape(x)[:2]
+        b, c = x.shape[:2]
         att_map = self.conv_att(x)
         att_map = att_map.reshape((b, self.token_len, 1, -1))
         att_map = F.softmax(att_map, axis=-1)
@@ -154,7 +154,7 @@ class BIT(nn.Layer):
         return x
 
     def decode(self, x, m):
-        b, c, h, w = paddle.shape(x)
+        b, c, h, w = x.shape
         x = x.transpose((0, 2, 3, 1)).flatten(1, 2)
         x = self.decoder(x, m)
         x = x.transpose((0, 2, 1)).reshape((b, c, h, w))
@@ -172,7 +172,7 @@ class BIT(nn.Layer):
         else:
             token1 = self._get_reshaped_tokens(x1)
             token2 = self._get_reshaped_tokens(x2)
-
+
        # Transformer encoder forward
         token = paddle.concat([token1, token2], axis=1)
         token = self.encode(token)
@@ -265,7 +265,7 @@ class CrossAttention(nn.Layer):
             nn.Linear(inner_dim, dim), nn.Dropout(dropout_rate))
 
     def forward(self, x, ref):
-        b, n = paddle.shape(x)[:2]
+        b, n = x.shape[:2]
         h = self.n_heads
 
         q = self.fc_q(x)
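
Note: every hunk above swaps paddle.shape(x) for the Tensor.shape attribute. As a standalone sketch (not part of the patch), the practical difference in dynamic-graph mode is that Tensor.shape yields plain Python ints while paddle.shape() returns a 1-D integer Tensor:

# Standalone sketch contrasting the two ways of reading a tensor's shape
# in PaddlePaddle's dynamic-graph mode.
import paddle

x = paddle.rand([2, 3, 8, 8])

# Tensor.shape is a plain Python list of ints here, so b and c are ints.
b, c = x.shape[:2]

# paddle.shape() instead returns a 1-D integer Tensor holding [2, 3, 8, 8].
shape_tensor = paddle.shape(x)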