
Issue about adding STN #95

Open

Powerfulidot opened this issue Sep 13, 2024 · 4 comments

Comments

@Powerfulidot

I added an STN (a spatial transformer network) to the top of the LPRNet architecture, but training this new structure seems quite difficult and the loss doesn't go down. Does anybody know how to deal with this?

@risangbaskoro

My implementation is to create another class that extends nn.Module; this class works as the STN.

The purpose of this module is to output the grid_sample that the backbone can then learn from within a few epochs.

In my implementation, I create two classes:

  • LocNet, which is the localization module
  • SpatialTransformerLayer (the STN), which takes any localization module (in my case, LocNet) when initialized and returns a grid sample on the forward pass, thus applying an affine transformation to the input.

The following lines are my implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LocNet(nn.Module):
    """LocNet architecture for Spatial Transformer Layer"""

    def __init__(self):
        super().__init__()

        # self.avg_pool = nn.AvgPool2d(kernel_size=(3, 3), stride=(2, 2))
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(24, 94))
        self.conv_l = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )
        self.conv_r = nn.Conv2d(
            in_channels=3, out_channels=32, kernel_size=(5, 5), stride=(3, 3)
        )

        self.dropout = nn.Dropout2d()

        self.fc_1 = nn.Linear(in_features=64 * 7 * 30, out_features=32)
        self.fc_2 = nn.Linear(in_features=32, out_features=6)

        self.fc_2.weight.data.zero_()
        self.fc_2.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: the number of batch
                - C: channel
                - H: height (pixel) of the image
                - W: width (pixel) of the image
        Return:
            torch.Tensor of affine matrices shape (N, 2, 3).
        """
        x_l = self.avg_pool(input)
        x_l = self.conv_l(x_l)

        x_r = self.conv_r(input)

        xs = torch.cat([x_l, x_r], dim=1)
        xs = self.dropout(xs)

        xs = xs.flatten(start_dim=1)  # Flatten for fully-connected layer
        xs = self.fc_1(xs)
        xs = torch.tanh(xs)  # activation
        xs = self.fc_2(xs)
        xs = torch.tanh(xs)  # activation
        theta = xs.view(-1, 2, 3)  # transform the shape to (N, 2, 3)
        return theta


class SpatialTransformerLayer(nn.Module):
    """Spatial Transformer Layer module

    Args:
        localization (torch.nn.Module): Module to generate localization.
        align_corners (bool):
            Whether to align_corners. See https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
              and https://pytorch.org/docs/stable/generated/torch.nn.functional.affine_grid.html
    """

    def __init__(self, localization: nn.Module, align_corners: bool = False):
        super().__init__()
        self.localization = localization
        self.align_corners = align_corners

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): Tensor of shape (N, C, H, W), where:
                - N: the number of batch
                - C: channel
                - H: height (pixel) of the image
                - W: width (pixel) of the image

        Return:
            torch.Tensor of grid sample.
        """
        theta = self.localization(input)
        grid = F.affine_grid(
            theta=theta, size=input.shape, align_corners=self.align_corners
        )
        return F.grid_sample(input, grid=grid, align_corners=self.align_corners)
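
As a quick sanity check, you can pass a dummy batch through the layer and confirm the shapes. This is a minimal sketch under the assumption that the input is shaped like LPRNet's (3, 24, 94) input and that the two classes above are in scope:

# Hedged example: shape check only, using the LocNet and
# SpatialTransformerLayer classes defined above.
import torch

locnet = LocNet()
stn = SpatialTransformerLayer(localization=locnet, align_corners=False)

x = torch.randn(8, 3, 24, 94)   # dummy batch shaped like LPRNet input
theta = locnet(x)               # (8, 2, 3) affine matrices
out = stn(x)                    # (8, 3, 24, 94), same shape as the input

print(theta.shape, out.shape)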

Then you can subclass LPRNet, probably as follows:

class LPRNet(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Initialization steps
        localization = LocNet()
        self.stn = SpatialTransformerLayer(localization=localization)
        # Other initialization steps...

    def forward(self, x):
        xs = self.stn(x)
        # do the rest with the backbone and global context
        return xs

Keep in mind that in the original paper the STN is initially turned off and is only enabled after 5k epochs. You might want to add simple conditional logic and a property to your model class to handle this.
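
One way to implement that switch is sketched below; the class name LPRNetWithSTN, the stn_enabled flag, and the exact threshold handling are illustrative assumptions, not part of the original code:

import torch.nn as nn

class LPRNetWithSTN(nn.Module):
    """Minimal sketch: keep the STN bypassed until the training loop enables it."""

    def __init__(self):
        super().__init__()
        self.stn = SpatialTransformerLayer(localization=LocNet())
        self.stn_enabled = False  # STN starts disabled
        # ... initialize the LPRNet backbone and global context here ...

    def forward(self, x):
        if self.stn_enabled:
            x = self.stn(x)  # rectify the input before the backbone
        # ... pass x through the backbone / global context ...
        return x

# In the training loop, enable the STN after the warm-up period:
model = LPRNetWithSTN()
for epoch in range(10_000):
    if epoch >= 5_000:  # illustrative threshold, matching the note above
        model.stn_enabled = True
    # ... usual training step ...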

If you train your model enough, the result may look similar to the following (the first four rows are the input after augmentation; the rest are transformed by the model):

[Image: grid_combined — augmented inputs (top rows) and STN-transformed outputs (bottom rows)]

@Powerfulidot (Author)

@risangbaskoro thank you! I'll give it a try!

@zjykzj

zjykzj commented Sep 21, 2024

> I added an STN (a spatial transformer network) to the top of the LPRNet architecture, but training this new structure seems quite difficult and the loss doesn't go down. Does anybody know how to deal with this?

@Powerfulidot From the training results, adding STNet should significantly improve LPRNet's license-plate recognition performance, but attention should be paid to the model training process. For more info, see #96.

| Model | ARCH | Input Shape | GFLOPs | Model Size (MB) | ChineseLicensePlate Accuracy (%) | Training Data | Testing Data |
| --- | --- | --- | --- | --- | --- | --- | --- |
| LPRNet | CONV | (3, 24, 94) | 0.3 | 1.9 | 60.105 | 269,621 | 149,002 |
| LPRNet+STNet | CONV | (3, 24, 94) | 0.3 | 2.2 | 72.261 | 269,621 | 149,002 |

@Powerfulidot (Author)

@risangbaskoro @zjykzj
@risangbaskoro @zjykzj Well, thank you both for your help, but after applying your methods the issue remains: the loss either refuses to go down or drops really slowly. On the other hand, when using the original LPRNet structure with the same training parameters there is no problem.
