explicit differentiation for RNN gives wrong results #2185
I'm surprised this works at all with the input format given. What does the PyTorch code look like and have you verified it's doing the same thing? |
what should be the format? |
PyTorch is quite different: it takes input of shape (batch_size, seq_len, features) |
Flux supports something very similar. This is why it's important to see the PyTorch code as well, I have a feeling this is not an apples-to-apples comparison. |
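To make the layout difference concrete, here is a minimal pure-Python sketch of converting a batch-first input, as the PyTorch model in this thread uses, into the per-timestep sequence of (features × batch) matrices that the Flux code in this thread feeds to the model one step at a time. The helper `batch_first_to_time_major` is hypothetical and is not part of either library; nested lists stand in for arrays.

```python
# Hypothetical helper (not a PyTorch or Flux API): converts a batch-first
# nested list x[b][t][f] into a time-major list seq[t][f][b], i.e. one
# (features x batch) matrix per timestep.
def batch_first_to_time_major(x):
    batch_size = len(x)
    seq_len = len(x[0])
    features = len(x[0][0])
    return [
        [[x[b][t][f] for b in range(batch_size)] for f in range(features)]
        for t in range(seq_len)
    ]


# Toy batch: 2 samples, 3 timesteps, 4 features.
# Each value encodes its own (b, t, f) position as 100*b + 10*t + f.
batch = [
    [[100 * b + 10 * t + f for f in range(4)] for t in range(3)]
    for b in range(2)
]
seq = batch_first_to_time_major(batch)

print(len(seq), len(seq[0]), len(seq[0][0]))  # timesteps, features, batch
```

If the two frameworks are fed the same images in these two layouts, the comparison is closer to apples-to-apples. If the axes are mixed up, the model sees a different sequence.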
pytorch implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

# ------------------------------ DATA -----------------------------------
train_data = MNIST(train=True, root='data', transform=ToTensor())
test_data = MNIST(train=False, root='data', transform=ToTensor())
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000, shuffle=True)

# ---------------------------- MODEL --------------------------------------
class RNN(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(RNN, self).__init__()
        self.rnn = nn.RNN(input_dim, 128, batch_first=True)
        self.fc = nn.Linear(128, output_dim)

    def forward(self, x, h):
        x, h = self.rnn(x, h)
        x = F.relu(x)
        x = self.fc(x)
        # get last layer from rnn
        return x[:, -1, :], h

    def init_hidden(self, batch_size):
        return torch.zeros([1, batch_size, 128])

# ----------------------- HELPER -----------------------------------
# seq_len = 28, input_dim=28, num_classes=10
model = RNN(input_dim=28, output_dim=10)
loss_fn = nn.CrossEntropyLoss()  # includes softmax layer too so we don't need it in the model

def accuracy(X, y):
    total_samples = X.shape[0]
    h = model.init_hidden(batch_size=total_samples)
    with torch.no_grad():
        pred_values, _ = model(X, h)
    return torch.sum(pred_values.max(1)[1] == y) / total_samples

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --------------------------------- TRAIN LOOP ------------------------
for epoch in range(1, 11):
    for data in train_loader:
        features = data[0].squeeze(1)  # convert (batch_size, 1, 28, 28) to (batch_size, 28, 28)
        h = model.init_hidden(batch_size=features.shape[0])  # hidden state
        labels = data[1]
        predicted_values, _ = model(features, h)
        loss = loss_fn(predicted_values, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # get test data
    test_data = next(iter(test_loader))
    test_features = test_data[0].squeeze(1)  # convert (batch_size, 1, 28, 28) to (batch_size, 28, 28)
    test_labels = test_data[1]
    print(f"epoch : {epoch}\\10\taccuracy : {accuracy(test_features, test_labels)}")

epoch : 1\10 accuracy : 0.8370000123977661
of course i used smaller batch_size in Flux like 64, 32, but still the same result |
Thanks, will try to take a look over the next couple of days. One quick observation though:
This is also true for Flux's |
yes i know, i've used logitcrossentropy without softmax, also softmax with crossentropy, but still same results |
I haven't run the code. But is there a possibility that model input is mistaken? 50% accuracy really reminds me of my one data-processing experience. |
how i should prepare my data then? |
I took a shot at rewriting the model the way I would have implemented it. It reaches 91% accuracy right after the first epoch with batchsize=64.
using Flux
using Flux: onehotbatch, onecold, params, gradient
using MLDatasets: MNIST
using Base.Iterators: partition
using Statistics: mean
using Random: shuffle
#---------------------------------- DATA -------------------------------------
DATA_TRAIN = MNIST.traindata(Float32)
DATA_TEST = MNIST.testdata(Float32)
#-------------------------------- PREPROCESS DATA ------------------------------
x_train = [x for x in eachslice(DATA_TRAIN[1], dims=2)] # reshape to vector of size 28 with matrix of size 28 x 60000
x_test = [x for x in eachslice(DATA_TEST[1], dims=2)] # reshape to vector of size 28 with matrix of size 28 x 10000
# create onehotbatch for train label
y_train = onehotbatch(DATA_TRAIN[2], 0:9)
y_test = DATA_TEST[2]
#------------------------------ CONSTANTS ---------------------------------------
INPUT_DIM = size(x_train[1], 1)
OUTPUT_DIM = 10 # number of classes
LR = 0.001f0 # learning rate
EPOCHS = 10
BATCH_SIZE = 64
TOTAL_SAMPLES = size(x_train[1], 2)
#--------------------------------- BUILD MODEL -----------------------------------
model = Chain(
    RNN(INPUT_DIM => 128, relu),
    Dense(128, OUTPUT_DIM)
)
#----------------------------- HELPER FUNCTIONS --------------------------------------
function loss_fn_2(m, x, y)
    out = [m(xi) for xi in x] # generate output for each of the 28 timesteps
    Flux.Losses.logitcrossentropy(out[end], y) # compute loss based on predictions of the latest timestep
end
function accuracy_eval(m, x, y)
    Flux.reset!(m)
    out = [m(xi) for xi in x]
    mean(onecold(out[end], 0:9) .== y)
end
θ = params(model) # model parameters to be updated during training
opt = Flux.ADAM(LR) # optimizer function
#---------------------------- RUN TRAINING ----------------------------------------------
for epoch ∈ 1:EPOCHS
    for idx ∈ partition(1:TOTAL_SAMPLES, BATCH_SIZE)
        features = [x[:, idx] for x ∈ x_train]
        labels = y_train[:, idx]
        Flux.reset!(model)
        gs = gradient(θ) do
            loss = loss_fn_2(model, features, labels)
        end
        # update model
        Flux.Optimise.update!(opt, θ, gs)
    end
    # evaluate model
    @info epoch
    @show accuracy_eval(model, x_test, y_test)
end
I think the data preprocessing was done fine (I just dropped the TensorCast dependency, as I hit an issue with it and it felt simpler not to use it). I'm really unclear about what went wrong with your implementation. It's just speculation, but perhaps the gradients didn't get propagated through the following part:
as there's no explicit passing of the initial computation to the second. Again, just a wild guess here. |
Your code works, but I really don't know why my code isn't working if the data preprocessing is the same:
using Flux
using Flux: onehotbatch, onecold, params, gradient
using MLDatasets: MNIST
using Base.Iterators: partition, product
using TensorCast
using Statistics: mean
using Random: shuffle
using StatsBase
using ChainRulesCore, Zygote
ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})
Zygote.refresh()
# ---------------------------------- DATA -------------------------------------
TRAIN_DATA, TRAIN_LABELS = MNIST.traindata(Float32)
TEST_DATA, TEST_LABELS = MNIST.testdata(Float32)
TRAIN_LABELS = onehotbatch(TRAIN_LABELS, 0:9)
# convert 3d arrays to vector of 2d arrays
@cast TRAIN_FEATURES[i][j, k] := TRAIN_DATA[i, j, k]
@cast TEST_FEATURES[i][j, k] := TEST_DATA[i, j, k]
INPUT_DIM = size(TRAIN_FEATURES[1], 1)
DATA = [([x[:, idx] for x in TRAIN_FEATURES], TRAIN_LABELS[:, idx]) for idx ∈ partition(shuffle(1:size(TRAIN_LABELS, 2)), 1000)]
# ----------------------------------- MODEL --------------------------------------------
model = Chain(
    RNN(INPUT_DIM, 128, relu),
    Dense(128, 10)
)
# --------------------------------- HELPER -----------------------------------------------
function loss_fn(X, Y)
    Flux.reset!(model)
    out = [model(x) for x ∈ X]
    Flux.Losses.logitcrossentropy(out[end], Y)
end
function accuracy(X, Y)
    Flux.reset!(model) # Only important for recurrent network
    out = [model(x) for x ∈ X]
    mean(onecold(out[end], 0:9) .== Y)
end
θ = params(model)
opt = Flux.ADAM()
evalcb() = @show(accuracy(TEST_FEATURES, TEST_LABELS))
# ----------------------------------- TRAIN -------------------------
Flux.@epochs 30 Flux.train!(loss_fn, θ, DATA, opt, cb = Flux.throttle(evalcb, 5))
still doesn't work |
Not on a computer right now, but I think you should remove the |
I found out that if I remove the model argument from the loss and accuracy functions, I get bad results; otherwise it works as expected. Can you explain why this happens? It seems very strange. |
Can you show the before and after code for that change? It's not immediately clear what the difference would be. |
@alerem18 if you manage to clarify what's the difference causing a bad result we can decide if we have an actual bug or not |
loss_fn(X, Y), accuracy(X, Y) ===> bad results. Passing the model through the loss and accuracy functions works as expected. If you don't pass it to those functions, you get bad results: the model stops improving after a while and accuracy stays around 50-60%. |
What we're asking for is full code examples that show the good and bad results. Without that, |
using Flux
using Flux: gradient, logitcrossentropy, params, Momentum
using OneHotArrays: onecold, onehotbatch
using MLDatasets: MNIST
using Random: shuffle
using Statistics: mean
using Base.Iterators: partition
# ------------------- data --------------------------
train_x, train_y = MNIST(split=:train).features, MNIST(split=:train).targets
test_x, test_y = MNIST(split=:test).features, MNIST(split=:test).targets
train_y = onehotbatch(train_y, 0:9)
train_x = [x for x ∈ eachslice(train_x, dims=2)]
test_x = [x for x ∈ eachslice(test_x, dims=2)]
# ------------------ constants ---------------------
INPUT_SIZE = 28
NUM_CLASSES = 10
BATCH_SIZE = 1000
EPOCHS = 5
# ------------------ model --------------------------
model = Chain(
    RNN(INPUT_SIZE, 128, relu),
    RNN(128, 64, relu),
    Dense(64, NUM_CLASSES)
)
# ---------------- helper --------------------------
loss_fn(m, X, y) = logitcrossentropy([m(x) for x ∈ X][end], y)
accuracy(m, X, y) = mean(onecold([m(x) for x ∈ X][end], 0:9) .== y)
opt = Momentum()
θ = params(model)
# --------------- train -----------------------------
for epoch ∈ 1:EPOCHS
    for idx ∈ partition(shuffle(1:size(train_y, 2)), BATCH_SIZE)
        Flux.reset!(model)
        X = [x[:, idx] for x ∈ train_x]
        y = train_y[:, idx]
        gs = gradient(θ) do
            loss_fn(model, X, y)
        end
        Flux.Optimise.update!(opt, θ, gs)
    end
    Flux.reset!(model)
    test_acc = accuracy(model, test_x, test_y)
    @info "Epoch : $epoch | accuracy : $test_acc"
end
[ Info: Epoch : 1 | accuracy : 0.3968
Edit the loss and accuracy functions as below and you get these results:
loss_fn(X, y) = logitcrossentropy([model(x) for x ∈ X][end], y)
accuracy(X, y) = mean(onecold([model(x) for x ∈ X][end], 0:9) .== y)
[ Info: Epoch : 1 | accuracy : 0.2795 |
Thanks, I can reproduce. The cause isn't obvious to me, but the behavior seems to point that the In all cases, it appears safer to pass the model explicitly to the loss and accuracy functions. This also looks like non-obvious behavior that can lead to unexpectedly bad results, so it would be worth documenting if we can confirm the root cause. |
A quick sanity check would be moving |
Unfortunately, no luck with adding training params instantiation within the training loop. The following results in the same accuracy plateau around 60%:
ps = params(model)
Flux.reset!(model)
gs = gradient(ps) do
    loss_fn2(X, y)
end |
In the modified code, the loss_fn and accuracy functions do not take the params of the model as input, and they call the model directly within the function to compute the loss and accuracy. The params function is used to extract the trainable parameters of a model, which is necessary for computing gradients and updating the model parameters during training. When params is used, the optimizer is able to track the gradients of the model parameters and update them accordingly during optimization. By not using params, the optimizer is not able to track the gradients of the model parameters correctly and this can lead to incorrect optimization and lower accuracy. Therefore, not using params in the modified code is a mistake and can result in lower accuracy. any thoughts? |
Part of the "magic" of passing a The problem here is that something is causing the aforementioned tracking to not work. Ordinarily both versions of the code should behave similarly, so this is a bug. It's also why we're moving away from magical implicit params to directly passing the model/trainable params to |
For reference, this is how you could use the new explicit gradient / Optimisers.jl mode:
loss_fn1(m, X, y) = logitcrossentropy([m(x) for x ∈ X][end], y)
accuracy1(m, X, y) = mean(onecold([m(x) for x ∈ X][end], 0:9) .== y)
rule = Flux.Optimisers.Adam()
opts = Flux.Optimisers.setup(rule, model);
for epoch ∈ 1:5
    for idx ∈ partition(shuffle(1:size(train_y, 2)), BATCH_SIZE)
        X = [x[:, idx] for x ∈ train_x]
        y = train_y[:, idx]
        Flux.reset!(model)
        gs = gradient(model) do m
            loss_fn1(m, X, y)
        end
        Flux.Optimisers.update!(opts, model, gs[1]);
    end
    Flux.reset!(model)
    test_acc = accuracy1(model, test_x, test_y)
    @info "Epoch : $epoch | accuracy : $test_acc"
end |
the RNN gradient with Zygote might have a bug. Here's my short test code. Keeping outputs in an array and in a scalar give me different gradients. How come?
using Flux
using Random
Random.seed!(149)
layer1 = Flux.Recur(Flux.RNNCell(1 => 1, identity))
x = Float32[0.8, 0.9]
y = Float32(-0.7)
Flux.reset!(layer1)
e1, g1 = Flux.withgradient(layer1) do m
    yhat = 0.0
    for i in 1:2
        yhat = m([x[i]])
    end
    loss = Flux.mse(yhat, y)
    println(loss)
    return loss
end
println("flux gradients: ", g1[1])
Flux.reset!(layer1)
e2, g2 = Flux.withgradient(layer1) do m
    yhat = [m([x[i]]) for i in 1:2]
    loss = Flux.mse(yhat[end], y)
    println(loss)
    return loss
end
println("flux gradients: ", g2[1]) |
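As a reference point that does not involve Zygote, the two formulations in the test above compute the same mathematical function, so their gradients should be identical. The sketch below is a scalar re-implementation in plain Python, not Flux code. The weights `wi`, `wh`, `b` and the initial state `h0` are made-up values (the ones produced by `Random.seed!(149)` are not shown in the thread), and gradients are estimated with central finite differences.

```python
def rnn_step(wi, wh, b, h, x):
    # Scalar RNN cell with identity activation: h_new = wi*x + wh*h + b
    return wi * x + wh * h + b


def loss_overwrite(p, xs, y):
    # Keeps only the latest output in a scalar, like the first formulation.
    wi, wh, b, h = p
    yhat = 0.0
    for x in xs:
        h = rnn_step(wi, wh, b, h, x)
        yhat = h
    return (yhat - y) ** 2


def loss_collect(p, xs, y):
    # Collects every output in a list and uses the last one,
    # like the second formulation.
    wi, wh, b, h = p
    outs = []
    for x in xs:
        h = rnn_step(wi, wh, b, h, x)
        outs.append(h)
    return (outs[-1] - y) ** 2


def central_diff(f, p, xs, y, eps=1e-6):
    # Numerical gradient of f with respect to each entry of p.
    grads = []
    for i in range(len(p)):
        hi, lo = list(p), list(p)
        hi[i] += eps
        lo[i] -= eps
        grads.append((f(hi, xs, y) - f(lo, xs, y)) / (2 * eps))
    return grads


# Assumed parameters [wi, wh, b, h0]; inputs and target taken from the test.
p = [0.5, -0.3, 0.1, 0.0]
xs, y = [0.8, 0.9], -0.7

g_overwrite = central_diff(loss_overwrite, p, xs, y)
g_collect = central_diff(loss_collect, p, xs, y)
print(g_overwrite)
print(g_collect)
```

Any difference between the two gradients reported by an AD system for this pair of formulations therefore points at the differentiation machinery rather than at the model.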
There's effectively something fishy going on with the RNN gradients.
using Flux
layer2 = Flux.Recur(Flux.RNNCell(1, 1, identity))
layer2.cell.Wi .= 5.0
layer2.cell.Wh .= 4.0
layer2.cell.b .= 0f0
layer2.cell.state0 .= 7.0
x = [[2f0], [3f0]]
Flux.reset!(layer2)
ps = Flux.params(layer2)
e2, g2 = Flux.withgradient(ps) do
    out = [layer2(xi) for xi in x]
    sum(out[2])
end
julia> g2[ps[1]]
1×1 Matrix{Float32}:
3.0
julia> g2[ps[2]]
1×1 Matrix{Float32}:
38.0
julia> g2[ps[3]]
1-element Fill{Float32}, with entry equal to 1.0
julia> g2[ps[4]] # nothing
Theoretical gradients are:
julia> ∇Wi = x[1] .* layer2.cell.Wh .+ x[2]
1×1 Matrix{Float32}:
11.0
julia> ∇Wh = 2 .* layer2.cell.Wh .* layer2.cell.state0 .+ x[1] .* layer2.cell.Wi
1×1 Matrix{Float32}:
66.0
julia> ∇b = layer2.cell.Wh .+ 1
1×1 Matrix{Float32}:
5.0
julia> ∇state0 = layer2.cell.Wh .^ 2
1×1 Matrix{Float32}:
16.0
Worse, the gradients are different (yet still wrong) when using the explicit mode :\ I tested on older versions of Flux and things got even weirder: I got the same bad gradients going back to v0.11.4. However, when trying on Julia 1.6.5, I got correct gradients with all tested Flux versions, v0.11.4 up to v0.13.4, and the latest Zygote v0.6.58 (both implicit and explicit modes)! The same bad gradients were observed on Julia 1.7.2 and 1.9.0-rc1. So it seems that something changed between Julia v1.6 and v1.7 that had an impact on gradient correctness. Any idea @ToucheSir? |
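The theoretical gradients quoted above (11, 66, 5 and 16) can be checked independently of Flux and Zygote. The following is a minimal sketch in plain Python, assuming the cell computes h_t = Wi*x_t + Wh*h_(t-1) + b with identity activation, which is what the hand-derived formulas imply. The function `last_output` is a hypothetical stand-in for running the layer over both inputs and summing the second output.

```python
def last_output(wi, wh, b, h0, xs):
    # Unrolls the scalar recurrence h = wi*x + wh*h + b over the inputs
    # and returns the final hidden state.
    h = h0
    for x in xs:
        h = wi * x + wh * h + b
    return h


# Values from the reduced example: Wi = 5, Wh = 4, b = 0, state0 = 7.
params = [5.0, 4.0, 0.0, 7.0]
xs = [2.0, 3.0]


def numeric_grad(i, eps=1e-4):
    # Central finite difference with respect to the i-th parameter.
    hi, lo = list(params), list(params)
    hi[i] += eps
    lo[i] -= eps
    return (last_output(*hi, xs) - last_output(*lo, xs)) / (2 * eps)


names = ["Wi", "Wh", "b", "state0"]
grads = [numeric_grad(i) for i in range(4)]
for name, g in zip(names, grads):
    print(name, round(g, 6))
```

Because the output is at most quadratic in each parameter, central differences reproduce the analytic values up to floating-point rounding, so this agrees with the hand-derived gradients rather than with the values returned by the implicit-mode call.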
If I had to guess, something about lowering changed between those two versions. The more concerning part is that our test suite didn't catch this. I've always had a sinking feeling that https://github.com/FluxML/Flux.jl/blob/master/test/layers/recurrent.jl did not provide sufficient coverage, and unfortunately this only confirms that... |
I'll open a PR by tomorrow to add the above gradient tests. I'm also disappointed not to have taken the time to manually validate those RNN gradients until now. Zygote is quite a footgun :\ |
In this implementation, the |
If you consider the initial state non-trainable, then I think it's mostly equivalent since other libraries are passing all zeros as the initial state. If you have a custom initial state however or want it to be trainable (which PyTorch at least does not appear to support directly), then it is not the same as you say. I'm unsure why the original design is the way it is (cc @mkschleg for possible theories), but reworking the initial state is one of those things we're investigating for our overhaul of the RNN API. |
Regarding the initialization of the initial state: although it may not be the common form encountered in PyTorch, this paper, which has the LSTM author as a co-author, points to the relevance of learning the initial state (see section 5.1 at page 135). There is also this blog post discussing it: https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html. I also vaguely recall that MXNet used to have a learnable initial state as a feature, but couldn't confirm. By applying |
The reduced example in #2185 (comment) |
I tried to implement an RNN model to classify the MNIST dataset, but I get an accuracy of around 40-50% even after running it for more than 20 epochs, while in PyTorch I get an accuracy of up to 90% after just 4-5 epochs.
Here is my code:
What am I doing wrong?