
Commit

add op and extra opt
antinucleon committed Jun 19, 2015
1 parent e4e2178 commit 3674aef
Showing 3 changed files with 119 additions and 6 deletions.
7 changes: 5 additions & 2 deletions include/mxnet/operator.h
@@ -26,6 +26,8 @@ class Operator {
struct Option {
/*! \brief whether it is training phase*/
int is_train;
/*! \brief whether propagate gradient to x in backprop */
int prop_grad;
};
/*! \brief gradient request type the request can have */
enum GradReqType {
@@ -43,7 +45,7 @@ class Operator {
* \param name parameter name
* \param val string for configuration
*/
virtual void SetParam(const char *name, const char *val) {}
virtual void SetParam(const char *name, const char *val) {}
/*!
* \brief infer the shape of output given the input data
* \param in_shape the shape of input arguments of the operator
@@ -73,7 +75,8 @@
* \param req_types request types of the gradient saving operation
* \sa GradReqType
*/
virtual void Backward(RunContext ctx,
virtual void Backward(Option opt,
RunContext ctx,
const std::vector<TBlob> &grad_next,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_grad,
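Note on the API change: Backward now takes an Option as its first argument, so call sites construct one before invoking it. A minimal call-site sketch; op, ctx, the TBlob vectors, and req_types are illustrative placeholders rather than names taken from this diff, and the meaning of prop_grad is inferred from its doc comment above:

  Operator::Option opt;
  opt.is_train = 1;   // training phase
  opt.prop_grad = 0;  // assumption: 0 means do not propagate the gradient to the input x
  op->Backward(opt, ctx, grad_next, in_data, out_grad, req_types);
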
9 changes: 5 additions & 4 deletions src/operator/activation_op-inl.hpp
@@ -4,8 +4,8 @@
* \brief activation operator of mxnet
*/

#ifndef ACTIVATION_OP_INL_HPP
#define ACTIVATION_OP_INL_HPP
#ifndef SRC_OPERATOR_ACTIVATION_OP_INL_HPP_
#define SRC_OPERATOR_ACTIVATION_OP_INL_HPP_
#pragma once
#include <mxnet/operator.h>
#include <vector>
@@ -34,7 +34,8 @@ class ActivationOp : public Operator {
mshadow::Tensor<xpu, 2> out = out_data[0].FlatTo2D(stream);
out = mshadow::expr::F<ForwardOp>(in);
}
virtual void Backward(RunContext ctx,
virtual void Backward(Option opt,
RunContext ctx,
const std::vector<TBlob> &grad_next,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_grad,
@@ -56,6 +57,6 @@
}; // class ActivationOp
} // namespace mxnet

#endif // ACTIVATION_OP_INL_HPP
#endif // SRC_OPERATOR_ACTIVATION_OP_INL_HPP_
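
For reference, Forward above applies the functor elementwise with mshadow (out = F<ForwardOp>(in)). The matching backward pattern is sketched here with placeholder tensor names; saved_output, incoming_grad, and grad_in are illustrative, and which TBlob arguments they map to is not visible in this hunk:

  // chain rule for an elementwise activation: BackwardOp is evaluated on the saved
  // forward output (as sigmoid_grad and tanh_grad in op.h below expect), then scaled
  // by the gradient arriving from the next layer.
  grad_in = mshadow::expr::F<BackwardOp>(saved_output) * incoming_grad;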


109 changes: 109 additions & 0 deletions src/operator/op.h
@@ -0,0 +1,109 @@
/*!
* Copyright (c) 2015 by Contributors
* \file op.h
* \brief extra mshadow operation for mxnet
* \author Bing Xu
*/
#ifndef SRC_OPERATOR_OP_H_

tqchen (Member) commented on Jun 20, 2015: do MXNET_OPERATOR_OP_H_

#define SRC_OPERATOR_OP_H_
#pragma once

tqchen (Member) commented on Jun 20, 2015: remove pragma once, no need for now


#include <cmath>      // expf, tanhf, powf, sqrtf
#include <algorithm>  // std::max

namespace mxnet {
/*! \brief operations for ActivationLayer */
namespace op {
struct identity {
MSHADOW_XINLINE static real_t Map(real_t a) {
return a;
}
};
struct identity_grad {
MSHADOW_XINLINE static real_t Map(real_t a) {
return 1.0f;
}
};

/*! \brief sigmoid unit */
struct sigmoid {
MSHADOW_XINLINE static real_t Map(real_t a) {
return 1.0f / (1.0f + expf(-a));
}
};
struct sigmoid_grad {
MSHADOW_XINLINE static real_t Map(real_t a) {
return a * (1.0f - a);
}
};
/*! \brief Rectified Linear Operation */
struct relu {
MSHADOW_XINLINE static real_t Map(real_t a) {
return std::max(a, 0.0f);
}
};
struct relu_grad {
MSHADOW_XINLINE static real_t Map(real_t a) {
return a > 0.0f ? 1.0f : 0.0f;
}
};

/*! \brief Leaky ReLU Operation */
struct xelu {
MSHADOW_XINLINE static real_t Map(real_t a, real_t b) {
return a > 0 ? a : a / b;
}
};

struct xelu_grad {
MSHADOW_XINLINE static real_t Map(real_t a, real_t b) {
return a > 0 ? 1.0f : 1.0f / b;
}
};

struct tanh {
MSHADOW_XINLINE static real_t Map(real_t a) {
return tanhf(a);
}
};

struct tanh_grad {
MSHADOW_XINLINE static real_t Map(real_t a) {
return 1.0f - a * a;
}
};


struct square {
MSHADOW_XINLINE static real_t Map(real_t a) {
return a * a;
}
};

/*! \brief used to generate a Bernoulli mask */
struct threshold {
MSHADOW_XINLINE static real_t Map(real_t a, real_t b) {
return a < b ? 1.0f : 0.0f;
}
};

/*! \brief used to generate element-wise power */
struct power {
MSHADOW_XINLINE static real_t Map(real_t a, real_t b) {
return powf(a, b);
}
};

/*! \brief used to generate element-wise square root */
struct square_root {
MSHADOW_XINLINE static real_t Map(real_t a) {
return sqrtf(a);
}
};

} // namespace op
} // namespace mxnet

#endif // SRC_OPERATOR_OP_H_
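
These functors are meant to be mapped elementwise through mshadow's expression templates, in the same way ActivationOp::Forward uses F<ForwardOp>(in) above. A minimal usage sketch, assuming 2-D float tensors in, out, grad_out, and grad_in have already been allocated on the same device:

  using namespace mshadow::expr;
  out     = F<mxnet::op::sigmoid>(in);                  // elementwise sigmoid
  grad_in = F<mxnet::op::sigmoid_grad>(out) * grad_out; // sigmoid_grad and tanh_grad take the forward output, not the input
  // op::threshold is intended for Bernoulli masks: threshold(u, pkeep) yields 1.0f
  // when a uniform sample u is below the keep probability pkeep, else 0.0f.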



0 comments on commit 3674aef
