Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pt): calculate stat during compression if --skip-neighbor-stat #4330

Merged
merged 1 commit into from
Nov 9, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 9, 2024

If --skip-neighbor-stat is set during training, when calling dp compress, first calculate the neighbor stat.

Summary by CodeRabbit

  • New Features

    • Enhanced enable_compression function to accept a training_script parameter for improved error handling and functionality.
    • Updated the compress command to allow specification of a training script during execution.
    • Introduced a new testing framework for models using the --skip-neighbor-stat flag, validating their functionality.
  • Bug Fixes

    • Improved error handling for cases where the model's minimum neighbor distance is not saved.
  • Tests

    • Added a new test class and methods to validate the functionality of models initialized with skip neighbor statistics.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Copy link
Contributor

coderabbitai bot commented Nov 9, 2024

📝 Walkthrough

Walkthrough

The changes in this pull request primarily focus on enhancing the enable_compression function in compress.py by adding a new optional parameter, training_script, which allows users to specify a training script if the model's minimum neighbor distance is not available. Additionally, the main.py file is updated to accommodate this new parameter in the compress command. The testing framework is also expanded with a new function and class to validate the behavior of models initialized with a specific training scenario involving the --skip-neighbor-stat flag.

Changes

File Change Summary
deepmd/pt/entrypoints/compress.py Updated enable_compression function to include a new optional parameter training_script. Added logic for handling cases where the model's minimum neighbor distance is not saved, including logging and raising a ValueError. New imports for utilities added.
deepmd/pt/entrypoints/main.py Modified the main function to include the training_script parameter in the enable_compression function call within the compress command block. No other significant changes.
source/tests/pt/test_model_compression_se_a.py Introduced a new function _init_models_skip_neighbor_stat() and a new test class TestSkipNeighborStat to validate model functionality with the --skip-neighbor-stat flag. Updated setUpModule and added new assertions in test methods.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Main
    participant Compress

    User->>Main: Execute compress command with training_script
    Main->>Compress: Call enable_compression(input_file, output, ..., training_script)
    Compress->>Compress: Check minimum neighbor distance
    alt Min neighbor distance not saved
        Compress->>Compress: Log informational message
        Compress->>Compress: Raise ValueError if training_script not provided
    else Min neighbor distance saved
        Compress->>Compress: Load training data
        Compress->>Compress: Compute minimum neighbor distance
        Compress->>Compress: Update model's min_nbor_dist
    end
Loading

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@github-actions github-actions bot added the Python label Nov 9, 2024
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (2)
deepmd/pt/entrypoints/compress.py (1)

45-54: Enhance error handling for training script loading

The error handling could be more robust for the training script scenario.

Consider wrapping the training script loading in a try-except block:

     if training_script is None:
         raise ValueError(
             "The model does not have a minimum neighbor distance, "
             "so the training script and data must be provided "
             "(via -t,--training-script)."
         )
+    try:
+        jdata = j_loader(training_script)
+    except (json.JSONDecodeError, FileNotFoundError) as e:
+        raise ValueError(f"Failed to load training script: {str(e)}")
source/tests/pt/test_model_compression_se_a.py (1)

608-730: Consider reducing code duplication through test inheritance.

While the implementation is correct and thorough, there's significant code duplication with other test classes (e.g., TestDeepPotAPBC). Consider refactoring to extract common test logic into a base test class.

Example refactoring approach:

+class BaseDeepPotTest(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.coords = np.array([
+            12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
+            00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
+            3.51, 2.51, 2.60, 4.27, 3.22, 1.56,
+        ])
+        cls.atype = [0, 1, 1, 0, 1, 1]
+        cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])
+
+    def _test_attrs(self, dp_original, dp_compressed):
+        # Common attribute tests
+        ...
+
+    def _test_1frame(self, dp_original, dp_compressed):
+        # Common single frame tests
+        ...
+
-class TestSkipNeighborStat(unittest.TestCase):
+class TestSkipNeighborStat(BaseDeepPotTest):
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
         cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT)
         cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT)
-        cls.coords = np.array([...])  # Remove duplicated setup
-        cls.atype = [0, 1, 1, 0, 1, 1]
-        cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])

     def test_attrs(self):
-        # Replace with call to base class method
+        self._test_attrs(self.dp_original, self.dp_compressed)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 22123aa and 5014740.

📒 Files selected for processing (3)
  • deepmd/pt/entrypoints/compress.py (2 hunks)
  • deepmd/pt/entrypoints/main.py (1 hunks)
  • source/tests/pt/test_model_compression_se_a.py (2 hunks)
🔇 Additional comments (7)
deepmd/pt/entrypoints/compress.py (4)

3-27: LGTM: Well-organized imports

The new imports are properly organized and all are utilized in the implementation.


29-29: LGTM: Proper logger setup

Good practice using module-level logger setup.


36-38: LGTM: Backward-compatible signature change

The new optional parameter maintains backward compatibility while adding the needed functionality.


45-74: Verify handling of --skip-neighbor-stat scenario

Let's verify that this implementation correctly handles the case when --skip-neighbor-stat is used during training.

✅ Verification successful

Based on the search results, I can now generate a final response as we have sufficient context about the skip-neighbor-stat functionality:

Implementation correctly handles --skip-neighbor-stat scenario

