Skip to content

umzi2/MoSR

Repository files navigation

Mamba Out Super Resolution

The MoSR architecture was inspired by MambaOut.

def detect(state, *, verbose=True):
    """Infer MoSR hyper-parameters from a checkpoint state dict.

    The values are derived purely from tensor shapes in ``state``. By default
    the detected configuration is printed, as in the original snippet. It is
    now also returned, so it can be used programmatically.

    Args:
        state: Mapping of parameter names to tensors (anything exposing
            ``.shape``). It must contain ``gblocks.0.weight``,
            ``gblocks.1.fc1.weight`` and ``gblocks.1.conv.weight``. It must
            also contain the keys of one supported upsampler:
            ``upsampler.init_pos`` / ``upsampler.end_conv.weight`` /
            ``upsampler.offset.weight`` for "dys",
            ``upsampler.in_to_k.weight`` for "gps", or
            ``upsampler.0.weight`` for "ps".
        verbose: If true (the default), print the detected values in the
            original format.

    Returns:
        dict with keys ``in_ch``, ``out_ch``, ``dim``, ``n_block``,
        ``upsampler``, ``upscale``, ``kernel_size``, ``expansion_ratio`` and
        ``conv_ratio``.

    Raises:
        KeyError: if ``state`` is a dict and a required key is missing.
    """
    import math  # local import: the snippet has no module-level imports

    # get_seq_len is defined elsewhere. The "- 6" presumably discounts the
    # non-block entries of the ``gblocks`` sequence. TODO confirm against the
    # model definition.
    n_block = get_seq_len(state, "gblocks") - 6

    # Assumes conv weight layout (out_channels, in_channels, kH, kW).
    first_conv = state["gblocks.0.weight"].shape
    in_ch = first_conv[1]
    dim = first_conv[0]

    # Presumably fc1 projects dim -> 2 * hidden (gate + value halves, as in
    # MambaOut's gated block), hence the extra division by 2.
    fc1 = state["gblocks.1.fc1.weight"].shape
    expansion_ratio = (fc1[0] / fc1[1]) / 2

    conv = state["gblocks.1.conv.weight"].shape
    conv_ratio = conv[0] / dim
    kernel_size = conv[2]

    # Identify the upsampler by its characteristic keys, then recover the
    # scale factor from the width of its output projection.
    if "upsampler.init_pos" in state:
        upsampler = "dys"
        out_ch = state["upsampler.end_conv.weight"].shape[0]
        upscale = math.isqrt(state["upsampler.offset.weight"].shape[0] // 8)
    elif "upsampler.in_to_k.weight" in state:
        upsampler = "gps"
        out_ch = in_ch
        upscale = math.isqrt(state["upsampler.in_to_k.weight"].shape[0] // 8 // out_ch)
    else:
        upsampler = "ps"
        out_ch = in_ch
        # Pixel-shuffle input carries out_ch * upscale**2 channels.
        upscale = math.isqrt(state["upsampler.0.weight"].shape[0] // out_ch)

    if verbose:
        print(f"""    in_ch: {in_ch}
    out_ch: {out_ch}
    dim: {dim}
    n_block: {n_block}
    upsampler: {upsampler}
    upscale: {upscale}
    kernel_size: {kernel_size}
    expansion_ratio: {expansion_ratio}
    conv_ratio: {conv_ratio}""")

    return {
        "in_ch": in_ch,
        "out_ch": out_ch,
        "dim": dim,
        "n_block": n_block,
        "upsampler": upsampler,
        "upscale": upscale,
        "kernel_size": kernel_size,
        "expansion_ratio": expansion_ratio,
        "conv_ratio": conv_ratio,
    }


# State-dict keys whose presence identifies a MoSR checkpoint: the stem
# conv (gblocks.0) followed by every weight/bias pair of the first block's
# sub-layers (gblocks.1), in declaration order.
_STEM_KEYS = ["gblocks.0.weight", "gblocks.0.bias"]
_FIRST_BLOCK_LAYERS = ("norm", "fc1", "conv", "fc2")

signature = _STEM_KEYS + [
    f"gblocks.1.{layer}.{param}"
    for layer in _FIRST_BLOCK_LAYERS
    for param in ("weight", "bias")
]

References:

Training code from NeoSR

MambaOut

TODO:

  • Release pretrained weights

About

No description or website provided.

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages