
Log wandb step using wandb native step arg in addition to the "step" key. #613

Merged
merged 4 commits on Oct 11, 2023

Conversation

gabrielilharco
Collaborator

Currently we are passing step to wandb as another metric instead of using the step argument to wandb.log (see https://docs.wandb.ai/ref/python/log). This causes two "step" variables to be logged and can cause inconsistencies.
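To make the inconsistency concrete, here is a toy model of wandb's step handling (a stand-in, not real wandb, based on the wandb.log docs' description that an omitted step falls back to an auto-incremented internal counter; the metric names and logging interval are made up):

```python
# Minimal stand-in for wandb's step handling: when step is omitted,
# each log() call lands on the next internal step.
class FakeWandb:
    def __init__(self):
        self._next_step = 0
        self.history = {}  # internal step -> merged logged dict

    def log(self, data, step=None):
        if step is None:
            step = self._next_step
        self.history.setdefault(step, {}).update(data)
        self._next_step = step + 1

# Current behavior: our training step is only another metric, so if
# we log every 10 optimizer steps, wandb's x-axis (0, 1, 2, ...)
# diverges from the 'step' values we record (0, 10, 20, ...).
before = FakeWandb()
for our_step in (0, 10, 20):
    before.log({'train/loss': 1.0, 'step': our_step})

# Proposed: also pass the native step argument, keeping the 'step'
# key for backwards compatibility.
after = FakeWandb()
for our_step in (0, 10, 20):
    after.log({'train/loss': 1.0, 'step': our_step}, step=our_step)

print(sorted(before.history))  # [0, 1, 2]
print(sorted(after.history))   # [0, 10, 20]
```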

@rom1504
Collaborator

rom1504 commented Aug 28, 2023

Making this change now would make it difficult to compare new and old runs.

Is there any benefit?

@gabrielilharco
Collaborator Author

gabrielilharco commented Aug 28, 2023

We can keep 'step': step inside the dict for backwards compatibility. Probably still good to fix the wandb step; I see no harm in doing so. I think having "step" and "Step" variables that differ from each other can be confusing for new users of the codebase.

@EIFY
Contributor

EIFY commented Sep 28, 2023

Hmm, shouldn't we also fix the logging line in evaluate()?

if args.wandb:
    assert wandb is not None, 'Please install wandb.'
    for name, val in metrics.items():
        wandb.log({f"val/{name}": val, 'epoch': epoch})

To be consistent we probably should also use the number of training steps so far here, e.g.

if args.wandb:
    assert wandb is not None, 'Please install wandb.'
    dataloader = data['train'].dataloader
    num_batches_per_epoch = dataloader.num_batches // args.accum_freq
    step = num_batches_per_epoch * epoch
    for name, val in metrics.items():
        wandb.log({f"val/{name}": val, 'epoch': epoch}, step=step)
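A worked example of the step arithmetic above (the numbers are invented; accum_freq divides out the batches consumed per optimizer step):

```python
# Hypothetical numbers: the raw dataloader yields 1000 batches per
# epoch, and gradients are accumulated over accum_freq = 2 batches,
# so one epoch performs 500 optimizer steps.
num_batches = 1000        # dataloader.num_batches
accum_freq = 2            # args.accum_freq
num_batches_per_epoch = num_batches // accum_freq  # 500

# After 3 completed epochs, validation metrics are logged at
# step 1500, matching the training step counter at that point.
epoch = 3
step = num_batches_per_epoch * epoch
print(step)  # 1500
```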

@EIFY
Contributor

EIFY commented Sep 28, 2023

Upon closer look I think we can consolidate prefixing and wandb.log() calls to reduce remote calls over the network, e.g.

# In train_one_epoch():

            log_data = {"train/" + k: v for k, v in log_data.items()}
            if tb_writer is not None:
                for name, val in log_data.items():
                    tb_writer.add_scalar(name, val, step)
            log_data['step'] = step
            if args.wandb:
                assert wandb is not None, 'Please install wandb.'
                wandb.log(log_data, step=step)
# In evaluate():

    log_data = {"val/" + k: v for k, v in metrics.items()}

    if args.save_logs:
        if tb_writer is not None:
            for name, val in log_data.items():
                tb_writer.add_scalar(name, val, epoch)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        dataloader = data['train'].dataloader
        num_batches_per_epoch = dataloader.num_batches // args.accum_freq
        step = num_batches_per_epoch * epoch
        log_data['epoch'] = epoch
        wandb.log(log_data, step=step)

Overall I strongly vote for fixing this. The current wandb.log({name: val, 'step': step}) call is confusing by itself, and the issue is exacerbated by how we iterate over log_data.items(), since by default wandb's internal "step" increments after each call. Even worse, len(log_data) depends on the loss we use (e.g. DistillClipLoss adds an extra "distill_loss"), so even comparable training runs may not be aligned along the "step" x-axis on the wandb dashboard.
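A sketch of that misalignment, using a toy stand-in for wandb's auto-increment (an assumption from the wandb.log docs, not real wandb; the metric dicts are illustrative):

```python
# Toy model: an omitted step falls back to an auto-incremented
# internal counter, advancing once per log() call.
class FakeWandb:
    def __init__(self):
        self._next_step = 0

    def log(self, data, step=None):
        if step is None:
            step = self._next_step
        self._next_step = step + 1

# One log() call per metric (the current pattern): a run whose loss
# adds an extra "distill_loss" key burns through internal steps
# faster, so the two runs drift apart on the "step" x-axis.
run_a, run_b = FakeWandb(), FakeWandb()
for name, val in {'loss': 0.5, 'lr': 1e-3}.items():
    run_a.log({name: val, 'step': 0})
for name, val in {'loss': 0.5, 'lr': 1e-3, 'distill_loss': 0.1}.items():
    run_b.log({name: val, 'step': 0})
print(run_a._next_step, run_b._next_step)  # 2 3 -> misaligned

# Consolidated call with the native step argument: always aligned,
# regardless of how many metrics each run logs.
run_a, run_b = FakeWandb(), FakeWandb()
run_a.log({'train/loss': 0.5, 'train/lr': 1e-3, 'step': 0}, step=0)
run_b.log({'train/loss': 0.5, 'train/lr': 1e-3,
           'train/distill_loss': 0.1, 'step': 0}, step=0)
print(run_a._next_step, run_b._next_step)  # 1 1 -> aligned
```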

@gabrielilharco
Collaborator Author

Thanks @EIFY. Just pushed some changes. Tested on a run and wandb is looking good. We still have a lowercase step variable being logged so everything is backwards compatible.

@rom1504
Collaborator

rom1504 commented Sep 28, 2023

the current Step should simply not be used in wandb ui

I think changing this will make it really confusing to compare old runs, where Step is meaningless, with new runs, where Step is meaningful.

if we do change it, can we at least add a big warning in some logging section in the readme?
and indeed also keep the old step key

@gabrielilharco
Collaborator Author

Even with the change, people can keep doing comparisons the exact same way they did before, by looking at lowercase step.

I'm also fine with keeping things as they are. I think it's cleaner and less confusing to have a Step that's meaningful, but I also don't think it's a big deal. So up to you @rom1504

@EIFY
Contributor

EIFY commented Sep 28, 2023

@rom1504 wandb ui defaults to its own Step and honestly I didn't know it was customizable until now. I think it would be nice for the dashboard to behave as expected at first sight, and now there is the additional benefit of aligning the validation metrics by Step.

@EIFY
Contributor

EIFY commented Oct 10, 2023

@rom1504 Could we merge this? I think concerns have been addressed.

@rom1504 rom1504 changed the title Fix wandb step Log wandb step using wandb native step arg in addition to the "step" key. Oct 11, 2023
@rom1504 rom1504 merged commit 4ccb752 into mlfoundations:main Oct 11, 2023
5 checks passed
@rom1504
Collaborator

rom1504 commented Oct 11, 2023

Yes merged

Interpause pushed a commit to Interpause/open_clip that referenced this pull request May 23, 2024
…key. (mlfoundations#613)

* wandb step fix

* backwards compat fix

* update wandb calls

* update readme