Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(jax): calculate virial in call_lower #4304

Merged
merged 2 commits into from
Nov 5, 2024

Conversation

njzjz
Copy link
Member

@njzjz njzjz commented Nov 4, 2024

Summary by CodeRabbit

  • New Features
    • Enhanced output of the model by providing a reduced form of the virial tensor, improving usability for further calculations and analyses.
    • Introduced a new test class, TestEnerLower, to evaluate lower-level energy models, excluding TensorFlow functionality.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
@njzjz njzjz requested a review from wanghan-iapcm November 4, 2024 04:39
@njzjz njzjz marked this pull request as ready for review November 4, 2024 04:39
Copy link
Contributor

coderabbitai bot commented Nov 4, 2024

📝 Walkthrough
📝 Walkthrough

Walkthrough

The changes involve a modification to the forward_common_atomic function within the deepmd/jax/model/base_model.py file. A new line is added to compute a reduced version of the extended_virial output. This line calculates the sum of extended_virial along the first axis and assigns it to model_predict with a key that appends "_redu" to kk_derv_c. The overall structure and logic of the function remain unchanged.

Changes

File Path Change Summary
deepmd/jax/model/base_model.py Modified forward_common_atomic to compute a reduced version of extended_virial by summing it along the first axis and assigning it to model_predict.
source/tests/consistent/model/test_ener.py Added a new test class TestEnerLower to test energy models without TensorFlow, including specific evaluation methods and properties for lower-level interactions.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant Model
    participant ExtendedVirial

    User->>Model: Call forward_common_atomic()
    Model->>ExtendedVirial: Calculate extended_virial
    ExtendedVirial-->>Model: Return extended_virial
    Model->>Model: Compute model_predict (sum of extended_virial)
    Model-->>User: Return model_predict
Loading

Possibly related PRs

  • feat(jax): energy model (no grad support) #4226: The changes in this PR involve modifications to the forward_common_atomic function, which is directly related to the changes made in the main PR that also modifies this function to compute a reduced version of the extended_virial output.
  • feat(jax): atomic virial #4290: This PR also modifies the forward_common_atomic function to handle the do_atomic_virial parameter, which is relevant to the changes in the main PR that enhance the output of the function related to the virial tensor calculations.

Suggested reviewers

  • wanghan-iapcm
  • iProzd

Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@github-actions github-actions bot added the Python label Nov 4, 2024
Copy link

codecov bot commented Nov 4, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 84.23%. Comparing base (bfbe2ed) to head (0f30e01).
Report is 5 commits behind head on devel.

Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4304      +/-   ##
==========================================
- Coverage   84.38%   84.23%   -0.16%     
==========================================
  Files         563      570       +7     
  Lines       52810    53072     +262     
  Branches     3054     3054              
==========================================
+ Hits        44564    44704     +140     
- Misses       7287     7410     +123     
+ Partials      959      958       -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@wanghan-iapcm
Copy link
Collaborator

Could you please add a UT for detecting the bug?

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
@njzjz
Copy link
Member Author

njzjz commented Nov 5, 2024

Could you please add a UT for detecting the bug?

added

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (2)
source/tests/consistent/model/test_ener.py (2)

272-304: Add class docstring to document the purpose of TestEnerLower.

The class should have a docstring explaining its purpose and how it differs from TestEner.


262-463: Well-structured test implementation addressing the PR objectives.

This test class effectively:

  1. Implements the requested unit tests for virial calculation
  2. Provides consistent testing across multiple backends
  3. Handles the lower-level interface appropriately

Consider adding more specific test cases that would have caught the original bug.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 71255cb and 0f30e01.

📒 Files selected for processing (1)
  • source/tests/consistent/model/test_ener.py (4 hunks)
🔇 Additional comments (5)
source/tests/consistent/model/test_ener.py (5)

9-20: LGTM: Necessary imports added for neighbor list and coordinate handling.

The new imports provide essential utilities for the lower-level testing implementation.


53-55: LGTM: JAX array conversion utilities properly imported.

The imports align with the pattern used for other backends (PyTorch, TensorFlow).


345-400: LGTM: Comprehensive setup with proper coordinate preprocessing.

The setup correctly:

  1. Initializes base coordinates and types
  2. Normalizes coordinates
  3. Extends coordinates with ghost atoms
  4. Builds neighbor lists

437-463: LGTM: Proper extraction of virial-related outputs for each backend.

The extract_ret method correctly handles:

  • Extended virial for PyTorch
  • Energy derivative for JAX
  • Consistent shape normalization via ravel

404-435: LGTM: Backend-specific implementations properly handle virial calculations.

Each backend's implementation correctly:

  • Converts tensors to/from the backend's format
  • Sets do_atomic_virial=True for virial calculation
  • Handles the extended coordinates and mapping
✅ Verification successful

Backend implementations correctly handle virial calculations across the codebase

The verification confirms that:

  • All backend implementations (dp, pt, jax) consistently set do_atomic_virial=True
  • The virial calculations are thoroughly tested across different test files:
    • Unit tests in test_dp_model.py verify virial consistency between forward and lower-level implementations
    • Integration tests in test_io.py validate virial calculations across different backends
    • Spin-aware virial calculations are tested in test_ener_spin_model.py
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Verify that virial calculation is consistent across backends
# Look for other virial-related tests
rg -l "do_atomic_virial.*=.*True" "source/tests"

Length of output: 286


Script:

#!/bin/bash
# Check the implementation of virial calculations in other test files
rg -A 10 "do_atomic_virial.*=.*True" "source/tests" --type py

Length of output: 12666

@wanghan-iapcm wanghan-iapcm added this pull request to the merge queue Nov 5, 2024
Merged via the queue into deepmodeling:devel with commit dabedd2 Nov 5, 2024
51 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants