Skip to content

Commit

Permalink
Revert "move c++ lstm to also use size_t, since Rust uses usize"
Browse files Browse the repository at this point in the history
This reverts commit 3b174ea.
  • Loading branch information
ZuseZ4 committed Nov 9, 2024
1 parent 8945897 commit a73a1ad
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 103 deletions.
92 changes: 46 additions & 46 deletions enzyme/benchmarks/ReverseMode/adbench/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ using json = nlohmann::json;

struct LSTMInput
{
size_t l;
size_t c;
size_t b;
int l;
int c;
int b;
std::vector<double> main_params;
std::vector<double> extra_params;
std::vector<double> state;
Expand All @@ -34,60 +34,60 @@ struct LSTMOutput {
};

extern "C" {
void rust_unsafe_dlstm_objective(size_t l, size_t c, size_t b, double const *main_params,
void rust_unsafe_dlstm_objective(int l, int c, int b, double const *main_params,
double *dmain_params,
double const *extra_params,
double *dextra_params, double *state,
double const *sequence, double *loss,
double *dloss);

void rust_unsafe_lstm_objective(size_t l, size_t c, size_t b, double const *main_params,
void rust_unsafe_lstm_objective(int l, int c, int b, double const *main_params,
double const *extra_params, double *state,
double const *sequence, double *loss);

void rust_safe_lstm_objective(size_t l, size_t c, size_t b, double const *main_params,
void rust_safe_lstm_objective(int l, int c, int b, double const *main_params,
double const *extra_params, double *state,
double const *sequence, double *loss);

void cxx_restrict_lstm_objective(size_t l, size_t c, size_t b, double const *main_params,
void cxx_restrict_lstm_objective(int l, int c, int b, double const *main_params,
double const *extra_params, double *state,
double const *sequence, double *loss);

void cxx_mayalias_lstm_objective(size_t l, size_t c, size_t b, double const *main_params,
void cxx_mayalias_lstm_objective(int l, int c, int b, double const *main_params,
double const *extra_params, double *state,
double const *sequence, double *loss);

void rust_safe_dlstm_objective(size_t l, size_t c, size_t b, double const *main_params,
void rust_safe_dlstm_objective(int l, int c, int b, double const *main_params,
double *dmain_params, double const *extra_params,
double *dextra_params, double *state,
double const *sequence, double *loss,
double *dloss);

void dlstm_objective_mayalias(size_t l, size_t c, size_t b, double const *main_params,
void dlstm_objective_mayalias(int l, int c, int b, double const *main_params,
double *dmain_params, double const *extra_params,
double *dextra_params, double *state,
double const *sequence, double *loss,
double *dloss);

void dlstm_objective_restrict(size_t l, size_t c, size_t b, double const *main_params,
void dlstm_objective_restrict(int l, int c, int b, double const *main_params,
double *dmain_params, double const *extra_params,
double *dextra_params, double *state,
double const *sequence, double *loss,
double *dloss);

void lstm_objective_b(size_t l, size_t c, size_t b, const double *main_params,
void lstm_objective_b(int l, int c, int b, const double *main_params,
double *main_paramsb, const double *extra_params,
double *extra_paramsb, double *state,
const double *sequence, double *loss, double *lossb);

void adept_dlstm_objective(size_t l, size_t c, size_t b, double const *main_params,
void adept_dlstm_objective(int l, int c, int b, double const *main_params,
double *dmain_params, double const *extra_params,
double *dextra_params, double *state,
double const *sequence, double *loss, double *dloss);
}

void read_lstm_instance(const string& fn,
size_t* l, size_t* c, size_t* b,
int* l, int* c, int* b,
vector<double>& main_params,
vector<double>& extra_params,
vector<double>& state,
Expand All @@ -100,46 +100,46 @@ void read_lstm_instance(const string& fn,
exit(1);
}

fscanf(fid, "%zu %zu %zu", l, c, b);
fscanf(fid, "%i %i %i", l, c, b);

size_t l_ = *l, c_ = *c, b_ = *b;
int l_ = *l, c_ = *c, b_ = *b;

size_t main_sz = 2 * l_ * 4 * b_;
size_t extra_sz = 3 * b_;
size_t state_sz = 2 * l_ * b_;
size_t seq_sz = c_ * b_;
int main_sz = 2 * l_ * 4 * b_;
int extra_sz = 3 * b_;
int state_sz = 2 * l_ * b_;
int seq_sz = c_ * b_;

main_params.resize(main_sz);
extra_params.resize(extra_sz);
state.resize(state_sz);
sequence.resize(seq_sz);

for (size_t i = 0; i < main_sz; i++) {
for (int i = 0; i < main_sz; i++) {
fscanf(fid, "%lf", &main_params[i]);
}

for (size_t i = 0; i < extra_sz; i++) {
for (int i = 0; i < extra_sz; i++) {
fscanf(fid, "%lf", &extra_params[i]);
}

for (size_t i = 0; i < state_sz; i++) {
for (int i = 0; i < state_sz; i++) {
fscanf(fid, "%lf", &state[i]);
}

for (size_t i = 0; i < c_ * b_; i++) {
for (int i = 0; i < c_ * b_; i++) {
fscanf(fid, "%lf", &sequence[i]);
}

/*char ch;
fscanf(fid, "%c", &ch);
fscanf(fid, "%c", &ch);
for (size_t i = 0; i < c_; i++) {
for (int i = 0; i < c_; i++) {
unsigned char ch;
fscanf(fid, "%c", &ch);
size_t cb = ch;
for (size_t j = b_ - 1; j >= 0; j--) {
size_t p = pow(2, j);
int cb = ch;
for (int j = b_ - 1; j >= 0; j--) {
int p = pow(2, j);
if (cb >= p) {
sequence[(i + 1) * b_ - j - 1] = 1;
cb -= p;
Expand All @@ -154,9 +154,9 @@ void read_lstm_instance(const string& fn,
}

typedef void(*deriv_t)(
size_t l,
size_t c,
size_t b,
int l,
int c,
int b,
double const* main_params,
double* dmain_params,
double const* extra_params,
Expand All @@ -170,7 +170,7 @@ typedef void(*deriv_t)(
template<deriv_t deriv>
void calculate_jacobian(struct LSTMInput &input, struct LSTMOutput &result)
{
for(size_t i=0; i<100; i++) {
for(int i=0; i<100; i++) {

double* main_params_gradient_part = result.gradient.data();
double* extra_params_gradient_part = result.gradient.data() + input.main_params.size();
Expand Down Expand Up @@ -198,7 +198,7 @@ void calculate_jacobian(struct LSTMInput &input, struct LSTMOutput &result)

double calculate_mayalias_primal(struct LSTMInput &input) {
double loss = 0.0;
for (size_t i = 0; i < 100; i++) {
for (int i = 0; i < 100; i++) {
cxx_mayalias_lstm_objective(
input.l, input.c, input.b, input.main_params.data(),
input.extra_params.data(), input.state.data(),
Expand All @@ -209,7 +209,7 @@ double calculate_mayalias_primal(struct LSTMInput &input) {

double calculate_restrict_primal(struct LSTMInput &input) {
double loss = 0.0;
for (size_t i = 0; i < 100; i++) {
for (int i = 0; i < 100; i++) {
cxx_restrict_lstm_objective(
input.l, input.c, input.b, input.main_params.data(),
input.extra_params.data(), input.state.data(),
Expand All @@ -220,7 +220,7 @@ double calculate_restrict_primal(struct LSTMInput &input) {

double calculate_unsafe_primal(struct LSTMInput &input) {
double loss = 0.0;
for (size_t i = 0; i < 100; i++) {
for (int i = 0; i < 100; i++) {
rust_unsafe_lstm_objective(
input.l, input.c, input.b, input.main_params.data(),
input.extra_params.data(), input.state.data(),
Expand All @@ -231,7 +231,7 @@ double calculate_unsafe_primal(struct LSTMInput &input) {

double calculate_safe_primal(struct LSTMInput &input) {
double loss = 0.0;
for (size_t i = 0; i < 100; i++) {
for (int i = 0; i < 100; i++) {
rust_safe_lstm_objective(input.l, input.c, input.b,
input.main_params.data(),
input.extra_params.data(), input.state.data(),
Expand Down Expand Up @@ -265,7 +265,7 @@ int main(const int argc, const char* argv[]) {
}
printf("\n");

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = { 0, std::vector<double>(Jcols) };

{
Expand Down Expand Up @@ -299,7 +299,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = { 0, std::vector<double>(Jcols) };

{
Expand Down Expand Up @@ -332,7 +332,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = { 0, std::vector<double>(Jcols) };

{
Expand Down Expand Up @@ -366,7 +366,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = {0, std::vector<double>(Jcols)};

{
Expand Down Expand Up @@ -399,7 +399,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = { 0, std::vector<double>(Jcols) };

{
Expand Down Expand Up @@ -434,7 +434,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = {0, std::vector<double>(Jcols)};

{
Expand Down Expand Up @@ -467,7 +467,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = {0, std::vector<double>(Jcols)};

{
Expand Down Expand Up @@ -500,7 +500,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = {0, std::vector<double>(Jcols)};

{
Expand Down Expand Up @@ -533,7 +533,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = {0, std::vector<double>(Jcols)};

{
Expand Down Expand Up @@ -566,7 +566,7 @@ int main(const int argc, const char* argv[]) {

std::vector<double> state = std::vector<double>(input.state.size());

size_t Jcols = 8 * input.l * input.b + 3 * input.b;
int Jcols = 8 * input.l * input.b + 3 * input.b;
struct LSTMOutput result = {0, std::vector<double>(Jcols)};

{
Expand Down
Loading

0 comments on commit a73a1ad

Please sign in to comment.