Skip to content
Snippets Groups Projects
Commit 15ba130f authored by StrangerZhang's avatar StrangerZhang
Browse files

fix mobile_v2 single output and neck hard coding

parent b5da3c98
No related branches found
No related tags found
No related merge requests found
......@@ -128,6 +128,8 @@ class MobileNetV2(nn.Sequential):
outputs.append(x)
p0, p1, p2, p3, p4 = [outputs[i] for i in [1, 2, 3, 5, 7]]
out = [outputs[i] for i in self.used_layers]
if len(out) == 1:
return out[0]
return out
......
......@@ -9,32 +9,37 @@ import torch.nn as nn
class AdjustLayer(nn.Module):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels, out_channels, center_size=7):
super(AdjustLayer, self).__init__()
self.downsample = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels),
)
self.center_size = center_size
def forward(self, x):
x = self.downsample(x)
if x.size(3) < 20:
l = 4
r = l + 7
l = (x.size(3) - self.center_size) // 2
r = l + self.center_size
x = x[:, :, l:r, l:r]
return x
class AdjustAllLayer(nn.Module):
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels, out_channels, center_size=7):
super(AdjustAllLayer, self).__init__()
self.num = len(out_channels)
if self.num == 1:
self.downsample = AdjustLayer(in_channels[0], out_channels[0])
self.downsample = AdjustLayer(in_channels[0],
out_channels[0],
center_size)
else:
for i in range(self.num):
self.add_module('downsample'+str(i+2),
AdjustLayer(in_channels[i], out_channels[i]))
AdjustLayer(in_channels[i],
out_channels[i],
center_size))
def forward(self, features):
if self.num == 1:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment