Added Multi Scale Fourier Feature Network as per Issue SciML#498
Given how much code was rewritten, it would make more sense either to define this as a struct, or to define a single Fourier feature layer that encapsulates the Fourier feature encoding and the dense layers. Users could then build MSFF networks with FastChains in a separate file, or in a new file for networks defined using general FastChains.

Additionally, bias is not included, because I do not yet understand how to include it in a standalone function.
It has not been rigorously tested, although it seems to compile. I did get an inexact error that I couldn't debug when trying to run some operations. I would need some help to tackle that issue, now that I know the general way the network can be implemented.
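
One possible source of the inexact error, offered as an assumption since the failing call is not shown: the default `widthHidden = Int((2/3)*in + out)` converts a `Float64` to `Int`, and that conversion throws an `InexactError` whenever the value is not a whole number. A minimal sketch in plain Julia, where `width_strict` and `width_rounded` are hypothetical helper names used only for illustration:

```julia
# Hypothetical helpers mirroring the default hidden-width computation.
width_strict(in, out) = Int((2/3)*in + out)          # throws unless the value is a whole number
width_rounded(in, out) = round(Int, (2/3)*in + out)  # rounds to the nearest integer instead

# With in = 4 and out = 1 the value is about 3.67, so the strict conversion fails.
threw = try
    width_strict(4, 1)
    false
catch err
    err isa InexactError
end

println(threw)                # whether the strict version raised an InexactError
println(width_rounded(4, 1))  # the rounded width
```

If this is indeed the cause, replacing `Int(...)` with `round(Int, ...)` in the default arguments would avoid it.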
Parvfect authored Mar 29, 2022
1 parent 29b0dd5 commit 01fbd3f
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/fast_layers.jl
@@ -405,6 +405,40 @@ end
paramlength(f::StaticDense{out,in,bias}) where {out,in,bias} = out*(in + bias)
initial_params(f::StaticDense) = f.initial_params()


"""
Multi-scale Fourier feature network, implemented as per the paper https://arxiv.org/abs/2012.10047
At a high level, it performs the following:
1. Separates the network into spatial and time domains
2. Encodes Fourier features for each domain based on the scale parameters
3. Propagates through the hidden layers
4. Concatenates the Fourier features within each domain and multiplies them element-wise
5. Propagates through the final layer to get the output
```julia
MSFFSpacetimeNetwork(in, out, nHidden, nspace, ntime, nFourierSpace, XscaleParameters, nFourierTime, TscaleParameters, σ = identity, widthHidden = round(Int, (2/3)*in + out))
```
The activation function defaults to identity, and the width of the hidden layers defaults to 2/3 of the input size plus the output size (rounded to an integer), a common rule of thumb.
`nspace` and `ntime` define the number of spatial and time input dimensions respectively.
`nFourierSpace` and `nFourierTime` define the number of Fourier features to be used for each domain.
`XscaleParameters` and `TscaleParameters` define the scale parameters for the Fourier features of each domain.
"""
function MSFFSpacetimeNetwork(in::Int, out::Int, nHidden::Int, nspace::Int, ntime::Int, nFourierSpace::Int, XscaleParameters::Vector{Float64}, nFourierTime::Int, TscaleParameters::Vector{Float64}, σ = identity, widthHidden = round(Int, (2/3)*in + out))
    # One random projection matrix per scale parameter, for each domain.
    # `in ÷ 2` is integer division; `in//2` builds a Rational, which is not a valid array dimension.
    Wx = [randn(Float32, 1, in ÷ 2) .* Float32(s) for s in XscaleParameters]
    Wt = [randn(Float32, 1, in ÷ 2) .* Float32(s) for s in TscaleParameters]
    # `MSFF` was undefined in the original; `MSFFNetwork` is assumed to be the intended function.
    # NOTE: `MSFFNetwork` returns a vector of chains, so the `*` below still needs a layer that
    # evaluates both branches and multiplies their outputs element-wise.
    return FastChain(MSFFNetwork(nspace, out, nHidden, Wx, σ, widthHidden) * MSFFNetwork(ntime, out, nHidden, Wt, σ, widthHidden), FastDense(widthHidden, out))
end

function MSFFNetwork(in::Int, out::Int, nHidden::Int, W::AbstractVector, activationFunction = tanh, widthHidden = round(Int, (2/3)*in + out))
    # One Fourier feature network per projection matrix in W.
    return [FourierFeatureNetwork(in, out, nHidden, widthHidden, activationFunction, Wi) for Wi in W]
end

function FourierFeatureNetwork(in::Int, out::Int, nHidden::Int, widthHidden::Int, σ = identity, W = randn(Float32, 1, in ÷ 2))
    # FastChain calls plain functions as f(x, p); sin and cos are broadcast element-wise.
    # NOTE: `Maxout` comes from Flux and is not a Fast layer, so the hidden stack is still untested.
    return FastChain((x, p) -> [sin.(x*W) cos.(x*W)], Maxout(() -> FastDense(in, widthHidden, σ), nHidden))
end


# Override FastDense to exclude the branch from the check
function Cassette.overdub(ctx::DiffEqSensitivity.HasBranchingCtx, f::FastDense, x, p)
y = reshape(p[1:(f.out*f.in)],f.out,f.in)*x
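
The encoding step described in the docstring (one random projection per scale parameter, followed by sine and cosine) can be sketched in plain Julia without DiffEqFlux. This is only an illustration under stated assumptions: `fourier_features` is a hypothetical helper, and the `(m, d)` layout of the projection matrix and the example sizes are choices made here, not taken from the commit.

```julia
# Sketch of a multi-scale Fourier feature encoding in plain Julia.
# `fourier_features` and the matrix layout are assumptions for illustration.

# x: input batch of size (d, n); B: projection matrix of size (m, d).
# Returns a (2m, n) matrix stacking sin and cos of the projected inputs.
fourier_features(x::AbstractMatrix, B::AbstractMatrix) = vcat(sin.(B * x), cos.(B * x))

d, m, n = 2, 8, 5            # input dimension, features per scale, batch size
scales = Float32[1, 10]      # one projection matrix per scale parameter
Bs = [randn(Float32, m, d) .* s for s in scales]

x = rand(Float32, d, n)
encoded = [fourier_features(x, B) for B in Bs]

println(size.(encoded))      # one (2m, n) block per scale
```

Each encoded block would then be passed through the hidden layers before the per-scale outputs are concatenated, as the docstring's steps 3 and 4 describe.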
