@@ -44,24 +44,19 @@ class STANet(nn.Layer):
Raises:
ValueError: When `att_type` has an illegal value (unsupported attention type).
"""
-
- def __init__(
- self,
- in_channels,
- num_classes,
- att_type='BAM',
- ds_factor=1
- ):
+
+ def __init__(self, in_channels, num_classes, att_type='BAM', ds_factor=1):
super().__init__()
 
WIDTH = 64
 
self.extract = build_feat_extractor(in_ch=in_channels, width=WIDTH)
- self.attend = build_sta_module(in_ch=WIDTH, att_type=att_type, ds=ds_factor)
+ self.attend = build_sta_module(
+ in_ch=WIDTH, att_type=att_type, ds=ds_factor)
self.conv_out = nn.Sequential(
- Conv3x3(WIDTH, WIDTH, norm=True, act=True),
- Conv3x3(WIDTH, num_classes)
- )
+ Conv3x3(
+ WIDTH, WIDTH, norm=True, act=True),
+ Conv3x3(WIDTH, num_classes))
 
self.init_weight()
 
@@ -71,8 +66,9 @@ class STANet(nn.Layer):
 
f1, f2 = self.attend(f1, f2)
 
- y = paddle.abs(f1- f2)
- y = F.interpolate(y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
+ y = paddle.abs(f1 - f2)
+ y = F.interpolate(
+ y, size=paddle.shape(t1)[2:], mode='bilinear', align_corners=True)
 
pred = self.conv_out(y)
return [pred]
@@ -84,10 +80,7 @@ class STANet(nn.Layer):
 
 
def build_feat_extractor(in_ch, width):
- return nn.Sequential(
- Backbone(in_ch, 'resnet18'),
- Decoder(width)
- )
+ return nn.Sequential(Backbone(in_ch, 'resnet18'), Decoder(width))
 
 
def build_sta_module(in_ch, att_type, ds):
@@ -100,15 +93,24 @@ def build_sta_module(in_ch, att_type, ds):
 
 
class Backbone(nn.Layer, KaimingInitMixin):
- def __init__(self, in_ch, arch, pretrained=True, strides=(2,1,2,2,2)):
+ def __init__(self, in_ch, arch, pretrained=True, strides=(2, 1, 2, 2, 2)):
super().__init__()
 
if arch == 'resnet18':
- self.resnet = resnet.resnet18(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
+ self.resnet = resnet.resnet18(
+ pretrained=pretrained,
+ strides=strides,
+ norm_layer=get_norm_layer())
elif arch == 'resnet34':
- self.resnet = resnet.resnet34(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
+ self.resnet = resnet.resnet34(
+ pretrained=pretrained,
+ strides=strides,
+ norm_layer=get_norm_layer())
elif arch == 'resnet50':
- self.resnet = resnet.resnet50(pretrained=pretrained, strides=strides, norm_layer=get_norm_layer())
+ self.resnet = resnet.resnet50(
+ pretrained=pretrained,
+ strides=strides,
+ norm_layer=get_norm_layer())
else:
raise ValueError
 
@@ -116,13 +118,12 @@ class Backbone(nn.Layer, KaimingInitMixin):
 
if in_ch != 3:
self.resnet.conv1 = nn.Conv2D(
- in_ch,
+ in_ch,
64,
kernel_size=7,
stride=strides[0],
padding=3,
- bias_attr=False
- )
+ bias_attr=False)
 
if not pretrained:
self.init_weight()
@@ -153,10 +154,11 @@ class Decoder(nn.Layer, KaimingInitMixin):
self.dr3 = Conv1x1(256, 96, norm=True, act=True)
self.dr4 = Conv1x1(512, 96, norm=True, act=True)
self.conv_out = nn.Sequential(
- Conv3x3(384, 256, norm=True, act=True),
+ Conv3x3(
+ 384, 256, norm=True, act=True),
nn.Dropout(0.5),
- Conv1x1(256, f_ch, norm=True, act=True)
- )
+ Conv1x1(
+ 256, f_ch, norm=True, act=True))
 
self.init_weight()
 
@@ -166,9 +168,12 @@ class Decoder(nn.Layer, KaimingInitMixin):
f3 = self.dr3(feats[2])
f4 = self.dr4(feats[3])
 
- f2 = F.interpolate(f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
- f3 = F.interpolate(f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
- f4 = F.interpolate(f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
+ f2 = F.interpolate(
+ f2, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
+ f3 = F.interpolate(
+ f3, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
+ f4 = F.interpolate(
+ f4, size=paddle.shape(f1)[2:], mode='bilinear', align_corners=True)
 
x = paddle.concat([f1, f2, f3, f4], axis=1)
y = self.conv_out(x)
@@ -194,23 +199,23 @@ class BAM(nn.Layer):
def forward(self, x):
x = x.flatten(-2)
x_rs = self.pool(x)
-
+
b, c, h, w = paddle.shape(x_rs)
- query = self.conv_q(x_rs).reshape((b,-1,h*w)).transpose((0,2,1))
- key = self.conv_k(x_rs).reshape((b,-1,h*w))
+ query = self.conv_q(x_rs).reshape((b, -1, h * w)).transpose((0, 2, 1))
+ key = self.conv_k(x_rs).reshape((b, -1, h * w))
energy = paddle.bmm(query, key)
energy = (self.key_ch**(-0.5)) * energy
 
attention = self.softmax(energy)
 
- value = self.conv_v(x_rs).reshape((b,-1,w*h))
+ value = self.conv_v(x_rs).reshape((b, -1, w * h))
 
- out = paddle.bmm(value, attention.transpose((0,2,1)))
- out = out.reshape((b,c,h,w))
+ out = paddle.bmm(value, attention.transpose((0, 2, 1)))
+ out = out.reshape((b, c, h, w))
 
out = F.interpolate(out, scale_factor=self.ds)
out = out + x
- return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
+ return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
 
 
class PAMBlock(nn.Layer):
@@ -234,13 +239,13 @@ class PAMBlock(nn.Layer):
query = self.conv_q(x_rs)
key = self.conv_k(x_rs)
value = self.conv_v(x_rs)
-
+
# Split the whole image into subregions.
b, c, h, w = paddle.shape(x_rs)
query = self._split_subregions(query)
key = self._split_subregions(key)
value = self._split_subregions(value)
-
+
# Perform subregion-wise attention.
out = self._attend(query, key, value)
 
@@ -250,40 +255,43 @@ class PAMBlock(nn.Layer):
return out
 
def _attend(self, query, key, value):
- energy = paddle.bmm(query.transpose((0,2,1)), key) # batch matrix multiplication
+ energy = paddle.bmm(query.transpose((0, 2, 1)),
+ key) # batch matrix multiplication
energy = (self.key_ch**(-0.5)) * energy
attention = F.softmax(energy, axis=-1)
- out = paddle.bmm(value, attention.transpose((0,2,1)))
+ out = paddle.bmm(value, attention.transpose((0, 2, 1)))
return out
 
def _split_subregions(self, x):
b, c, h, w = paddle.shape(x)
assert h % self.scale == 0 and w % self.scale == 0
- x = x.reshape((b, c, self.scale, h//self.scale, self.scale, w//self.scale))
- x = x.transpose((0,2,4,1,3,5)).reshape((b*self.scale*self.scale, c, -1))
+ x = x.reshape(
+ (b, c, self.scale, h // self.scale, self.scale, w // self.scale))
+ x = x.transpose((0, 2, 4, 1, 3, 5)).reshape(
+ (b * self.scale * self.scale, c, -1))
return x
 
def _recons_whole(self, x, b, c, h, w):
- x = x.reshape((b, self.scale, self.scale, c, h//self.scale, w//self.scale))
- x = x.transpose((0,3,1,4,2,5)).reshape((b, c, h, w))
+ x = x.reshape(
+ (b, self.scale, self.scale, c, h // self.scale, w // self.scale))
+ x = x.transpose((0, 3, 1, 4, 2, 5)).reshape((b, c, h, w))
return x
 
 
class PAM(nn.Layer):
- def __init__(self, in_ch, ds, scales=(1,2,4,8)):
+ def __init__(self, in_ch, ds, scales=(1, 2, 4, 8)):
super().__init__()
 
- self.stages = nn.LayerList([
- PAMBlock(in_ch, scale=s, ds=ds)
- for s in scales
- ])
- self.conv_out = Conv1x1(in_ch*len(scales), in_ch, bias=False)
+ self.stages = nn.LayerList(
+ [PAMBlock(
+ in_ch, scale=s, ds=ds) for s in scales])
+ self.conv_out = Conv1x1(in_ch * len(scales), in_ch, bias=False)
 
def forward(self, x):
x = x.flatten(-2)
res = [stage(x) for stage in self.stages]
out = self.conv_out(paddle.concat(res, axis=1))
- return out.reshape(out.shape[:-1]+[out.shape[-1]//2, 2])
+ return out.reshape(out.shape[:-1] + [out.shape[-1] // 2, 2])
 
 
class Attention(nn.Layer):
@@ -294,4 +302,4 @@ class Attention(nn.Layer):
def forward(self, x1, x2):
x = paddle.stack([x1, x2], axis=-1)
y = self.att(x)
- return y[...,0], y[...,1]
+ return y[..., 0], y[..., 1]