Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Prodigy optimizer #895

Closed
wants to merge 8 commits into from
Closed

Add Prodigy optimizer #895

wants to merge 8 commits into from

Conversation

swfsql
Copy link
Contributor

@swfsql swfsql commented Dec 2, 2023

Closes #894.

This is a tentative to implement this optimizer. It's based on a pytorch implementation from here.
That implementation is similar to adam, and so were the files where I started at. After the pytorch->dfdx translation, I gave it a try at cuda kernel - there's probably a lot of naivety and some incorrectness.
For the fp16 side of things, I couldn't compile dfdx even without the prodigy changes, so I'm not sure if they are correct.

With the reasons above and some below, I'm leaving this as a draft.
If I happen to find it good use in later experiments or find bugs in it, I hope to update it here.

Testing

I've added some very basic rust tests and compared with the equivalent pytorch test.

For every pytorch test, they are somewhat like this:

import torch
import numpy as np
from prodigyopt import Prodigy

x = torch.tensor([[0.1, 0.2]], requires_grad=False).float()
m = torch.nn.Linear(2, 2, bias=False)
with torch.no_grad():
    w = torch.tensor([[3., 4.], [5., 6.]], requires_grad=True).float()
    m.weight = torch.nn.Parameter(w)
y = torch.tensor([[7e2, 8e2]], requires_grad=False).float()
loss_fn = torch.nn.MSELoss()
opt = Prodigy(m.parameters(), lr=1.) # this is the optimizer settings
preds = []
grads = []
weights = []
for i in range(0, 10):
    pred = m(x)
    preds.append(pred.detach().numpy().tolist())
    loss = loss_fn(pred, y)
    loss.backward()
    grads.append(m.weight.grad.detach().numpy().tolist())
    opt.step()
    weights.append(m.weight.detach().numpy().tolist())
    opt.zero_grad()
print(f"preds: {preds}")
print(f"grads: {grads}")
print(f"weights: {weights}")

Where the only change between each test happens on the # this is the optimizer settings line.

  1. The first test compares against that default optimizer settings:
    fn test_default_prodigy_params() {
    let (dev, x, y, m) = init();
    let opt = Prodigy::new(&m, Default::default());
  2. The second compares with the settings:
    fn test_custom_prodigy_params() {
    let (dev, x, y, m) = init();
    let opt = Prodigy::new(
    &m,
    ProdigyConfig {
    lr: 2e1,
    betas: [0.5, 0.25],
    beta3: Some(0.4),
    eps: 1e-8,
    weight_decay: None,
    use_bias_correction: true,
    safeguard_warmup: true,
    d0: 1e-5,
    d_coef: 0.5,
    growth_rate: 1.02,
    },
    );
opt = Prodigy(m.parameters(), lr=2e1, betas=(0.5, 0.25),beta3=0.4,eps=1e-8,use_bias_correction=True,safeguard_warmup=True,d0=1e-5,d_coef=0.5,growth_rate=1.02)
  1. The third compares with the settings:
    fn test_prodigy_l2_decay() {
    let (dev, x, y, m) = init();
    let opt = Prodigy::new(
    &m,
    ProdigyConfig {
    betas: [0.5, 0.25],
    beta3: Some(0.4),
    weight_decay: Some(WeightDecay::L2(1.0)),
    ..Default::default()
    },
    );
opt = Prodigy(m.parameters(), betas=(0.5, 0.25),beta3=0.4,weight_decay=1.0,decouple=False)
  1. The fourth compares with the settings:
    fn test_prodigy_decoupled_decay() {
    let (dev, x, y, m) = init();
    let opt = Prodigy::new(
    &m,
    ProdigyConfig {
    betas: [0.5, 0.25],
    beta3: Some(0.4),
    weight_decay: Some(WeightDecay::Decoupled(1e3)),
    ..Default::default()
    },
    );
opt = Prodigy(m.parameters(), betas=(0.5, 0.25),beta3=0.4,weight_decay=1.0,decouple=True)

Besides this I've tried making a comparison with a unet experiment, and adam seemed to be much better, so I may have implemented something incorrectly. So this is another reason why this PR is still a draft.

rainiwu and others added 8 commits January 26, 2024 00:29

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
Remove ftz

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
Avoid ci errors
@swfsql
Copy link
Contributor Author

swfsql commented Mar 1, 2024

I'll prioritize moving this experiment to a separate crate, but feel free to ping in case anyone have some question or suggestion.
Note that when I tried using this model I didn't had any success!

@swfsql swfsql closed this Mar 1, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add Prodigy optimizer
2 participants