Reduce VRAM footprint of CutlassMLP #21

Merged: 1 commit, Jan 19, 2022
2 changes: 2 additions & 0 deletions include/tiny-cuda-nn/networks/cutlass_mlp.h
@@ -151,6 +151,8 @@ class CutlassMLP : public Network<T> {
Activation m_activation;
Activation m_output_activation;

+bool m_can_fuse_activation;
+
static const uint32_t tensorcore_width = 8;

// Streams and events
39 changes: 27 additions & 12 deletions src/cutlass_mlp.cu
Expand Up @@ -50,7 +50,8 @@ m_network_width{network_width},
m_output_width{output_width},
m_n_hidden_layers{n_hidden_layers},
m_activation{activation},
-m_output_activation{output_activation}
+m_output_activation{output_activation},
+m_can_fuse_activation{activation != Activation::Sine}
{
m_padded_output_width = next_multiple(m_output_width, tensorcore_width);

@@ -92,8 +93,8 @@ m_output_activation{output_activation}
}

// Buffers to keep data from the forward and backward pass
-m_forward_tmp.resize(m_n_hidden_layers * 2);
-m_backward_tmp.resize(m_n_hidden_layers * 2);
+m_forward_tmp.resize(m_can_fuse_activation ? m_n_hidden_layers : (m_n_hidden_layers * 2));
+m_backward_tmp.resize(m_can_fuse_activation ? m_n_hidden_layers : (m_n_hidden_layers * 2));

// 1 stream per matrix.
m_training_splitk_streams.resize(m_n_hidden_layers + 1);
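This pair of `resize` calls is where the VRAM saving comes from: when the activation can be fused into the matrix multiplication, only one temporary matrix per hidden layer is kept instead of two. A back-of-the-envelope sketch of the effect; the layer count, buffer shape, batch size, and element type below are illustrative assumptions, not values taken from this diff:

```cpp
// Illustrative accounting only: all sizes below are assumptions for the sketch.
#include <cstdint>
#include <cstdio>

int main() {
	const uint64_t n_hidden_layers = 8;
	const uint64_t network_width   = 64;        // rows of each hidden-layer matrix
	const uint64_t batch_size      = 1u << 18;  // columns of each temporary matrix
	const uint64_t bytes_per_elem  = 2;         // e.g. __half

	const uint64_t bytes_per_buffer = network_width * batch_size * bytes_per_elem;

	// Unfused activation (e.g. Sine): pre- and post-activation buffer per layer.
	const uint64_t unfused = n_hidden_layers * 2 * bytes_per_buffer;
	// Fused activation (e.g. ReLU): a single buffer per layer.
	const uint64_t fused = n_hidden_layers * bytes_per_buffer;

	std::printf("forward tmp, unfused: %llu MiB\n", (unsigned long long)(unfused >> 20));
	std::printf("forward tmp, fused:   %llu MiB\n", (unsigned long long)(fused >> 20));
	return 0;
}
```

The same halving applies to `m_backward_tmp`, since it is resized with the identical expression.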
@@ -241,12 +242,28 @@ void CutlassMLP<T>::forward(cudaStream_t stream, const GPUMatrix<T>& input, GPUM
// Run the actual network
uint32_t tmp_idx = 0;

-bool fused = compute_layer<FullLayer>(stream, false, m_activation, input_weight_matrix(use_inference_matrices), input, m_forward_tmp.at(tmp_idx), m_forward_tmp.at(tmp_idx+1));
+bool fused = compute_layer<FullLayer>(
+	stream,
+	false,
+	m_activation,
+	input_weight_matrix(use_inference_matrices),
+	input,
+	m_forward_tmp.at(tmp_idx),
+	m_can_fuse_activation ? m_forward_tmp.at(tmp_idx) : m_forward_tmp.at(tmp_idx+1)
+);
tmp_idx += fused ? 1 : 2;

// layers
for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) {
-fused = compute_layer<FullLayer>(stream, false, m_activation, weight_matrix_at(use_inference_matrices, i), m_forward_tmp.at(tmp_idx-1), m_forward_tmp.at(tmp_idx), m_forward_tmp.at(tmp_idx+1));
+fused = compute_layer<FullLayer>(
+	stream,
+	false,
+	m_activation,
+	weight_matrix_at(use_inference_matrices, i),
+	m_forward_tmp.at(tmp_idx-1),
+	m_forward_tmp.at(tmp_idx),
+	m_can_fuse_activation ? m_forward_tmp.at(tmp_idx) : m_forward_tmp.at(tmp_idx+1)
+);
tmp_idx += fused ? 1 : 2;
}

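The reshaped `compute_layer` calls pass `m_forward_tmp.at(tmp_idx)` twice when fusion is possible, so the pre-activation and post-activation values share one buffer; otherwise the activated values land in the separate slot `tmp_idx+1`. A small, self-contained sketch of how this `tmp_idx` bookkeeping maps layers onto `m_forward_tmp` slots; the layer count is an assumption for illustration, not a value from this diff:

```cpp
// Stand-alone sketch of the tmp_idx bookkeeping above: fused layers consume one
// m_forward_tmp slot, unfused layers consume two.
#include <cstdint>
#include <cstdio>

int main() {
	const uint32_t n_hidden_matmuls = 3;  // assumed for illustration
	for (int pass = 0; pass < 2; ++pass) {
		const bool can_fuse = (pass == 0);
		std::printf("can_fuse_activation = %s\n", can_fuse ? "true" : "false");

		uint32_t tmp_idx = 0;
		// Input layer: pre-activation goes to slot tmp_idx; the activated values
		// go to the same slot (fused) or to tmp_idx+1 (unfused).
		std::printf("  input layer: writes slot %u%s\n", tmp_idx, can_fuse ? "" : " and +1");
		tmp_idx += can_fuse ? 1 : 2;

		for (uint32_t i = 0; i < n_hidden_matmuls; ++i) {
			// Each hidden matmul reads the previous activation from slot tmp_idx-1.
			std::printf("  hidden %u:    reads slot %u, writes slot %u%s\n",
			            i, tmp_idx - 1, tmp_idx, can_fuse ? "" : " and +1");
			tmp_idx += can_fuse ? 1 : 2;
		}
		std::printf("  slots used: %u\n\n", tmp_idx);
	}
	return 0;
}
```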
@@ -275,8 +292,6 @@ void CutlassMLP<T>::backward(
allocate_backward_buffers(batch_size);
}

-bool can_fuse_activation = m_activation != Activation::Sine;
-
// Compute transfer of output activation in-place... it's treated specially for performance reasons
if (m_output_activation != Activation::None) {
activation_backward_output_gpu(stream, dL_doutput.n_elements(), m_output_activation, output.data(), dL_doutput.data(), m_backward_output_tmp.data());
@@ -315,7 +330,7 @@
return;
}

-uint32_t tmp_idx = (can_fuse_activation ? (m_n_hidden_matmuls+1) : ((m_n_hidden_matmuls+1) * 2)) - 1;
+uint32_t tmp_idx = (m_can_fuse_activation ? (m_n_hidden_matmuls+1) : ((m_n_hidden_matmuls+1) * 2)) - 1;
uint32_t backward_tmp_idx = 0;

if (compute_param_gradients) {
@@ -329,14 +344,14 @@
cudaEventRecord(m_training_splitk_events.at(backward_tmp_idx), m_training_splitk_streams.at(backward_tmp_idx));
}

-if (!can_fuse_activation) {
+if (!m_can_fuse_activation) {
fc_multiply<FullLayer>(stream, output_weight_matrix(use_inference_matrices).transposed(), tmp_dL_doutput, m_backward_tmp.at(backward_tmp_idx));
activation_backward_gpu(stream, m_activation, m_forward_tmp.at(tmp_idx-1), m_backward_tmp.at(backward_tmp_idx));
} else {
fc_multiply<FullLayer>(stream, output_weight_matrix(use_inference_matrices).transposed(), tmp_dL_doutput, m_forward_tmp.at(tmp_idx), m_backward_tmp.at(backward_tmp_idx), m_activation, true);
}

-tmp_idx -= can_fuse_activation ? 1 : 2;
+tmp_idx -= m_can_fuse_activation ? 1 : 2;
++backward_tmp_idx;

// layers
@@ -350,14 +365,14 @@
cudaEventRecord(m_training_splitk_events.at(backward_tmp_idx), m_training_splitk_streams.at(backward_tmp_idx));
}

-if (!can_fuse_activation) {
+if (!m_can_fuse_activation) {
fc_multiply<FullLayer>(stream, weight_matrix_at(use_inference_matrices, matrix_idx).transposed(), m_backward_tmp.at(backward_tmp_idx-1), m_backward_tmp.at(backward_tmp_idx));
activation_backward_gpu(stream, m_activation, m_forward_tmp.at(tmp_idx-1), m_backward_tmp.at(backward_tmp_idx));
} else {
fc_multiply<FullLayer>(stream, weight_matrix_at(use_inference_matrices, matrix_idx).transposed(), m_backward_tmp.at(backward_tmp_idx-1), m_forward_tmp.at(tmp_idx), m_backward_tmp.at(backward_tmp_idx), m_activation, true);
}

-tmp_idx -= can_fuse_activation ? 1 : 2;
+tmp_idx -= m_can_fuse_activation ? 1 : 2;
++backward_tmp_idx;
}

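The `Sine` special case exists because the unfused backward path (`activation_backward_gpu`) consumes the stored pre-activation values in `m_forward_tmp.at(tmp_idx-1)`, whereas the fused path passes the stored activation output to `fc_multiply`, which suggests the activation derivative is recovered from the output alone. A hedged per-element sketch of why that works for ReLU-style activations but not for sine; this is standard calculus for illustration, not code from the repository:

```cpp
// Per-element illustration of why some activation gradients can be computed
// from the output alone while sine cannot.
#include <cmath>
#include <cstdio>

// ReLU: y = max(x, 0). The derivative (1 if y > 0, else 0) is recoverable from
// y alone, so the backward pass can be fused and x need not be stored.
float relu_backward_from_output(float y, float dL_dy) {
	return y > 0.0f ? dL_dy : 0.0f;
}

// Sine: y = sin(x). The derivative cos(x) = ±sqrt(1 - y*y) has a sign that
// depends on x, so the pre-activation value must be kept around, hence the
// second temporary buffer per layer when the activation is Sine.
float sine_backward_from_preactivation(float x, float dL_dy) {
	return std::cos(x) * dL_dy;
}

int main() {
	const float x = 2.5f;      // pre-activation value
	const float dL_dy = 1.0f;  // upstream gradient
	std::printf("ReLU grad (from output): %f\n", relu_backward_from_output(std::fmax(x, 0.0f), dL_dy));
	std::printf("Sine grad (needs x):     %f\n", sine_backward_from_preactivation(x, dL_dy));
	return 0;
}
```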