Skip to content

Commit

Permalink
modify for pytorch 0.5.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jinserk committed Jul 19, 2018
1 parent 2321716 commit d1dd79f
Showing 1 changed file with 67 additions and 63 deletions.
130 changes: 67 additions & 63 deletions pytorch_binding/src/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,28 @@
#endif

extern "C" int cpu_ctc(THFloatTensor *probs,
THFloatTensor *grads,
THIntTensor *labels,
THIntTensor *label_sizes,
THIntTensor *sizes,
int minibatch_size,
THFloatTensor *costs,
int blank_label) {

float *probs_ptr = probs->storage->data + probs->storageOffset;
THFloatTensor *grads,
THIntTensor *labels,
THIntTensor *label_sizes,
THIntTensor *sizes,
int minibatch_size,
THFloatTensor *costs,
int blank_label)
{
float *probs_ptr = THFloatTensor_data(probs);
float *grads_ptr;
if (grads->storage) {
grads_ptr = grads->storage->data + grads->storageOffset;
if (THFloatTensor_storage(grads)) {
grads_ptr = THFloatTensor_data(grads);
} else {
grads_ptr = NULL; // this will trigger the score forward code path
grads_ptr = NULL; // this will trigger the score forward code path
}

int *sizes_ptr = sizes->storage->data + sizes->storageOffset;
int *labels_ptr = labels->storage->data + labels->storageOffset;
int *label_sizes_ptr = label_sizes->storage->data + label_sizes->storageOffset;
float *costs_ptr = costs->storage->data + costs->storageOffset;
int *sizes_ptr = THIntTensor_data(sizes);
int *labels_ptr = THIntTensor_data(labels);
int *label_sizes_ptr = THIntTensor_data(label_sizes);
float *costs_ptr = THFloatTensor_data(costs);

int probs_size = THFloatTensor_size(probs, 2);

ctcOptions options;
memset(&options, 0, sizeof(options));
Expand All @@ -49,64 +51,66 @@ extern "C" int cpu_ctc(THFloatTensor *probs,

size_t cpu_size_bytes;
get_workspace_size(label_sizes_ptr, sizes_ptr,
(int) probs->size[2], minibatch_size,
probs_size, minibatch_size,
options, &cpu_size_bytes);

float* cpu_workspace = (float*) new unsigned char[cpu_size_bytes];

compute_ctc_loss(probs_ptr, grads_ptr,
labels_ptr, label_sizes_ptr,
sizes_ptr, probs->size[2],
sizes_ptr, probs_size,
minibatch_size, costs_ptr,
cpu_workspace, options);

delete cpu_workspace;
return 1;
}

#ifdef WARPCTC_ENABLE_GPU
extern "C" int gpu_ctc(THCudaTensor *probs,
THCudaTensor *grads,
THIntTensor *labels,
THIntTensor *label_sizes,
THIntTensor *sizes,
int minibatch_size,
THFloatTensor *costs,
int blank_label) {

float *probs_ptr = probs->storage->data + probs->storageOffset;
float *grads_ptr;
if (grads->storage) {
grads_ptr = grads->storage->data + grads->storageOffset;
} else {
grads_ptr = NULL; // this will trigger the score forward code path
}

int *sizes_ptr = sizes->storage->data + sizes->storageOffset;
int *labels_ptr = labels->storage->data + labels->storageOffset;
int *label_sizes_ptr = label_sizes->storage->data + label_sizes->storageOffset;
float *costs_ptr = costs->storage->data + costs->storageOffset;

ctcOptions options;
memset(&options, 0, sizeof(options));
options.loc = CTC_GPU;
options.blank_label = blank_label;
options.stream = THCState_getCurrentStream(state);

size_t gpu_size_bytes;
get_workspace_size(label_sizes_ptr, sizes_ptr,
(int) probs->size[2], minibatch_size,
options, &gpu_size_bytes);

float* gpu_workspace;
THCudaMalloc(state, (void **) &gpu_workspace, gpu_size_bytes);

compute_ctc_loss(probs_ptr, grads_ptr,
labels_ptr, label_sizes_ptr,
sizes_ptr, probs->size[2],
minibatch_size, costs_ptr,
gpu_workspace, options);

THCudaFree(state, (void *) gpu_workspace);
return 1;
}
extern "C" int gpu_ctc(THCudaTensor *probs,
THCudaTensor *grads,
THIntTensor *labels,
THIntTensor *label_sizes,
THIntTensor *sizes,
int minibatch_size,
THFloatTensor *costs,
int blank_label)
{
float *probs_ptr = THCudaTensor_data(state, probs);
float *grads_ptr;
if (THCudaTensor_storage(state, grads)) {
grads_ptr = THCudaTensor_data(state, grads);
} else {
grads_ptr = NULL; // this will trigger the score forward code path
}

int *sizes_ptr = THIntTensor_data(sizes);
int *labels_ptr = THIntTensor_data(labels);
int *label_sizes_ptr = THIntTensor_data(label_sizes);
float *costs_ptr = THFloatTensor_data(costs);

int probs_size = THFloatTensor_size(probs, 2);

ctcOptions options;
memset(&options, 0, sizeof(options));
options.loc = CTC_GPU;
options.blank_label = blank_label;
options.stream = THCState_getCurrentStream(state);

size_t gpu_size_bytes;
get_workspace_size(label_sizes_ptr, sizes_ptr,
probs_size, minibatch_size,
options, &gpu_size_bytes);

void* gpu_workspace = THCudaMalloc(state, gpu_size_bytes);

compute_ctc_loss(probs_ptr, grads_ptr,
labels_ptr, label_sizes_ptr,
sizes_ptr, probs_size,
minibatch_size, costs_ptr,
gpu_workspace, options);

THCudaFree(state, (void *) gpu_workspace);
return 1;
}
#endif

0 comments on commit d1dd79f

Please sign in to comment.