Add mixed precision #40
Changes from all commits
b470df0
7698f28
aedc1ec
8d50953
3fb715e
fa9c4b0
9877f14
4bd562f
First changed file (the transformer language-model pipe benchmark script):

@@ -8,7 +8,7 @@
 import torchtext
 from torchtext.data.utils import get_tokenizer

-import fairscale.nn.pipe.pipe as pipe
+from fairscale.nn import Pipe

 try:
     from fairscale.optim.adam import Adam  # type: ignore
@@ -129,13 +129,15 @@ def make_model(device, ntokens):
     dropout = 0
     initrange = 0.1

-    model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).to(device)
+    model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).half().to(device)
+    balance = generate_balance(min(num_devices, 4), len(model))
+    p = Pipe(model, balance)

     criterion = nn.CrossEntropyLoss()
     lr = 0.01  # learning rate
-    optimizer = Adam(model.parameters(), lr=lr)
+    optimizer = Adam(p.parameters(), lr=lr, mixed_precision=True)

-    return model, criterion, optimizer
+    return p, criterion, optimizer


 def train(train_data, model, criterion, optimizer, bptt, ntokens):
@@ -221,7 +223,7 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
     if can_benchmark and len(model.balance) == 4:
         # Assert that words per second is within 3 standard deviations of the average
         # of six golden runs
-        assert wps > 20052.1 - (3 * 359)
+        assert wps > 27799.2 - (3 * 522.145)

         print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
         print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
@@ -230,10 +232,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,

         # Assert that memory usage on each GPU is within 10% of golden run
         # Right-hand-side is golden run bytes * 110%
-        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 365916160 * 1.1
-        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1281024 * 1.1
-        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2788864 * 1.1
-        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 190724608 * 1.1
+        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 210479616 * 1.1
+        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 640512 * 1.1
+        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1605120 * 1.1
+        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 113801216 * 1.1
         print("No regression detected")

Review comment (on the new golden-run numbers): Nice, I'm curious: what tool did you use to get these exact numbers?

Reply: I used the values printed by the above four lines of code.
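For readers who want to reproduce those golden-run numbers, the counters are the standard torch.cuda memory statistics that the benchmark already prints. A minimal sketch (the helper name is made up for illustration and is not part of this PR):

import torch

# Hypothetical helper, not in the PR: collect the per-GPU peak allocations
# that the benchmark prints, e.g. to refresh the golden-run thresholds.
def peak_allocated_bytes():
    return [
        torch.cuda.memory_stats(d)["allocated_bytes.all.peak"]
        for d in range(torch.cuda.device_count())
    ]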
@@ -259,7 +261,5 @@ def generate_balance(num_devices, num_layers):
     device = torch.device("cuda")
     ntokens, train_data, val_data, test_data = get_data(device)
     model, criterion, optimizer = make_model(device, ntokens)
-    balance = generate_balance(min(num_devices, 4), len(model))
-    p = pipe.Pipe(model, balance)
-    benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens)
-    del p
+    benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens)
+    del model
Second changed file (the fused Adam CUDA kernel):

@@ -4,6 +4,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <stdio.h>
+#include <assert.h>
 #include <cmath>
 #include "ATen/TensorUtils.h"
 // #include "ATen/Type.h"
@@ -19,9 +20,7 @@ typedef enum{
   ADAM_MODE_1 =1 // eps outside square root
 } adamMode_t;

-
-
-template <int DEPTH, typename T, typename GRAD_T>
+template <int DEPTH, typename PARAM_T, typename GRAD_T>
 struct AdamFunctor
 {
     __device__ __forceinline__ void operator()(
@@ -40,26 +39,26 @@ struct AdamFunctor
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

-       GRAD_T* p = (GRAD_T *)tl.addresses[0][tensor_loc];
+       PARAM_T* p = (PARAM_T *)tl.addresses[0][tensor_loc];
        p += chunk_idx*chunk_size;
-       T* m = (T *)tl.addresses[1][tensor_loc];
+       float* m = (float *)tl.addresses[1][tensor_loc];
        m += chunk_idx*chunk_size;
-       T* v = (T *)tl.addresses[2][tensor_loc];
+       float* v = (float *)tl.addresses[2][tensor_loc];
        v += chunk_idx*chunk_size;
        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
        g += chunk_idx*chunk_size;
-       GRAD_T* p_copy = NULL;
+       at::Half* p_copy = NULL;
        if (DEPTH == 5) {
-           p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
+           p_copy = (at::Half*)tl.addresses[4][tensor_loc];
            p_copy += chunk_idx*chunk_size;
        }

        n -= chunk_idx*chunk_size;

-       T incoming_p[ILP];
-       T incoming_m[ILP];
-       T incoming_v[ILP];
-       T incoming_g[ILP];
+       PARAM_T incoming_p[ILP];
+       float incoming_m[ILP];
+       float incoming_v[ILP];
+       GRAD_T incoming_g[ILP];

        for(int i_start = 0;
            i_start < n && i_start < chunk_size;

Review comment (on the float* types of m and v): Naive question: why do we have the types of m and v as float?

Reply: These are momentum and velocity! Right now we require them to be floats; in the Python code, when they are instantiated, they are always dtype=torch.float32. The next pull request will add the option for them to be fp16.
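To illustrate that reply: the optimizer state stays in FP32 even when parameters and gradients are FP16, which is why the kernel casts those addresses to float*. A minimal Python-side sketch under that assumption (the state keys follow the usual PyTorch Adam naming and are not necessarily fairscale's exact ones):

import torch

# Illustrative only: momentum (m) and velocity (v) buffers are created in FP32
# regardless of the parameter dtype, matching the kernel's float* m / float* v.
def init_adam_state(param: torch.Tensor) -> dict:
    return {
        "exp_avg": torch.zeros_like(param, dtype=torch.float32),     # m
        "exp_avg_sq": torch.zeros_like(param, dtype=torch.float32),  # v
    }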
@@ -74,10 +73,10 @@ struct AdamFunctor

                int i = i_start + threadIdx.x + ii*blockDim.x;
                if (i < n && i < chunk_size) {
-                   incoming_p[ii] = static_cast<T>(p[i]);
+                   incoming_p[ii] = static_cast<PARAM_T>(p[i]);
                    incoming_m[ii] = m[i];
                    incoming_v[ii] = v[i];
-                   incoming_g[ii] = static_cast<T>(g[i]);
+                   incoming_g[ii] = static_cast<GRAD_T>(g[i]);
                }
            }

@@ -91,7 +90,7 @@ struct AdamFunctor
                int j = i_start + threadIdx.x + ii*blockDim.x;

                if(j < n && j < chunk_size) {
-                   T scaled_grad = incoming_g[ii]/grad_scale;
+                   float scaled_grad = incoming_g[ii]/grad_scale;
                    m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
                    v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
                    float denom;
@@ -100,8 +99,8 @@ struct AdamFunctor
                    else // Mode 1
                        denom = sqrtf(v[j]) + eps;
                    float update = (m[j]/denom) + (decay*incoming_p[ii]);
-                   p[j] = (GRAD_T)(incoming_p[ii] - (step_size*update));
-                   if (DEPTH == 5) p_copy[j] = p[j];
+                   p[j] = (PARAM_T)(incoming_p[ii] - (step_size*update));
+                   if (DEPTH == 5) p_copy[j] = (at::Half) p[j];
                }
            }
        }
@@ -135,24 +134,65 @@ void fused_adam_cuda(
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    size_t tl_sz = tensor_lists.size();
-   AT_ASSERTM(tl_sz == 4, "expected tensor lists of size 4");
-
-   // check that the model and gradients are FP32
-   AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
-   AT_ASSERTM(tensor_lists[3][0].scalar_type() == at::ScalarType::Float);
-   multi_tensor_apply<4>(
-       BLOCK_SIZE,
-       chunk_size,
-       noop_flag,
-       tensor_lists,
-       AdamFunctor<4, float, float>(),
-       beta1,
-       beta2,
-       eps,
-       grad_scale,
-       step_size,
-       (adamMode_t) mode,
-       decay
-   );
+   assert(tl_sz == 4 || tl_sz == 5);
+
+   if(tl_sz == 5) {
+       // Mixed precision case
+       assert(tensor_lists[0][0].scalar_type() == at::ScalarType::Float);
+       assert(tensor_lists[3][0].scalar_type() == at::ScalarType::Half);
+       assert(tensor_lists[4][0].scalar_type() == at::ScalarType::Half);
+       multi_tensor_apply<5>(
+           BLOCK_SIZE,
+           chunk_size,
+           noop_flag,
+           tensor_lists,
+           AdamFunctor<5, float, at::Half>(),
+           beta1,
+           beta2,
+           eps,
+           grad_scale,
+           step_size,
+           (adamMode_t) mode,
+           decay
+       );
+   } else {
+       // tl_sz == 4
+       assert(tensor_lists[0][0].scalar_type() == tensor_lists[3][0].scalar_type());
+       if(tensor_lists[0][0].scalar_type() == at::ScalarType::Float) {
+           // Full precision case
+           multi_tensor_apply<4>(
+               BLOCK_SIZE,
+               chunk_size,
+               noop_flag,
+               tensor_lists,
+               AdamFunctor<4, float, float>(),
+               beta1,
+               beta2,
+               eps,
+               grad_scale,
+               step_size,
+               (adamMode_t) mode,
+               decay
+           );
+       } else if (tensor_lists[0][0].scalar_type() == at::ScalarType::Half) {
+           // "Memory Efficient Training" case
+           multi_tensor_apply<4>(
+               BLOCK_SIZE,
+               chunk_size,
+               noop_flag,
+               tensor_lists,
+               AdamFunctor<4, at::Half, at::Half>(),
+               beta1,
+               beta2,
+               eps,
+               grad_scale,
+               step_size,
+               (adamMode_t) mode,
+               decay
+           );
+       } else {
+           throw "Parameters must be of type float or half";
+       }
+   }
    THCudaCheck(cudaGetLastError());
 }
Review comment: Nice speedups and memory reduction (below)!

Reply: Thanks!