Skip to content

Commit

Permalink
Merge pull request #1032 from nickdidio/complex-numbers2
Browse files Browse the repository at this point in the history
Complex numbers vals_c
  • Loading branch information
bob-carpenter authored Aug 13, 2021
2 parents d919452 + 18b9ff6 commit 120cb36
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 1 deletion.
40 changes: 40 additions & 0 deletions src/cmdstan/io/json/json_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <sstream>
#include <string>
#include <vector>
#include <complex>

namespace cmdstan {

Expand Down Expand Up @@ -119,6 +120,45 @@ class json_data : public stan::io::var_context {
return empty_vec_r_;
}

/**
* Read out the complex values for the variable with the specifed
* name and return a flat vector of complex values.
*
* @param name Name of Variable of type string.
* @return Vector of complex numbers with values equal to the read input.
*/
std::vector<std::complex<double>> vals_c(const std::string &name) const {
if (contains_r_only(name)) {
auto &&vec_r = (vars_r_.find(name)->second);
auto &&val_r = vec_r.first;
auto &&dim_r = vec_r.second;
std::vector<std::complex<double>> vec_c(val_r.size() / 2);
int offset = 1;
for (int i = 0; i < dim_r.size() - 1; ++i) {
offset *= dim_r[i];
}
for (int i = 0; i < vec_c.size(); ++i) {
vec_c[i] = std::complex<double>{val_r[i], val_r[i + offset]};
}
return vec_c;
} else if (contains_i(name)) {
auto &&vec_i = (vars_i_.find(name)->second);
auto &&val_i = vec_i.first;
auto &&dim_i = vec_i.second;
std::vector<std::complex<double>> vec_c(val_i.size() / 2);
int offset = 1;
for (int i = 0; i < dim_i.size() - 1; ++i) {
offset *= dim_i[i];
}
for (int i = 0; i < vec_c.size(); ++i) {
vec_c[i] = std::complex<double>{static_cast<double>(val_i[i]),
static_cast<double>(val_i[i + offset])};
}
return vec_c;
}
return std::vector<std::complex<double>>{};
}

/**
* Return the dimensions for the variable with the specified
* name.
Expand Down
84 changes: 84 additions & 0 deletions src/test/interface/io/json/json_data_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <boost/math/special_functions/fpclassify.hpp>
#include <gtest/gtest.h>

#include <complex>

void test_int_var(cmdstan::json::json_data &jdata, const std::string &text,
const std::string &name,
const std::vector<int> &expected_vals,
Expand Down Expand Up @@ -52,6 +54,22 @@ void test_real_var(cmdstan::json::json_data &jdata, const std::string &text,
EXPECT_EQ(expected_vals[i], vals[i]);
}

void test_complex_var(cmdstan::json::json_data &jdata, const std::string &text,
const std::string &name,
const std::vector<std::complex<double>> &expected_vals,
const std::vector<size_t> &expected_dims) {
EXPECT_EQ(true, (jdata.contains_r(name) || jdata.contains_i(name)));
std::vector<size_t> dims = jdata.dims_r(name);
dims.pop_back();
EXPECT_EQ(expected_dims.size(), dims.size());
for (size_t i = 0; i < dims.size(); i++)
EXPECT_EQ(expected_dims[i], dims[i]);
std::vector<std::complex<double>> vals = jdata.vals_c(name);
EXPECT_EQ(expected_vals.size(), vals.size());
for (size_t i = 0; i < vals.size(); i++)
EXPECT_EQ(expected_vals[i], vals[i]);
}

void test_exception(const std::string &input,
const std::string &exception_text) {
try {
Expand Down Expand Up @@ -84,6 +102,16 @@ TEST(ioJson, jsonData_scalar_real) {
test_real_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_scalar_complex) {
std::string txt = "{ \"foo\" : [1.1, 2.2] }";
std::stringstream in(txt);
cmdstan::json::json_data jdata(in);
std::vector<std::complex<double>> expected_vals;
expected_vals.push_back(std::complex<double>(1.1, 2.2));
std::vector<size_t> expected_dims;
test_complex_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_mult_vars) {
std::string txt = "{ \"foo\" : 1, \"bar\" : 0.1 }";
std::stringstream in(txt);
Expand Down Expand Up @@ -144,6 +172,18 @@ TEST(ioJson, jsonData_real_array_1D) {
test_real_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_complex_array_1D) {
std::string txt = "{ \"foo\" : [ [1.1, 2.2], [3, 4] ] }";
std::stringstream in(txt);
cmdstan::json::json_data jdata(in);
std::vector<std::complex<double>> expected_vals;
expected_vals.push_back(std::complex<double>(1.1, 2.2));
expected_vals.push_back(std::complex<double>(3, 4));
std::vector<size_t> expected_dims;
expected_dims.push_back(2);
test_complex_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_array_1D_inf) {
std::string txt = "{ \"foo\" : [ 1.1, \"Inf\" ] }";
std::stringstream in(txt);
Expand Down Expand Up @@ -198,6 +238,21 @@ TEST(ioJson, jsonData_real_array_2D) {
test_real_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_complex_array_2D) {
std::string txt = "{ \"foo\" : [ [ [1, 2], [3, 4] ], [ [5, 6], [7, 8] ] ] }";
std::stringstream in(txt);
cmdstan::json::json_data jdata(in);
std::vector<std::complex<double>> expected_vals;
expected_vals.push_back(std::complex<double>(1, 2));
expected_vals.push_back(std::complex<double>(5, 6));
expected_vals.push_back(std::complex<double>(3, 4));
expected_vals.push_back(std::complex<double>(7, 8));
std::vector<size_t> expected_dims;
expected_dims.push_back(2);
expected_dims.push_back(2);
test_complex_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_real_array_3D) {
std::string txt
= "{ \"foo\" : [ [ [ 11.1, 11.2, 11.3, 11.4 ], [ 12.1, 12.2, 12.3, 12.4 "
Expand Down Expand Up @@ -240,6 +295,35 @@ TEST(ioJson, jsonData_real_array_3D) {
test_real_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_complex_array_3D) {
std::string txt
= "{ \"foo\" : [ [ [ [11.1, 11.2], [11.3, 11.4] ], [ [12.1,"
" 12.2], [12.3, 12.4] ], "
"[ [13.1, 13.2], [13.3, 13.4]] ],"
" [ [ [21.1, 21.2], [21.3, 21.4] ], [ [22.1, 22.2], "
"[22.3, 22.4] ], [ [23.1, 23.2], [23.3, 23.4]] ] ] }";
std::stringstream in(txt);
cmdstan::json::json_data jdata(in);
std::vector<std::complex<double>> expected_vals;
expected_vals.push_back(std::complex<double>(11.1, 11.2));
expected_vals.push_back(std::complex<double>(21.1, 21.2));
expected_vals.push_back(std::complex<double>(12.1, 12.2));
expected_vals.push_back(std::complex<double>(22.1, 22.2));
expected_vals.push_back(std::complex<double>(13.1, 13.2));
expected_vals.push_back(std::complex<double>(23.1, 23.2));
expected_vals.push_back(std::complex<double>(11.3, 11.4));
expected_vals.push_back(std::complex<double>(21.3, 21.4));
expected_vals.push_back(std::complex<double>(12.3, 12.4));
expected_vals.push_back(std::complex<double>(22.3, 22.4));
expected_vals.push_back(std::complex<double>(13.3, 13.4));
expected_vals.push_back(std::complex<double>(23.3, 23.4));
std::vector<size_t> expected_dims;
expected_dims.push_back(2);
expected_dims.push_back(3);
expected_dims.push_back(2);
test_complex_var(jdata, txt, "foo", expected_vals, expected_dims);
}

TEST(ioJson, jsonData_int_array_3D) {
std::string txt
= "{ \"foo\" : [ [ [ 111, 112, 113, 114 ], [ 121, 122, 123, "
Expand Down
2 changes: 1 addition & 1 deletion stan
Submodule stan updated 1 files
+1 −1 lib/stan_math

0 comments on commit 120cb36

Please sign in to comment.