Test review assistant - DO NOT MERGE!!! #1

Open · wants to merge 4 commits into base: main
Conversation

cuongvng (Owner)

No description provided.

src/gen_sr.py (review comment, resolved)
@cuongvng (Owner, Author)

src/eval_bilinear.py

  1. Bug and Functionality Issues:

    • Line 4: The SSIM import is removed while SSIM is still called later in the file, which will raise a NameError at runtime.
    -from dippykit.metrics import PSNR, SSIM
    +from dippykit.metrics import PSNR
    • Suggestion: Re-add the SSIM import (see the usage sketch after this list).
    from dippykit.metrics import PSNR, SSIM
  2. Comment Typo:

    • Line 48: ssnr appears to be a typo for ssim.
    - # Remove ssnr from metrics
    + # Remove ssim from metrics
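
For context, a minimal sketch of how both metrics are typically consumed in an evaluation script like this one; the inputs and the exact PSNR/SSIM call signatures are assumptions, not taken from the diff:

import numpy as np
from dippykit.metrics import PSNR, SSIM  # both names are used below

# Hypothetical inputs: a ground-truth HR image and a bilinear-upscaled image,
# same shape and dtype.
hr_img = np.random.randint(0, 256, (128, 128), dtype=np.uint8)
sr_img = np.random.randint(0, 256, (128, 128), dtype=np.uint8)

psnr = PSNR(sr_img, hr_img)  # assumed call signature: (test, reference)
ssim = SSIM(sr_img, hr_img)  # without the import above, this line raises NameError
print("PSNR:", psnr, "SSIM:", ssim)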

src/gen_sr.py

  1. Performance and Maintainability:
    • Lines 8-9: Removing the argparse and pathlib imports reduces the script's flexibility if they were previously used for command-line arguments and path handling.

    • Suggestion: If you are certain these modules are no longer needed, make sure the related functionality is covered elsewhere; otherwise, consider retaining them (see the sketch after the quoted imports).

import argparse
from pathlib import Path
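
If the script is meant to stay configurable, a minimal sketch of the CLI pattern these two imports usually support; the flag names and defaults here are hypothetical, since the original argument handling is not shown in the diff:

import argparse
from pathlib import Path

def parse_args() -> argparse.Namespace:
    # Hypothetical arguments for a super-resolution generation script.
    parser = argparse.ArgumentParser(description="Generate super-resolved images.")
    parser.add_argument("--input-dir", type=Path, default=Path("data/lr"),
                        help="Directory of low-resolution input images.")
    parser.add_argument("--output-dir", type=Path, default=Path("results/sr"),
                        help="Directory for super-resolved outputs.")
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)  # pathlib keeps path handling concise
    print(f"Reading from {args.input_dir}, writing to {args.output_dir}")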

src/generator.py

  1. Functionality Issues:
    • Lines 15-20: The pixel_shufflers sequence was removed, even though it appears integral to the upscaling process; without this block, the upscale factor may never be applied to the feature maps.

    • Suggestion: If the pixel_shufflers blocks are essential for upscaling, reintroduce them; otherwise, make sure an equivalent upscaling step is in place (a reference implementation sketch follows the quoted code).

self.pixel_shufflers = nn.Sequential()
for i in range(int(np.log2(upscale_factor))):
    self.pixel_shufflers.add_module(f"pixel_shuffle_blk_{i}", PixelShufflerBlock(in_channels=64, upscale_factor=2))
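
The PixelShufflerBlock implementation is not part of the diff; for reference, a typical SRGAN-style version of such a block, offered as a hypothetical sketch rather than the repository's actual code:

import numpy as np
import torch
import torch.nn as nn

class PixelShufflerBlock(nn.Module):
    """Hypothetical upsampling block: conv -> PixelShuffle -> PReLU."""
    def __init__(self, in_channels: int, upscale_factor: int):
        super().__init__()
        # The conv expands channels by upscale_factor**2 so that PixelShuffle
        # can trade channel depth for spatial resolution.
        self.conv = nn.Conv2d(in_channels, in_channels * upscale_factor ** 2,
                              kernel_size=3, padding=1)
        self.shuffle = nn.PixelShuffle(upscale_factor)
        self.act = nn.PReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.shuffle(self.conv(x)))

# Stacking log2(upscale_factor) x2 blocks reproduces the removed sequence:
upscale_factor = 4
pixel_shufflers = nn.Sequential()
for i in range(int(np.log2(upscale_factor))):
    pixel_shufflers.add_module(f"pixel_shuffle_blk_{i}",
                               PixelShufflerBlock(in_channels=64, upscale_factor=2))

x = torch.randn(1, 64, 24, 24)
print(pixel_shufflers(x).shape)  # torch.Size([1, 64, 96, 96]): a 4x upscale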

src/train.py

  1. Documentation:

    • Lines 34-36: Commenting out the detailed docstring reduces code readability and maintainability.
    -	'''
    -	:param `resume_training`: whether to continue training from previous checkpoint or not.
    -	If checkpoints cannot be found, train from beginning, regardless of `resume_training`.
    -	'''
    +	# '''
    +	# :param `resume_training`: whether to continue training from previous checkpoint or not.
    +	# If checkpoints cannot be found, train from beginning, regardless of `resume_training`.
    +	# '''
    • Suggestion: Restore the docstring to provide clear documentation.
    '''
    :param `resume_training`: whether to continue training from previous checkpoint or not.
    If checkpoints cannot be found, train from beginning, regardless of `resume_training`.
    '''
  2. Functionality Issues:

    • Lines 60-65: Removing the print statement may be acceptable, but make sure equivalent logging still records when training starts (see the logging sketch after this list).
    - print("Training from start ...")
  3. Performance Concerns:

    • Lines 72-93: Removing the warmup phase can hurt training performance; pre-training the generator on a pixel-wise (MAE) loss before adversarial training typically has a direct effect on convergence and stability.

    • Suggestion: Reintroduce the warmup phase, particularly if it is important for model convergence.

    ## Warm up G
    if warmup:
        for w in range(WARMUP_EPOCHS):
            print(f"\nWarmup: {w+1}")
            for (batch, hr_batch), lr_batch in zip(enumerate(hr_train_loader), lr_train_loader):
                hr_img, lr_img = hr_batch[0].to(device), lr_batch[0].to(device)
                optimizer_G.zero_grad()
    
                sr_img = G(lr_img)
                err_G = warmup_loss(sr_img, hr_img)
                err_G.backward()
    
                optimizer_G.step()
                if batch % 10 == 0:
                    print(f"\tBatch: {batch + 1}/{len(data_train_hr) // BATCH_SIZE}")
                    print(f"\tMAE G: {err_G.item():.4f}")

By addressing these suggestions, the modified code should retain both its functionality and its training performance.
