feat(jax): build nlist in the SavedModel & fix nopbc for StableHLO and SavedModel #4318
📝 Walkthrough
The pull request introduces several enhancements across multiple files, primarily focusing on the …
Sequence Diagram(s)
sequenceDiagram
participant User
participant DeepEval
participant HLO
participant Model
User->>DeepEval: Initialize with model file
DeepEval->>HLO: Create HLO with parameters
HLO->>Model: Prepare model for evaluation
User->>DeepEval: Evaluate model with coordinates
DeepEval->>Model: Call model with parameters
Model-->>DeepEval: Return predictions
DeepEval-->>User: Return evaluation results
Actionable comments posted: 7
🧹 Outside diff range and nitpick comments (24)
deepmd/jax/jax2tf/region.py (3)
29-48: Extract common shape validation logic.
Similar to `phys2inter`, this function would benefit from shape validation. Consider extracting the validation into a shared utility function to avoid code duplication.

    +def _validate_coord_cell_shapes(coord: tnp.ndarray, cell: tnp.ndarray, name: str):
    +    coord_shape = tf.shape(coord)
    +    cell_shape = tf.shape(cell)
    +    tf.debugging.assert_equal(
    +        coord_shape[-1],
    +        cell_shape[-1],
    +        message=f"{name}: Last dimension of coord and cell must match"
    +    )
    +    tf.debugging.assert_equal(
    +        cell_shape[-2:],
    +        tf.constant([3, 3]),
    +        message=f"{name}: Cell must be a 3x3 matrix"
    +    )

     def inter2phys(...):
    +    _validate_coord_cell_shapes(coord, cell, "inter2phys")
         return tnp.matmul(coord, cell)

51-72: Optimize matrix operations for better performance.
The current implementation performs two matrix multiplications by calling `phys2inter` and `inter2phys`. This could be optimized by caching the inverse cell matrix and reusing it.

     def normalize_coord(
         coord: tnp.ndarray,
         cell: tnp.ndarray,
     ) -> tnp.ndarray:
    +    rec_cell = tf.linalg.inv(cell)
    -    icoord = phys2inter(coord, cell)
    +    icoord = tnp.matmul(coord, rec_cell)
         icoord = tnp.remainder(icoord, 1.0)
    -    return inter2phys(icoord, cell)
    +    return tnp.matmul(icoord, cell)

75-93: Add validation for cell tensor in `to_face_distance`.
The function reshapes the input tensor but doesn't validate that the final dimensions are correct.

     def to_face_distance(
         cell: tnp.ndarray,
     ) -> tnp.ndarray:
         cshape = tf.shape(cell)
    +    tf.debugging.assert_equal(
    +        cshape[-2:],
    +        tf.constant([3, 3]),
    +        message="Cell must be a 3x3 matrix"
    +    )
         dist = b_to_face_distance(tnp.reshape(cell, [-1, 3, 3]))
         return tnp.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0))

deepmd/jax/utils/serialization.py (2)
72-76: Simplify ghost atoms logic using ternary operator
The logic is correct, but it can be more concise using a ternary operator.

    -    if has_ghost_atoms:
    -        nghost_ = nghost
    -    else:
    -        nghost_ = 0
    +    nghost_ = nghost if has_ghost_atoms else 0

🧰 Tools
🪛 Ruff
72-75: Use ternary operator `nghost_ = nghost if has_ghost_atoms else 0` instead of `if`-`else`-block. Replace `if`-`else`-block with `nghost_ = nghost if has_ghost_atoms else 0` (SIM108)
Tests needed for new serialization configurations
The implementation introduces new serialization configurations for ghost atoms and atomic virial combinations, but no corresponding test coverage was found. Tests should be added to verify:
- Serialization with ghost atoms enabled/disabled
- Serialization with atomic virial calculation enabled/disabled
- Combined scenarios of both configurations
Specific areas needing test coverage:
- `exported_whether_do_atomic_virial()` function behavior
- Serialization of all four combinations stored in `@variables`
- Proper handling of ghost atoms in the exported models
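The four combinations above could be enumerated for a parametrized test along these lines. This is a minimal sketch, not the project's test code: `variable_name` is a hypothetical helper, and only the two `*_no_ghost` names are quoted elsewhere in this review; the base names `stablehlo` and `stablehlo_atomic_virial` are assumed by analogy.

```python
from itertools import product


def variable_name(do_atomic_virial: bool, has_ghost_atoms: bool) -> str:
    """Hypothetical helper: map one export configuration to a variable name.

    Only the two *_no_ghost names appear in this review; the base names
    "stablehlo" and "stablehlo_atomic_virial" are assumptions.
    """
    name = "stablehlo"
    if do_atomic_virial:
        name += "_atomic_virial"
    if not has_ghost_atoms:
        name += "_no_ghost"
    return name


# Every (atomic virial, ghost atoms) combination a test suite should cover.
ALL_CASES = {
    (virial, ghost): variable_name(virial, ghost)
    for virial, ghost in product((False, True), repeat=2)
}
```

Iterating over `ALL_CASES` in a test would make a missing configuration show up as a missing key rather than as a silently untested branch.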
🔗 Analysis chain
Line range hint 56-119: Verify test coverage for all configurations
While the implementation looks solid, we should ensure proper test coverage for all combinations of:
- Atomic virial calculation (True/False)
- Ghost atoms presence (True/False)
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

    #!/bin/bash
    # Search for test files covering these scenarios
    rg -l "exported_whether_do_atomic_virial" --type python | grep "test"

Length of output: 98
Script:

    #!/bin/bash
    # Search for test files covering these scenarios without file type filter
    rg -l "exported_whether_do_atomic_virial" | grep "test"

Length of output: 53
Script:

    #!/bin/bash
    # Search for test files that might cover the serialization functionality
    rg -l "stablehlo.*ghost|stablehlo.*atomic_virial" | grep "test"

Length of output: 63
Script:

    #!/bin/bash
    # Look for test files in the deepmd directory
    fd "test.*\.py$" deepmd/

Length of output: 83
Script:

    #!/bin/bash
    # Search for test files in the jax directory specifically
    fd "test.*\.py$" deepmd/jax/
    # Look for any serialization tests
    rg -l "serialize|serializ" deepmd/jax/ | grep "test"
    # Check content of test files for related functionality
    rg "atomic_virial|ghost" deepmd/jax/ -g "test*.py" -C 2

Length of output: 141
source/tests/consistent/io/test_io.py (2)
186-202: Add documentation for nopbc test cases.
While the implementation is correct, it would be helpful to add a comment explaining that passing `None` as the box parameter indicates no periodic boundary conditions. This would make the test's intention more explicit.

    +    # Test evaluation without periodic boundary conditions by passing None as box
         ret = deep_eval.eval(
             self.coords,
             None,
             self.atype,
             fparam=fparam,
             aparam=aparam,
         )

The review comment's suggestion for additional test coverage is valid.
While the codebase has logic to handle zero ghost atoms in the JAX serialization code, there are no explicit test cases verifying this scenario. The existing tests in `test_virtual_type.py` only cover non-zero ghost atoms (`nghost=10`). Adding test cases for zero ghost atoms would help ensure the polymorphic_shape fix works correctly for this edge case.
- Add test cases in `source/tests/consistent/io/test_io.py` to verify model behavior with zero ghost atoms
- Consider adding similar test coverage in `source/tests/tf/test_virtual_type.py`
🔗 Analysis chain
Line range hint 145-218: Consider adding test cases for zero ghost atoms.
While the test coverage for nopbc is comprehensive, consider adding specific test cases to verify the fix for zero ghost atoms mentioned in the PR objectives. This would ensure that the polymorphic_shape issue is properly addressed.
Based on the search results, I can see that there are existing tests for ghost atoms in `source/tests/tf/test_virtual_type.py` and `source/tests/tf/common.py` that cover cases with non-zero ghost atoms (e.g., `nghost = 10` and `nghost = 4`). The codebase also has specific handling for zero ghost atoms in JAX serialization (`deepmd/jax/jax2tf/serialization.py` and `deepmd/jax/utils/serialization.py`).
Let me verify if there are specific test cases for zero ghost atoms.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

    #!/bin/bash
    # Description: Check if there are any existing tests for zero ghost atoms
    # Test: Search for test cases involving ghost atoms
    rg -i "ghost.*atom|nghost" --type py

Length of output: 8281
Script:

    #!/bin/bash
    # Check for test cases that explicitly test zero ghost atoms
    rg -B2 -A2 "nghost.*=.*0|has_ghost_atoms.*=.*false" --type py

Length of output: 1487
deepmd/jax/jax2tf/tfmodel.py (1)
144-147: Consider using a ternary operator for better readability
The call selection can be more concise.

    -        if do_atomic_virial:
    -            call = self._call_atomic_virial
    -        else:
    -            call = self._call
    +        call = self._call_atomic_virial if do_atomic_virial else self._call

🧰 Tools
🪛 Ruff
144-147: Use ternary operator `call = self._call_atomic_virial if do_atomic_virial else self._call` instead of `if`-`else`-block. Replace `if`-`else`-block with `call = self._call_atomic_virial if do_atomic_virial else self._call` (SIM108)
deepmd/jax/jax2tf/make_model.py (4)
50-72: Document all function parameters in the docstring
The parameters `call_lower`, `rcut`, `sel`, `mixed_types`, and `model_output_def` are not documented in the docstring. Please add descriptions for these parameters to enhance clarity and maintainability.
Example addition to the docstring:

         """
         Return model prediction from lower interface.

         Parameters
         ----------
    +    call_lower
    +        A callable that accepts extended coordinates and other parameters, returning model predictions as a dictionary.
    +    rcut
    +        Cut-off radius for neighbor list construction.
    +    sel
    +        List of integers specifying selected atom types.
    +    mixed_types
    +        Boolean indicating whether to treat atom types as mixed.
    +    model_output_def
    +        Definition of the model outputs.
         coord
             The coordinates of the atoms.
             shape: nf x (nloc x 3)

75-76: Improve variable naming for better readability
Using aliases like `cc`, `bb`, `fp`, `ap` for `coord`, `box`, `fparam`, `aparam` reduces code readability. It's clearer to use the original variable names throughout the function.
Apply this diff to enhance readability:

    -    cc, bb, fp, ap = coord, box, fparam, aparam
    -    del coord, box, fparam, aparam
    +    # Use original variable names throughout the function

75-76: Remove unnecessary variable deletion
The `del` statements for `coord`, `box`, `fparam`, and `aparam` might be unnecessary. Python's garbage collector handles memory management, and removing these lines can simplify the code.
Apply this diff to remove unnecessary code:

         cc, bb, fp, ap = coord, box, fparam, aparam
    -    del coord, box, fparam, aparam

93-93: Clarify the logic of `distinguish_types` parameter
The expression `distinguish_types=not mixed_types` may be confusing. Consider renaming `mixed_types` to `ignore_atom_types` or adding a comment to clarify that `distinguish_types` is the inverse of `mixed_types`.
deepmd/jax/jax2tf/transform_output.py (3)
54-54: Remove unused variable `mldims`.
The variable `mldims` is assigned but never used, which could lead to confusion. Please remove it to keep the code clean.
🧰 Tools
🪛 Ruff
54-54: Local variable `mldims` is assigned to but never used. Remove assignment to unused variable `mldims` (F841)
🪛 GitHub Check: CodeQL
[notice] 54-54: Unused local variable. Variable mldims is not used.
83-83: Replace `assert` with explicit exception handling.
Using `assert` statements for control flow may not be ideal in production code because they can be disabled with optimization flags. Consider raising a specific exception with a clear error message to handle cases where `vdef.r_differentiable` is `False` while `vdef.c_differentiable` is `True`.
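One way the explicit check could look is sketched below. This is an illustration under assumptions, not the code in `transform_output.py`: `VarDefStub` is a hypothetical stand-in that models only the two flags named above, and the choice of `ValueError` is an assumption.

```python
from dataclasses import dataclass


@dataclass
class VarDefStub:
    """Hypothetical stand-in for an output definition; only two flags are modeled."""

    r_differentiable: bool
    c_differentiable: bool


def check_differentiability(vdef: VarDefStub) -> None:
    """Raise instead of asserting, so the check still runs under `python -O`."""
    if vdef.c_differentiable and not vdef.r_differentiable:
        raise ValueError(
            "c_differentiable requires r_differentiable to be True as well"
        )
```

Unlike an `assert`, the `raise` is not stripped when Python runs with optimization enabled, so the invalid flag combination is always reported.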
41-44: Enhance the docstring for `communicate_extended_output`.
The current docstring is brief. Providing detailed descriptions of the parameters, return values, and any important computational steps would improve maintainability and clarity for other developers.
deepmd/jax/jax2tf/nlist.py (6)
2-4: Simplify import statement by removing unnecessary parentheses
The parentheses around `Union` are unnecessary when importing a single item.
Apply this diff to simplify the import:

    -from typing import (
    -    Union,
    -)
    +from typing import Union

9-11: Simplify import statement by removing unnecessary parentheses
The parentheses around `to_face_distance` are unnecessary when importing a single item.
Apply this diff to simplify the import:

    -from .region import (
    -    to_face_distance,
    -)
    +from .region import to_face_distance

14-14: Remove reference to 'chatgpt' in comment for professionalism
It's better to avoid mentioning specific tools like 'chatgpt' in code comments for professionalism and maintainability.
Apply this diff to update the comment:

    -## translated from torch implementation by chatgpt
    +## Translated from torch implementation

28-28: Correct typo in docstring: 'exptended' should be 'extended'
There's a typo in the parameter description for `coord`.
Apply this diff to fix the typo:

    -        exptended coordinates of shape [batch_size, nall x 3]
    +        extended coordinates of shape [batch_size, nall x 3]

154-154: Correct typo in docstring: 'peridoc' should be 'periodic'
There's a typo in the function description.
Apply this diff to fix the typo:

    -    """Extend the coordinates of the atoms by appending peridoc images.
    +    """Extend the coordinates of the atoms by appending periodic images.

147-147: Remove reference to 'chatgpt' in comment for professionalism
As before, it's better to avoid mentioning specific tools like 'chatgpt' in code comments.
Apply this diff to update the comment:

    -## translated from torch implementation by chatgpt
    +## Translated from torch implementation

deepmd/jax/model/hlo.py (2)
183-192: Consider adding comments to clarify the conditional logic in `call_lower`.
The conditional logic in `call_lower` now depends on the dimensions of `extended_coord` and `nlist`, as well as the `do_atomic_virial` flag. Adding comments to explain the reasoning behind these conditions can improve code readability and maintainability.
For example:

     def call_lower(
         self,
         extended_coord: jnp.ndarray,
         extended_atype: jnp.ndarray,
         nlist: jnp.ndarray,
         mapping: Optional[jnp.ndarray] = None,
         fparam: Optional[jnp.ndarray] = None,
         aparam: Optional[jnp.ndarray] = None,
         do_atomic_virial: bool = False,
     ):
    +    # Determine if ghost atoms are present based on the shape of extended_coord and nlist
         if extended_coord.shape[1] > nlist.shape[1]:
    +        # Case with ghost atoms
             if do_atomic_virial:
                 call_lower = self._call_lower_atomic_virial
             else:
                 call_lower = self._call_lower
         else:
    +        # Case without ghost atoms
             if do_atomic_virial:
                 call_lower = self._call_lower_atomic_virial_no_ghost
             else:
                 call_lower = self._call_lower_no_ghost
         return call_lower(
             extended_coord,
             extended_atype,

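The same four-way selection can also be written as a table lookup, which makes the branches easy to enumerate in a test. The following is a sketch under assumptions, not the code in `hlo.py`: `select_call_lower` and `TABLE` are hypothetical names, and the lambdas merely stand in for the four deserialized functions.

```python
from typing import Callable, Dict, Tuple


def select_call_lower(
    nall: int,
    nloc: int,
    do_atomic_virial: bool,
    table: Dict[Tuple[bool, bool], Callable],
) -> Callable:
    """Pick the callable for (has_ghost_atoms, do_atomic_virial).

    nall stands for extended_coord.shape[1] and nloc for nlist.shape[1].
    """
    has_ghost_atoms = nall > nloc
    return table[(has_ghost_atoms, do_atomic_virial)]


# Hypothetical stand-ins: each lambda returns the attribute name it represents.
TABLE = {
    (True, False): lambda: "_call_lower",
    (True, True): lambda: "_call_lower_atomic_virial",
    (False, False): lambda: "_call_lower_no_ghost",
    (False, True): lambda: "_call_lower_atomic_virial_no_ghost",
}
```

Whether a table reads better than the nested `if`/`else` is a judgment call; its main advantage is that a test can loop over the keys and hit every branch.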
183-192: Ensure unit tests cover all branches of the new conditional logic.
With the updated `call_lower` method introducing additional branches, it's important to verify that all scenarios are properly tested. This includes cases:
- With and without ghost atoms (i.e., when `extended_coord.shape[1] > nlist.shape[1]` and when it's not).
- With `do_atomic_virial` set to both `True` and `False`.
deepmd/jax/jax2tf/serialization.py (1)
55-58: Simplify assignment with a ternary operator
Consider replacing the `if`-`else` block with a ternary operator for conciseness and readability.
Apply this diff to implement the suggestion:

         # nghost >= 1 is assumed if there is
         # other workaround does not work, such as
         # nall; nloc + nghost - 1
    -    if has_ghost_atoms:
    -        nghost = "nghost"
    -    else:
    -        nghost = "0"
    +    nghost = "nghost" if has_ghost_atoms else "0"

🧰 Tools
🪛 Ruff
55-58: Use ternary operator `nghost = "nghost" if has_ghost_atoms else "0"` instead of `if`-`else`-block. Replace `if`-`else`-block with `nghost = "nghost" if has_ghost_atoms else "0"` (SIM108)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (11)
- deepmd/jax/infer/deep_eval.py (1 hunks)
- deepmd/jax/jax2tf/__init__.py (2 hunks)
- deepmd/jax/jax2tf/make_model.py (1 hunks)
- deepmd/jax/jax2tf/nlist.py (1 hunks)
- deepmd/jax/jax2tf/region.py (1 hunks)
- deepmd/jax/jax2tf/serialization.py (5 hunks)
- deepmd/jax/jax2tf/tfmodel.py (2 hunks)
- deepmd/jax/jax2tf/transform_output.py (1 hunks)
- deepmd/jax/model/hlo.py (3 hunks)
- deepmd/jax/utils/serialization.py (3 hunks)
- source/tests/consistent/io/test_io.py (2 hunks)
🔇 Additional comments (15)
deepmd/jax/jax2tf/__init__.py (2)
14-14: Verify the impact of enabling numpy behavior globally.
The call to `tnp.experimental_enable_numpy_behavior()` affects the entire module. While necessary for the neighbor list calculations, we should ensure it doesn't interfere with existing functionality.
Let's check for any other numpy operations that might be affected:
3-3: LGTM, but note the experimental nature of the import.
The import of `tensorflow.experimental.numpy` is necessary for the neighbor list calculations, but be aware that it's from TensorFlow's experimental namespace.
Let's check if there are any stability notices or deprecation warnings:
✅ Verification successful
Import of tensorflow.experimental.numpy is consistently used across the module
The import of `tensorflow.experimental.numpy` as `tnp` is well-integrated throughout the jax2tf module:
- Used consistently in 6 files within the jax2tf package
- Primarily used for array operations in neighbor list calculations, coordinate transformations, and model outputs
- The `enable_numpy_behavior()` is properly configured in `__init__.py`
The experimental nature of the import is acceptable here as:
- It's used consistently across the entire module
- The functionality is essential for numpy-like operations in TensorFlow
- The code follows TensorFlow's recommended patterns for numpy compatibility
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check TensorFlow documentation or issues related to experimental.numpy stability
gh search issues "tensorflow experimental.numpy stability" --repo tensorflow/tensorflow --json title,body --jq '.[] | select(.title | contains("experimental.numpy") or contains("tnp")) | {title: .title, body: .body}'
Length of output: 218
Script:
#!/bin/bash
# Let's check the actual usage of tnp in the codebase to understand its impact
rg "import tensorflow.experimental.numpy as tnp" -A 3
# Also check for any direct usage of tnp
rg "\btnp\." -B 1 -A 1
# Look for any configuration or setup related to numpy behavior
rg "enable_numpy_behavior|set_numpy_enabled"
Length of output: 17217
deepmd/jax/jax2tf/region.py (1)
2-3: Consider the stability implications of using experimental NumPy.
The code uses `tensorflow.experimental.numpy` which, being experimental, may have breaking changes in future releases. Consider:
- Adding version constraints for TensorFlow
- Documenting the minimum required TensorFlow version
- Adding error handling for potential API changes
deepmd/jax/utils/serialization.py (4)
56-58: LGTM: Function signature properly updated
The addition of the `has_ghost_atoms` parameter aligns with the PR's objective to handle cases where the number of ghost atoms is zero. The type hints are correctly specified.
79-83: LGTM: Shape definitions properly handle ghost atoms
The shape definitions correctly use `nghost_` to handle both ghost and no-ghost cases, ensuring proper tensor dimensions for coordinates, atom types, neighbor lists, and mapping arrays.
92-109: LGTM: Comprehensive export configurations
The implementation properly handles all combinations of atomic virial calculation and ghost atoms presence, which directly addresses the PR's objective. The naming convention is clear and consistent.
116-119: LGTM: Variable storage properly implemented
The new variables for no-ghost scenarios are correctly added to the data dictionary with appropriate naming and proper use of np.void for serialized data storage.
source/tests/consistent/io/test_io.py (2)
145-145: LGTM: Clear and consistent variable declaration.
The addition of the `rets_nopbc` list follows the same pattern as the existing `rets` list and aligns with the PR's objective to test no periodic boundary conditions scenarios.
210-218: LGTM: Well-implemented verification logic with improved error messages.
The assertions for nopbc results:
- Maintain consistency with existing verification patterns
- Include helpful backend information in error messages
- Properly handle unsupported cases with NaN checks
- Use appropriate tolerance values
deepmd/jax/jax2tf/tfmodel.py (2)
55-56: LGTM: New wrapper methods for atomic virial calculations
The addition of the `_call` and `_call_atomic_virial` attributes follows the existing pattern and maintains consistency with the codebase structure.
149-165: LGTM: Robust parameter initialization
The initialization of None parameters with empty arrays is well-implemented:
- Preserves batch dimensions from input coordinates
- Uses correct data type (float64)
- Maintains appropriate tensor shapes for each parameter
- Prevents TensorFlow errors about None values
This change improves the robustness of the model wrapper.
deepmd/jax/infer/deep_eval.py (1)
100-105: LGTM! The changes align with PR objectives.
The addition of the `stablehlo_no_ghost` and `stablehlo_atomic_virial_no_ghost` parameters properly addresses the ghost atoms handling issue mentioned in the PR objectives.
Let's verify that these new variables are present in HLO files:
✅ Verification successful
The new variables are properly integrated into the codebase
The verification confirms that:
- The variables are properly defined in the `HLO` class constructor (`deepmd/jax/model/hlo.py`)
- They are correctly initialized and used with `jax_export.deserialize()` in the HLO class
- The serialization logic in `deepmd/jax/utils/serialization.py` properly stores these variables in the model data
The implementation in `deep_eval.py` correctly reads these variables from the model data, maintaining consistency with how they are stored during serialization.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check for the presence of new variables in HLO files
# Expected: Files containing these new variables
# Search for files containing the new variable names
rg -l "stablehlo_no_ghost|stablehlo_atomic_virial_no_ghost" --type py
# Search for test files that verify these variables
rg -l "test.*stablehlo.*ghost" --type py
Length of output: 202
Script:
#!/bin/bash
# Let's examine the implementation and usage of these variables in the found files
# Check hlo.py implementation
echo "=== Content from deepmd/jax/model/hlo.py ==="
rg -A 5 -B 5 "stablehlo_no_ghost|stablehlo_atomic_virial_no_ghost" deepmd/jax/model/hlo.py
echo -e "\n=== Content from deepmd/jax/utils/serialization.py ==="
rg -A 5 -B 5 "stablehlo_no_ghost|stablehlo_atomic_virial_no_ghost" deepmd/jax/utils/serialization.py
# Check if there are any tests that might not directly mention "stablehlo" and "ghost"
echo -e "\n=== Related test files ==="
fd "test.*\.py" | xargs rg -l "no_ghost"
Length of output: 1891
deepmd/jax/jax2tf/make_model.py (3)
96-103: Confirm that `call_lower` is called with correct arguments
Ensure that the `call_lower` function accepts the provided arguments `extended_coord`, `extended_atype`, `nlist`, `mapping`, `fparam`, and `aparam`. Verify that the boolean parameter `do_atomic_virial` is correctly handled within `call_lower`.
Run the following script to verify the signature of `call_lower`:
#!/bin/bash
# Description: Verify the signature of `call_lower`.
# Expected: `call_lower` accepts the specified arguments and matches the Callable type hint.
# Since `call_lower` is passed as a parameter, search for typical usages or definitions
rg -A 5 'def call_lower' | rg 'def call_lower\('
# Alternatively, check for any type hints or comments indicating the expected parameters
rg -A 5 'call_lower' | rg 'Callable\[['
105-109: Verify that `do_atomic_virial` is correctly propagated
Ensure that the `do_atomic_virial` flag is correctly used inside `communicate_extended_output`, and that the function handles it appropriately.
Please run the following script to check the usage of `do_atomic_virial`:
#!/bin/bash
# Description: Check how `do_atomic_virial` is used in `communicate_extended_output`.
# Expected: The flag is used to conditionally process atomic virial calculations.
# Find the definition of `communicate_extended_output` and its usage of `do_atomic_virial`
rg -A 15 'def communicate_extended_output' | rg 'do_atomic_virial'
85-86: Ensure `extend_coord_with_ghosts` handles `box` being `None`
Passing `bb` to `extend_coord_with_ghosts` when `bb` is `None` may cause issues inside that function. Verify that `extend_coord_with_ghosts` can handle a `None` value for `box`.
Please run the following script to check if `extend_coord_with_ghosts` can handle `bb` being `None`:
✅ Verification successful
`extend_coord_with_ghosts` safely handles `None` box/cell parameter
Looking at the implementations, both the JAX and dpmodel versions properly handle the case when the cell/box parameter is None or empty:
- In JAX (jax2tf/nlist.py): Checks if cell shape is empty with `if tf.shape(cell)[-1] == 0`
- In dpmodel (dpmodel/utils/nlist.py): Explicitly checks `if cell is None`
In both cases, it safely returns the original coordinates without any ghost atoms when the cell/box is None.
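The early-return behavior described above can be pictured with a small pure-Python sketch. It is only an illustration under assumptions: `extend_coord_sketch` is a hypothetical function working on flat lists, it merges the two checks (`None` and empty cell) into one, and the periodic branch is deliberately left out.

```python
def extend_coord_sketch(coord, atype, cell):
    """Return (extended_coord, extended_atype, mapping) for the no-PBC case.

    coord is a flat list of 3 * nloc floats, atype a list of nloc ints,
    and cell is None or an empty sequence when there is no periodic box.
    """
    nloc = len(atype)
    if cell is None or len(cell) == 0:
        # No box: nothing to replicate, so every atom maps to itself.
        return list(coord), list(atype), list(range(nloc))
    # Generating periodic images is outside the scope of this sketch.
    raise NotImplementedError("periodic image generation is omitted here")
```

With no box, the extended arrays equal the local ones, which is why the number of ghost atoms is zero on this path.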
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check if `extend_coord_with_ghosts` handles `box` being `None`.
# Expected: No errors or exceptions when `box` is `None`.
# Find the definition of `extend_coord_with_ghosts`
rg -A 10 'def extend_coord_with_ghosts' | grep 'box'
# Search for any checks or handling of `box` being `None` within the function
rg -A 20 'def extend_coord_with_ghosts' | rg 'if box is None|box is not None'
Length of output: 537
Script:
#!/bin/bash
# Let's try a different approach to find the implementation of extend_coord_with_ghosts
# Search for the function definition with more context
rg -A 20 "def extend_coord_with_ghosts"
# Also search in case it's defined with different spacing
rg -A 20 "def\s+extend_coord_with_ghosts"
# Try finding any usage of this function to understand the context
rg "extend_coord_with_ghosts\("
Length of output: 15278
Script:
#!/bin/bash
# Let's check the implementation in jax2tf/nlist.py since that's the one being used in make_model.py
rg -A 50 "def extend_coord_with_ghosts" deepmd/jax/jax2tf/nlist.py
# Also check if there's any handling of None cell/box in the dpmodel implementation since it has Optional typing
rg -A 50 "def extend_coord_with_ghosts" deepmd/dpmodel/utils/nlist.py
Length of output: 3600
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files

    @@            Coverage Diff             @@
    ##            devel    #4318      +/-   ##
    ==========================================
    + Coverage   84.50%   84.52%   +0.01%
    ==========================================
      Files         570      575       +5
      Lines       53162    53388     +226
      Branches     3059     3059
    ==========================================
    + Hits        44927    45124     +197
    - Misses       7272     7301      +29
      Partials      963      963

☔ View full report in Codecov by Sentry.
Actionable comments posted: 4
🧹 Outside diff range and nitpick comments (6)
source/tests/jax/jax2tf/__init__.py (1)
1-8: Add missing TensorFlow configurations
Based on the AI summary, this file should include TensorFlow's experimental numpy configuration and eager execution check.
Consider adding these configurations:

     # SPDX-License-Identifier: LGPL-3.0-or-later
     import pytest
    +import tensorflow as tf
    +import tensorflow.experimental.numpy as tnp

     from ...utils import (
         DP_TEST_TF2_ONLY,
     )

    +if not tf.executing_eagerly():
    +    raise RuntimeError("This module must be run in eager execution mode")
    +
    +tnp.experimental_enable_numpy_behavior()
    +
     pytestmark = pytest.mark.skipif(not DP_TEST_TF2_ONLY, reason="TF2 conflicts with TF1")

source/tests/jax/jax2tf/test_region.py (3)
4-4: Consider the stability implications of using experimental numpy.
While using `tensorflow.experimental.numpy` is acceptable in tests, be aware that its API might change in future TensorFlow versions. Consider adding a comment explaining why the experimental version is preferred over the stable numpy.
18-22: Add documentation for cell initialization.
The cell initialization involves complex shape transformations. Consider adding comments to explain:
- Why the cell is initialized with this specific 3D structure
- The purpose of reshaping to [1, 1, -1, 3]
- The significance of tiling to [4, 5, 1, 1]
25-34: Enhance test coverage with edge cases.
While the current test with random inputs is good, consider adding specific test cases for:
- Zero vectors
- Unit vectors
- Extreme values
- Negative coordinates
This would ensure the transformation handles all possible scenarios correctly.
source/tests/jax/jax2tf/test_nlist.py (2)
17-44: Add docstrings to improve test documentation.
The test class and setUp method would benefit from docstrings explaining:
- The purpose of the test suite
- The structure of the test data
- The meaning of key parameters (nf, nloc, ns, etc.)
Example addition:

     class TestNeighList(tf.test.TestCase):
    +    """Test suite for neighbor list construction and coordinate extension.
    +
    +    Tests the functionality of build_neighbor_list and extend_coord_with_ghosts
    +    with realistic molecular simulation data.
    +    """
         def setUp(self):
    +        """Initialize test data including coordinates, cell, and reference lists.
    +
    +        Sets up:
    +        - nf (int): Number of frames (3)
    +        - nloc (int): Number of local atoms (3)
    +        - ns (int): Number of shifts (75)
    +        - cell (array): Simulation cell parameters
    +        - coordinates and atom types
    +        """

45-65: Extract magic numbers into named constants.
The test uses magic numbers for array comparisons. Consider extracting these into named constants at the class level for better maintainability.

     class TestNeighList(tf.test.TestCase):
    +    # Constants for array comparisons
    +    INVALID_NEIGHBOR = -1
    +
         def test_build_notype(self):
    -        nlist_mask = nlist[0] == -1
    +        nlist_mask = nlist[0] == self.INVALID_NEIGHBOR

📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (6)
- .github/workflows/test_python.yml (1 hunks)
- pyproject.toml (2 hunks)
- source/tests/jax/__init__.py (1 hunks)
- source/tests/jax/jax2tf/__init__.py (1 hunks)
- source/tests/jax/jax2tf/test_nlist.py (1 hunks)
- source/tests/jax/jax2tf/test_region.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- source/tests/jax/__init__.py
🔇 Additional comments (6)
source/tests/jax/jax2tf/test_region.py (1)
1-53
: Verify the connection with ghost atoms handling.
The PR objectives mention handling ghost atoms and neighbor lists, but this test file focuses on coordinate transformations. Please clarify how these region transformations relate to the ghost atoms functionality.
✅ Verification successful
Region transformations are essential for ghost atoms and neighbor list handling
The verification shows clear connections between region transformations and ghost atoms functionality:
- The inter2phys and to_face_distance functions are directly used in neighbor list construction across all implementations (TF/PT/JAX):
  - In deepmd/tf/utils/nlist.py, deepmd/pt/utils/nlist.py, and deepmd/jax/jax2tf/nlist.py
  - Used to calculate buffer sizes for ghost atoms based on cutoff distances
- The region transformations serve critical purposes:
  - inter2phys: Converts internal coordinates to physical coordinates needed for ghost atom positioning
  - to_face_distance: Calculates cell face distances to determine how many ghost atom copies are needed in each direction
- These functions are part of the core infrastructure used by the ghost atoms implementation found in files like:
  - deepmd/tf/utils/nlist.py: extend_coord_with_ghosts
  - deepmd/pt/utils/nlist.py: extend_input_and_build_neighbor_list
  - deepmd/jax/jax2tf/nlist.py: build_neighbor_list
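To make the link between face distances and ghost atom copies concrete, here is a minimal sketch in plain Python. The names `ghost_buffer_counts` and `shift_count` are hypothetical and the `ceil(rcut / distance)` rule is an assumption about the general technique, not a copy of the deepmd implementation. The face distances and cutoff are made-up values chosen to reproduce the 5x5x3 shift grid mentioned elsewhere in this review.

```python
import math


def ghost_buffer_counts(face_dist, rcut):
    """Image cells needed on each side of the home cell, per cell direction.

    Assumption: one extra image is required for every face distance
    (or fraction of one) that fits inside the cutoff radius.
    """
    return [int(math.ceil(rcut / d)) for d in face_dist]


def shift_count(nbuff):
    """Total number of shift vectors: images on both sides plus the home cell."""
    total = 1
    for n in nbuff:
        total *= 2 * n + 1
    return total


# Hypothetical face distances (dx, dy, dz) and cutoff, for illustration only.
nbuff = ghost_buffer_counts([0.9, 0.8, 2.1], rcut=1.01)
print(nbuff, shift_count(nbuff))
```

A shorter face distance means more images are needed in that direction, which is why a skewed cell can require a larger ghost region than its edge lengths alone suggest.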
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Search for connections between region transformations and ghost atoms
rg -l "ghost.*region|region.*ghost" --type py
# Look for usage of these transformation functions in neighbor list construction
ast-grep --pattern 'inter2phys|to_face_distance'
Length of output: 98
Script:
#!/bin/bash
# Search for files containing ghost atoms implementation
rg -l "ghost.*atom" --type py
# Search for implementations using region transformations
rg -A 5 "inter2phys|to_face_distance" --type py
# Look for neighbor list implementations
rg -l "neighbor.*list|neigh.*list" --type py
# Check the actual implementation files for region.py
fd region.py
Length of output: 20719
.github/workflows/test_python.yml (1)
61-61
: LGTM! Test paths align with PR objectives.
The addition of test_io.py and the entire jax2tf directory to TF2 eager mode testing appropriately covers both the DeepEval class modifications and the new neighbor list functionality.
Let's verify the test coverage:
✅ Verification successful
Test paths are appropriate and contain relevant test coverage
The verification confirms that:
- test_io.py thoroughly tests the DeepEval class with both PBC and non-PBC conditions across different backends (TF, PyTorch, JAX)
- test_nlist.py comprehensively tests the neighbor list functionality with:
  - Type-based and non-type-based neighbor list building
  - Ghost atom coordinate extension
  - Proper handling of periodic boundary conditions
  - Validation against reference neighbor lists
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Verify that the test files exist and contain relevant test cases
# Test 1: Check if the test files exist
echo "Checking test files existence..."
fd -t f "test_io.py$" "source/tests/consistent/io"
fd -t f "test_nlist.py$" "source/tests/jax/jax2tf"
# Test 2: Check test content for neighbor list and nopbc related tests
echo "Checking test content..."
rg -l "test.*nlist|test.*pbc" "source/tests/jax/jax2tf"
Length of output: 367
Script:
#!/bin/bash
# Let's check the actual content of these test files to verify coverage
echo "Content of test_io.py:"
cat source/tests/consistent/io/test_io.py
echo -e "\nContent of test_nlist.py:"
cat source/tests/jax/jax2tf/test_nlist.py
# Also check for any other test files in jax2tf directory
echo -e "\nOther test files in jax2tf directory:"
fd -t f "^test.*\.py$" "source/tests/jax/jax2tf"
Length of output: 15126
source/tests/jax/jax2tf/test_nlist.py (2)
101-109
: Add explanation for the coordinate transformation logic.
The shift vector calculation and matrix transformation logic is complex and would benefit from detailed comments explaining:
- The purpose of the transformation
- Why the inverse cell matrix is used
- The expected output format
#!/bin/bash
# Check if similar matrix transformations are documented elsewhere
rg "tf.linalg.inv.*cell" -B 2 -A 2
4-4
: Consider the stability implications of using experimental NumPy.
The use of tensorflow.experimental.numpy might lead to stability issues as experimental features can change or be deprecated. Consider documenting this dependency clearly and monitoring for updates.
pyproject.toml (2)
408-411
: LGTM: Consistent module-level import restrictions.
The addition of "deepmd.jax" and "jax" to banned module-level imports follows the established pattern of restricting direct imports of deep learning frameworks, similar to the existing restrictions on tensorflow and torch.
424-424
: LGTM: Appropriate test directory exception.
Adding "source/tests/jax/**" to the TID253 exceptions is consistent with the existing pattern for test directories and necessary for allowing JAX imports in test files.
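A minimal sketch of the pattern that such a module-level import ban encourages: the heavy dependency is imported inside the function that needs it rather than at the top of the module. The function name `load_backend` is hypothetical, and the standard-library `json` module stands in for a framework such as JAX so the example runs anywhere.

```python
import importlib


def load_backend(name="json"):
    """Import a backend lazily, only when it is actually requested.

    Assumption: `name` is an importable module. Here `json` is a stand-in
    for a heavy framework that should not be imported at module import time.
    """
    return importlib.import_module(name)


backend = load_backend()
print(backend.dumps({"nloc": 3}))
```

Test files are usually exempted from this rule because they exercise one backend directly and are only collected when that backend is installed.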
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (4)
source/tests/jax/jax2tf/test_region.py (2)
26-34
: Add docstrings to explain the test class and cell structure.
The test setup is well-implemented, but would benefit from documentation explaining:
- The purpose of the TestRegion class
- The structure and significance of the cell array
- Why these specific dimensions (4, 5, -1, 3) are used
 class TestRegion(tf.test.TestCase):
+    """Tests for region transformation functions.
+
+    The test class validates coordinate transformations and distance calculations
+    using a 4x5 batch of 3x3 cell matrices. Each cell represents a transformation
+    from internal to physical coordinates.
+    """
     def setUp(self):
+        """Sets up test fixtures with a batched cell array for coordinate transforms."""
35-44
: Consider adding edge cases to the coordinate transformation test.
While the basic test is solid, consider adding test cases for:
- Zero coordinates
- Large coordinate values
- Negative coordinates
- Boundary values
 def test_inter_to_phys(self):
+    """Test internal to physical coordinate transformation with various cases."""
     rng = tf.random.Generator.from_seed(GLOBAL_SEED)
     inter = rng.normal(shape=[4, 5, 3, 3])
+    # Test regular case
     phys = inter2phys(inter, self.cell)
     for ii in range(4):
         for jj in range(5):
             expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj])
             self.assertAllClose(
                 phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec
             )
+
+    # Test edge cases
+    edge_cases = [
+        tnp.zeros([4, 5, 3, 3]),  # Zero coordinates
+        tnp.ones([4, 5, 3, 3]) * 1e6,  # Large values
+        tnp.ones([4, 5, 3, 3]) * -1.0,  # Negative values
+    ]
+    for case in edge_cases:
+        phys = inter2phys(case, self.cell)
+        for ii in range(4):
+            for jj in range(5):
+                expected_phys = tnp.matmul(case[ii, jj], self.cell[ii, jj])
+                self.assertAllClose(
+                    phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec
+                )
source/tests/jax/jax2tf/test_nlist.py (2)
28-55
: Add docstrings to explain test setup and data structure.
The test setup is comprehensive but would benefit from documentation explaining:
- The purpose and structure of the test data
- The meaning of magic numbers (e.g., nloc=3, ns=553)
- The format and meaning of ref_nlist array
Example docstring:
class TestNeighList(tf.test.TestCase):
    """Tests for neighbor list construction and coordinate extension.

    Test data structure:
    - nf: Number of frames
    - nloc: Number of local atoms per frame
    - ns: Number of shifts (5x5x3 grid)
    - ref_nlist: Reference neighbor list with format [...]
    """
77-98
: Optimize test_build_type by reducing redundant operations.
The test method could be optimized by:
- Caching the mapping operation results
- Using vectorized operations instead of the loop
 def test_build_type(self):
     ecoord, eatype, mapping = extend_coord_with_ghosts(
         self.coord, self.atype, self.cell, self.rcut
     )
     nlist = build_neighbor_list(
         ecoord,
         eatype,
         self.nloc,
         self.rcut,
         self.nsel,
         distinguish_types=True,
     )
     self.assertAllClose(nlist[0], nlist[1])
     nlist_mask = nlist[0] == -1
     nlist_loc = mapping[0][nlist[0]]
     nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc)
-    for ii in range(2):
-        self.assertAllClose(
-            tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1),
-            tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1),
-        )
+    # Split and sort once, then compare
+    nlist_split = tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1), axis=-1)
+    ref_split = tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1), axis=-1)
+    self.assertAllClose(nlist_split, ref_split)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
source/tests/jax/jax2tf/__init__.py (1 hunks)
source/tests/jax/jax2tf/test_nlist.py (1 hunks)
source/tests/jax/jax2tf/test_region.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- source/tests/jax/jax2tf/__init__.py
🔇 Additional comments (4)
source/tests/jax/jax2tf/test_region.py (2)
1-24
: LGTM! Well-structured imports and test setup.
The imports are properly organized with clear separation of concerns. Good practice using conditional imports and skip markers to handle TensorFlow version compatibility.
46-63
: Add input validation and consider vectorizing calculations.
The previous review comment about input validation is still applicable. Additionally, consider vectorizing the calculations to improve test performance:
 def test_to_face_dist(self):
+    """Test face distance calculations with validation."""
     cell0 = self.cell[0][0]
     vol = tf.linalg.det(cell0)
+    # Ensure cell is not singular
+    self.assertGreater(tf.abs(vol), 1e-10)
+
     # area of surfaces xy, xz, yz
     sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1]))
     sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2]))
     syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2]))
+    # Validate surface areas
+    self.assertGreater(sxy, 1e-10)
+    self.assertGreater(sxz, 1e-10)
+    self.assertGreater(syz, 1e-10)
+
     # vol / area gives distance
     dz = vol / sxy
     dy = vol / sxz
     dx = vol / syz
     expected = tnp.array([dx, dy, dz])
     dists = to_face_distance(self.cell)
-    for ii in range(4):
-        for jj in range(5):
-            self.assertAllClose(
-                dists[ii][jj], expected, rtol=self.prec, atol=self.prec
-            )
+    # Vectorized comparison
+    self.assertAllClose(
+        dists, tnp.broadcast_to(expected, [4, 5, 3]),
+        rtol=self.prec, atol=self.prec
+    )
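The volume-over-area calculation that this test checks can be written out in plain Python. This is a sketch of the geometry only; `face_distances` and its helpers are hypothetical names, not the project's to_face_distance, and taking the absolute value of the volume is an assumption added here so that left-handed cells also give positive distances.

```python
import math


def cross(a, b):
    """Cross product of two 3-vectors."""
    return [
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    ]


def norm(v):
    """Euclidean length of a vector."""
    return math.sqrt(sum(x * x for x in v))


def face_distances(cell):
    """Distances between opposite faces of a parallelepiped cell.

    `cell` holds the three cell vectors as rows. The distance along a
    direction is the cell volume divided by the area of the face spanned
    by the other two vectors.
    """
    bc = cross(cell[1], cell[2])
    vol = abs(sum(cell[0][i] * bc[i] for i in range(3)))  # |a . (b x c)|
    if vol < 1e-10:
        raise ValueError("cell is singular")
    syz = norm(bc)                      # face spanned by b and c
    sxz = norm(cross(cell[0], cell[2]))  # face spanned by a and c
    sxy = norm(cross(cell[0], cell[1]))  # face spanned by a and b
    return [vol / syz, vol / sxz, vol / sxy]


print(face_distances([[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]]))
```

For an orthorhombic cell the face distances reduce to the edge lengths, which makes a convenient sanity check before trying skewed cells.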
source/tests/jax/jax2tf/test_nlist.py (2)
1-26
: LGTM! Well-structured imports and test setup.
The imports are properly organized with clear conditional logic to handle TensorFlow version conflicts.
56-164
: Test coverage aligns well with PR objectives.
The test suite effectively validates the neighbor list construction in SavedModel format and handles ghost atoms appropriately, which directly addresses the PR objectives. The tests verify:
- Neighbor list construction with and without type distinction
- Coordinate extension with ghost atoms
- Grid alignment and shift vectors
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 0
🧹 Outside diff range and nitpick comments (1)
source/tests/consistent/io/test_io.py (1)
149-149
: Add docstring explaining nopbc test cases.
While the implementation of no periodic boundary condition tests is good, consider adding documentation to explain:
- The purpose of these test cases
- The expected behavior when box is None
- The relationship with the PR's nopbc fixes
Example docstring addition:
 def test_deep_eval(self):
+    """Test deep_eval with both periodic and non-periodic boundary conditions.
+
+    This test verifies:
+    1. Regular evaluation with periodic boundary conditions
+    2. Evaluation without periodic boundaries (box=None)
+    3. Atomic-level outputs for both cases
+    """
Also applies to: 190-206
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
source/tests/consistent/io/test_io.py (4 hunks)
source/tests/jax/jax2tf/test_nlist.py (1 hunks)
source/tests/jax/jax2tf/test_region.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- source/tests/jax/jax2tf/test_nlist.py
- source/tests/jax/jax2tf/test_region.py
🔇 Additional comments (2)
source/tests/consistent/io/test_io.py (2)
3-3
: LGTM: Good memory management practice.
The addition of garbage collection after object deletion is a good practice to prevent memory leaks and OOM issues in CI environments.
Also applies to: 112-114, 207-208
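A small sketch of why an explicit collection can matter after `del`: objects caught in a reference cycle are not freed by reference counting alone. `FakeModel` is a hypothetical stand-in, and whether the real backend objects hold such cycles is an assumption, not something the review states.

```python
import gc
import weakref


class FakeModel:
    """Hypothetical stand-in for a large evaluator object."""

    def __init__(self):
        self.owner = self  # reference cycle: refcount never reaches zero


def release_with_gc():
    """Return (alive after del, alive after gc.collect())."""
    was_enabled = gc.isenabled()
    gc.disable()  # keep the automatic collector from interfering with the demo
    try:
        model = FakeModel()
        probe = weakref.ref(model)
        del model
        alive_after_del = probe() is not None
        gc.collect()
        alive_after_collect = probe() is not None
    finally:
        if was_enabled:
            gc.enable()
    return alive_after_del, alive_after_collect


print(release_with_gc())
```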
216-223
: Verify tolerance values for nopbc comparisons.
The assertions use the same tolerance values (rtol=1e-12, atol=1e-12) as the periodic boundary tests. Please verify if these tolerances are appropriate for nopbc cases, as different numerical behaviors might occur without periodic boundaries.
✅ Verification successful
Based on the search results, I can now generate the final response:
Tolerance values for nopbc comparisons are appropriate
The tolerance values (rtol=1e-12, atol=1e-12) are appropriate for non-periodic boundary condition (nopbc) cases because:
- These values are consistent with other high-precision tests in the codebase, particularly:
  - The model initialization tests use the same tolerances (1e-12) for both periodic and non-periodic cases
  - The consistent IO tests use these tolerances uniformly across different backends
- Non-periodic calculations are typically simpler than periodic ones (no wrapping/image calculations), so using the same strict tolerances is reasonable
- The test already handles special cases with the np.isnan(vv2).all() check for unsupported features
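The comparison logic described above can be sketched without NumPy as follows. `compare_outputs` is a hypothetical helper written for this illustration; it mirrors the usual `|actual - desired| <= atol + rtol * |desired|` criterion and treats an all-NaN result as "backend does not support this output", which is an assumption about how the test interprets the NaN check.

```python
import math


def compare_outputs(reference, other, rtol=1e-12, atol=1e-12):
    """Compare two flat lists of floats.

    Returns False if `other` is entirely NaN (output treated as unsupported
    and skipped), True if all values agree within tolerance, and raises
    AssertionError otherwise.
    """
    if len(reference) != len(other):
        raise AssertionError("length mismatch")
    if other and all(math.isnan(x) for x in other):
        return False  # nothing to compare for this backend
    for idx, (ref, val) in enumerate(zip(reference, other)):
        if not abs(val - ref) <= atol + rtol * abs(ref):
            raise AssertionError(f"mismatch at index {idx}: {ref} vs {val}")
    return True


print(compare_outputs([1.0, -2.5], [1.0, -2.5]))
print(compare_outputs([1.0, -2.5], [float("nan"), float("nan")]))
```

Writing the check as `not abs(...) <= tol` rather than `abs(...) > tol` makes a single stray NaN fail the comparison instead of slipping through.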
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check if there are any existing tests that use different tolerances for nopbc cases
# Search for tolerance values in test files
echo "Searching for existing tolerance patterns in tests..."
rg -A 2 'assert_allclose.*nopbc' source/tests/
Length of output: 165
Script:
#!/bin/bash
# Let's search for all assert_allclose usage in test files to understand tolerance patterns
echo "Searching for assert_allclose patterns in tests..."
rg -B 2 -A 2 'assert_allclose' source/tests/consistent/io/test_io.py
# Also search for nopbc related tests
echo -e "\nSearching for nopbc related tests..."
rg -B 2 -A 2 'nopbc' source/tests/consistent/io/test_io.py
Length of output: 1551
Script:
#!/bin/bash
# Let's search for tolerance patterns in other test files to get a broader view
echo "Searching for tolerance patterns in all test files..."
rg -B 2 -A 2 'rtol.*atol' source/tests/
# Also search specifically for non-periodic boundary condition related tests
echo -e "\nSearching for other nopbc/non-periodic tests..."
rg -B 2 -A 2 'non.?periodic|nopbc' source/tests/
Length of output: 42671
The tests may fail because the number of CPU threads is not limited.
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
The test failed due to out-of-memory. I don't know why JAX consumed so much memory... I add a fixture to call |
Oh, it does not work. |
I move the JAX test in |
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (2)
source/tests/jax/jax2tf/test_nlist.py (2)
34-54
: Consider making test data more dynamic and maintainable.
The current setup has several potential improvements:
- The cell matrix and reference neighbor list are hardcoded, making tests brittle
- The test could benefit from parameterization to cover more cases
Consider:
- Using @pytest.mark.parametrize to test different cell configurations
- Moving test data to external fixtures
- Adding helper methods to generate reference neighbor lists dynamically
Example parameterization:
@pytest.mark.parametrize("cell_type", [
    "cubic",
    "orthorhombic",
    "triclinic"
])
def test_build_notype(self, cell_type):
    cell = get_test_cell(cell_type)
    # ... rest of the test
77-98
: Expand type-specific neighbor list testing.
The current test could be enhanced with:
- More diverse type combinations
- Boundary cases between different types
- Verification of type-specific cutoffs
Consider adding:
def test_build_type_diverse(self):
    # Test with more atom types
    diverse_types = tnp.array([-1, 0, 1, 2, 3], dtype=tnp.int32)
    nsel_diverse = [5] * 4  # neighbor selection for each type
    # ... rest of the test
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
source/tests/jax/jax2tf/test_nlist.py (1 hunks)
source/tests/jax/jax2tf/test_region.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- source/tests/jax/jax2tf/test_region.py
🔇 Additional comments (1)
source/tests/jax/jax2tf/test_nlist.py (1)
6-6
: Consider the stability implications of using experimental NumPy.
The use of tensorflow.experimental.numpy could lead to potential stability issues as it's an experimental API. Consider:
- Adding error handling for experimental features
- Documenting any known limitations
- Having a fallback mechanism
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (4)
source/jax2tf_tests/test_region.py (2)
29-33
: Consider documenting the significance of tensor dimensions.
The cell tensor is initialized with specific dimensions [4, 5, -1, 3] but the significance of these dimensions (4 and 5) is not documented. Consider adding comments explaining what these dimensions represent or extracting them as named constants.
47-64
: Consider enhancing the mathematical documentation.
While the code includes basic comments, it would be beneficial to add more detailed documentation about:
- The mathematical formulas used for face distance calculations
- Why volume divided by area gives the correct distance
- References to relevant geometric principles
This would make the test more maintainable and educational for future contributors.
Example documentation format:
def test_to_face_dist(self):
    """Tests face distance calculations using volume/area method.

    Mathematical background:
    - For a parallelepiped, the distance to a face is given by V/A
      where V is the volume and A is the area of the parallel face
    - Volume is calculated using determinant of cell matrix
    - Surface areas are calculated using cross products of edge vectors

    References:
    - [Add relevant geometric principles or papers]
    """
source/jax2tf_tests/test_nlist.py (2)
29-55
: Add docstrings and comments to explain test data structure.
The setUp method initializes several parameters but lacks documentation explaining:
- The significance of dimensions (nf=3, nloc=3, ns=75)
- The structure of reference neighbor lists
- The coordinate transformation process
Add docstrings following this structure:
def setUp(self):
    """Initialize test data for neighbor list construction.

    Test data structure:
    - nf (int): Number of frames
    - nloc (int): Number of local atoms
    - ns (int): Number of shifts (5x5x3 grid)
    - cell (array): Simulation cell matrix
    - icoord (array): Initial coordinates
    - ref_nlist (array): Reference neighbor lists
    """
56-76
: Enhance test_build_notype with edge cases.
While the basic neighbor list construction is tested, consider adding tests for:
- Empty neighbor lists
- Maximum number of neighbors
- Single atom systems
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
.github/workflows/test_python.yml (1 hunks)
pyproject.toml (2 hunks)
source/jax2tf_tests/__init__.py (1 hunks)
source/jax2tf_tests/test_nlist.py (1 hunks)
source/jax2tf_tests/test_region.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- source/jax2tf_tests/__init__.py
🚧 Files skipped from review as they are similar to previous changes (2)
- .github/workflows/test_python.yml
- pyproject.toml
🔇 Additional comments (5)
source/jax2tf_tests/test_region.py (2)
36-46
: LGTM! Well-structured test with proper assertions.
The test effectively validates the coordinate transformation across batch dimensions while ensuring reproducibility through a fixed seed.
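The transformation being validated is a row-vector times matrix product, as the test's expected value (a matmul of the internal coordinates with the cell) indicates. A minimal plain-Python sketch for a single frame follows; `inter_to_phys_sketch` is a hypothetical name, and the batching over the leading dimensions used in the test is left out.

```python
def inter_to_phys_sketch(inter, cell):
    """Convert internal (fractional) coordinates to physical coordinates.

    `inter` is a list of 3-vectors and `cell` holds the three cell vectors
    as rows, so each physical coordinate is sum_k inter[k] * cell[k].
    """
    return [
        [sum(row[k] * cell[k][d] for k in range(3)) for d in range(3)]
        for row in inter
    ]


cell = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]]
print(inter_to_phys_sketch([[0.5, 0.5, 0.5]], cell))
```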
6-8
: Consider the stability implications of using experimental numpy.
The code uses tensorflow.experimental.numpy which may have API changes in future versions. Consider documenting the minimum required TensorFlow version and tracking the module's stability status.
source/jax2tf_tests/test_nlist.py (3)
77-98
: LGTM! Comprehensive type-aware neighbor list testing.
The test properly validates type-distinguished neighbor lists and verifies the results against reference data.
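The comparison against reference data relies on mapping indices in the extended (local plus ghost) system back to local atom indices while keeping -1 as the padding value. A sketch of that step in plain Python, with made-up data; `map_to_local` is a hypothetical helper, not a function from the test file.

```python
def map_to_local(nlist, mapping):
    """Replace extended-atom indices with local indices, preserving -1 padding.

    `mapping[j]` gives the local atom that extended atom `j` is an image of.
    """
    return [[-1 if j == -1 else mapping[j] for j in row] for row in nlist]


# Made-up example: 3 local atoms followed by 4 ghost images.
mapping = [0, 1, 2, 0, 1, 2, 0]
nlist = [[1, 4, -1], [0, 5, 6]]
local = map_to_local(nlist, mapping)
print(local)
# Neighbor order is not significant, so compare sorted rows.
print([sorted(row) for row in local])
```

Masking before indexing matters: using -1 directly as an index would silently pick the last extended atom instead of marking an empty slot.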
5-6
: Consider the stability implications of using experimental numpy.
The code imports tensorflow.experimental.numpy, which is an experimental module. While it provides numpy-like functionality, its API might change in future TensorFlow releases.
Let's check if there's a more stable alternative:
Consider:
- Adding a comment explaining why the experimental module is necessary
- Adding version constraints in requirements.txt
- Creating an abstraction layer to isolate experimental API usage
29-55
: 🛠️ Refactor suggestion
Add tests for zero ghost atoms case.
The PR objectives mention issues with ghost atoms when the count is zero, but this case isn't explicitly tested. Additionally, given the memory concerns mentioned in the PR comments, consider adding memory usage assertions.
Let's check if other test files cover this:
Add test cases for:
- Zero ghost atoms scenario
- Memory usage patterns using memory_profiler or similar tools
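As a rough picture of what a zero-ghost-atom case looks like, the sketch below builds a brute-force neighbor list for an open (non-periodic) system in which every atom is local. `naive_neighbor_list` is a hypothetical reference implementation written for this illustration and is not the build_neighbor_list under review; the coordinates are made up.

```python
import math


def naive_neighbor_list(coords, nloc, rcut, nsel):
    """Brute-force neighbor list: nearest `nsel` atoms within `rcut`, padded with -1."""
    nlist = []
    for i in range(nloc):
        candidates = []
        for j, cj in enumerate(coords):
            if j == i:
                continue
            dist = math.dist(coords[i], cj)
            if dist < rcut:
                candidates.append((dist, j))
        candidates.sort()  # closest neighbors first
        picked = [j for _, j in candidates[:nsel]]
        nlist.append(picked + [-1] * (nsel - len(picked)))
    return nlist


coords = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 0.0, 0.0)]
nloc = 3
nghost = len(coords) - nloc  # 0: no periodic images in an open system
print(nghost, naive_neighbor_list(coords, nloc, rcut=2.0, nsel=2))
```

A list like this can serve as the expected value when checking that the exported model handles an input with no ghost atoms.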
It finally passed after I moved the |
Per our discussion, use TF to build the neighbor list in the SavedModel format.
Also, fix a bug that occurs when the number of ghost atoms is zero. The polymorphic_shape needs to be larger than 1, and nghost == 0 triggered the error. Previously, I also tried nall or nghost - 1, but neither worked. Finally, I exported two different functions, so four functions are now stored in the model: (calculate virial or not) x (nghost is zero or not). Tests for nopbc are added.
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Tests
Chores
pyproject.toml
.
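The four-function layout from the PR description (virial or not, crossed with zero ghost atoms or not) can be pictured as a small dispatch table. This is a sketch only: the function labels and the `select_exported_function` helper are hypothetical, and the real SavedModel may name and select its functions differently.

```python
def select_exported_function(functions, do_atomic_virial, nghost):
    """Pick one of four exported variants.

    `functions` maps (with_virial, no_ghost) to an exported callable or name.
    The nghost == 0 case gets its own variant because, per the PR description,
    a zero-sized ghost dimension triggered an error in the polymorphic export.
    """
    return functions[(bool(do_atomic_virial), nghost == 0)]


# Hypothetical labels standing in for the stored functions.
exported = {
    (False, False): "call",
    (True, False): "call_atomic_virial",
    (False, True): "call_no_ghost",
    (True, True): "call_no_ghost_atomic_virial",
}
print(select_exported_function(exported, do_atomic_virial=True, nghost=0))
```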