-
Notifications
You must be signed in to change notification settings - Fork 527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(pt): calculate stat during compression if --skip-neighbor-stat
#4330
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
📝 WalkthroughWalkthroughThe changes in this pull request primarily focus on enhancing the Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant Main
participant Compress
User->>Main: Execute compress command with training_script
Main->>Compress: Call enable_compression(input_file, output, ..., training_script)
Compress->>Compress: Check minimum neighbor distance
alt Min neighbor distance not saved
Compress->>Compress: Log informational message
Compress->>Compress: Raise ValueError if training_script not provided
else Min neighbor distance saved
Compress->>Compress: Load training data
Compress->>Compress: Compute minimum neighbor distance
Compress->>Compress: Update model's min_nbor_dist
end
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (2)
deepmd/pt/entrypoints/compress.py (1)
45-54
: Enhance error handling for training script loadingThe error handling could be more robust for the training script scenario.
Consider wrapping the training script loading in a try-except block:
if training_script is None: raise ValueError( "The model does not have a minimum neighbor distance, " "so the training script and data must be provided " "(via -t,--training-script)." ) + try: + jdata = j_loader(training_script) + except (json.JSONDecodeError, FileNotFoundError) as e: + raise ValueError(f"Failed to load training script: {str(e)}")source/tests/pt/test_model_compression_se_a.py (1)
608-730
: Consider reducing code duplication through test inheritance.While the implementation is correct and thorough, there's significant code duplication with other test classes (e.g.,
TestDeepPotAPBC
). Consider refactoring to extract common test logic into a base test class.Example refactoring approach:
+class BaseDeepPotTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.coords = np.array([ + 12.83, 2.56, 2.18, 12.09, 2.87, 2.74, + 00.25, 3.32, 1.68, 3.36, 3.00, 1.81, + 3.51, 2.51, 2.60, 4.27, 3.22, 1.56, + ]) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def _test_attrs(self, dp_original, dp_compressed): + # Common attribute tests + ... + + def _test_1frame(self, dp_original, dp_compressed): + # Common single frame tests + ... + -class TestSkipNeighborStat(unittest.TestCase): +class TestSkipNeighborStat(BaseDeepPotTest): @classmethod def setUpClass(cls): + super().setUpClass() cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT) cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT) - cls.coords = np.array([...]) # Remove duplicated setup - cls.atype = [0, 1, 1, 0, 1, 1] - cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) def test_attrs(self): - # Replace with call to base class method + self._test_attrs(self.dp_original, self.dp_compressed)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
deepmd/pt/entrypoints/compress.py
(2 hunks)deepmd/pt/entrypoints/main.py
(1 hunks)source/tests/pt/test_model_compression_se_a.py
(2 hunks)
🔇 Additional comments (7)
deepmd/pt/entrypoints/compress.py (4)
3-27
: LGTM: Well-organized imports
The new imports are properly organized and all are utilized in the implementation.
29-29
: LGTM: Proper logger setup
Good practice using module-level logger setup.
36-38
: LGTM: Backward-compatible signature change
The new optional parameter maintains backward compatibility while adding the needed functionality.
45-74
: Verify handling of --skip-neighbor-stat scenario
Let's verify that this implementation correctly handles the case when --skip-neighbor-stat
is used during training.
✅ Verification successful
Based on the search results, I can now generate a final response as we have sufficient context about the skip-neighbor-stat functionality:
Implementation correctly handles --skip-neighbor-stat scenario
The current implementation in compress.py
is correct as it:
- Only computes minimum neighbor distance when it's not available in the model
- Requires training script in such cases to compute it from training data
- Aligns with test cases in
test_model_compression_se_a.py
which explicitly test compression after training with --skip-neighbor-stat
The test file demonstrates the expected workflow:
- Train with --skip-neighbor-stat
- Freeze the model
- Compress with training script provided
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for skip-neighbor-stat usage in training code and tests
rg -l "skip-neighbor-stat" --type python
# Search for related test cases
rg -l "test.*compress.*skip.*neighbor" --type python
Length of output: 157
Script:
#!/bin/bash
# Search for skip-neighbor-stat usage without file type restriction
rg -l "skip.?neighbor.?stat"
# Search for related test cases
rg -l "test.*compress.*skip.*neighbor"
# Look for neighbor stat related code in training
rg "neighbor.*stat" -A 5 deepmd/pt/entrypoints/train.py
Length of output: 538
Script:
#!/bin/bash
# Check content of relevant test files for compression with skip-neighbor-stat
rg -A 10 "skip.?neighbor.?stat" source/tests/pt/test_model_compression_se_a.py source/tests/pt/test_init_model.py
# Look for neighbor stat related code in training
rg "neighbor.*stat" -A 5 deepmd/pt/entrypoints/train.py deepmd/pt/entrypoints/compress.py
# Check main.py for skip-neighbor-stat argument handling
rg "skip.?neighbor.?stat" -A 5 deepmd/pt/entrypoints/main.py
Length of output: 4676
deepmd/pt/entrypoints/main.py (1)
568-568
: LGTM! Please verify argument parser and error handling.
The addition of training_script
parameter aligns with the PR objective to calculate neighbor statistics during compression when they were skipped during training.
Let's verify the argument parser setup and error handling:
✅ Verification successful
Let me search for the argument parser definition and the usage of training_script
in the compression flow.
Let me check the argument parser definition in the main file since it wasn't found in the previous search.
Let me try one final search to locate the argument parser setup.
Based on the gathered information, I can now provide the final response:
LGTM! The training_script parameter is properly handled.
The code shows proper implementation of the training_script parameter:
- It's correctly passed from FLAGS to enable_compression function
- The enable_compression function has proper error handling that raises a clear error message when training_script is required but not provided
- The error message explains why the parameter is needed: "The model does not have a minimum neighbor distance, so the training script and data must be provided"
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check if training_script argument is properly defined and handled
# Test 1: Search for training_script in argument parser definition
echo "Checking argument parser definition..."
rg -A 5 "training_script.*add_argument"
# Test 2: Search for error handling in enable_compression
echo "Checking error handling in enable_compression..."
rg -A 10 "def enable_compression"
Length of output: 23386
Script:
#!/bin/bash
# Search for argument parser definition for training_script
echo "Checking argument parser definition..."
rg -B 5 -A 2 "training_script.*add_argument|training_script.*help" deepmd/pt/entrypoints/main.py
# Search for training_script usage in compression flow
echo "Checking training_script usage in compression..."
rg -B 5 -A 5 "training_script" deepmd/pt/entrypoints/compress.py
Length of output: 1406
Script:
#!/bin/bash
# Search for compress subcommand parser setup
echo "Checking compress subcommand parser setup..."
rg -B 10 -A 10 "compress.*add_parser" deepmd/pt/entrypoints/main.py
# Search for training script related argument
echo "Checking training script argument..."
rg -B 2 -A 2 "\-t|--training.*script" deepmd/pt/entrypoints/main.py
Length of output: 548
Script:
#!/bin/bash
# Search for the entire compress subcommand setup
echo "Checking compress subcommand setup..."
rg -B 20 -A 20 "def compress\(|compress_parser" deepmd/pt/entrypoints/main.py
# Also check if training_script is used in the compress function
echo "Checking training_script usage in main..."
rg -B 5 -A 5 "FLAGS\.training_script" deepmd/pt/entrypoints/main.py
Length of output: 652
source/tests/pt/test_model_compression_se_a.py (2)
77-102
: LGTM! Well-structured initialization function.
The implementation follows the established pattern of other initialization functions while correctly incorporating the new --skip-neighbor-stat
functionality. Good use of consistent naming conventions with the suffix approach.
104-118
: LGTM! Clean module setup changes.
The changes properly integrate the new skip neighbor stat model variables while maintaining the existing initialization pattern.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## devel #4330 +/- ##
==========================================
- Coverage 84.61% 84.53% -0.08%
==========================================
Files 571 571
Lines 53163 53182 +19
Branches 3059 3055 -4
==========================================
- Hits 44982 44956 -26
- Misses 7218 7262 +44
- Partials 963 964 +1 ☔ View full report in Codecov by Sentry. |
If
--skip-neighbor-stat
is set during training, when callingdp compress
, first calculate the neighbor stat.Summary by CodeRabbit
New Features
enable_compression
function to accept atraining_script
parameter for improved error handling and functionality.compress
command to allow specification of a training script during execution.--skip-neighbor-stat
flag, validating their functionality.Bug Fixes
Tests