Issue about adding STN #95
My implementation creates a separate localization module whose purpose is to output the affine transformation matrix, plus a wrapper layer that applies it. In my implementation, I create two classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LocNet(nn.Module):
    """LocNet architecture for the Spatial Transformer Layer."""

    def __init__(self):
        super().__init__()
        # self.avg_pool = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(24, 94))
        self.conv_l = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )
        self.conv_r = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )
        self.dropout = nn.Dropout2d()
        self.fc_1 = nn.Linear(in_features=64 * 7 * 30, out_features=32)
        self.fc_2 = nn.Linear(in_features=32, out_features=6)
        # Initialize the regressor to the identity transform.
        self.fc_2.weight.data.zero_()
        self.fc_2.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: batch size
                - C: number of channels
                - H: image height in pixels
                - W: image width in pixels

        Returns:
            torch.Tensor of affine matrices with shape (N, 2, 3).
        """
        x_l = self.avg_pool(input)
        x_l = self.conv_l(x_l)
        x_r = self.conv_r(input)
        xs = torch.cat([x_l, x_r], dim=1)
        xs = self.dropout(xs)
        xs = xs.flatten(start_dim=1)  # flatten for the fully-connected layers
        xs = self.fc_1(xs)
        xs = torch.tanh(xs)  # activation
        xs = self.fc_2(xs)
        xs = torch.tanh(xs)  # activation
        theta = xs.view(-1, 2, 3)  # reshape to (N, 2, 3)
        return theta


class SpatialTransformerLayer(nn.Module):
    """Spatial Transformer Layer module.

    Args:
        localization (torch.nn.Module): Module that generates the localization parameters.
        align_corners (bool): Whether to align corners. See
            https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            and https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html
    """

    def __init__(self, localization: nn.Module, align_corners: bool = False):
        super().__init__()
        self.localization = localization
        self.align_corners = align_corners

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: batch size
                - C: number of channels
                - H: image height in pixels
                - W: image width in pixels

        Returns:
            torch.Tensor: the grid-sampled (transformed) input.
        """
        theta = self.localization(input)
        grid = F.affine_grid(
            theta=theta, size=input.shape, align_corners=self.align_corners
        )
        return F.grid_sample(input, grid=grid, align_corners=self.align_corners)
```

Then, you can do model subclassing from the LPRNet. Probably as follows:

```python
class LPRNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Initialization steps
        localization = LocNet()
        self.stn = SpatialTransformerLayer(localization=localization)
        # Other initialization steps...

    def forward(self, x):
        xs = self.stn(x)
        # do the rest with the backbone and global context
        return xs
```

Keep in mind that in the original paper, the STN is initially turned off and then enabled at 5k epochs. You might want to have simple conditional logic and a property in your model class for that. If you train your model long enough, the result may be similar to the following (the first four rows are the input after augmentation; the rest are transformed by the model):
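The delayed-enable schedule described above can be sketched with plain conditional logic. This is only an illustrative sketch: the `stn_enabled` helper and the `stn_start_epoch` default are hypothetical names, not from the referenced code.

```python
def stn_enabled(epoch: int, stn_start_epoch: int = 5000) -> bool:
    """Return True once training has reached the epoch where the STN is switched on.

    Mirrors the schedule of keeping the STN off early in training;
    the name and default threshold here are illustrative assumptions.
    """
    return epoch >= stn_start_epoch


# Inside the model's forward you could then bypass the STN early in training:
# xs = self.stn(x) if stn_enabled(self.current_epoch) else x
print(stn_enabled(4999), stn_enabled(5000))
```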
@risangbaskoro thank you! I'll give it a try!
@Powerfulidot Judging from the training results, adding STNet should significantly improve the performance of LPRNet for license plate recognition, but attention should be paid to the model training process. For more info, see #96
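If you adapt the localization head to a different input size, the `in_features=64 * 7 * 30` of `fc_1` in the code above can be sanity-checked with the standard convolution output-size formula. A quick sketch, assuming the 24×94 input resolution used by the LocNet code:

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    # floor((size - kernel) / stride) + 1, with no padding or dilation
    return (size - kernel) // stride + 1


h = conv_out(24, kernel=5, stride=3)  # height after the 5x5, stride-3 conv
w = conv_out(94, kernel=5, stride=3)  # width after the same conv
channels = 32 + 32                    # conv_l and conv_r outputs concatenated
print(h, w, channels * h * w)         # flattened feature count fed to fc_1
```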
@risangbaskoro @zjykzj |
I added an STN (the spatial transformer network) to the top of LPRNet, but I find training this new structure quite difficult: the loss doesn't go down. Does anybody know how to deal with this?