Weight decay causes loaded model to not match saved one #1201

Closed
danielhers opened this issue Jan 24, 2018 · 9 comments
Labels: major bug (Issues that silently cause incorrect results, break installation on common environments, etc.)

danielhers (Collaborator) commented Jan 24, 2018

This caused me a lot of frustration until I finally figured out why my saved models' results don't match when I load them.
After training a model and saving it, I expect it to produce exactly the same results as it did just before it was saved (assuming no updates were done in between, of course). However, this is not the case when weight decay is used: it looks like the weight decay is not applied to the loaded model, even though it is set globally in dynet_config.

Minimal working example:

import dynet_config
dynet_config.set(weight_decay=1e-5)  # set the global weight decay (lambda)
import dynet as dy

m1 = dy.ParameterCollection()
p1 = m1.add_parameters(1)
t = dy.SimpleSGDTrainer(m1)
p1.expr().forward()
p1.expr().backward()
t.update()  # a single update, so one step of weight decay is accumulated
dy.renew_cg()
v1 = p1.expr().value()  # value just before saving (decay applied)
dy.save("test", [p1])
m2 = dy.ParameterCollection()
[p2] = dy.load("test", m2)
v2 = p2.expr().value()  # value after loading (decay no longer applied)
assert v1 == v2, "%s != %s" % (v1, v2)
# >>> AssertionError: -1.5506035089492798 != -1.5506190061569214

Changing weight_decay to 0 fixes the problem.

Related to #917.

danielhers (Collaborator, Author) commented:

Is there a way I can force my existing trained models to use the same weight decay as when they were saved, so I can reproduce their results?

@neubig added the major bug label on Jan 24, 2018
neubig (Contributor) commented Jan 24, 2018

Ouch, this was probably broken when we switched to the new model loading format. There are two ways to fix this (cc @xunzhang):

  1. Save the current amount of weight decay along with the model and read it back in at model load time.
  2. When saving parameters, trigger a reset of the weight decay by calling reset_and_rescale_weight_decay().

Previously we did option 1 because we didn't want to trigger a heavy operation unrelated to saving every time the model is saved, but I think option 2 is probably conceptually simpler and wouldn't require changing the model-saving format.
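
As a conceptual illustration of option 2 (plain numpy with made-up numbers, not DyNet internals): folding the accumulated decay factor into the stored weights and then resetting the factor to 1 leaves the effective weights unchanged, which is why no change to the saving format would be needed.

import numpy as np

lam, num_updates = 1e-5, 100            # illustrative values only
stored = np.array([0.3, -1.2, 0.7])     # raw weights, kept undecayed internally
factor = (1 - lam) ** num_updates       # accumulated global decay factor

effective_before = stored * factor      # what a forward pass effectively uses

# "reset and rescale": bake the factor into the stored weights, then reset it
stored = stored * factor
factor = 1.0

effective_after = stored * factor
assert np.allclose(effective_before, effective_after)  # effective weights unchanged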

danielhers (Collaborator, Author) commented Jan 24, 2018

But since I'm setting the weight decay globally to the same value, shouldn't it apply at model load time too?
Edit: oh, I think I get it: the setting only determines lambda, but the current weight decay factor depends on the number of updates performed. Correct?
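
A quick sanity check with the two values from the assertion message in the minimal example above (one update, weight_decay=1e-5): rescaling the post-load value by (1 - 1e-5) reproduces the pre-save value up to float32 precision.

v_before_save = -1.5506035089492798
v_after_load = -1.5506190061569214
print(v_after_load * (1 - 1e-5))  # ~ -1.5506035, i.e. v_before_save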

@xunzhang self-assigned this on Jan 24, 2018
xunzhang (Collaborator) commented:

@neubig will take a look.

danielhers (Collaborator, Author) commented Jan 24, 2018

Thanks. So if I know how many updates were applied to a model and what the weight decay rate was, can I use this information to recover the exact model (as a workaround for now)?

neubig (Contributor) commented Jan 24, 2018

Yes. The current weight decay value will be (1-weight_decay)^num_updates. You can write a script to modify the model file appropriately and it will recover the original model.
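
For reference, a minimal sketch of applying that correction in memory after loading rather than rewriting the model file, continuing the example from the first comment. It assumes Parameters.as_array() is available in your DyNet version and that you know how many updates were performed before saving (1 in the example above).

import dynet as dy

# Values used by the training run that produced "test" (see the example above):
weight_decay, num_updates = 1e-5, 1
factor = (1 - weight_decay) ** num_updates   # accumulated decay at save time

m2 = dy.ParameterCollection()
[p2] = dy.load("test", m2)
corrected = p2.as_array() * factor           # matches v1, the value just before saving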

danielhers (Collaborator, Author) commented:

Thanks, it worked!

xunzhang (Collaborator) commented:

Addressed the issue in the pull request above.
I think neither option 1 nor option 2 that @neubig mentioned above is perfect. For 1, the weight decay is not part of the model itself, so it should not be included in the saved model files. For 2, we shouldn't reset the weight decay before saving: in the checkpointing case, for example, users may continue training after saving model files.

neubig (Contributor) commented Jan 29, 2018

Fixed by #1206

@neubig closed this as completed on Jan 29, 2018