Added [convolutional] activation=normalize_channels_softmax
AlexeyAB committed Dec 6, 2019
1 parent 5d13aad commit c9c745c
Showing 6 changed files with 150 additions and 1 deletion.
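As the commit title indicates, the new activation is selected from a network .cfg file. A minimal, hypothetical [convolutional] block that enables it might look like the following (the filter count, size, stride, and pad are placeholder values, not taken from this commit; only the activation string comes from the code below):

[convolutional]
filters=255
size=1
stride=1
pad=1
activation=normalize_channels_softmax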
2 changes: 1 addition & 1 deletion include/darknet.h
@@ -102,7 +102,7 @@ typedef struct tree {

// activations.h
typedef enum {
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH, NORM_CHAN
LOGISTIC, RELU, RELIE, LINEAR, RAMP, TANH, PLSE, LEAKY, ELU, LOGGY, STAIR, HARDTAN, LHTAN, SELU, SWISH, MISH, NORM_CHAN, NORM_CHAN_SOFTMAX
}ACTIVATION;

// parser.h
79 changes: 79 additions & 0 deletions src/activation_kernels.cu
@@ -400,6 +400,10 @@ extern "C" void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta
else if (a == HARDTAN) gradient_array_hardtan_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == RELU) gradient_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == NORM_CHAN) gradient_array_relu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else if (a == NORM_CHAN_SOFTMAX) {
printf(" Error: should be used custom NORM_CHAN_SOFTMAX-function for gradient \n");
exit(0);
}
else if (a == SELU) gradient_array_selu_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> >(x, n, delta);
else
gradient_array_kernel << <cuda_gridsize(n), BLOCK, 0, get_cuda_stream() >> > (x, n, a, delta);
@@ -456,4 +460,79 @@ extern "C" void activate_array_normalize_channels_ongpu(float *x, int n, int bat

activate_array_normalize_channels_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (x, size, batch, channels, wh_step, output_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}



__global__ void activate_array_normalize_channels_softmax_kernel(float *x, int size, int batch, int channels, int wh_step, float *output_gpu)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;

int wh_i = i % wh_step;
int b = i / wh_step;

const float eps = 0.0001;
if (i < size) {
float sum = eps;
int k;
for (k = 0; k < channels; ++k) {
float val = x[wh_i + k * wh_step + b*wh_step*channels];
sum += expf(val);
}
for (k = 0; k < channels; ++k) {
float val = x[wh_i + k * wh_step + b*wh_step*channels];
val = expf(val) / sum;
output_gpu[wh_i + k * wh_step + b*wh_step*channels] = val;
}
}
}

extern "C" void activate_array_normalize_channels_softmax_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu)
{
// n = w*h*c*batch
// size = w*h*batch
int size = n / channels;

const int num_blocks = get_number_of_blocks(size, BLOCK);

activate_array_normalize_channels_softmax_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (x, size, batch, channels, wh_step, output_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}
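
Read off the two loops above, the forward kernel computes, for each spatial position wh_i of each batch element b, a softmax across the C channels with a small epsilon added to the denominator:

$$ o_k = \frac{e^{x_k}}{\epsilon + \sum_{j=1}^{C} e^{x_j}} \qquad (\epsilon = 10^{-4},\ k = 1,\dots,C) $$

One thread handles one (wh_i, b) pair and loops over all channels itself, which is why the launch size is n / channels rather than n.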



__global__ void gradient_array_normalize_channels_softmax_kernel(float *x, int size, int batch, int channels, int wh_step, float *delta_gpu)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;

int wh_i = i % wh_step;
int b = i / wh_step;

const float eps = 0.0001;
if (i < size) {
float grad = eps;
int k;
for (k = 0; k < channels; ++k) {
float out = x[wh_i + k * wh_step + b*wh_step*channels];
float delta = delta_gpu[wh_i + k * wh_step + b*wh_step*channels];
grad += out*delta;
}
for (k = 0; k < channels; ++k) {
float delta = delta_gpu[wh_i + k * wh_step + b*wh_step*channels];
delta = delta * grad;
delta_gpu[wh_i + k * wh_step + b*wh_step*channels] = delta;
}
}
}

extern "C" void gradient_array_normalize_channels_softmax_ongpu(float *output_gpu, int n, int batch, int channels, int wh_step, float *delta_gpu)
{
// n = w*h*c*batch
// size = w*h*batch
int size = n / channels;

const int num_blocks = get_number_of_blocks(size, BLOCK);

gradient_array_normalize_channels_softmax_kernel << <num_blocks, BLOCK, 0, get_cuda_stream() >> > (output_gpu, size, batch, channels, wh_step, delta_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}
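
Reading the backward kernel the same way: the first argument is the layer's output (the convolutional backward pass below passes l.output_gpu), and every channel's delta at a pixel is rescaled by a single per-pixel factor built from the outputs and the incoming deltas:

$$ s = \epsilon + \sum_{j=1}^{C} o_j \, \delta_j, \qquad \delta_k \leftarrow s \cdot \delta_k $$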
60 changes: 60 additions & 0 deletions src/activations.c
@@ -48,6 +48,7 @@ ACTIVATION get_activation(char *s)
if (strcmp(s, "swish") == 0) return SWISH;
if (strcmp(s, "mish") == 0) return MISH;
if (strcmp(s, "normalize_channels") == 0) return NORM_CHAN;
if (strcmp(s, "normalize_channels_softmax") == 0) return NORM_CHAN_SOFTMAX;
if (strcmp(s, "loggy")==0) return LOGGY;
if (strcmp(s, "relu")==0) return RELU;
if (strcmp(s, "elu")==0) return ELU;
@@ -176,6 +177,61 @@ void activate_array_normalize_channels(float *x, const int n, int batch, int cha
}
}

void activate_array_normalize_channels_softmax(float *x, const int n, int batch, int channels, int wh_step, float *output)
{
int size = n / channels;

int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
int wh_i = i % wh_step;
int b = i / wh_step;

const float eps = 0.0001;
if (i < size) {
float sum = eps;
int k;
for (k = 0; k < channels; ++k) {
float val = x[wh_i + k * wh_step + b*wh_step*channels];
sum += expf(val);
}
for (k = 0; k < channels; ++k) {
float val = x[wh_i + k * wh_step + b*wh_step*channels];
val = expf(val) / sum;
output[wh_i + k * wh_step + b*wh_step*channels] = val;
}
}
}
}
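
For orientation on the calling convention (n is w*h*c*batch, channels is c, wh_step is w*h), here is a minimal, self-contained sketch that runs the CPU version on a tiny tensor. It assumes it is compiled inside the darknet source tree together with activations.c, and it passes the same buffer as input and output, which is what the convolutional layer does and is safe here because each element is read before it is overwritten:

#include <stdio.h>
#include "activations.h"  /* declares activate_array_normalize_channels_softmax() */

int main(void)
{
    /* batch = 1, channels = 3, w*h = 2  ->  n = 6, channel-major layout per batch */
    float x[6] = {  1.0f, 2.0f,    /* channel 0 at the two pixels */
                    0.5f, 0.5f,    /* channel 1 */
                   -1.0f, 3.0f };  /* channel 2 */

    activate_array_normalize_channels_softmax(x, 6, 1, 3, 2, x);  /* in place */

    /* the three channel values at each pixel now sum to roughly 1 */
    for (int i = 0; i < 6; ++i) printf("x[%d] = %f\n", i, x[i]);
    return 0;
}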

void gradient_array_normalize_channels_softmax(float *x, const int n, int batch, int channels, int wh_step, float *delta)
{
int size = n / channels;

int i;
#pragma omp parallel for
for (i = 0; i < size; ++i) {
int wh_i = i % wh_step;
int b = i / wh_step;

const float eps = 0.0001;
if (i < size) {
float grad = eps;
int k;
for (k = 0; k < channels; ++k) {
float out = x[wh_i + k * wh_step + b*wh_step*channels];
float d = delta[wh_i + k * wh_step + b*wh_step*channels];
grad += out*d;
}
for (k = 0; k < channels; ++k) {
float d = delta[wh_i + k * wh_step + b*wh_step*channels];
d = d * grad;
delta[wh_i + k * wh_step + b*wh_step*channels] = d;
}
}
}
}

float gradient(float x, ACTIVATION a)
{
switch(a){
@@ -189,6 +245,10 @@ float gradient(float x, ACTIVATION a)
return relu_gradient(x);
case NORM_CHAN:
return relu_gradient(x);
case NORM_CHAN_SOFTMAX:
printf(" Error: should be used custom NORM_CHAN_SOFTMAX-function for gradient \n");
exit(0);
return 0;
case ELU:
return elu_gradient(x);
case SELU:
4 changes: 4 additions & 0 deletions src/activations.h
@@ -23,6 +23,7 @@ void activate_array(float *x, const int n, const ACTIVATION a);
void activate_array_swish(float *x, const int n, float * output_sigmoid, float * output);
void activate_array_mish(float *x, const int n, float * activation_input, float * output);
void activate_array_normalize_channels(float *x, const int n, int batch, int channels, int wh_step, float *output);
void activate_array_normalize_channels_softmax(float *x, const int n, int batch, int channels, int wh_step, float *output);
#ifdef GPU
void activate_array_ongpu(float *x, int n, ACTIVATION a);
void activate_array_swish_ongpu(float *x, int n, float *output_sigmoid_gpu, float *output_gpu);
@@ -31,6 +32,9 @@ void gradient_array_ongpu(float *x, int n, ACTIVATION a, float *delta);
void gradient_array_swish_ongpu(float *x, int n, float *sigmoid_gpu, float *delta);
void gradient_array_mish_ongpu(int n, float *activation_input_gpu, float *delta);
void activate_array_normalize_channels_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu);
void activate_array_normalize_channels_softmax_ongpu(float *x, int n, int batch, int channels, int wh_step, float *output_gpu);
void gradient_array_normalize_channels_softmax_ongpu(float *output_gpu, int n, int batch, int channels, int wh_step, float *delta_gpu);

#endif

static inline float stair_activate(float x)
3 changes: 3 additions & 0 deletions src/convolutional_kernels.cu
@@ -395,6 +395,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels_ongpu(l.output_gpu, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output_gpu);
else if (l.activation == NORM_CHAN_SOFTMAX) activate_array_normalize_channels_softmax_ongpu(l.output_gpu, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output_gpu);
else if (l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
//if(l.activation != LINEAR && l.activation != LEAKY) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
//if (l.binary || l.xnor) swap_binary(&l);
@@ -601,6 +602,7 @@ void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
if (l.activation == SWISH) activate_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
else if (l.activation == MISH) activate_array_mish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.output_gpu);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels_ongpu(l.output_gpu, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output_gpu);
else if (l.activation == NORM_CHAN_SOFTMAX) activate_array_normalize_channels_softmax_ongpu(l.output_gpu, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output_gpu);
else if (l.activation != LINEAR) activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
//if(l.dot > 0) dot_error_gpu(l);
if(l.binary || l.xnor) swap_binary(&l);
@@ -645,6 +647,7 @@ void backward_convolutional_layer_gpu(convolutional_layer l, network_state state

if (l.activation == SWISH) gradient_array_swish_ongpu(l.output_gpu, l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
else if (l.activation == MISH) gradient_array_mish_ongpu(l.outputs*l.batch, l.activation_input_gpu, l.delta_gpu);
else if (l.activation == NORM_CHAN_SOFTMAX) gradient_array_normalize_channels_softmax_ongpu(l.output_gpu, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.delta_gpu);
else gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);

if (!l.batch_normalize)
3 changes: 3 additions & 0 deletions src/convolutional_layer.c
@@ -1204,6 +1204,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else if (l.activation == NORM_CHAN_SOFTMAX) activate_array_normalize_channels_softmax(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
return;

@@ -1245,6 +1246,7 @@ void forward_convolutional_layer(convolutional_layer l, network_state state)
if (l.activation == SWISH) activate_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == MISH) activate_array_mish(l.output, l.outputs*l.batch, l.activation_input, l.output);
else if (l.activation == NORM_CHAN) activate_array_normalize_channels(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else if (l.activation == NORM_CHAN_SOFTMAX) activate_array_normalize_channels_softmax(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.output);
else activate_array_cpu_custom(l.output, l.outputs*l.batch, l.activation);

if(l.binary || l.xnor) swap_binary(&l);
@@ -1383,6 +1385,7 @@ void backward_convolutional_layer(convolutional_layer l, network_state state)

if (l.activation == SWISH) gradient_array_swish(l.output, l.outputs*l.batch, l.activation_input, l.delta);
else if (l.activation == MISH) gradient_array_mish(l.outputs*l.batch, l.activation_input, l.delta);
else if (l.activation == NORM_CHAN_SOFTMAX) gradient_array_normalize_channels_softmax(l.output, l.outputs*l.batch, l.batch, l.out_c, l.out_w*l.out_h, l.delta);
else gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);

if (l.batch_normalize) {
