
A pythonic extension for pytorch that adds limited support for complex valued algebra in neural networks.


joerenner/cplxmodule


CplxModule

A lightweight extension for torch.nn that adds layers and activations which respect algebraic operations over the field of complex numbers.

The implementation is based on the ICLR 2018 paper on Deep Complex Networks [1]_ and borrows ideas from its authors' implementation.
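The algebraic fact behind such layers is that a complex weight matrix W = A + iB acting on an input z = x + iy can be evaluated with real arithmetic only: Wz = (Ax - By) + i(Bx + Ay). The framework-free sketch below (plain Python, no dependency on cplxmodule or torch) computes a complex matrix-vector product this way and checks it against Python's built-in complex numbers. The helper names `matvec` and `cplx_matvec` are hypothetical and are not part of the library.

```python
# Framework-free sketch: a complex linear map evaluated with real arithmetic.
# Hypothetical helpers for illustration only; not part of cplxmodule.


def matvec(m, v):
    """Real matrix (list of rows) times real vector (list)."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def cplx_matvec(a, b, x, y):
    """Apply W = A + iB to z = x + iy and return (real part, imaginary part).

    (A + iB)(x + iy) = (Ax - By) + i(Bx + Ay)
    """
    ax, by = matvec(a, x), matvec(b, y)
    bx, ay = matvec(b, x), matvec(a, y)
    re = [p - q for p, q in zip(ax, by)]
    im = [p + q for p, q in zip(bx, ay)]
    return re, im


# Real and imaginary parts of the weights and of the input.
A = [[1.0, 2.0], [0.5, -1.0]]
B = [[0.0, -1.0], [3.0, 0.25]]
x, y = [1.0, -2.0], [0.5, 4.0]

re, im = cplx_matvec(A, B, x, y)

# Reference result using Python's built-in complex type.
W = [[complex(p, q) for p, q in zip(ra, rb)] for ra, rb in zip(A, B)]
z = [complex(p, q) for p, q in zip(x, y)]
expected = [sum(w * v for w, v in zip(row, z)) for row in W]

assert all(abs(complex(r, i) - e) < 1e-12 for r, i, e in zip(re, im, expected))
```

Since convolution is also linear, the same decomposition yields a complex convolution from two real convolutions, which is the construction described in [1].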

Installation

To install directly from git with pip, run

pip install --upgrade git+https://github.com/ivannz/cplxmodule.git

or, for a developer (editable) install, run from the root of the local repo

pip install -e .


Example

The module is designed to plug directly into existing torch.nn sequential models.

Importing the building blocks.

import torch
import torch.nn

# complex-valued tensor class
from cplxmodule import cplx

# converters
from cplxmodule.nn import RealToCplx, CplxToReal

# layers encapsulating other complex-valued layers
from cplxmodule.nn.sequential import CplxSequential

# common layers
from cplxmodule.nn.layers import CplxConv1d, CplxLinear

# activation layers
from cplxmodule.nn.activation import CplxModReLU, CplxActivation

After the RealToCplx layer, intermediate inputs are Cplx objects: abstractions for complex-valued tensors, represented by their real and imaginary parts, that obey complex arithmetic. Mixed-type arithmetic such as torch.Tensor +, - or * Cplx is currently not supported.

n_features, n_channels = 16, 4
z = torch.randn(3, n_features*2)

# use a new name so the imported `cplx` module is not shadowed
w = RealToCplx()(z)
print(w)
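To make this arithmetic concrete, here is a minimal stand-in for a Cplx-like container: a hypothetical `ToyCplx` that stores separate real and imaginary parts (plain floats here instead of tensors) and implements +, - and * by the usual complex rules. Mirroring the limitation described above, it refuses mixed-type operands instead of promoting them. This is an illustrative sketch under those assumptions, not cplxmodule's actual Cplx class.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyCplx:
    """Hypothetical complex value stored as separate real and imaginary parts."""

    real: float
    imag: float

    def __add__(self, other):
        if not isinstance(other, ToyCplx):
            return NotImplemented  # no mixed-type arithmetic; Python raises TypeError
        return ToyCplx(self.real + other.real, self.imag + other.imag)

    def __sub__(self, other):
        if not isinstance(other, ToyCplx):
            return NotImplemented
        return ToyCplx(self.real - other.real, self.imag - other.imag)

    def __mul__(self, other):
        if not isinstance(other, ToyCplx):
            return NotImplemented
        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        return ToyCplx(
            self.real * other.real - self.imag * other.imag,
            self.real * other.imag + self.imag * other.real,
        )


u, v = ToyCplx(1.0, 2.0), ToyCplx(3.0, -1.0)
print(u * v)  # ToyCplx(real=5.0, imag=5.0)
```

Returning `NotImplemented` rather than raising directly lets Python try the reflected operation first; with a float operand neither side handles the pair, so `u + 1.0` ends in a `TypeError`.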

Stacking and constructing linear pipelines:

n_features, n_channels = 16, 4
z = torch.randn(256, n_features*2)

# gain network works on the modulus of the complex input
modulus_gain = torch.nn.Sequential(
    torch.nn.Linear(n_features, n_channels * n_features),
    torch.nn.Sigmoid(),
)

# purely complex-to-complex sequential container
complex_model = CplxSequential(
    # complex: batch x n_features
    CplxLinear(n_features, n_channels * n_features, bias=True),

    # complex: batch x (n_channels * n_features), reshaped into channels
    CplxActivation(torch.reshape, shape=(-1, n_channels, n_features)),

    # complex: batch x n_channels x n_features
    CplxConv1d(n_channels, 3 * n_channels, kernel_size=4, stride=1, bias=False),

    # complex: batch x (3 * n_channels) x (n_features - (4-1))
    CplxModReLU(threshold=0.15),

    # complex: batch x (3 * n_channels) x (n_features - (4-1))
    CplxActivation(torch.flatten, start_dim=-2),
)

# branching into complex within a real-to-real model
real_input_model = torch.nn.Sequential(
    # real: batch x (n_features * 2)
    torch.nn.Linear(n_features * 2, n_features * 2),

    # real: batch x (n_features * 2)
    RealToCplx(),

    # complex: batch x n_features
    complex_model,

    # complex: batch x (3 * n_channels * (n_features - (4-1)))
    CplxToReal(),

    # real: batch x ((3 * n_channels * (n_features - (4-1))) * 2)
)

print(real_input_model(z).shape)
# >>> torch.Size([256, 312])
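The final width of 312 follows from the standard output-length rule for a 1-D convolution (here with no padding and unit dilation, which is assumed to be CplxConv1d's default, as in torch.nn.Conv1d), followed by flattening and by CplxToReal turning each complex feature into two reals. The framework-free sketch below reproduces that count. It also writes out modReLU in its common thresholded form, modReLU(z) = ReLU(|z| - t) * z / |z|, which shrinks the modulus and keeps the phase; treating CplxModReLU's `threshold` as t is an assumption. All function names are hypothetical.

```python
def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution (same formula as torch.nn.Conv1d)."""
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def pipeline_out_features(n_features, n_channels, kernel_size=4):
    """Real output width of the example pipeline above (hypothetical helper)."""
    # CplxConv1d: n_channels -> 3 * n_channels, length shrinks by kernel_size - 1
    length = conv1d_out_len(n_features, kernel_size)
    # CplxActivation(torch.flatten, start_dim=-2): merge channel and length axes
    n_complex = 3 * n_channels * length
    # CplxToReal: each complex feature becomes two real numbers
    return 2 * n_complex


def modrelu(z, threshold):
    """Thresholded modReLU on a Python complex; returns 0 when |z| <= threshold."""
    modulus = abs(z)
    if modulus <= threshold:
        return 0j
    return (modulus - threshold) * z / modulus


print(pipeline_out_features(16, 4))  # 312
print(modrelu(3 + 4j, 1.0))
```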

References

.. [1] Trabelsi, C., Bilaniuk, O., Zhang, Y., Serdyuk, D., Subramanian, S., Santos, J. F., ... & Pal, C. J. (2017). Deep complex networks. arXiv preprint arXiv:1705.09792
