Feat(pt): Support fitting_net statistics. #4504
base: devel
Conversation
📝 Walkthrough

The pull request introduces a new method, `compute_input_stats`, to the fitting network, allowing the mean and standard deviation of input fitting parameters to be computed from sampled data.
Sequence Diagram

```mermaid
sequenceDiagram
    participant AM as Atomic Model
    participant FN as Fitting Network
    participant Sampler as Data Sampler
    AM->>Sampler: Prepare sampled data
    AM->>FN: compute_input_stats(sampled_data)
    FN-->>FN: Calculate mean and std dev
    FN-->>AM: Statistics computed
```
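The flow above can be sketched in miniature. This is an illustrative stand-in, not the real deepmd-kit API: class and attribute names are hypothetical, and NumPy is used in place of PyTorch so the demo is self-contained.

```python
import numpy as np

# Hypothetical sketch of the diagrammed interaction: the atomic model prepares
# sampled frames and delegates statistics computation to the fitting network.
class FittingNet:
    def __init__(self, numb_fparam):
        self.numb_fparam = numb_fparam
        self.fparam_avg = None
        self.fparam_inv_std = None

    def compute_input_stats(self, sampled):
        # "Calculate mean and std dev" step from the diagram
        cat = np.concatenate([f["fparam"] for f in sampled], axis=0)
        cat = cat.reshape(-1, self.numb_fparam)
        self.fparam_avg = cat.mean(axis=0)
        # ddof=1 matches torch.std's default unbiased estimator
        self.fparam_inv_std = 1.0 / cat.std(axis=0, ddof=1)

# "Prepare sampled data" step: two frames of mock fitting parameters
sampled = [{"fparam": np.array([[1.0, 2.0]])},
           {"fparam": np.array([[3.0, 6.0]])}]
fn = FittingNet(numb_fparam=2)
fn.compute_input_stats(sampled)
```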
Actionable comments posted: 1
🧹 Nitpick comments (2)
deepmd/pt/model/task/fitting.py (2)
437-440: Use a ternary operator for compactness.

Ruff suggests replacing the `if callable(...)` block with a ternary operator. This is a minor readability enhancement.

```diff
-if callable(merged):
-    sampled = merged()
-else:
-    sampled = merged
+sampled = merged() if callable(merged) else merged
```

🧰 Tools
🪛 Ruff (0.8.2)

437-440: Use ternary operator `sampled = merged() if callable(merged) else merged` instead of `if`-`else` block. Replace `if`-`else` block with `sampled = merged() if callable(merged) else merged` (SIM108)
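The pattern Ruff is flagging accepts either eager data or a lazy loader. A minimal runnable sketch of the ternary form, with `resolve` as a hypothetical helper name:

```python
# Ternary form suggested by SIM108: `merged` may be either a list of frames or
# a zero-argument callable that produces one (a lazy loader).
def resolve(merged):
    return merged() if callable(merged) else merged

frames = [{"fparam": [1.0]}]
assert resolve(frames) is frames          # plain data passes through unchanged
assert resolve(lambda: frames) is frames  # callable is invoked to get the data
```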
457-457: Implement `aparam` statistics.

The TODO suggests you plan to handle `aparam` similarly to `fparam`. Let me know if you'd like help implementing that.
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (2)

- deepmd/pt/model/atomic_model/dp_atomic_model.py (1 hunks)
- deepmd/pt/model/task/fitting.py (3 hunks)
🔇 Additional comments (3)

deepmd/pt/model/atomic_model/dp_atomic_model.py (1)

301-301: Consider verifying that `fitting_net` is defined before usage.

If `fitting_net` were ever `None`, invoking `compute_input_stats` would raise an exception. A quick check ensures safe execution.

deepmd/pt/model/task/fitting.py (2)

7-7: New import statements look good.

Thanks for adding `Callable`, `Union`, and `DPPath`; these additions enable clearer type hints and better modularity. Also applies to: 43-45

416-436: Comprehensive documentation.

The docstring clearly explains the purpose and usage of `compute_input_stats`. This addition aligns with the PR objective to compute input statistics for fitting parameters.
```python
if self.numb_fparam > 0:
    cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
    cat_data = torch.reshape(cat_data, [-1, self.numb_fparam])
    fparam_avg = torch.mean(cat_data, axis=0)
    fparam_std = torch.std(cat_data, axis=0)
    fparam_inv_std = 1.0 / fparam_std
    self.fparam_avg.copy_(
        torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype)
    )
    self.fparam_inv_std.copy_(
        torch.tensor(
            fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype
        )
    )
    # TODO: stat aparam
```
🛠️ Refactor suggestion

Handle potential zero or near-zero standard deviation.

Currently, the code divides by `fparam_std`, potentially leading to `inf` or `NaN` values if `std == 0`. Consider adding a small epsilon or performing a check to avoid division by zero.

```diff
 fparam_std = torch.std(cat_data, axis=0)
+epsilon = 1e-12
+fparam_std = torch.where(fparam_std < epsilon, torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device), fparam_std)
 fparam_inv_std = 1.0 / fparam_std
```
📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
if self.numb_fparam > 0:
    cat_data = torch.cat([frame["fparam"] for frame in sampled], dim=0)
    cat_data = torch.reshape(cat_data, [-1, self.numb_fparam])
    fparam_avg = torch.mean(cat_data, axis=0)
    fparam_std = torch.std(cat_data, axis=0)
    epsilon = 1e-12
    fparam_std = torch.where(
        fparam_std < epsilon,
        torch.tensor(epsilon, dtype=fparam_std.dtype, device=fparam_std.device),
        fparam_std,
    )
    fparam_inv_std = 1.0 / fparam_std
    self.fparam_avg.copy_(
        torch.tensor(fparam_avg, device=env.DEVICE, dtype=self.fparam_avg.dtype)
    )
    self.fparam_inv_std.copy_(
        torch.tensor(
            fparam_inv_std, device=env.DEVICE, dtype=self.fparam_inv_std.dtype
        )
    )
    # TODO: stat aparam
```
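The effect of the zero-std guard suggested in the review can be demonstrated in isolation. This is a minimal sketch using NumPy instead of PyTorch (the review's fix uses `torch.where`; the clamping logic is the same), with `safe_inv_std` as a hypothetical helper name:

```python
import numpy as np

# Sketch of the epsilon clamp: a constant column has std == 0, which would
# otherwise make the inverse standard deviation infinite.
def safe_inv_std(cat_data, epsilon=1e-12):
    std = cat_data.std(axis=0, ddof=1)  # ddof=1 matches torch.std's default
    std = np.where(std < epsilon, epsilon, std)  # avoid division by zero
    return 1.0 / std

# First column is constant (std 0 -> clamped); second has std 2.
data = np.array([[1.0, 5.0],
                 [1.0, 7.0],
                 [1.0, 9.0]])
inv = safe_inv_std(data)
```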
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@           Coverage Diff            @@
##            devel    #4504   +/-   ##
=======================================
  Coverage   84.59%   84.59%
=======================================
  Files         675      675
  Lines       63574    63589      +15
  Branches     3490     3486       -4
=======================================
+ Hits        53778    53791      +13
- Misses       8670     8672       +2
  Partials     1126     1126
```

☔ View full report in Codecov by Sentry.
Support fitting_net statistics to calculate the mean value and standard deviation of `fparam`/`aparam`, so that `fparam`/`aparam` can be normalized automatically before being concatenated to the descriptor.

Summary by CodeRabbit

- New Features
- Bug Fixes
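The normalization the description refers to can be sketched with the computed statistics. All values and names below are illustrative, not taken from the PR:

```python
import numpy as np

# Hypothetical precomputed statistics (what compute_input_stats would produce)
fparam_avg = np.array([2.0, 4.0])
fparam_inv_std = np.array([0.5, 0.25])

# Standardize a raw fparam frame: subtract the mean, scale by 1/std.
fparam = np.array([[4.0, 8.0]])
normalized = (fparam - fparam_avg) * fparam_inv_std

# Downstream, the normalized values would be concatenated to the descriptor,
# e.g. something like np.concatenate([descriptor, normalized], axis=-1).
```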