-
Notifications
You must be signed in to change notification settings - Fork 520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix(jax): calculate virial in call_lower
#4304
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
📝 Walkthrough📝 WalkthroughWalkthroughThe changes involve a modification to the Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant Model
participant ExtendedVirial
User->>Model: Call forward_common_atomic()
Model->>ExtendedVirial: Calculate extended_virial
ExtendedVirial-->>Model: Return extended_virial
Model->>Model: Compute model_predict (sum of extended_virial)
Model-->>User: Return model_predict
Possibly related PRs
Suggested reviewers
Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## devel #4304 +/- ##
==========================================
- Coverage 84.38% 84.23% -0.16%
==========================================
Files 563 570 +7
Lines 52810 53072 +262
Branches 3054 3054
==========================================
+ Hits 44564 44704 +140
- Misses 7287 7410 +123
+ Partials 959 958 -1 ☔ View full report in Codecov by Sentry. |
Could you please add a UT for detecting the bug? |
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
added |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (2)
source/tests/consistent/model/test_ener.py (2)
272-304
: Add class docstring to document the purpose of TestEnerLower.The class should have a docstring explaining its purpose and how it differs from TestEner.
262-463
: Well-structured test implementation addressing the PR objectives.This test class effectively:
- Implements the requested unit tests for virial calculation
- Provides consistent testing across multiple backends
- Handles the lower-level interface appropriately
Consider adding more specific test cases that would have caught the original bug.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
source/tests/consistent/model/test_ener.py
(4 hunks)
🔇 Additional comments (5)
source/tests/consistent/model/test_ener.py (5)
9-20
: LGTM: Necessary imports added for neighbor list and coordinate handling.
The new imports provide essential utilities for the lower-level testing implementation.
53-55
: LGTM: JAX array conversion utilities properly imported.
The imports align with the pattern used for other backends (PyTorch, TensorFlow).
345-400
: LGTM: Comprehensive setup with proper coordinate preprocessing.
The setup correctly:
- Initializes base coordinates and types
- Normalizes coordinates
- Extends coordinates with ghost atoms
- Builds neighbor lists
437-463
: LGTM: Proper extraction of virial-related outputs for each backend.
The extract_ret method correctly handles:
- Extended virial for PyTorch
- Energy derivative for JAX
- Consistent shape normalization via ravel
404-435
: LGTM: Backend-specific implementations properly handle virial calculations.
Each backend's implementation correctly:
- Converts tensors to/from the backend's format
- Sets
do_atomic_virial=True
for virial calculation - Handles the extended coordinates and mapping
✅ Verification successful
Backend implementations correctly handle virial calculations across the codebase
The verification confirms that:
- All backend implementations (
dp
,pt
,jax
) consistently setdo_atomic_virial=True
- The virial calculations are thoroughly tested across different test files:
- Unit tests in
test_dp_model.py
verify virial consistency between forward and lower-level implementations - Integration tests in
test_io.py
validate virial calculations across different backends - Spin-aware virial calculations are tested in
test_ener_spin_model.py
- Unit tests in
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Verify that virial calculation is consistent across backends
# Look for other virial-related tests
rg -l "do_atomic_virial.*=.*True" "source/tests"
Length of output: 286
Script:
#!/bin/bash
# Check the implementation of virial calculations in other test files
rg -A 10 "do_atomic_virial.*=.*True" "source/tests" --type py
Length of output: 12666
Summary by CodeRabbit
TestEnerLower
, to evaluate lower-level energy models, excluding TensorFlow functionality.