Generalize MaxAndArgmax to all Commutative Operations and Datatypes and all Destination Tensor Sizes #334

Open
wants to merge 34 commits into
base: master
c3ae76c
Current status of reduction generalization and small-destination
obilaniu Jan 24, 2017
939a115
Add strb_init() function.
obilaniu Jan 26, 2017
a21bcb5
Moved the reduction API to reduction.h.
obilaniu Jan 26, 2017
6ed2534
Feedback Applied.
obilaniu Jan 26, 2017
a0654c2
More style fixes on switches.
obilaniu Jan 26, 2017
67e163e
Refactoring of all non-code-gen-related functions.
obilaniu Mar 3, 2017
b88ae57
Added variadic string append function strb_appendv().
obilaniu Mar 3, 2017
0949626
Massive refactor of kernel codegen.
obilaniu Mar 5, 2017
8fe9083
Added testcases for all reductions.
obilaniu Mar 5, 2017
32bd11d
Muzzle incorrect GCC maybe-uninitialized diagnostic.
obilaniu Mar 5, 2017
19bd939
Current State
obilaniu May 15, 2017
1a2df8d
Current State
obilaniu Jun 13, 2017
fffd323
Remove warp axis select.
obilaniu Jun 14, 2017
1cfe552
Massive cleanup.
obilaniu Jun 14, 2017
2317ca1
More planning for 2-stage reduction.
obilaniu Jun 14, 2017
c3977d8
Near-complete rewrite based on 1/2-phase code model with workspace.
obilaniu Jun 27, 2017
8fc792b
More fixes.
obilaniu Jul 4, 2017
8debf2d
Really dumb division bug fixed.
obilaniu Jul 4, 2017
5f4ec4e
Fix summation tests:
obilaniu Jul 4, 2017
eb108be
Add huge sum-reduction and pepper kernel with `restrict` keyword, it
obilaniu Jul 4, 2017
ce9c067
Massive Refactor into effectively a lattice engine.
obilaniu Jul 12, 2017
c9a0389
More refactoring.
obilaniu Jul 14, 2017
6fb0793
Delete an "initialization" that should not be there.
obilaniu Jul 14, 2017
4a17f48
Added an initialization that WAS needed.
obilaniu Jul 14, 2017
328c957
Add a bunch of local_barrier()'s.
obilaniu Jul 14, 2017
925688c
Style fixes.
obilaniu Jul 14, 2017
8f5250e
Muzzle -Wdeclaration-after-statement in check_reduction.c.
obilaniu Jul 23, 2017
fac52b6
Easy feedback fixes applied.
obilaniu Jul 25, 2017
f129c69
Add stdargs support to the error API.
obilaniu Aug 4, 2017
76fd38c
Deleted recently-removed properties.
obilaniu Aug 26, 2017
0832fa1
Added missing header
obilaniu Aug 26, 2017
c679474
For test purposes, create buffer of ULONG rather than unsupported SIZE.
obilaniu Aug 27, 2017
ecde75c
Bugfix in GpuReduction_new().
obilaniu Aug 27, 2017
79d3649
Bugfixes in check_reduction.c
obilaniu Aug 27, 2017
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -143,6 +143,7 @@ set(headers
gpuarray/extension.h
gpuarray/ext_cuda.h
gpuarray/kernel.h
gpuarray/reduction.h
gpuarray/types.h
gpuarray/util.h
)
1 change: 1 addition & 0 deletions src/cluda_cuda.h
@@ -60,6 +60,7 @@
#define GA_DECL_SHARED_PARAM(type, name)
#define GA_DECL_SHARED_BODY(type, name) extern __shared__ type name[];
#define GA_WARP_SIZE warpSize
#define restrict __restrict__

struct ga_half {
ga_ushort data;
551 changes: 277 additions & 274 deletions src/cluda_cuda.h.c

Large diffs are not rendered by default.

40 changes: 2 additions & 38 deletions src/gpuarray/array.h
@@ -604,44 +604,8 @@ GPUARRAY_PUBLIC void GpuArray_fprintf(FILE *fd, const GpuArray *a);

GPUARRAY_PUBLIC int GpuArray_fdump(FILE *fd, const GpuArray *a);

/**
* @brief Computes simultaneously the maxima and the arguments of maxima over
* specified axes of the tensor.
*
* Returns two tensors of identical shape. Both tensors' axes are a subset of
* the axes of the original tensor. The axes to be reduced are specified by
* the caller, and the maxima and arguments of maxima are computed over them.
*
* @param [out] dstMax The resulting tensor of maxima
* @param [out] dstArgmax the resulting tensor of arguments at maxima
* @param [in] src The source tensor.
* @param [in] reduxLen The number of axes reduced. Must be >= 1 and
* <= src->nd.
* @param [in] reduxList A list of integers of length reduxLen, indicating
* the axes to be reduced. The order of the axes
* matters for dstArgmax index calculations. All
* entries in the list must be unique, >= 0 and
* < src->nd.
*
* For example, if a 5D-tensor is reduced with an axis
* list of [3,4,1], then reduxLen shall be 3, and the
* index calculation in every point shall take the form
*
* dstArgmax[i0,i2] = i3 * src.shape[4] * src.shape[1] +
* i4 * src.shape[1] +
* i1
*
* where (i3,i4,i1) are the coordinates of the maximum-
* valued element within subtensor [i0,:,i2,:,:] of src.
* @return GA_NO_ERROR if the operation was successful, or a non-zero error
* code otherwise.
*/

GPUARRAY_PUBLIC int GpuArray_maxandargmax(GpuArray* dstMax,
GpuArray* dstArgmax,
const GpuArray* src,
unsigned reduxLen,
const unsigned* reduxList);



#ifdef __cplusplus
}
169 changes: 169 additions & 0 deletions src/gpuarray/reduction.h
@@ -0,0 +1,169 @@
#ifndef GPUARRAY_REDUCTION_H
#define GPUARRAY_REDUCTION_H
/**
* \file reduction.h
* \brief Reduction functions.
*/

#include <gpuarray/array.h>

#ifdef _MSC_VER
#ifndef inline
#define inline __inline
#endif
#endif

#ifdef __cplusplus
extern "C" {
#endif
#ifdef CONFUSE_EMACS
}
#endif


/* Data Structures */
struct GpuReductionAttr;
struct GpuReduction;
typedef struct GpuReductionAttr GpuReductionAttr;
typedef struct GpuReduction GpuReduction;


/**
* Supported array reduction operations.
*/

typedef enum _ga_reduce_op {
/* d0 , d1 */
GA_ELEMWISE,
GA_REDUCE_COPY=GA_ELEMWISE, /* (copy) */
GA_REDUCE_SUM, /* + */
GA_REDUCE_PROD, /* * */
GA_REDUCE_PRODNZ, /* * (!=0) */
GA_REDUCE_MIN, /* min() */
GA_REDUCE_MAX, /* max() */
GA_REDUCE_ARGMIN, /* argmin() */
GA_REDUCE_ARGMAX, /* argmax() */
GA_REDUCE_MINANDARGMIN, /* min() , argmin() */
GA_REDUCE_MAXANDARGMAX, /* max() , argmax() */
GA_REDUCE_AND, /* & */
GA_REDUCE_OR, /* | */
GA_REDUCE_XOR, /* ^ */
GA_REDUCE_ALL, /* &&/all() */
GA_REDUCE_ANY, /* ||/any() */

GA_REDUCE_ENDSUPPORTED /* Must be last element in enum */
} ga_reduce_op;


/* External Functions */

/**
* @brief Create, modify and free the attributes of a reduction operator.
*
* @param [out] grAttr The reduction operator attributes object.
* @param [in] op The reduction operation.
* @param [in] maxSrcDims The maximum number of supported source dimensions.
* @param [in] maxDstDims The maximum number of supported destination dimensions.
* @param [in] s0Typecode The typecode of the source tensor.
* @param [in] d0Typecode The typecode of the first destination tensor.
* @param [in] d1Typecode The typecode of the second destination tensor.
* @param [in] i0Typecode The typecode of the indices.
*/

GPUARRAY_PUBLIC int GpuReductionAttr_new (GpuReductionAttr** grAttr,
gpucontext* gpuCtx);
GPUARRAY_PUBLIC int GpuReductionAttr_setop (GpuReductionAttr* grAttr,
ga_reduce_op op);
GPUARRAY_PUBLIC int GpuReductionAttr_setdims (GpuReductionAttr* grAttr,
unsigned maxSrcDims,
unsigned maxDstDims);
GPUARRAY_PUBLIC int GpuReductionAttr_sets0type (GpuReductionAttr* grAttr,
int s0Typecode);
GPUARRAY_PUBLIC int GpuReductionAttr_setd0type (GpuReductionAttr* grAttr,
int d0Typecode);
GPUARRAY_PUBLIC int GpuReductionAttr_setd1type (GpuReductionAttr* grAttr,
int d1Typecode);
GPUARRAY_PUBLIC int GpuReductionAttr_seti0type (GpuReductionAttr* grAttr,
int i0Typecode);
GPUARRAY_PUBLIC int GpuReductionAttr_appendopname (GpuReductionAttr* grAttr,
size_t n,
char* name);
GPUARRAY_PUBLIC int GpuReductionAttr_issensitive (const GpuReductionAttr* grAttr);
GPUARRAY_PUBLIC int GpuReductionAttr_requiresS0 (const GpuReductionAttr* grAttr);
GPUARRAY_PUBLIC int GpuReductionAttr_requiresD0 (const GpuReductionAttr* grAttr);
GPUARRAY_PUBLIC int GpuReductionAttr_requiresD1 (const GpuReductionAttr* grAttr);
GPUARRAY_PUBLIC void GpuReductionAttr_free (GpuReductionAttr* grAttr);

/**
* @brief Create a new GPU reduction operator with the given attributes.
*
* @param [out] gr The reduction operator.
* @param [in] grAttr The reduction operator attributes.
*
* @return GA_NO_ERROR if the operator was created successfully
* GA_INVALID_ERROR if some argument was invalid
* GA_NO_MEMORY if memory allocation failed anytime during creation
* or other non-zero error codes otherwise.
*/

GPUARRAY_PUBLIC int GpuReduction_new (GpuReduction** gr,
const GpuReductionAttr* grAttr);

/**
* @brief Deallocate an operator allocated by GpuReduction_new().
*/

GPUARRAY_PUBLIC void GpuReduction_free (GpuReduction* gr);

/**
* @brief Invoke an operator allocated by GpuReduction_new() on a source tensor.
*
* Returns one (in the case of min-and-argmin/max-and-argmax, two) destination
* tensors. The destination tensor(s)' axes are a strict subset of the axes of the
* source tensor. The axes to be reduced are specified by the caller, and the
* reduction is performed over these axes, which are then removed in the
* destination.
*
* @param [in] gr The reduction operator.
* @param [out] d0 The destination tensor.
* @param [out] d1 The second destination tensor, for argmin/argmax operations.
* @param [in] s0 The source tensor.
* @param [in] reduxLen The number of axes reduced. Must be >= 1 and
* <= s0->nd.
* @param [in] reduxList A list of integers of length reduxLen, indicating
* the axes to be reduced. The order of the axes
* matters for d1 index calculations (argmin, argmax,
* minandargmin, maxandargmax). All entries in the list must be
* unique, >= 0 and < s0->nd.
*
* For example, if a 5D-tensor is maxandargmax-reduced with an
* axis list of [3,4,1], then reduxLen shall be 3, and the
* index calculation in every point shall take the form
*
* d1[i0,i2] = i3 * s0.shape[4] * s0.shape[1] +
* i4 * s0.shape[1] +
* i1
*
* where (i3,i4,i1) are the coordinates of the maximum-
* valued element within subtensor [i0,:,i2,:,:] of s0.
* @param [in] flags Reduction operator invocation flags. Currently must be
* set to 0.
*
* @return GA_NO_ERROR if the operator was invoked successfully, or a non-zero
* error code otherwise.
*/

GPUARRAY_PUBLIC int GpuReduction_call (const GpuReduction* gr,
GpuArray* d0,
GpuArray* d1,
const GpuArray* s0,
unsigned reduxLen,
const int* reduxList,
int flags);


#ifdef __cplusplus
}
#endif

#endif
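
Note on the d1 index calculation documented for GpuReduction_call(): the formula in the doc comment is a row-major flattening of the reduced coordinates, taken in the order in which the axes appear in reduxList. The following is a minimal host-side sketch of that arithmetic only. The helper name flatten_redux_index() and its parameters are hypothetical and are not part of this PR or of libgpuarray; the sketch does not launch any kernel and is meant purely to make the ordering rule concrete.

```c
#include <assert.h>
#include <stddef.h>

/*
 * Hypothetical helper (illustration only, not part of libgpuarray).
 *
 * shape     : extent of every axis of the source tensor
 * coord     : full coordinate of one element of the source tensor
 * reduxList : axes being reduced, in caller-specified order
 * reduxLen  : number of entries in reduxList
 *
 * Returns the flattened index over the reduced axes, i.e. the value the
 * doc comment describes for d1 when `coord` is the selected element.
 */
static size_t flatten_redux_index(const size_t *shape,
                                  const size_t *coord,
                                  const int    *reduxList,
                                  unsigned      reduxLen)
{
    size_t   idx = 0;
    unsigned k;

    for (k = 0; k < reduxLen; k++) {
        int axis = reduxList[k];

        /* Coordinates must lie inside the tensor. */
        assert(axis >= 0);
        assert(coord[axis] < shape[axis]);

        /* Horner-style accumulation: earlier axes in reduxList vary slowest. */
        idx = idx * shape[axis] + coord[axis];
    }
    return idx;
}
```

For reduxList = [3,4,1] the loop expands to (i3 * shape[4] + i4) * shape[1] + i1, which is the same expression as in the doc comment. With shape {2,3,4,5,6} and (i1,i3,i4) = (1,3,4) it gives 3*6*3 + 4*3 + 1 = 67, whereas the axis order [1,4,3] on the same element gives 53, which is why the order of reduxList matters for the argument output.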