forked from hjk41/MXNet.cpp
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
173 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
#pragma once | ||
|
||
#include <functional> | ||
#include <vector> | ||
#include <cinttypes> | ||
#include <string> | ||
#include <list> | ||
|
||
#define ALLREDUCE_IN_PLACE ((void *)-1) | ||
|
||
#define ALLREDUCE_OP_SUM 0 | ||
#define ALLREDUCE_OP_UDR 1 | ||
#define ALLREDUCE_OP_NUM 2 | ||
|
||
#define ALLREDUCE_TYPE_INT32 0 | ||
#define ALLREDUCE_TYPE_UINT32 1 | ||
#define ALLREDUCE_TYPE_INT64 2 | ||
#define ALLREDUCE_TYPE_UINT64 3 | ||
#define ALLREDUCE_TYPE_FLOAT 4 | ||
#define ALLREDUCE_TYPE_DOUBLE 5 | ||
#define ALLREDUCE_TYPE_NUM 6 | ||
|
||
#define PS_ROLE_ALL (-1) | ||
#define PS_ROLE_COORDINATOR (0) | ||
|
||
#define PS_CMD_WAIT (-1) | ||
#define PS_CMD_WAIT_ACK (-2) | ||
|
||
# if defined(EXPORTDLL) | ||
# if defined(_WIN32) | ||
# define ExportDll __declspec(dllexport) | ||
# else | ||
# define ExportDll __attribute__((visibility("default"))) | ||
# endif | ||
# else | ||
# if defined(_WIN32) | ||
# define ExportDll __declspec(dllimport) | ||
# else | ||
# define ExportDll | ||
# endif | ||
# endif | ||
|
||
typedef void(*reducer_t)(void *state, void *opnd, size_t state_size); | ||
typedef void(*broadcast_udf_t)(void *bcastbuf, size_t size, void *args); | ||
|
||
typedef void(*ps_push_callback_t)(size_t count, uint64_t *keys, void **vals, size_t *val_sizes, void *args); | ||
typedef void(*val_deallocator_t)(void *args); | ||
|
||
class ExportDll ChaNaPSBase | ||
{ | ||
public: | ||
ChaNaPSBase() { } | ||
|
||
// user defined interface | ||
//static ChaNaPSBase* Create(); | ||
|
||
virtual void control(int cmd_id, void *data, const size_t len) { } | ||
|
||
virtual void server_process_pull( | ||
uint64_t *keys, | ||
size_t key_count, | ||
void **vals, | ||
size_t *val_sizes, | ||
val_deallocator_t *dealloc, | ||
void **args, | ||
bool *fixed_val_size | ||
) = 0; | ||
|
||
virtual void server_process_push(size_t key_count, uint64_t *keys, void **vals, size_t *valsizes) = 0; | ||
virtual void worker_apply_pull(void *args) = 0; | ||
}; | ||
|
||
typedef ChaNaPSBase*(*ps_create_function_t)(void *args); | ||
|
||
ExportDll std::list<std::string> chana_config_get_string_value_list(const char* section, const char* key, char splitter, const char* dsptr); | ||
|
||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif /* __cplusplus */ | ||
|
||
ExportDll void CreateSyncEngine(const char *machine_list_filename, const int iothread_per_machine_count, bool use_rdma); | ||
ExportDll int GetMyRank(); | ||
ExportDll int GetMachineCount(); | ||
ExportDll size_t GetPSThreadNumPerMachine(); | ||
ExportDll void BarrierEnter(); | ||
ExportDll void ChaNa_AllReduce(const void *sendbuf, void *recvbuf, size_t elemcount, size_t elemsize, int elemtype, int op); | ||
ExportDll void RegisterUserDefinedReducer(reducer_t reducer, int elemtype); | ||
|
||
ExportDll void AsyncBroadCast(void *bcastbuf, size_t size); | ||
|
||
// must perform barrier between this function and the following AsyncBroadCast() | ||
ExportDll void RegisterAsyncBroadCastHandler(broadcast_udf_t bcast_cb, void *args); | ||
|
||
|
||
/////////////////////// Parameter Server Interfaces //////////////////////////////////////// | ||
|
||
ExportDll void CreateParameterServer( | ||
std::string machine_list_file, | ||
const int ps_per_machine_count, | ||
bool use_rdma, | ||
ps_create_function_t create_function, | ||
void *args | ||
); | ||
|
||
ExportDll void CreateParameterServerWithPort( | ||
std::string machine_list_file, | ||
const int ps_per_machine_count, | ||
const int port, | ||
bool use_rdma, | ||
ps_create_function_t create_function, | ||
void *args | ||
); | ||
|
||
ExportDll ChaNaPSBase * ChaNaPSGetInstance(int inst_id); | ||
ExportDll void ChaNaPSWait(); | ||
ExportDll uint64_t ChaNaPSControl(int roles, void *data, size_t len, val_deallocator_t cb, void *args); | ||
ExportDll uint64_t ChaNaPSPull(size_t count, uint64_t *keys, void **vals, size_t *val_sizes, void *args); | ||
ExportDll uint64_t ChaNaPSPush(size_t count, uint64_t *keys, void **vals, size_t *val_sizes, ps_push_callback_t cb, void *args); | ||
|
||
// init | ||
ExportDll bool chana_initialize(int argc, const char *argv[]); | ||
ExportDll bool chana_is_initialized(void); | ||
|
||
// config | ||
ExportDll const char* chana_config_get_value_string(const char* section, const char* key, const char* default_value, const char* dsptr); | ||
ExportDll void chana_config_set(const char* section, const char* key, const char* value, const char* dsptr); | ||
ExportDll bool chana_config_get_value_list(const char* section, const char* key, char splitter, const char* dsptr); | ||
ExportDll bool chana_config_get_value_bool(const char* section, const char* key, bool default_value, const char* dsptr); | ||
ExportDll uint64_t chana_config_get_value_uint64(const char* section, const char* key, uint64_t default_value, const char* dsptr); | ||
ExportDll double chana_config_get_value_double(const char* section, const char* key, double default_value, const char* dsptr); | ||
ExportDll int chana_config_get_all_keys(const char* section, const char** buffers, int buffer_count); | ||
ExportDll bool chana_config_has_section(const char* section); | ||
ExportDll bool chana_config_has_key(const char* section, const char* key); | ||
ExportDll void chana_config_dump(const char* file); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif /* __cplusplus */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
|
||
Microsoft Visual Studio Solution File, Format Version 12.00 | ||
# Visual Studio 2013 | ||
VisualStudioVersion = 12.0.40629.0 | ||
MinimumVisualStudioVersion = 10.0.40219.1 | ||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "dmlc-core", "..\dmlc-core\dmlc-core.vcxproj", "{D22EA6D1-CE76-4C85-9075-D7172816E93F}" | ||
EndProject | ||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "MxNetCppCommon", "MxNetCppCommon\MxNetCppCommon.vcxproj", "{19D5EC88-F7F3-40E1-83D8-693FF451C4B6}" | ||
EndProject | ||
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Mnist", "Mnist\Mnist.vcxproj", "{D8FC80FA-0565-467E-8E62-BA405A93B358}" | ||
EndProject | ||
Global | ||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||
Debug|x64 = Debug|x64 | ||
Release|x64 = Release|x64 | ||
EndGlobalSection | ||
GlobalSection(ProjectConfigurationPlatforms) = postSolution | ||
{D22EA6D1-CE76-4C85-9075-D7172816E93F}.Debug|x64.ActiveCfg = Debug|x64 | ||
{D22EA6D1-CE76-4C85-9075-D7172816E93F}.Debug|x64.Build.0 = Debug|x64 | ||
{D22EA6D1-CE76-4C85-9075-D7172816E93F}.Release|x64.ActiveCfg = Release|x64 | ||
{D22EA6D1-CE76-4C85-9075-D7172816E93F}.Release|x64.Build.0 = Release|x64 | ||
{19D5EC88-F7F3-40E1-83D8-693FF451C4B6}.Debug|x64.ActiveCfg = Debug|x64 | ||
{19D5EC88-F7F3-40E1-83D8-693FF451C4B6}.Debug|x64.Build.0 = Debug|x64 | ||
{19D5EC88-F7F3-40E1-83D8-693FF451C4B6}.Release|x64.ActiveCfg = Release|x64 | ||
{19D5EC88-F7F3-40E1-83D8-693FF451C4B6}.Release|x64.Build.0 = Release|x64 | ||
{D8FC80FA-0565-467E-8E62-BA405A93B358}.Debug|x64.ActiveCfg = Debug|x64 | ||
{D8FC80FA-0565-467E-8E62-BA405A93B358}.Debug|x64.Build.0 = Debug|x64 | ||
{D8FC80FA-0565-467E-8E62-BA405A93B358}.Release|x64.ActiveCfg = Release|x64 | ||
{D8FC80FA-0565-467E-8E62-BA405A93B358}.Release|x64.Build.0 = Release|x64 | ||
EndGlobalSection | ||
GlobalSection(SolutionProperties) = preSolution | ||
HideSolutionNode = FALSE | ||
EndGlobalSection | ||
EndGlobal |