Skip to content

Commit

Permalink
Add Euclidean SSE loss
Browse files Browse the repository at this point in the history
  • Loading branch information
danielsuo committed Apr 8, 2016
1 parent 9559b7a commit e5fa55c
Showing 1 changed file with 84 additions and 16 deletions.
100 changes: 84 additions & 16 deletions marvin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,25 @@ __global__ void Loss_SmoothL1(size_t CUDA_NUM_LOOPS, size_t N,
}
}

__global__ void Loss_EuclideanSSE(size_t CUDA_NUM_LOOPS, size_t N,
const StorageT *pred, const StorageT *target,
const StorageT *weight, StorageT *loss) {
// diff = f( weight * (pred - target) )
// f(x) = 0.5 * x^2
const size_t idxBase = size_t(CUDA_NUM_LOOPS) *
(size_t(CUDA_NUM_THREADS) * size_t(blockIdx.x) +
size_t(threadIdx.x));
if (idxBase >= N) return;
for (size_t idx = idxBase; idx < min(N, idxBase + CUDA_NUM_LOOPS); ++idx) {

ComputeT val =
GPUStorage2ComputeT(pred[idx]) - GPUStorage2ComputeT(target[idx]);
if (weight != NULL) val *= GPUStorage2ComputeT(weight[idx]);

loss[idx] = GPUCompute2StorageT(0.5 * val * val);
}
}

__global__ void LossGrad_SmoothL1(
size_t CUDA_NUM_LOOPS, size_t N, ComputeT scale, const StorageT *pred,
const StorageT *target, const StorageT *weight, StorageT *diff) {
Expand Down Expand Up @@ -1193,6 +1212,26 @@ __global__ void LossGrad_SmoothL1(
}
}

__global__ void LossGrad_EuclideanSSE(
size_t CUDA_NUM_LOOPS, size_t N, ComputeT scale, const StorageT *pred,
const StorageT *target, const StorageT *weight, StorageT *diff) {
// diff = scale * f'( weight * (pred - target) )
// f'(x) = x

const size_t idxBase = size_t(CUDA_NUM_LOOPS) *
(size_t(CUDA_NUM_THREADS) * size_t(blockIdx.x) +
size_t(threadIdx.x));
if (idxBase >= N) return;
for (size_t idx = idxBase; idx < min(N, idxBase + CUDA_NUM_LOOPS); ++idx) {

ComputeT val =
GPUStorage2ComputeT(pred[idx]) - GPUStorage2ComputeT(target[idx]);
if (weight != NULL) val *= GPUStorage2ComputeT(weight[idx]);

diff[idx] = GPUCompute2StorageT(GPUStorage2ComputeT(diff[idx]) + scale * val);
}
}

__global__ void Loss_Contrastive(
size_t CUDA_NUM_LOOPS, size_t N, int C, ComputeT margin, const StorageT *a,
const StorageT *b, const StorageT *y, StorageT *loss) {
Expand Down Expand Up @@ -2604,8 +2643,8 @@ std::vector<Tensor<T>*> readTensors(std::string filename, size_t max_count = SIZ
tensors.push_back(new Tensor<T>(fp));
count++;
if (count>=max_count) break;
int c = getc(fp);
ungetc(c, fp);
int c = getc(fp);
ungetc(c, fp);
}
fclose(fp);
return tensors;
Expand Down Expand Up @@ -5219,24 +5258,42 @@ class LossLayer : public Layer {
std::endl;
FatalError(__LINE__);
}
if (!same_dim(in[0]->dim, in[1]->dim)) {
std::cout <<
"LossLayer: SmoothL1 should have the same dimensions" <<
std::endl;
FatalError(__LINE__);
}
if (in.size() == 3 && !same_dim(in[0]->dim, in[2]->dim)) {
std::cout <<
"LossLayer: SmoothL1 should have the same dimensions" <<
std::endl;
FatalError(__LINE__);
}
loss_numel = numel(in[0]->dim);
if (!same_dim(in[0]->dim, in[1]->dim)) {
std::cout <<
"LossLayer: SmoothL1 should have the same dimensions" <<
std::endl;
FatalError(__LINE__);
}
if (in.size() == 3 && !same_dim(in[0]->dim, in[2]->dim)) {
std::cout <<
"LossLayer: SmoothL1 should have the same dimensions" <<
std::endl;
FatalError(__LINE__);
}
loss_numel = numel(in[0]->dim);
break;
case Contrastive:
loss_numel = numExamples;
break;
case EuclideanSSE:
if (!(in.size() == 2 || in.size() == 3)) {
std::cout << "LossLayer: EuclideanSSE should have 2 or 3 ins" <<
std::endl;
FatalError(__LINE__);
}
if (!same_dim(in[0]->dim, in[1]->dim)) {
std::cout <<
"LossLayer: EuclideanSSE should have the same dimensions" <<
std::endl;
FatalError(__LINE__);
}
if (in.size() == 3 && !same_dim(in[0]->dim, in[2]->dim)) {
std::cout <<
"LossLayer: EuclideanSSE should have the same dimensions" <<
std::endl;
FatalError(__LINE__);
}
loss_numel = numel(in[0]->dim);
break;
case HingeL1:
break;
Expand Down Expand Up @@ -5319,6 +5376,11 @@ class LossLayer : public Layer {
loss_values);
break;
case EuclideanSSE:
Loss_EuclideanSSE<<<CUDA_GET_BLOCKS(loss_numel),
CUDA_NUM_THREADS>>>(
CUDA_GET_LOOPS(loss_numel),
loss_numel, in[0]->dataGPU, in[1]->dataGPU,
(in.size()==3 ? in[2]->dataGPU : NULL), loss_values);
break;
case HingeL1:
break;
Expand Down Expand Up @@ -5377,6 +5439,12 @@ class LossLayer : public Layer {
in[0]->diffGPU, in[1]->diffGPU);
break;
case EuclideanSSE:
LossGrad_EuclideanSSE<<<CUDA_GET_BLOCKS(loss_numel),
CUDA_NUM_THREADS>>>(
CUDA_GET_LOOPS(loss_numel),
loss_numel, scale, in[0]->dataGPU, in[1]->dataGPU,
(in.size()==3 ? in[2]->dataGPU : NULL),
in[0]->diffGPU);
break;
case HingeL1:
break;
Expand Down Expand Up @@ -7480,4 +7548,4 @@ class Solver{
};
};

} // namespace marvin
} // namespace marvin

0 comments on commit e5fa55c

Please sign in to comment.