Added validation to the training script and exposed more of the settings for network #113
base: main
Conversation
Force-pushed from 5b543a5 to 1a36344
…ngs for network Signed-off-by: LimitingFactor <aswift0n3@gmail.com>
Updated the wandb initialisation Signed-off-by: LimitingFactor <aswift0n3@gmail.com>
Force-pushed from f715d7e to 8ad3efc
Hi @cfd1, thanks for the PR! Is this ready for review?

@mnabian Yes, it's ready for review, please.
```python
do_concat_trick: bool = False
num_processor_checkpoint_segments: int = 0
# activation_fn: str = "relu"

# performance configs
amp: bool = False
jit: bool = False

# test & visualization configs
```
Change to "visualization configs"
```python
def get_options():
    parser = argparse.ArgumentParser()

    parser.add_argument("--entity", "-e", type=str, default=None)
```
Can we add these to the config file and not use argparse here?
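The suggestion above can be sketched with a plain dataclass config; note this is only a dependency-free illustration of the pattern — the field names besides `entity`, `amp`, and `jit` and the JSON loader are assumptions, and the actual repo would more likely use its existing Hydra/OmegaConf-style config file.

```python
import json
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class Config:
    # W&B setting previously passed via argparse
    entity: Optional[str] = None
    # performance configs mirrored from the constants above
    amp: bool = False
    jit: bool = False


def load_config(path: str) -> Config:
    """Read a JSON config file and override the dataclass defaults,
    ignoring any keys the dataclass does not define."""
    with open(path) as f:
        overrides = json.load(f)
    valid = {f.name for f in fields(Config)}
    return Config(**{k: v for k, v in overrides.items() if k in valid})
```

With this in place, `get_options()` and its `argparse` calls disappear and the training script takes a single config path.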
```python
# initialize distributed manager
DistributedManager.initialize()
dist = DistributedManager()

# initialize loggers
if wandb:
```
Is this if statement necessary? This can be done by changing the `mode` argument in `initialize_wandb`.
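The pattern the reviewer is pointing at can be sketched as follows; the exact keyword name taken by Modulus's `initialize_wandb` helper is an assumption here, but W&B itself supports a `mode` setting where `"disabled"` turns all logging calls into no-ops, so the flag can be mapped onto the mode instead of branching:

```python
def wandb_mode(use_wandb: bool) -> str:
    # "disabled" makes every W&B call a no-op, so the caller
    # never needs an `if wandb:` branch around logger setup.
    return "online" if use_wandb else "disabled"


# instead of:
#     if wandb:
#         initialize_wandb(...)
# write (keyword name assumed):
#     initialize_wandb(..., mode=wandb_mode(cfg.wandb))
```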
```python
if __name__ == "__main__":
```

```python
@torch.no_grad()
def validation(self):
```
This is very similar to the `predict` method in `inference.py`. Can we make the prediction code more modular to avoid code duplication?
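One way to act on this review comment is to factor the shared forward pass into a single helper that both `validation` and the inference path call. A minimal sketch of that shape — the names `predict_batch`, `Runner`, and the metric shown are hypothetical stand-ins, not the repo's API:

```python
def predict_batch(model, batch):
    """Single shared forward pass used by both validation and inference."""
    return model(batch)


class Runner:
    def __init__(self, model):
        self.model = model

    def validation(self, batches):
        # validation aggregates a metric over predictions on many batches
        return sum(predict_batch(self.model, b) for b in batches) / len(batches)

    def inference(self, batch):
        # inference returns the raw prediction for one sample
        return predict_batch(self.model, batch)
```

Any later change to the prediction logic (e.g. denormalization or the concat trick) then lands in one place.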
```python
# Train the model
tmp_start = time.time()
loss_train_agg = 0
for graph in tqdm(trainer.dataloader):
```
Does `tqdm` play nicely in the multi-gpu runs?
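The usual concern is that under DDP every process draws its own progress bar. A common fix, sketched here with an assumed `RANK` environment variable as set by `torchrun` (the repo's `DistributedManager` exposes the rank directly):

```python
import os

# Enable the progress bar only on rank 0 so multi-GPU runs
# don't print one interleaved bar per process.
rank = int(os.environ.get("RANK", "0"))  # torchrun sets RANK; default 0
is_rank_zero = rank == 0


def maybe_progress(iterable, enabled):
    """Wrap `iterable` in tqdm only when `enabled`; otherwise pass through."""
    if not enabled:
        return iterable
    from tqdm import tqdm  # assumed available, as in the training script
    return tqdm(iterable)
```

`tqdm` also accepts a `disable=` keyword, so `tqdm(dataloader, disable=rank != 0)` achieves the same thing without a wrapper.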