
chore(cc): merge get backend codes #4355

Merged
1 commit merged into deepmodeling:devel on Nov 14, 2024

Conversation

@njzjz (Member) commented Nov 13, 2024

Fix #4308.

Summary by CodeRabbit

  • New Features
    • Introduced a new function to dynamically determine the backend framework based on the model file type.
  • Improvements
    • Enhanced backend detection logic in multiple classes, allowing for more flexible model initialization.
    • Simplified control flow in the initialization methods of various components.
  • Bug Fixes
    • Improved error handling for unsupported backends and model formats during initialization processes.

Fix deepmodeling#4308.

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
@njzjz njzjz linked an issue Nov 13, 2024 that may be closed by this pull request
coderabbitai bot (Contributor) commented Nov 13, 2024

📝 Walkthrough

The changes in this pull request introduce a new function, get_backend, to dynamically determine the backend framework based on the model name provided. This function is utilized in multiple classes, including DipoleChargeModifier, DeepPot, DeepSpin, and DeepTensor, to replace hardcoded backend logic. The modifications enhance backend detection by allowing the system to adapt to various model formats, including JAX, while maintaining existing functionalities and signatures of methods.

Changes

Changed files:
  • source/api_cc/include/common.h: added function declaration DPBackend get_backend(const std::string& model) in the deepmd namespace.
  • source/api_cc/src/DataModifier.cc: updated DipoleChargeModifier to use get_backend(model) for backend detection; TensorFlow backend logic modified.
  • source/api_cc/src/DeepPot.cc: refactored the init method in DeepPot to use get_backend(model) for backend determination; similar changes in DeepPotModelDevi.
  • source/api_cc/src/DeepSpin.cc: refactored the init method in DeepSpin to use get_backend(model); expanded error handling for the PaddlePaddle backend.
  • source/api_cc/src/DeepTensor.cc: updated the init method in DeepTensor to use get_backend(model); refined error handling for unsupported backends.
  • source/api_cc/src/common.cc: added the implementation of get_backend, which determines the backend from the model file extension.
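The extension-to-backend mapping described above can be sketched as a small standalone function. This is a minimal illustration, not the actual deepmd implementation: the enum and the `ends_with` helper are local stand-ins.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for deepmd::DPBackend, for illustration only.
enum class DPBackend { TensorFlow, PyTorch, Paddle, JAX };

// Local helper: does `s` end with `suffix`?
static bool ends_with(const std::string& s, const std::string& suffix) {
  return s.size() >= suffix.size() &&
         s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0;
}

// Sketch of extension-based backend detection, mirroring the mapping
// discussed in this review: .pb -> TensorFlow, .pth -> PyTorch,
// .savedmodel -> JAX.
DPBackend get_backend(const std::string& model) {
  if (ends_with(model, ".pb")) return DPBackend::TensorFlow;
  if (ends_with(model, ".pth")) return DPBackend::PyTorch;
  if (ends_with(model, ".savedmodel")) return DPBackend::JAX;
  throw std::runtime_error("Unsupported model file format: " + model);
}
```

Callers can then branch on the returned enum instead of re-parsing the path in every class.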

Assessment against linked issues

Linked-issue objectives (#4308):
  • Update JAX backend handling for consistency across the codebase
  • Update backend type checks to include JAX in DeepTensor
  • Update backend type checks to include JAX in DeepPot
  • Update backend handling in DataModifier to include JAX

Possibly related PRs

Suggested labels

Python

Suggested reviewers

  • wanghan-iapcm


coderabbitai bot (Contributor) left a comment
Actionable comments posted: 0

🧹 Outside diff range and nitpick comments (3)
source/api_cc/src/DataModifier.cc (1)

Line range hint 31-45: Add missing JAX backend handling

According to the PR objectives (issue #4308), JAX backend support should be added. However, the current implementation doesn't include a case for JAX in the backend switch statement.

Add JAX handling similar to other backends. Here's the suggested change:

  const DPBackend backend = get_backend(model);
  if (deepmd::DPBackend::TensorFlow == backend) {
#ifdef BUILD_TENSORFLOW
    dcm = std::make_shared<deepmd::DipoleChargeModifierTF>(model, gpu_rank,
                                                           name_scope_);
#else
    throw deepmd::deepmd_exception("TensorFlow backend is not built");
#endif
  } else if (deepmd::DPBackend::PyTorch == backend) {
    throw deepmd::deepmd_exception("PyTorch backend is not supported yet");
  } else if (deepmd::DPBackend::Paddle == backend) {
    throw deepmd::deepmd_exception("PaddlePaddle backend is not supported yet");
+ } else if (deepmd::DPBackend::JAX == backend) {
+   throw deepmd::deepmd_exception("JAX backend is not supported yet");
  } else {
    throw deepmd::deepmd_exception("Unknown file type");
  }
source/api_cc/src/DeepSpin.cc (1)

Based on the search results and the previously examined code, I can now provide a final response:

JAX backend implementation missing for DeepSpin

The codebase has a JAX backend implementation for DeepPot (DeepPotJAX) but lacks a corresponding JAX backend implementation for DeepSpin. This is a potential issue since:

  • The DeepSpin class is designed to support multiple backends (as seen in the header file)
  • A JAX backend implementation (DeepSpinJAX) is missing, unlike the PyTorch case referenced in the review comment
  • The JAX backend would need to implement the DeepSpinBackend interface with spin-specific computations
🔗 Analysis chain

Line range hint 47-48: Verify JAX backend implementation

While the code correctly handles the JAX backend case, let's verify if there are any JAX-specific implementations that need to be added similar to the TensorFlow and PyTorch cases.


Let me check the actual JAX implementation to verify if it's properly implemented.


Let me check the JAX backend implementation for DeepSpin specifically.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for JAX backend implementation files

# Look for JAX-related header files
echo "Searching for JAX header files..."
fd -e h -e hpp ".*JAX.*"

# Look for JAX-related source files
echo "Searching for JAX source files..."
fd -e cc -e cpp ".*JAX.*"

# Search for JAX-related class definitions
echo "Searching for JAX class implementations..."
rg -l "class.*JAX" 

Length of output: 1300


Script:

#!/bin/bash
# Check JAX implementation details
echo "=== DeepPotJAX header ==="
cat source/api_cc/include/DeepPotJAX.h

echo -e "\n=== DeepPotJAX implementation ==="
cat source/api_cc/src/DeepPotJAX.cc

echo -e "\n=== DeepSpin header (for context) ==="
cat source/api_cc/include/DeepSpin.h

Length of output: 71320


Script:

#!/bin/bash
# Check for JAX-specific DeepSpin implementation
echo "=== Searching for DeepSpinJAX files ==="
fd -e h -e hpp -e cc -e cpp ".*DeepSpinJAX.*"

echo -e "\n=== Searching for JAX backend references in DeepSpin implementation ==="
rg -l "DeepSpinJAX" source/api_cc/src/DeepSpin.cc

Length of output: 345

source/api_cc/src/common.cc (1)

1403-1413: LGTM! Consider adding file validation and case-insensitive extension checks.

The implementation correctly determines the backend based on file extensions and aligns with the PR objectives for JAX backend support.

Consider these improvements:

  1. Add file existence validation before checking extensions
  2. Make extension checks case-insensitive for better robustness

Example implementation (the added checks also require the <algorithm> and <fstream> headers):

 deepmd::DPBackend deepmd::get_backend(const std::string& model) {
+  // Convert to lowercase for case-insensitive comparison
+  std::string model_lower = model;
+  std::transform(model_lower.begin(), model_lower.end(), model_lower.begin(), ::tolower);
+
+  // Validate file exists
+  std::ifstream f(model.c_str());
+  if (!f.good()) {
+    throw deepmd::deepmd_exception("Model file does not exist: " + model);
+  }
+
-  if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") {
+  if (model_lower.length() >= 4 && model_lower.substr(model_lower.length() - 4) == ".pth") {
     return deepmd::DPBackend::PyTorch;
-  } else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
+  } else if (model_lower.length() >= 3 && model_lower.substr(model_lower.length() - 3) == ".pb") {
     return deepmd::DPBackend::TensorFlow;
-  } else if (model.length() >= 11 && model.substr(model.length() - 11) == ".savedmodel") {
+  } else if (model_lower.length() >= 11 && model_lower.substr(model_lower.length() - 11) == ".savedmodel") {
     return deepmd::DPBackend::JAX;
   }
   throw deepmd::deepmd_exception("Unsupported model file format");
 }
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 320c7fd and bc57b50.

📒 Files selected for processing (6)
  • source/api_cc/include/common.h (1 hunks)
  • source/api_cc/src/DataModifier.cc (1 hunks)
  • source/api_cc/src/DeepPot.cc (1 hunks)
  • source/api_cc/src/DeepSpin.cc (1 hunks)
  • source/api_cc/src/DeepTensor.cc (1 hunks)
  • source/api_cc/src/common.cc (1 hunks)
🔇 Additional comments (7)
source/api_cc/src/DataModifier.cc (1)

31-31: Verify get_backend implementation for JAX support

Let's verify that the get_backend function in common.h properly handles JAX model detection.

✅ Verification successful

JAX support is properly implemented in get_backend function

The get_backend implementation in source/api_cc/src/common.cc correctly handles JAX models by checking for the .savedmodel extension and returning DPBackend::JAX. The function properly differentiates between PyTorch (.pth), TensorFlow (.pb), and JAX (.savedmodel) model formats.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for get_backend implementation and its handling of JAX
rg -A 10 "get_backend.*\(" source/api_cc/include/common.h source/api_cc/src/common.cc

Length of output: 1704

source/api_cc/include/common.h (1)

18-23: Documentation could be enhanced with error handling details.

The function declaration and basic documentation look good. However, consider enhancing the documentation with:

  • Error handling behavior for unknown/invalid model names
  • Expected format of the model name parameter

Example enhancement:

 /**
  * @brief Get the backend of the model.
  * @param[in] model The model name.
  * @return The backend of the model.
+ * @details The function analyzes the model name/path to determine the backend.
+ * Returns DPBackend::Unknown if the backend cannot be determined.
+ * @example
+ * get_backend("model.pb") // Returns DPBackend::TensorFlow
+ * get_backend("model.pth") // Returns DPBackend::PyTorch
  **/

Let's verify the implementation exists and matches this declaration:

✅ Verification successful

Implementation found and matches declaration with clear behavior

The implementation in source/api_cc/src/common.cc shows that:

  • Function determines backend based on file extension:
    • .pth -> PyTorch
    • .pb -> TensorFlow
    • .savedmodel -> JAX

The function is used correctly in multiple places (DeepPot, DeepTensor, DataModifier, DeepSpin) to determine the appropriate backend implementation.

Note: The original review comment suggesting documentation improvements is still valid but not critical, as the implementation clearly shows the expected behavior.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if get_backend implementation exists and matches declaration

# Search for the implementation
echo "Searching for get_backend implementation..."
rg -l "DPBackend\s+get_backend\s*\(.*string.*model"

# Verify the implementation matches the declaration
ast-grep --pattern 'DPBackend get_backend(const std::string& model)'

Length of output: 335


Script:

#!/bin/bash
# Let's search for the implementation in source files
echo "Searching for get_backend implementation in source files..."
rg -A 10 "DPBackend\s+get_backend\s*\(.*string.*model" --type cpp

# Also check if there are any test files that might help understand the behavior
echo -e "\nSearching for related test files..."
rg -l "get_backend.*model" --type cpp "*test*"

Length of output: 1170


Script:

#!/bin/bash
# Let's try a different approach to find the implementation and tests
# First, let's find all cpp files that might contain the implementation
echo "Searching for cpp source files..."
fd -e cpp -e cc

# Then search for get_backend implementation in these files
echo -e "\nSearching for get_backend implementation..."
rg -A 5 "get_backend.*model" --type cpp

# Look for test files with a better pattern
echo -e "\nSearching for test files..."
fd -t f "test" -e cpp -e cc

Length of output: 12669

source/api_cc/src/DeepSpin.cc (1)

Line range hint 39-54: LGTM: Backend handling implementation aligns with PR objectives

The changes effectively implement the backend selection logic by:

  1. Using the new get_backend function to determine the backend type
  2. Adding explicit support for JAX backend in the conditional checks
  3. Maintaining proper error handling for unsupported backends

The implementation is consistent with the PR objectives of updating JAX backend handling.
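The conditional checks described above use an if/else chain over the backend enum. An alternative sketch, shown here only as a design note, is a switch with no `default`, which lets compilers warn when a newly added backend value is left unhandled. The enum and messages below are illustrative stand-ins, not the actual deepmd API:

```cpp
#include <stdexcept>
#include <string>

// Illustrative stand-in for the backend enum.
enum class DPBackend { TensorFlow, PyTorch, Paddle, JAX };

// Switch-based dispatch: with no `default` case, -Wswitch-style warnings
// flag any DPBackend value that is not handled here.
std::string backend_name(DPBackend b) {
  switch (b) {
    case DPBackend::TensorFlow: return "TensorFlow";
    case DPBackend::PyTorch:    return "PyTorch";
    case DPBackend::Paddle:     return "PaddlePaddle";
    case DPBackend::JAX:        return "JAX";
  }
  throw std::runtime_error("Unknown backend");
}
```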

source/api_cc/src/DeepPot.cc (3)

Line range hint 40-65: LGTM! Backend initialization changes look good.

The implementation correctly:

  • Uses the new get_backend function for backend determination
  • Handles JAX backend appropriately through TensorFlow build infrastructure
  • Maintains proper error handling and build flags

Line range hint 66-600: No changes in compute methods.

The compute methods and their template instantiations remain unchanged and are not affected by the backend initialization changes.


Line range hint 601-900: Verify JAX model handling in DeepPotModelDevi.

The class should correctly handle JAX models through its DeepPot instances. Let's verify the implementation:

✅ Verification successful

Let me gather more information about how DeepPotModelDevi interacts with DeepPot instances to verify the JAX handling.


Let me try one more search with corrected paths to find the DeepPotModelDevi implementation.


Let me check one final aspect about how DeepPotModelDevi initializes the DeepPot instances.


Based on the evidence gathered from all the searches, I can now provide the final response:

DeepPotModelDevi correctly handles JAX models through delegation

The implementation shows that DeepPotModelDevi acts as a pure delegation layer, passing all computation requests to its underlying DeepPot instances without any backend-specific handling. This design ensures that JAX models are handled correctly since:

  1. All model-specific logic is encapsulated in the DeepPot instances
  2. No backend-specific code exists in DeepPotModelDevi
  3. The compute methods are pure pass-through calls to DeepPot's compute methods
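The delegation pattern described above can be sketched as follows. All names here are hypothetical stand-ins, not the real deepmd classes; the point is only that the outer class forwards every call to its per-model instances and carries no backend-specific logic itself:

```cpp
#include <memory>
#include <string>
#include <vector>

// Stand-in for a single-model interface; `compute` is a placeholder.
struct Model {
  explicit Model(std::string path) : path_(std::move(path)) {}
  // Placeholder computation: returns the path length, just to have output.
  double compute() const { return static_cast<double>(path_.size()); }
  std::string path_;
};

// Pure delegation layer: owns several Model instances and forwards each
// call, the pattern the review describes for DeepPotModelDevi.
class ModelDevi {
 public:
  explicit ModelDevi(const std::vector<std::string>& paths) {
    for (const auto& p : paths) {
      models_.emplace_back(std::make_unique<Model>(p));
    }
  }
  std::vector<double> compute_all() const {
    std::vector<double> out;
    for (const auto& m : models_) {
      out.push_back(m->compute());  // pass-through, no backend checks here
    }
    return out;
  }
 private:
  std::vector<std::unique_ptr<Model>> models_;
};
```

Because the outer class never inspects the model format, any backend the inner class supports (including JAX) works through it unchanged.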
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if DeepPotModelDevi properly initializes JAX models

# Search for DeepPotModelDevi usage with JAX models
rg -l "DeepPotModelDevi.*\.jax" || echo "No direct JAX model usage found"

# Check if there are any special cases for JAX in DeepPotModelDevi
ast-grep --pattern 'class DeepPotModelDevi {
  $$$
  if $COND {
    $$$
    jax
    $$$
  }
  $$$
}'

Length of output: 209


Script:

#!/bin/bash
# Check how DeepPotModelDevi initializes and uses DeepPot instances

# Look for DeepPotModelDevi constructor and initialization
ast-grep --pattern 'class DeepPotModelDevi {
  $$$
  DeepPotModelDevi($$$) {
    $$$
  }
  $$$
}'

# Check how DeepPot instances are created/stored
rg "DeepPot.*dps\[" -A 2

# Check if there's any backend-specific handling
rg "backend|jax|tensorflow|torch" source/api_cc/src/DeepPotModelDevi

Length of output: 272


Script:

#!/bin/bash
# Find the correct path for DeepPotModelDevi files
fd DeepPotModelDevi

# Look for constructor implementation
rg "DeepPotModelDevi::" -A 5

# Check how dps vector is used
rg "dps\[.*\].*compute" -A 3

Length of output: 11184


Script:

#!/bin/bash
# Check DeepPotModelDevi initialization of DeepPot instances
rg "void.*init.*DeepPotModelDevi" -A 15 source/api_cc/src/DeepPot.cc

# Look for DeepPot creation in DeepPotModelDevi
rg "new DeepPot" source/api_cc/src/DeepPot.cc

Length of output: 118

source/api_cc/src/DeepTensor.cc (1)

33-33: LGTM!

The backend detection is now correctly updated to use get_backend(model), allowing for dynamic backend support.

@github-actions github-actions bot added the C++ label Nov 13, 2024
@iProzd iProzd added this pull request to the merge queue Nov 14, 2024
Merged via the queue into deepmodeling:devel with commit d3095cf Nov 14, 2024
52 checks passed

Successfully merging this pull request may close these issues.

Update JAX backend handling for consistency across the codebase
3 participants