From 6a1ea09a3bc28fcfe6a53041c64f3ecfc2b6abb9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Thomas=20M=C3=BCller?=
Date: Wed, 19 Jan 2022 12:05:56 +0100
Subject: [PATCH] Reduce VRAM footprint of CutlassMLP

---
 include/tiny-cuda-nn/networks/cutlass_mlp.h |  2 ++
 src/cutlass_mlp.cu                          | 39 ++++++++++++++-------
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/include/tiny-cuda-nn/networks/cutlass_mlp.h b/include/tiny-cuda-nn/networks/cutlass_mlp.h
index 5a60f59a..24067142 100644
--- a/include/tiny-cuda-nn/networks/cutlass_mlp.h
+++ b/include/tiny-cuda-nn/networks/cutlass_mlp.h
@@ -151,6 +151,8 @@ class CutlassMLP : public Network<T> {
 	Activation m_activation;
 	Activation m_output_activation;
 
+	bool m_can_fuse_activation;
+
 	static const uint32_t tensorcore_width = 8;
 
 	// Streams and events
diff --git a/src/cutlass_mlp.cu b/src/cutlass_mlp.cu
index eba80ff6..0bc5eb16 100644
--- a/src/cutlass_mlp.cu
+++ b/src/cutlass_mlp.cu
@@ -50,7 +50,8 @@ m_network_width{network_width},
 m_output_width{output_width},
 m_n_hidden_layers{n_hidden_layers},
 m_activation{activation},
-m_output_activation{output_activation}
+m_output_activation{output_activation},
+m_can_fuse_activation{activation != Activation::Sine}
 {
 	m_padded_output_width = next_multiple(m_output_width, tensorcore_width);
 
@@ -92,8 +93,8 @@ m_output_activation{output_activation}
 	}
 
 	// Buffers to keep data from the forward and backward pass
-	m_forward_tmp.resize(m_n_hidden_layers * 2);
-	m_backward_tmp.resize(m_n_hidden_layers * 2);
+	m_forward_tmp.resize(m_can_fuse_activation ? m_n_hidden_layers : (m_n_hidden_layers * 2));
+	m_backward_tmp.resize(m_can_fuse_activation ? m_n_hidden_layers : (m_n_hidden_layers * 2));
 
 	// 1 stream per matrix.
 	m_training_splitk_streams.resize(m_n_hidden_layers + 1);
@@ -241,12 +242,28 @@ void CutlassMLP<T>::forward(cudaStream_t stream, const GPUMatrix<T>& input, GPUM
 
 	// Run the actual network
 	uint32_t tmp_idx = 0;
-	bool fused = compute_layer(stream, false, m_activation, input_weight_matrix(use_inference_matrices), input, m_forward_tmp.at(tmp_idx), m_forward_tmp.at(tmp_idx+1));
+	bool fused = compute_layer(
+		stream,
+		false,
+		m_activation,
+		input_weight_matrix(use_inference_matrices),
+		input,
+		m_forward_tmp.at(tmp_idx),
+		m_can_fuse_activation ? m_forward_tmp.at(tmp_idx) : m_forward_tmp.at(tmp_idx+1)
+	);
 	tmp_idx += fused ? 1 : 2;
 
 	// layers
 	for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) {
-		fused = compute_layer(stream, false, m_activation, weight_matrix_at(use_inference_matrices, i), m_forward_tmp.at(tmp_idx-1), m_forward_tmp.at(tmp_idx), m_forward_tmp.at(tmp_idx+1));
+		fused = compute_layer(
+			stream,
+			false,
+			m_activation,
+			weight_matrix_at(use_inference_matrices, i),
+			m_forward_tmp.at(tmp_idx-1),
+			m_forward_tmp.at(tmp_idx),
+			m_can_fuse_activation ? m_forward_tmp.at(tmp_idx) : m_forward_tmp.at(tmp_idx+1)
+		);
 		tmp_idx += fused ? 1 : 2;
 	}
 
@@ -275,8 +292,6 @@ void CutlassMLP<T>::backward(
 		allocate_backward_buffers(batch_size);
 	}
 
-	bool can_fuse_activation = m_activation != Activation::Sine;
-
 	// Compute transfer of output activation in-place... it's treated specially for performance reasons
 	if (m_output_activation != Activation::None) {
 		activation_backward_output_gpu(stream, dL_doutput.n_elements(), m_output_activation, output.data(), dL_doutput.data(), m_backward_output_tmp.data());
@@ -315,7 +330,7 @@ void CutlassMLP<T>::backward(
 		return;
 	}
 
-	uint32_t tmp_idx = (can_fuse_activation ? (m_n_hidden_matmuls+1) : ((m_n_hidden_matmuls+1) * 2)) - 1;
+	uint32_t tmp_idx = (m_can_fuse_activation ? (m_n_hidden_matmuls+1) : ((m_n_hidden_matmuls+1) * 2)) - 1;
 	uint32_t backward_tmp_idx = 0;
 
 	if (compute_param_gradients) {
@@ -329,14 +344,14 @@ void CutlassMLP<T>::backward(
 		cudaEventRecord(m_training_splitk_events.at(backward_tmp_idx), m_training_splitk_streams.at(backward_tmp_idx));
 	}
 
-	if (!can_fuse_activation) {
+	if (!m_can_fuse_activation) {
 		fc_multiply<FullLayer>(stream, output_weight_matrix(use_inference_matrices).transposed(), tmp_dL_doutput, m_backward_tmp.at(backward_tmp_idx));
 		activation_backward_gpu(stream, m_activation, m_forward_tmp.at(tmp_idx-1), m_backward_tmp.at(backward_tmp_idx));
 	} else {
 		fc_multiply<FullLayer>(stream, output_weight_matrix(use_inference_matrices).transposed(), tmp_dL_doutput, m_forward_tmp.at(tmp_idx), m_backward_tmp.at(backward_tmp_idx), m_activation, true);
 	}
 
-	tmp_idx -= can_fuse_activation ? 1 : 2;
+	tmp_idx -= m_can_fuse_activation ? 1 : 2;
 	++backward_tmp_idx;
 
 	// layers
@@ -350,14 +365,14 @@ void CutlassMLP<T>::backward(
 			cudaEventRecord(m_training_splitk_events.at(backward_tmp_idx), m_training_splitk_streams.at(backward_tmp_idx));
 		}
 
-		if (!can_fuse_activation) {
+		if (!m_can_fuse_activation) {
 			fc_multiply<FullLayer>(stream, weight_matrix_at(use_inference_matrices, matrix_idx).transposed(), m_backward_tmp.at(backward_tmp_idx-1), m_backward_tmp.at(backward_tmp_idx));
 			activation_backward_gpu(stream, m_activation, m_forward_tmp.at(tmp_idx-1), m_backward_tmp.at(backward_tmp_idx));
 		} else {
 			fc_multiply<FullLayer>(stream, weight_matrix_at(use_inference_matrices, matrix_idx).transposed(), m_backward_tmp.at(backward_tmp_idx-1), m_forward_tmp.at(tmp_idx), m_backward_tmp.at(backward_tmp_idx), m_activation, true);
 		}
 
-		tmp_idx -= can_fuse_activation ? 1 : 2;
+		tmp_idx -= m_can_fuse_activation ? 1 : 2;
 		++backward_tmp_idx;
 	}
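
Note on the mechanism: for every activation except Sine, compute_layer can apply the activation inside the GEMM epilogue, so the pre-activation values never need to be materialized and one temporary matrix per hidden layer suffices. Sine cannot be fused, presumably because its backward pass reads the pre-activation input (d/dx sin(x) = cos(x)), so both pre- and post-activation buffers are kept. The standalone sketch below illustrates the resulting buffer arithmetic; all sizes are assumed example values, not taken from the patch.

#include <cstdint>
#include <cstdio>

int main() {
	// Hypothetical example sizes; none of these constants come from the patch.
	const uint32_t n_hidden_layers = 4;
	const uint32_t network_width = 64;
	const uint32_t batch_size = 1u << 16;
	const uint32_t bytes_per_element = 2; // half precision

	for (bool can_fuse : {false, true}) {
		// Mirrors the patched m_forward_tmp.resize(...): one temporary matrix
		// per hidden layer when the activation is fused into the GEMM epilogue,
		// two (pre- and post-activation) when it is not (Activation::Sine).
		const uint32_t n_tmp = can_fuse ? n_hidden_layers : n_hidden_layers * 2;
		const double mib = double(n_tmp) * network_width * batch_size * bytes_per_element / (1024.0 * 1024.0);
		std::printf("fused=%d -> %u forward tmp matrices, %.1f MiB\n", int(can_fuse), n_tmp, mib);
	}
	return 0;
}

With these assumed sizes the fused path stores half as many temporaries (32 MiB instead of 64 MiB for the forward buffers alone, and the same again for m_backward_tmp), which is the VRAM reduction the patch title refers to.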