
RF: memset and batch size optimization for computing splits #4001

Conversation

@venkywonka (Contributor) commented Jun 22, 2021

  • optimization 1: increase the default maximum number of nodes that can be processed per batch (the `max_batch_size` hyperparameter)
    • This increases GPU memory usage, but for practical workloads the increase rarely exceeds 200 MB.
  • optimization 2: reduce the amount of memory touched by the memset operations per kernel call (see the sketch after the benchmark plots below)

  • This PR drastically reduces the total number of kernel invocations (while increasing the work per invocation), as well as the memory each kernel invocation must memset. This can be seen in the following plot on the `year` dataset.
    • x-axis: (with/without optimization 1) × (with/without optimization 2); y-axis: time (s)

    • CSRK = computeSplitRegressionKernel

    • ![year-nsys-kernel-and-memset-times-lite_mode-max_bach_size](https://user-images.githubusercontent.com/23023424/122897144-5b319380-d367-11eb-995f-9c05a086fc0f.png)


  • With `n_estimators: 10`, `n_streams: 4`, `max_depth: 32` (rest default), the following are the gbm-bench plots:
    • (main: branch-21.08, devel: current PR, skl: scikit-learn RF)
    • scores are accuracy for classification and MSE for regression
    • Note: scikit-learn runs with `n_jobs=-1`, so it leverages all 24 CPUs on my machine

![memset-batch-opt](https://user-images.githubusercontent.com/23023424/122897816-f88cc780-d367-11eb-9b0f-6384d4ef8cbb.png)
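To make the combined effect concrete, here is a minimal, self-contained CUDA sketch of the idea (names and sizes are illustrative, not the actual cuML internals): the histogram buffer is allocated once for the worst case, while each batch's memset covers only the bins its large nodes will actually touch.

```cpp
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define CUDA_CHECK(call)                                             \
  do {                                                               \
    cudaError_t err_ = (call);                                       \
    if (err_ != cudaSuccess) {                                       \
      std::fprintf(stderr, "CUDA error: %s\n",                       \
                   cudaGetErrorString(err_));                        \
      std::exit(1);                                                  \
    }                                                                \
  } while (0)

int main() {
  // Illustrative sizes (optimization 1 raises max_batch_size so that more
  // nodes are processed per kernel launch).
  const int max_batch_size = 4096;
  const int nbins = 128, colBlks = 4, nclasses = 2;

  // Allocate the histogram buffer once, for the worst case where every node
  // in a batch is "large" (i.e. needs global-memory histograms).
  const size_t maxHistBins =
      size_t(max_batch_size) * (1 + nbins) * colBlks * nclasses;
  int* hist = nullptr;
  CUDA_CHECK(cudaMalloc(&hist, sizeof(int) * maxHistBins));

  cudaStream_t s;
  CUDA_CHECK(cudaStreamCreate(&s));

  // Optimization 2: per batch, memset only the bins the current batch's
  // large nodes will touch, instead of the entire allocation.
  const int n_large_nodes_in_curr_batch = 37;  // example value
  const size_t nHistBins =
      size_t(n_large_nodes_in_curr_batch) * (1 + nbins) * colBlks * nclasses;
  CUDA_CHECK(cudaMemsetAsync(hist, 0, sizeof(int) * nHistBins, s));

  CUDA_CHECK(cudaStreamSynchronize(s));
  CUDA_CHECK(cudaFree(hist));
  CUDA_CHECK(cudaStreamDestroy(s));
  return 0;
}
```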

@venkywonka venkywonka requested review from a team as code owners June 22, 2021 09:18
@github-actions github-actions bot added CUDA/C++ Cython / Python Cython or Python issue labels Jun 22, 2021
@venkywonka venkywonka added Perf Related to runtime performance of the underlying code breaking Breaking change labels Jun 22, 2021
@RAMitchell (Contributor) left a comment:

Looks very nice, simple changes and awesome results.

```diff
@@ -338,13 +338,16 @@ struct Builder {
   raft::update_device(curr_nodes, h_nodes.data() + node_start, batchSize, s);

   int total_samples_in_curr_batch = 0;
+  int n_large_nodes_in_curr_batch = 0;
```
@RAMitchell (Contributor) commented:

Comment what this variable is, e.g. nodes with a number of training instances larger than the block size. These nodes require global memory for histograms.

@venkywonka (Contributor, Author) replied:

done 👍🏾
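One plausible wording for the requested annotation (hypothetical; paraphrasing the reviewer's suggestion rather than quoting the merged code):

```cpp
// Nodes in this batch whose number of training instances exceeds the
// threadblock size; these nodes accumulate their histograms in global memory.
int n_large_nodes_in_curr_batch = 0;
```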

```diff
@@ -446,7 +452,9 @@ struct ClsTraits {
   // Pick the max of two
   size_t smemSize = std::max(smemSize1, smemSize2);
   dim3 grid(b.total_num_blocks, colBlks, 1);
-  CUDA_CHECK(cudaMemsetAsync(b.hist, 0, sizeof(int) * b.nHistBins, s));
+  int nHistBins = 0;
+  nHistBins = n_large_nodes_in_curr_batch * (1 + nbins) * colBlks * nclasses;
```
@RAMitchell (Contributor) commented:

Zero-initialising this variable is unnecessary.

The `1+` in `(1+nbins)` should disappear when you merge the objective function PR.

@venkywonka (Contributor, Author) replied:

ah, that was needed in my initial prototype but I somehow missed removing it.. changed it 👍🏾
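A minimal standalone sketch of the post-review form (a hypothetical assembly of the diff and the fix above, not the literal cuML source; the bin count becomes a `const` with no redundant zero-initialisation):

```cpp
#include <cuda_runtime.h>

// Clear only the histogram bins needed by the large nodes of this batch.
inline void clearBatchHistograms(int* hist, int n_large_nodes_in_curr_batch,
                                 int nbins, int colBlks, int nclasses,
                                 cudaStream_t s) {
  const int nHistBins =
      n_large_nodes_in_curr_batch * (1 + nbins) * colBlks * nclasses;
  cudaMemsetAsync(hist, 0, sizeof(int) * nHistBins, s);
}
```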

```diff
@@ -507,7 +515,7 @@ struct RegTraits {
   */
  static void computeSplit(Builder<RegTraits<DataT, IdxT>>& b, IdxT col,
                           IdxT batchSize, CRITERION splitType,
-                          cudaStream_t s) {
+                          int& n_large_nodes_in_curr_batch, cudaStream_t s) {
```
@RAMitchell (Contributor) commented:

Why is n_large_nodes_in_curr_batch passed by reference? Is it modified somewhere?

@venkywonka (Contributor, Author) replied:

it was in my initial version, but you're right, it's unnecessary here. I have changed it to `const int` 👍🏾
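For reference, the declaration implied by combining the hunk above with the `const int` fix (a sketch; `Builder`, `RegTraits`, and `CRITERION` are the cuML types from the surrounding diff):

```cpp
// The batch statistic is read-only inside computeSplit, so the mutable
// reference becomes a plain const int.
static void computeSplit(Builder<RegTraits<DataT, IdxT>>& b, IdxT col,
                         IdxT batchSize, CRITERION splitType,
                         const int n_large_nodes_in_curr_batch,
                         cudaStream_t s);
```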

```diff
@@ -35,6 +35,7 @@ namespace DecisionTree {
 template <typename IdxT>
 struct WorkloadInfo {
   IdxT nodeid;       // Node in the batch on which the threadblock needs to work
+  IdxT large_nodeid; // counts only large nodes
```
@RAMitchell (Contributor) commented:

Make sure you comment what large nodes means.

@venkywonka (Contributor, Author) replied:

done 👍🏾
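A sketch of the documented struct (the comment wording is hypothetical; the fields come from the diff above):

```cpp
namespace DecisionTree {

template <typename IdxT>
struct WorkloadInfo {
  IdxT nodeid;        // Node in the batch on which the threadblock needs to work
  IdxT large_nodeid;  // Index counting only the "large" nodes, i.e. nodes with
                      // more training instances than a threadblock can
                      // histogram in shared memory
};

}  // namespace DecisionTree
```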

@dantegd (Member) left a comment:

codeowner approval

@dantegd (Member) commented Jun 23, 2021:

rerun tests

@dantegd dantegd added the improvement Improvement / enhancement to an existing function label Jun 23, 2021
@RAMitchell (Contributor) commented twice:

rerun tests

@codecov-commenter commented:

Codecov Report

❗ No coverage uploaded for pull request base (branch-21.08@f71d369).
The diff coverage is n/a.

```
@@               Coverage Diff               @@
##             branch-21.08    #4001   +/-   ##
===============================================
  Coverage                ?   85.44%
===============================================
  Files                   ?      230
  Lines                   ?    18088
  Branches                ?        0
===============================================
  Hits                    ?    15455
  Misses                  ?     2633
  Partials                ?        0
```

| Flag | Coverage Δ |
| --- | --- |
| dask | 48.04% <0.00%> (?) |
| non-dask | 77.79% <0.00%> (?) |

Flags with carried forward coverage won't be shown.

Continue to review the full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update f71d369...70665a3.

@dantegd (Member) commented Jun 29, 2021:

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 705e0df into rapidsai:branch-21.08 Jun 29, 2021
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request Oct 9, 2023