forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request apache#92 from sxjscience/pr
Add new operator `MatFillRowElement`
- Loading branch information
Showing
2 changed files
with
104 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/*! | ||
* Copyright (c) 2015 by Contributors | ||
* \file fill.h | ||
* \brief support for implicit array filling operation | ||
* \author Xingjian Shi | ||
*/ | ||
#ifndef MSHADOW_EXTENSION_FILL_H_ | ||
#define MSHADOW_EXTENSION_FILL_H_ | ||
|
||
#include "../extension.h" | ||
|
||
|
||
namespace mshadow { | ||
namespace expr { | ||
/*! | ||
* \brief Set value of a specific element in each line of the data matrix. | ||
* \tparam SrcExp type of src expression | ||
* \tparam ValExp type of val expression | ||
* \tparam IndexExp type of index expression | ||
* \tparam DType the type of ret expression | ||
*/ | ||
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | ||
struct MatFillRowElementExp: | ||
public Exp<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, | ||
DType, type::kChainer> { | ||
/*! \brief src operand */ | ||
const SrcExp &src_; | ||
const ValExp &val_; | ||
/*! \brief index operand */ | ||
const IndexExp &index_; | ||
/*! \brief constructor */ | ||
MatFillRowElementExp(const SrcExp &src, const ValExp &val, const IndexExp &index) | ||
: src_(src), val_(val), index_(index) {} | ||
}; | ||
|
||
template<typename SrcExp, typename ValExp, typename IndexExp, | ||
typename SDType, typename VDType, typename IDType, int e1, int e2, int e3> | ||
inline MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType> | ||
mat_fill_row_element(const Exp<SrcExp, SDType, e1> &src, | ||
const Exp<ValExp, VDType, e2> &val, | ||
const Exp<IndexExp, IDType, e3> &index) { | ||
TypeCheckPass<ExpInfo<SrcExp>::kDim == 2 && ExpInfo<ValExp>::kDim == 1 | ||
&& ExpInfo<IndexExp>::kDim == 1>::Error_Expression_Does_Not_Meet_Dimension_Req(); | ||
return MatFillRowElementExp<SrcExp, ValExp, IndexExp, SDType>(src.self(), | ||
val.self(), index.self()); | ||
} | ||
|
||
//---------------------- | ||
// Execution plan | ||
//---------------------- | ||
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | ||
struct Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType> { | ||
public: | ||
explicit Plan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &e) | ||
: src_(MakePlan(e.src_)), | ||
val_(MakePlan(e.val_)), | ||
index_(MakePlan(e.index_)) { | ||
} | ||
MSHADOW_XINLINE DType Eval(index_t y, index_t x) const { | ||
index_t idx = static_cast<index_t>(index_.Eval(0, y)); | ||
if (idx == x) { | ||
return static_cast<DType>(val_.Eval(0, y)); | ||
} else { | ||
return static_cast<DType>(src_.Eval(y, x)); | ||
} | ||
} | ||
|
||
private: | ||
expr::Plan<SrcExp, DType> src_; | ||
expr::Plan<ValExp, DType> val_; | ||
expr::Plan<IndexExp, DType> index_; | ||
}; | ||
|
||
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | ||
inline Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType> | ||
MakePlan(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &exp) { | ||
return Plan<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType>, DType>(exp); | ||
} | ||
|
||
template<int dim, typename SrcExp, typename ValExp, typename IndexExp, typename DType> | ||
struct ShapeCheck<dim, MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > { | ||
inline static Shape<dim> | ||
Check(const MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> &t) { | ||
CHECK(dim == 2) | ||
<< "MatFillRowElementExp only support 2 dimension output"; | ||
Shape<2> shape_src = ShapeCheck<2, SrcExp>::Check(t.src_); | ||
Shape<1> shape_val = ShapeCheck<1, ValExp>::Check(t.val_); | ||
Shape<1> shape_index = ShapeCheck<1, IndexExp>::Check(t.index_); | ||
CHECK((shape_src[0] == shape_index[0]) && (shape_index[0] == shape_val[0])) | ||
<< "mat_fill_row_element index length, val length and number of rows in matrix"; | ||
return shape_src; | ||
} | ||
}; | ||
|
||
template<typename SrcExp, typename ValExp, typename IndexExp, typename DType> | ||
struct ExpInfo<MatFillRowElementExp<SrcExp, ValExp, IndexExp, DType> > { | ||
static const int kDim = 2; | ||
static const int kDevMask = | ||
ExpInfo<SrcExp>::kDevMask & ExpInfo<ValExp>::kDevMask & ExpInfo<IndexExp>::kDevMask; | ||
}; | ||
} // namespace expr | ||
} // namespace mshadow | ||
#endif // MSHADOW_EXTENSION_FILL_H_ |