The current implementation in compress.py is correct as it:

  • Only computes minimum neighbor distance when it's not available in the model
  • Requires training script in such cases to compute it from training data
  • Aligns with test cases in test_model_compression_se_a.py which explicitly test compression after training with --skip-neighbor-stat

The test file demonstrates the expected workflow:

  1. Train with --skip-neighbor-stat
  2. Freeze the model
  3. Compress with training script provided
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for skip-neighbor-stat usage in training code and tests
rg -l "skip-neighbor-stat" --type python

# Search for related test cases
rg -l "test.*compress.*skip.*neighbor" --type python

Length of output: 157


Script:

#!/bin/bash
# Search for skip-neighbor-stat usage without file type restriction
rg -l "skip.?neighbor.?stat"

# Search for related test cases
rg -l "test.*compress.*skip.*neighbor"

# Look for neighbor stat related code in training
rg "neighbor.*stat" -A 5 deepmd/pt/entrypoints/train.py

Length of output: 538


Script:

#!/bin/bash
# Check content of relevant test files for compression with skip-neighbor-stat
rg -A 10 "skip.?neighbor.?stat" source/tests/pt/test_model_compression_se_a.py source/tests/pt/test_init_model.py

# Look for neighbor stat related code in training
rg "neighbor.*stat" -A 5 deepmd/pt/entrypoints/train.py deepmd/pt/entrypoints/compress.py

# Check main.py for skip-neighbor-stat argument handling
rg "skip.?neighbor.?stat" -A 5 deepmd/pt/entrypoints/main.py

Length of output: 4676

deepmd/pt/entrypoints/main.py (1)

568-568: LGTM! Please verify argument parser and error handling.

The addition of training_script parameter aligns with the PR objective to calculate neighbor statistics during compression when they were skipped during training.

Let's verify the argument parser setup and error handling:

✅ Verification successful

Let me search for the argument parser definition and the usage of training_script in the compression flow.


Let me check the argument parser definition in the main file since it wasn't found in the previous search.


Let me try one final search to locate the argument parser setup.


Based on the gathered information, I can now provide the final response:

LGTM! The training_script parameter is properly handled.

The code shows proper implementation of the training_script parameter:

  • It's correctly passed from FLAGS to enable_compression function
  • The enable_compression function has proper error handling that raises a clear error message when training_script is required but not provided
  • The error message explains why the parameter is needed: "The model does not have a minimum neighbor distance, so the training script and data must be provided"
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if training_script argument is properly defined and handled

# Test 1: Search for training_script in argument parser definition
echo "Checking argument parser definition..."
rg -A 5 "training_script.*add_argument" 

# Test 2: Search for error handling in enable_compression
echo "Checking error handling in enable_compression..."
rg -A 10 "def enable_compression" 

Length of output: 23386


Script:

#!/bin/bash
# Search for argument parser definition for training_script
echo "Checking argument parser definition..."
rg -B 5 -A 2 "training_script.*add_argument|training_script.*help" deepmd/pt/entrypoints/main.py

# Search for training_script usage in compression flow
echo "Checking training_script usage in compression..."
rg -B 5 -A 5 "training_script" deepmd/pt/entrypoints/compress.py

Length of output: 1406


Script:

#!/bin/bash
# Search for compress subcommand parser setup
echo "Checking compress subcommand parser setup..."
rg -B 10 -A 10 "compress.*add_parser" deepmd/pt/entrypoints/main.py

# Search for training script related argument
echo "Checking training script argument..."
rg -B 2 -A 2 "\-t|--training.*script" deepmd/pt/entrypoints/main.py

Length of output: 548


Script:

#!/bin/bash
# Search for the entire compress subcommand setup
echo "Checking compress subcommand setup..."
rg -B 20 -A 20 "def compress\(|compress_parser" deepmd/pt/entrypoints/main.py

# Also check if training_script is used in the compress function
echo "Checking training_script usage in main..."
rg -B 5 -A 5 "FLAGS\.training_script" deepmd/pt/entrypoints/main.py

Length of output: 652

source/tests/pt/test_model_compression_se_a.py (2)

77-102: LGTM! Well-structured initialization function.

The implementation follows the established pattern of other initialization functions while correctly incorporating the new --skip-neighbor-stat functionality. Good use of consistent naming conventions with the suffix approach.


104-118: LGTM! Clean module setup changes.

The changes properly integrate the new skip neighbor stat model variables while maintaining the existing initialization pattern.

deepmd/pt/entrypoints/compress.py Show resolved Hide resolved
Copy link

codecov bot commented Nov 9, 2024

Codecov Report

Attention: Patch coverage is 94.73684% with 1 line in your changes missing coverage. Please review.

Project coverage is 84.53%. Comparing base (22123aa) to head (5014740).
Report is 125 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/pt/entrypoints/compress.py 94.73% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4330      +/-   ##
==========================================
- Coverage   84.61%   84.53%   -0.08%     
==========================================
  Files         571      571              
  Lines       53163    53182      +19     
  Branches     3059     3055       -4     
==========================================
- Hits        44982    44956      -26     
- Misses       7218     7262      +44     
- Partials      963      964       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@njzjz njzjz added this pull request to the merge queue Nov 9, 2024
Merged via the queue into deepmodeling:devel with commit c12bc01 Nov 9, 2024
60 checks passed
@njzjz njzjz deleted the pt-compress-skip-neighbor-stat branch November 9, 2024 10:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants