Using user-provided weight matrices to build LSTM layers #2130

Closed
Chengfeng-Jia opened this issue Dec 2, 2022 · 2 comments

Comments


Chengfeng-Jia commented Dec 2, 2022

Describe the potential feature

In this demo for the Dense layer, a user can build a `Dense` layer from a provided weight matrix: `d1 = Dense(ones(2, 5), false, tanh)`. When I try to build an RNN/LSTM layer the same way, it fails: `Lstm = Flux.LSTMCell(W2_1, W2_2, b2, s2_1)` followed by `Lstm(data[1])`.

Here `W2_1`, `W2_2`, `b2`, and `s2_1` are predefined random arrays: two weight matrices, a bias vector, and an initial-state vector. The error is `MethodError: objects of type Flux.LSTMCell{Matrix{Float64}, Vector{Float64}, Vector{Float64}} are not callable`.
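The error message follows from how Julia dispatches on parametric types. Below is a minimal sketch using a hypothetical `ToyCell` type, not Flux's `LSTMCell`. Its only call method is restricted to instances whose state type is a 2-tuple of matrices, so an instance built with a single vector has no applicable method. Calling it therefore throws a `MethodError` of the same kind as the one reported above.

```julia
# Hypothetical toy type, not Flux's LSTMCell: a callable struct whose only call
# method is restricted (via the type parameter S) to a 2-tuple-of-matrices state.
struct ToyCell{S}
    state0::S
end

# Returns (new_state, output); this toy simply passes the state through.
(m::ToyCell{<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractVecOrMat{T}) where {T} = ((h, c), h)

good = ToyCell((zeros(2, 1), zeros(2, 1)))  # state: 2-tuple of matrices
bad  = ToyCell(zeros(2))                    # state: a single vector, as with s2_1

applicable(good, good.state0, randn(3))     # true
applicable(bad, good.state0, randn(3))      # false: no call method covers ToyCell{Vector{Float64}}
```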

Motivation

No response

Possible Implementation

No response

ToucheSir (Member)

The state you initialize the `LSTMCell` with needs to be a 2-tuple of `AbstractMatrix` (e.g. `(ones(1, 1), zeros(1, 1))`), not a single array as you described. Otherwise the cell's call method

`function (m::LSTMCell{I,H,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::Union{AbstractVecOrMat{T},OneHotArray}) where {I,H,V,T}`

does not match your cell's type, so Julia reports the cell as not callable. You can see how the main constructor does this initialization here:

`init_state = zeros32)`
`cell = LSTMCell(init(out * 4, in), init(out * 4, out), initb(out * 4), (init_state(out,1), init_state(out,1)))`
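For reference, the sketch below reproduces the fix without depending on Flux. `ToyLSTMCell` and `ToyRecur` are hypothetical names. They only mirror the field layout of Flux 0.13's `LSTMCell` (`Wi`, `Wh`, `b`, `state0`) and the idea behind `Flux.Recur`. The gate math is a standard LSTM step with an assumed gate order (input, forget, candidate, output), not Flux's source. The sketch also shows a second point from the quoted signature: the cell is called as `cell(state, x)` and returns `(new_state, output)`. Calling it with the input alone, as in `Lstm(data[1])`, therefore needs a wrapper that stores the state between calls.

```julia
# Hypothetical, dependency-free stand-in that mirrors the fields of Flux 0.13's
# LSTMCell (Wi, Wh, b, state0). It is NOT Flux's implementation.
struct ToyLSTMCell{A,V,S}
    Wi::A      # (4*out) × in   input-to-hidden weights
    Wh::A      # (4*out) × out  hidden-to-hidden weights
    b::V       # 4*out          bias
    state0::S  # initial state (h0, c0), each out × 1
end

sigmoid(x) = 1 / (1 + exp(-x))

# One time step. Like the signature quoted above, this method only applies when
# state0 is a 2-tuple of matrices whose element type T matches the input's.
function (m::ToyLSTMCell{A,V,<:NTuple{2,AbstractMatrix{T}}})((h, c), x::AbstractVecOrMat{T}) where {A,V,T}
    o = size(h, 1)
    g = m.Wi * x .+ m.Wh * h .+ m.b   # stacked gate pre-activations, (4o) × batch
    i = sigmoid.(g[1:o, :])           # input gate (assumed order)
    f = sigmoid.(g[o+1:2o, :])        # forget gate
    z = tanh.(g[2o+1:3o, :])          # candidate cell values
    u = sigmoid.(g[3o+1:4o, :])       # output gate
    c_new = f .* c .+ i .* z
    h_new = u .* tanh.(c_new)
    return (h_new, c_new), h_new      # (new state, output)
end

# Minimal wrapper that carries (h, c) between calls, similar in spirit to Flux.Recur.
mutable struct ToyRecur{C,S}
    cell::C
    state::S
end
ToyRecur(cell::ToyLSTMCell) = ToyRecur(cell, cell.state0)

function (r::ToyRecur)(x)
    r.state, y = r.cell(r.state, x)   # update the stored state, return the output
    return y
end

in_dim, out = 3, 2
Wi = randn(4out, in_dim)
Wh = randn(4out, out)
b  = zeros(4out)
s0 = (zeros(out, 1), zeros(out, 1))   # 2-tuple of matrices, not a single vector

lstm = ToyRecur(ToyLSTMCell(Wi, Wh, b, s0))
data = [randn(in_dim) for _ in 1:5]   # Float64 inputs, matching the Float64 state
ys = [lstm(x) for x in data]          # each output is out × 1
```

Assuming Flux 0.13, the equivalent with the real types is to pass `(h0, c0)` as the fourth argument of `Flux.LSTMCell` and wrap the cell as `Flux.Recur(cell, cell.state0)`. Later Flux releases reworked the recurrent-layer API, so check the documentation for the version you use. Because the same `T` constrains both the state and the input in the signature, a `Float32` input will not match a `Float64` state.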

Chengfeng-Jia (Author)

Thank you so much! It works.
