feat(jax/array-api): se_e2_a #4217

Merged
merged 3 commits into from
Oct 16, 2024

Conversation

njzjz
Member

@njzjz njzjz commented Oct 15, 2024

Summary by CodeRabbit

  • New Features

    • Introduced a new class DescrptSeAArrayAPI for enhanced array compatibility.
    • Added a new class DescrptSeA integrated with the Flax library for neural network modules.
    • Improved handling of atomic types and neighbor lists for better performance and clarity.
  • Tests

    • Enhanced test suite to support additional backends and configurations, including JAX and strict array API.
    • Added new evaluation methods for testing across different frameworks.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
@njzjz njzjz requested review from wanghan-iapcm and iProzd and removed request for wanghan-iapcm October 15, 2024 21:05
Contributor

coderabbitai bot commented Oct 15, 2024

Walkthrough

The changes in this pull request involve modifications to the DescrptSeA class and the introduction of new classes to enhance compatibility with array-like structures using the array_api_compat library. Key updates include improvements to the constructor and methods for handling numpy arrays, as well as the addition of a new class DescrptSeAArrayAPI. New test files have also been created to support various backends and configurations, ensuring the flexibility and robustness of the descriptor classes.

Changes

File Path | Change Summary
deepmd/dpmodel/descriptor/se_e2_a.py | Updated DescrptSeA constructor and methods for array compatibility; added DescrptSeAArrayAPI class extending DescrptSeA.
deepmd/jax/descriptor/se_e2_a.py | Introduced DescrptSeA class extending DescrptSeAArrayAPI; customized __setattr__ method for Flax compatibility.
source/tests/consistent/descriptor/test_se_e2_a.py | Enhanced test file to support additional backends; added properties and methods for evaluating descriptors.
deepmd/dpmodel/utils/nlist.py | Modified nlist_distinguish_types function for improved handling of neighbor lists and atomic types.

Possibly related PRs

  • fix: correct exclude_types in descriptors #3841: The changes in deepmd/dpmodel/descriptor/se_e2_a.py include modifications related to exclude_types, which aligns with the focus of this PR on correcting exclude_types in descriptors.
  • feat(jax/array-api): energy fitting #4204: The integration of array_api_compat in the GeneralFitting class and its methods relates to the enhancements made in the DescrptSeA class for compatibility with array-like structures, indicating a shared objective of improving compatibility with different array backends.

Suggested labels

Docs, Examples

Suggested reviewers

  • iProzd
  • wanghan-iapcm
  • anyangml

📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 21393f4 and 0d6d760.

📒 Files selected for processing (2)
  • deepmd/dpmodel/utils/nlist.py (1 hunks)
  • source/tests/consistent/descriptor/test_se_e2_a.py (4 hunks)
🧰 Additional context used
🔇 Additional comments (11)
source/tests/consistent/descriptor/test_se_e2_a.py (6)

15-16: LGTM: New imports added correctly.

The new imports for INSTALLED_ARRAY_API_STRICT and INSTALLED_JAX are correctly added and follow the existing import pattern in the file.


38-47: LGTM: Conditional imports added correctly.

The new conditional imports for JAX and Array API Strict follow the existing pattern in the file. The variable name DescrptSeAArrayAPIStrict is used consistently, resolving the issue mentioned in a previous review comment.


114-134: LGTM: Skip properties added correctly.

The new skip_jax and skip_array_api_strict properties are correctly implemented and follow the existing pattern for other backends. The skip_array_api_strict property now correctly uses INSTALLED_ARRAY_API_STRICT, resolving the issue mentioned in a previous review comment.
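
The skip-property pattern described above can be sketched as follows. This is a minimal illustration, not the test file's actual code: the `BackendSkips` class is hypothetical, and the two flags are recomputed here with `importlib.util.find_spec` as an assumption about how such flags could be derived.

```python
import importlib.util

# Hypothetical stand-ins for the INSTALLED_* flags imported by the test module:
# True when the optional backend can be imported in this environment.
INSTALLED_JAX = importlib.util.find_spec("jax") is not None
INSTALLED_ARRAY_API_STRICT = importlib.util.find_spec("array_api_strict") is not None


class BackendSkips:
    """Hypothetical mixin: each property tells the test runner to skip a backend."""

    @property
    def skip_jax(self) -> bool:
        # Skip JAX checks when JAX is not installed.
        return not INSTALLED_JAX

    @property
    def skip_array_api_strict(self) -> bool:
        # Use the flag that matches this backend, not another backend's flag.
        return not INSTALLED_ARRAY_API_STRICT


skips = BackendSkips()
```

Keeping one flag per backend avoids the copy-paste mistake the review mentions, where a skip property consults the wrong `INSTALLED_*` flag.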


139-140: LGTM: Class attributes added correctly.

The new class attributes jax_class and array_api_strict_class are correctly added and follow the existing pattern for other backends.


217-233: LGTM: Evaluation methods added correctly.

The new eval_jax and eval_array_api_strict methods are correctly implemented and follow the existing pattern for other backends. They use the appropriate parameters and call the correct evaluation functions.


Line range hint 1-234: Summary: Excellent implementation of JAX and Array API Strict support.

The changes in this file consistently extend the existing test infrastructure to support JAX and Array API Strict backends. All additions follow the established patterns and conventions in the file. Previous issues mentioned in past review comments have been successfully addressed. These changes enhance the test coverage and flexibility of the TestSeA class, allowing it to accommodate more computational frameworks.

deepmd/dpmodel/utils/nlist.py (5)

172-173: Updated type casting and sorting

The type casting of pick_mask has been changed to use xp.astype, and the sorting now explicitly uses the stable parameter. These changes improve clarity and ensure consistent behavior.

pick_mask = xp.astype(tnlist == ii, xp.int32)
sorted_indices = xp.argsort(-pick_mask, stable=True, axis=-1)

These changes are improvements to the code clarity and consistency. The explicit use of stable=True in xp.argsort ensures consistent sorting behavior across different array_api implementations.
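
To see why stability matters here, the following NumPy sketch sorts a type mask so that neighbors of the selected type come first while keeping their original relative order. It uses NumPy's `kind="stable"` as a stand-in for the array API `stable=True` argument; the helper name `pick_type_sorted` and the sample data are assumptions for illustration.

```python
import numpy as np


def pick_type_sorted(tnlist, itype):
    """Return the stable ordering that moves neighbors of type `itype` to the front."""
    pick_mask = (tnlist == itype).astype(np.int32)
    # Negating the mask puts the 1s (picked entries) first; a stable sort
    # preserves the original order among picked and among unpicked entries.
    order = np.argsort(-pick_mask, kind="stable", axis=-1)
    return order, np.take_along_axis(pick_mask, order, axis=-1)


# One atom with five neighbor slots; -1 marks an empty slot.
tnlist = np.array([[1, 0, 1, -1, 0]])
order, mask_sorted = pick_type_sorted(tnlist, 0)
```

With an unstable sort, the two type-0 neighbors (slots 1 and 4) could come out in either order, which would make results differ between backends.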


166-179: Overall assessment of changes in nlist_distinguish_types

The refactoring of the nlist_distinguish_types function appears to be a well-executed improvement. The changes simplify the code, potentially improve performance, and align better with array operation best practices. Key points:

  1. The 3D tiling of tmp_atype may offer better alignment with input shapes.
  2. Simplified neighbor list handling using xp.where and xp_take_along_axis improves readability.
  3. Explicit use of stable=True in sorting ensures consistent behavior.
  4. Updated conditional logic for filling the neighbor list is more concise.

These changes look positive, but they must preserve the original behavior. Please run the suggested verification scripts and test nlist_distinguish_types on a range of inputs, including edge cases, comparing the results against the previous implementation.


168-170: Simplified neighbor list handling logic

The handling of the neighbor list has been refactored to use xp.where instead of the previous copy and mask approach. This change simplifies the logic and potentially improves performance. However, we should ensure that the new implementation maintains the same functionality.

tnlist_0 = xp.where(mask, xp.zeros_like(nlist), nlist)
tnlist = xp_take_along_axis(tmp_atype, tnlist_0, axis=2)
tnlist = xp.where(mask, xp.full_like(tnlist, -1), tnlist)

To verify this change, you can run the following script to check for any potential issues:

#!/bin/bash
# Description: Verify the new neighbor list handling logic
# Note: ripgrep names the Python file type "py", not "python".

# Test: Check for usage of xp.where and xp_take_along_axis
rg --type py 'xp\.where.*xp\.zeros_like.*nlist' deepmd/dpmodel/utils/nlist.py
rg --type py 'xp_take_along_axis.*tmp_atype.*tnlist_0' deepmd/dpmodel/utils/nlist.py
rg --type py 'xp\.where.*xp\.full_like.*tnlist.*-1' deepmd/dpmodel/utils/nlist.py

# Test: Check for any remaining occurrences of the old logic
rg --type py 'tnlist.*=.*nlist\.copy\(\)' deepmd/dpmodel/utils/nlist.py
rg --type py 'tnlist\[mask\].*=.*-1' deepmd/dpmodel/utils/nlist.py
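
A behavioral check complements the text searches above. The sketch below reproduces the masked gather from this comment in plain NumPy, with `np.take_along_axis` standing in for `xp_take_along_axis`; the function name `neighbor_types` and the sample inputs are assumptions.

```python
import numpy as np


def neighbor_types(nlist, atype_ext):
    """Map neighbor indices to atom types, keeping -1 for empty slots.

    nlist:     (nf, nloc, nnei) neighbor indices, -1 marks an empty slot
    atype_ext: (nf, nall) atom types
    """
    nf, nloc, nnei = nlist.shape
    mask = nlist == -1
    # Tile types to (nf, nloc, nall) so every local atom can index all atoms.
    tmp_atype = np.tile(atype_ext[:, None, :], (1, nloc, 1))
    # Replace -1 with a valid index (0) before gathering, then restore -1.
    nlist_0 = np.where(mask, np.zeros_like(nlist), nlist)
    tnlist = np.take_along_axis(tmp_atype, nlist_0, axis=2)
    return np.where(mask, np.full_like(tnlist, -1), tnlist)


atype_ext = np.array([[0, 1, 1, 0]])
nlist = np.array([[[1, 3, -1], [0, 2, -1]]])
tnlist = neighbor_types(nlist, atype_ext)
```

Substituting index 0 before the gather avoids indexing with -1, which NumPy would silently interpret as "last atom" and other backends may treat differently.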

176-178: Modified conditional logic for filling the neighbor list

The conditional logic for filling the neighbor list has been updated to use xp.where with a boolean mask. This change simplifies the code and potentially improves readability.

inlist = xp.where(
    ~xp.astype(pick_mask_sorted, xp.bool), xp.full_like(inlist, -1), inlist
)

To ensure this change maintains the correct functionality, you can run the following script:

#!/bin/bash
# Description: Verify the new conditional logic for filling the neighbor list
# Note: ripgrep names the Python file type "py", not "python".

# Test: Check for the new xp.where usage
rg --type py 'xp\.where.*~xp\.astype.*pick_mask_sorted.*xp\.bool.*xp\.full_like.*inlist.*-1.*inlist' deepmd/dpmodel/utils/nlist.py

# Test: Check for any remaining occurrences of the old logic
rg --type py 'inlist\[pick_mask_sorted.*==.*0\].*=.*-1' deepmd/dpmodel/utils/nlist.py
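
The functional fill can also be compared directly against the in-place form it replaces. This is a NumPy-only sketch with made-up data; `fill_unpicked` is a hypothetical helper name.

```python
import numpy as np


def fill_unpicked(inlist, pick_mask_sorted):
    """Return a copy of `inlist` with -1 wherever the sorted mask is 0."""
    keep = pick_mask_sorted.astype(bool)
    return np.where(~keep, np.full_like(inlist, -1), inlist)


inlist = np.array([[4, 7, 2, 9]])
pick_mask_sorted = np.array([[1, 1, 0, 0]], dtype=np.int32)
filled = fill_unpicked(inlist, pick_mask_sorted)

# In-place form for comparison; this relies on mutable arrays.
expected = inlist.copy()
expected[pick_mask_sorted == 0] = -1
```

The `where` form never mutates its input, which is what makes it usable with backends whose arrays are immutable, such as JAX.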

166-166: Verify the correctness of the 3D tiling for tmp_atype

The initialization of tmp_atype has been changed from 2D to 3D tiling. While this change may be intentional to match the shape of the input atype, please ensure that it doesn't affect the functionality of the subsequent operations.

tmp_atype = xp.tile(atype[:, None, :], (1, nloc, 1))

To verify this change, you can run the following script to check the shape and content of tmp_atype:

✅ Verification successful

Verification Successful: The 3D tiling of tmp_atype is correctly implemented and functions as intended within the nlist_distinguish_types function.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the shape and content of tmp_atype

# Test: Check for usage of xp.tile with 3D shape
rg --type python 'xp\.tile.*\(1,\s*nloc,\s*1\)' deepmd/dpmodel/utils/nlist.py

# Test: Check for any other occurrences of tmp_atype
rg --type python 'tmp_atype' deepmd/dpmodel/utils/nlist.py

Length of output: 605


Script:

#!/bin/bash
# Description: Verify the shape and content of tmp_atype

# Test: Check for usage of xp.tile with 3D shape
rg --glob "*.py" 'xp\.tile\s*\(atype\[:,\s*None,\s*:\],\s*\(1,\s*nloc,\s*1\)\)' deepmd/dpmodel/utils/nlist.py

# Test: Check for any other occurrences of tmp_atype
rg --glob "*.py" 'tmp_atype' deepmd/dpmodel/utils/nlist.py

Length of output: 345


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

🧹 Outside diff range and nitpick comments (4)
source/tests/array_api_strict/descriptor/se_e2_a.py (2)

23-25: Consider using copy.deepcopy instead of serializing and deserializing

In the assignment of the embeddings attribute, you serialize value and then immediately deserialize it using NetworkCollection.deserialize(value.serialize()). If the intention is to create a deep copy of value, using copy.deepcopy(value) would be more direct and efficient.

Apply this change:

+ from copy import deepcopy

elif name in {"embeddings"}:
    if value is not None:
-       value = NetworkCollection.deserialize(value.serialize())
+       value = deepcopy(value)
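
Before applying this suggestion, it may be worth checking whether the round trip is only a copy. The sketch below uses two hypothetical classes, `BaseNet` and `BackendNet`, that are not part of the project. It shows that `deepcopy` preserves the object's class, while deserializing through another class rebuilds the object under that class. Whether the original code relies on this is not stated in the review.

```python
from copy import deepcopy


class BaseNet:
    """Hypothetical network holder with a serialize/deserialize pair."""

    def __init__(self, weights):
        self.weights = list(weights)

    def serialize(self):
        return {"weights": list(self.weights)}

    @classmethod
    def deserialize(cls, data):
        # Rebuilds the object as an instance of whichever class this is called on.
        return cls(data["weights"])


class BackendNet(BaseNet):
    """Hypothetical backend-specific subclass."""


original = BaseNet([1.0, 2.0])
copied = deepcopy(original)  # same class, independent data
rebuilt = BackendNet.deserialize(original.serialize())  # different class, same data
```

If the deserializing class differs from the class of `value`, the two approaches are not interchangeable.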

26-28: Clarify the purpose of the env_mat attribute assignment

When name == "env_mat", the code executes a pass statement, indicating that no action is taken upon assignment. While there is a comment # env_mat doesn't store any value, consider expanding this comment to provide more context on why no value is stored for env_mat, enhancing code readability and maintainability.

deepmd/jax/descriptor/se_e2_a.py (2)

24-26: Optimize 'embeddings' assignment to avoid unnecessary serialization

When assigning to embeddings, the code serializes and then deserializes value. This could introduce unnecessary overhead if value is already in the correct format. Consider checking if serialization is necessary or if value can be assigned directly to improve efficiency.


27-29: Clarify handling of 'env_mat' attribute assignment

In the __setattr__ method, when the attribute name is "env_mat", the code does nothing (pass). If the intent is to prevent env_mat from being set or stored, consider explicitly documenting this behavior or using a more explicit mechanism to prevent unintended assignments.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 16172e6 and 21393f4.

📒 Files selected for processing (5)
  • deepmd/dpmodel/descriptor/se_e2_a.py (5 hunks)
  • deepmd/dpmodel/utils/nlist.py (1 hunks)
  • deepmd/jax/descriptor/se_e2_a.py (1 hunks)
  • source/tests/array_api_strict/descriptor/se_e2_a.py (1 hunks)
  • source/tests/consistent/descriptor/test_se_e2_a.py (4 hunks)
🧰 Additional context used
🪛 Ruff
deepmd/dpmodel/utils/nlist.py

171-171: Local variable snsel is assigned to but never used

Remove assignment to unused variable snsel

(F841)

🔇 Additional comments (17)
source/tests/array_api_strict/descriptor/se_e2_a.py (1)

19-32: LGTM!

The implementation of the DescrptSeA class and the overridden __setattr__ method appropriately handle custom attribute assignments. The usage of to_array_api_strict_array, NetworkCollection, and PairExcludeMask appears correct and in line with best practices.

deepmd/jax/descriptor/se_e2_a.py (1)

21-33: Override of __setattr__ is well-structured and maintains class integrity

The custom __setattr__ method effectively handles specific attributes with necessary transformations while preserving the base class behavior through super().__setattr__(name, value). This ensures controlled attribute assignment and maintains the integrity of the class.
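
The general shape of such an override can be sketched as below. The classes `Base` and `Converted` are hypothetical, and `np.asarray` stands in for whatever backend conversion the real class applies; only the structure (convert selected attributes, pass through the rest, always delegate to `super()`) is taken from the description above.

```python
import numpy as np


class Base:
    """Hypothetical base class that assigns plain Python values in __init__."""

    def __init__(self):
        self.davg = [0.0, 0.0]
        self.dstd = [1.0, 1.0]
        self.env_mat = object()


class Converted(Base):
    """Hypothetical subclass that converts selected attributes on assignment."""

    def __setattr__(self, name, value):
        if name in {"davg", "dstd"}:
            # Convert statistics to the backend array type (NumPy here).
            value = np.asarray(value)
        elif name == "env_mat":
            # Nothing to convert: this attribute holds no array data.
            pass
        # Always delegate so that normal attribute storage still happens.
        super().__setattr__(name, value)


obj = Converted()
```

Because `Base.__init__` assigns through `self`, the conversion runs without the subclass having to repeat the constructor.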

source/tests/consistent/descriptor/test_se_e2_a.py (4)

15-16: Imports of INSTALLED_ARRAY_API_STRICT and INSTALLED_JAX are correctly added

The inclusion of INSTALLED_ARRAY_API_STRICT and INSTALLED_JAX in the imports enhances the conditional handling of different backends.


139-140: Class attributes for JAX and Array API Strict are appropriately assigned

The jax_class and array_api_strict_class attributes are correctly set to their respective descriptor classes.


226-234: Ensure eval_array_api_strict_descriptor method exists

The method eval_array_api_strict invokes self.eval_array_api_strict_descriptor, but its definition isn't present in the provided code. Verify that this method is implemented to avoid runtime errors.

Run the following script to confirm the existence of eval_array_api_strict_descriptor:

#!/bin/bash
# Description: Verify if 'eval_array_api_strict_descriptor' method is defined in the codebase.
# Note: ripgrep names the Python file type "py", not "python".

# Test: Search for the definition of 'eval_array_api_strict_descriptor'. Expect: At least one match indicating the method definition.
rg --type py -A 2 -B 2 'def eval_array_api_strict_descriptor\('

217-225: Ensure eval_jax_descriptor method exists

The method eval_jax calls self.eval_jax_descriptor, but it's not shown in the provided code whether eval_jax_descriptor is defined. Verify that this method exists to prevent AttributeError.

Run the following script to confirm the existence of eval_jax_descriptor:

✅ Verification successful

eval_jax_descriptor method is defined in source/tests/consistent/descriptor/common.py.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify if 'eval_jax_descriptor' method is defined in the codebase.

# Test: Search for the definition of 'eval_jax_descriptor'. Expect: At least one match indicating the method definition.
rg --type python -A 2 -B 2 'def eval_jax_descriptor\('

Length of output: 100


Script:

#!/bin/bash
# Description: Verify if 'eval_jax_descriptor' method is defined in the codebase.

# Test: Search for the definition of 'eval_jax_descriptor'. Expect: At least one match indicating the method definition.
rg -g "*.py" -A 2 -B 2 'def eval_jax_descriptor\('

Length of output: 413

deepmd/dpmodel/utils/nlist.py (1)

166-180: Changes enhance clarity and efficiency in neighbor list handling.

The modifications to the nlist_distinguish_types function improve the handling of atomic types and neighbor lists. The use of xp.where and xp_take_along_axis simplifies the logic and enhances code readability.

deepmd/dpmodel/descriptor/se_e2_a.py (10)

10-10: Ensure array_api_compat is included in dependencies

The import statement for array_api_compat is added. Please verify that array_api_compat is installed in the environment and included in your project's dependencies, such as in requirements.txt or setup.py, to prevent import issues.


18-20: Import of to_numpy_array is appropriate

The import of to_numpy_array from deepmd.dpmodel.common is necessary for serialization purposes later in the code.


193-201: Initialization of embeddings is correctly updated

The modification initializes the embeddings using NetworkCollection with appropriate dimensions based on self.type_one_side. The loop correctly iterates over the embedding indices, and each embedding is instantiated with the given parameters.


209-219: Proper assignment and initialization of class variables

The assignments to self.embeddings, self.env_mat, and other class variables like self.nnei, self.davg, self.dstd, and self.sel_cumsum are correctly implemented. The use of .item() after np.sum(self.sel) ensures that self.nnei is a scalar, which is appropriate.
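
The effect of `.item()` is easy to check in isolation. The `sel` values below are made up, and the `sel_cumsum` expression is an assumption about how such a cumulative offset list could be built.

```python
import numpy as np

sel = [46, 92]  # hypothetical per-type neighbor counts

# np.sum returns a NumPy scalar; .item() converts it to a plain Python int.
nnei = np.sum(sel).item()

# Cumulative offsets marking where each type's neighbors start and end.
sel_cumsum = [0, *np.cumsum(sel).tolist()]
```

A plain `int` avoids carrying a NumPy scalar into shape arguments or serialized data, where other array backends may not accept it.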


330-332: Utilization of array_api_compat for array operations

The cal_g method now uses array_api_compat to obtain the array namespace xp, enhancing compatibility with different array backends. The reshaping of ss using xp.reshape ensures that the code is compatible with the selected array API.
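
The namespace lookup works roughly as sketched below. `array_namespace` is the real entry point of `array_api_compat`. The fallback branch is an assumption added only so the sketch runs where the package is missing, and `add_trailing_axis` is a hypothetical helper rather than the code in `cal_g`.

```python
import numpy as np

try:
    from array_api_compat import array_namespace
except ImportError:
    # Assumption for illustration: fall back to NumPy if the package is absent.
    def array_namespace(*arrays):
        return np


def add_trailing_axis(ss):
    """Reshape `ss` to add a trailing axis using the input's own namespace."""
    xp = array_namespace(ss)  # resolves to the namespace matching `ss`
    return xp.reshape(ss, (*ss.shape, 1))


out = add_trailing_axis(np.zeros((2, 3)))
```

Because `xp` is derived from the input, the same function body serves NumPy, JAX, or any other array type the compat layer recognizes.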


454-455: Serialization uses consistent data types

Converting self.davg and self.dstd to numpy arrays using to_numpy_array ensures consistent data types during serialization, which is important for data integrity when saving and loading models.


509-591: Addition of DescrptSeAArrayAPI class enhances array compatibility

The new class DescrptSeAArrayAPI extends DescrptSeA and overrides the call method to utilize the array API provided by array_api_compat. This includes:

  • Checking self.type_one_side and raising NotImplementedError if it's False, which correctly reflects the current limitations.
  • Deleting the mapping parameter as it's unused.
  • Using xp for array operations, ensuring compatibility with different array libraries.
  • Replacing np.einsum with equivalent operations using xp.sum and broadcasting, which can offer performance benefits and compatibility.

546-549: Informative error message for unsupported configuration

The check for self.type_one_side and the subsequent NotImplementedError provide a clear indication that type_one_side == False is not supported in DescrptSeAArrayAPI. This helps users understand the limitations of the new class.


551-551: Unused parameter mapping is appropriately handled

The deletion of the unused parameter mapping with del mapping prevents potential confusion and indicates that it is intentionally not used in this method.


579-587: Optimized array operations

The replacement of the einsum calls with explicit sum and multiplication operations:

  • Line 579: gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
  • Line 587: grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)

These changes improve compatibility with array APIs that may not support einsum and can lead to performance improvements.
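
The equivalence of the two forms can be checked numerically in NumPy. The einsum subscripts below are inferred from the quoted expressions and are an assumption, as are the array sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
nf, nloc, nnei, ng, ng1 = 1, 3, 5, 4, 2

# First contraction: sum over the neighbor axis.
gg = rng.normal(size=(nf * nloc, nnei, ng))
tr = rng.normal(size=(nf * nloc, nnei, 4))
gr_sum = np.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1)
gr_einsum = np.einsum("lni,lnj->lij", gg, tr)

# Second contraction: sum over the last (length-4) axis.
gr = gr_sum.reshape(nf, nloc, ng, 4)
gr1 = gr[:, :, :ng1, :]
grrg_sum = np.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
grrg_einsum = np.einsum("flid,fljd->flij", gr, gr1)
```

One trade-off is that the broadcast form materializes the full product array before summing, so it can use more memory than einsum for large neighbor counts.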

codecov bot commented Oct 15, 2024

Codecov Report

Attention: Patch coverage is 98.48485% with 1 line in your changes missing coverage. Please review.

Project coverage is 83.52%. Comparing base (5c092e6) to head (0d6d760).
Report is 2 commits behind head on devel.

Files with missing lines Patch % Lines
deepmd/dpmodel/descriptor/se_e2_a.py 97.50% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##            devel    #4217      +/-   ##
==========================================
+ Coverage   83.50%   83.52%   +0.01%     
==========================================
  Files         541      542       +1     
  Lines       52488    52538      +50     
  Branches     3047     3043       -4     
==========================================
+ Hits        43831    43882      +51     
  Misses       7709     7709              
+ Partials      948      947       -1     


njzjz and others added 2 commits October 15, 2024 18:12
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>