diff --git a/include/Bitmask.h b/include/Bitmask.h index e9d379ae..119db008 100644 --- a/include/Bitmask.h +++ b/include/Bitmask.h @@ -25,13 +25,15 @@ namespace sperr { class Bitmask { public: // Constructor + // Bitmask(size_t nbits = 0); // How many bits does it hold initially? // Functions for both read and write // - auto size() const -> size_t; - void resize(size_t nbits); // Resize to hold n bits. - void reset(); // Set the current bitmask to be all 0s. + auto size() const -> size_t; // Num. of useful bits in this mask. + void resize(size_t nbits); // Resize to hold n bits. + void reset(); // Set the current bitmask to be all 0's. + void reset_true(); // Set the current bitmask to be all 1's. // Functions for read // @@ -43,6 +45,19 @@ class Bitmask { void write_long(size_t idx, uint64_t value); void write_bit(size_t idx, bool bit); + // Functions for direct access of the underlying data buffer + // Note: `use_bitstream()` reads the number of values (uint64_t type) that provide + // enough bits for the specified size of this mask. + // + auto view_buffer() const -> const std::vector&; + void use_bitstream(const void* p); + + // Compare if two Bitmasks are identical. + // +#if __cplusplus >= 202002L + auto operator==(const Bitmask& rhs) const -> bool = default; +#endif + private: std::vector m_buf; size_t m_num_bits = 0; diff --git a/include/Bitstream.h b/include/Bitstream.h index 49f8e7ca..f76d547e 100644 --- a/include/Bitstream.h +++ b/include/Bitstream.h @@ -19,10 +19,13 @@ * 3. Because of 2, true random writes is not possible; it's only possible at the end of * each word, e.g., positions of 63, 127, 191. * 4. A function call of flush() will align the writing position to the beginning of the - * next word. + * next word, i.e., the number of truly useful bits is lost! + * One wants to call wtell() to retrieve and keep that info. * 5. Functions write_bitstream() and parse_bitstream() take in a raw pointer and the * number of bits to write/read. The memory pointed to by the raw pointer needs to * be big enough to hold the number of bits specified. + * 6. get_bitstream() and write_bitstream() needs to be supplied a number of bits because + * a Bitstream itself will lose track of how many useful bits are there after flush(). */ #include diff --git a/include/SPECK1D_INT.h b/include/SPECK1D_INT.h index a2a878de..898c8337 100644 --- a/include/SPECK1D_INT.h +++ b/include/SPECK1D_INT.h @@ -16,12 +16,20 @@ class Set1D { // // Main SPECK1D_INT class; intended to be the base class of both encoder and decoder. // -class SPECK1D_INT : public SPECK_INT { +template +class SPECK1D_INT : public SPECK_INT { public: // Virtual destructor virtual ~SPECK1D_INT() = default; protected: + // + // Bring members from the base class to this derived class. + // + using SPECK_INT::m_LIP; + using SPECK_INT::m_dims; + using SPECK_INT::m_u64_garbage_val; + virtual void m_clean_LIS() override; virtual void m_initialize_lists() override; auto m_partition_set(const Set1D&) const -> std::array; diff --git a/include/SPECK1D_INT_DEC.h b/include/SPECK1D_INT_DEC.h index 1f310dc8..b5fa5289 100644 --- a/include/SPECK1D_INT_DEC.h +++ b/include/SPECK1D_INT_DEC.h @@ -8,8 +8,23 @@ namespace sperr { // // Main SPECK1D_INT_DEC class // -class SPECK1D_INT_DEC : public SPECK1D_INT { +template +class SPECK1D_INT_DEC : public SPECK1D_INT { private: + // + // Bring members from parent classes to this derived class. + // + using SPECK_INT::m_LIP; + using SPECK_INT::m_LSP_new; + using SPECK_INT::m_bit_idx; + using SPECK_INT::m_threshold; + using SPECK_INT::m_coeff_buf; + using SPECK_INT::m_bit_buffer; + using SPECK_INT::m_sign_array; + using SPECK_INT::m_u64_garbage_val; + using SPECK1D_INT::m_LIS; + using SPECK1D_INT::m_partition_set; + virtual void m_sorting_pass() override; void m_process_S(size_t idx1, size_t idx2, size_t& counter, bool read); diff --git a/include/SPECK1D_INT_ENC.h b/include/SPECK1D_INT_ENC.h index 66f2d4ac..206e0542 100644 --- a/include/SPECK1D_INT_ENC.h +++ b/include/SPECK1D_INT_ENC.h @@ -8,8 +8,22 @@ namespace sperr { // // Main SPECK1D_INT_ENC class // -class SPECK1D_INT_ENC : public SPECK1D_INT { +template +class SPECK1D_INT_ENC : public SPECK1D_INT { private: + // + // Bring members from parent classes to this derived class. + // + using SPECK_INT::m_LIP; + using SPECK_INT::m_LSP_new; + using SPECK_INT::m_threshold; + using SPECK_INT::m_coeff_buf; + using SPECK_INT::m_bit_buffer; + using SPECK_INT::m_sign_array; + using SPECK_INT::m_u64_garbage_val; + using SPECK1D_INT::m_LIS; + using SPECK1D_INT::m_partition_set; + virtual void m_sorting_pass() override; void m_process_S(size_t idx1, size_t idx2, SigType, size_t& counter, bool); diff --git a/include/SPECK3D_INT.h b/include/SPECK3D_INT.h index e9986e0b..905618e1 100644 --- a/include/SPECK3D_INT.h +++ b/include/SPECK3D_INT.h @@ -32,12 +32,23 @@ class Set3D { // // Main SPECK3D_INT class; intended to be the base class of both encoder and decoder. // -class SPECK3D_INT : public SPECK_INT { +template +class SPECK3D_INT : public SPECK_INT { public: // Virtual destructor virtual ~SPECK3D_INT() = default; protected: + // + // Bring members from the base class to this derived class. + // + using SPECK_INT::m_LIP; + using SPECK_INT::m_dims; + using SPECK_INT::m_LSP_new; + using SPECK_INT::m_coeff_buf; + using SPECK_INT::m_bit_buffer; + using SPECK_INT::m_u64_garbage_val; + virtual void m_clean_LIS() override; virtual void m_initialize_lists() override; diff --git a/include/SPECK3D_INT_DEC.h b/include/SPECK3D_INT_DEC.h index 04a756ad..3df8501c 100644 --- a/include/SPECK3D_INT_DEC.h +++ b/include/SPECK3D_INT_DEC.h @@ -8,8 +8,24 @@ namespace sperr { // // Main SPECK3D_INT_DEC class // -class SPECK3D_INT_DEC : public SPECK3D_INT { +template +class SPECK3D_INT_DEC : public SPECK3D_INT { private: + // + // Bring members from parent classes to this derived class. + // + using SPECK_INT::m_LIP; + using SPECK_INT::m_dims; + using SPECK_INT::m_LSP_new; + using SPECK_INT::m_bit_idx; + using SPECK_INT::m_threshold; + using SPECK_INT::m_coeff_buf; + using SPECK_INT::m_bit_buffer; + using SPECK_INT::m_sign_array; + using SPECK_INT::m_u64_garbage_val; + using SPECK3D_INT::m_LIS; + using SPECK3D_INT::m_partition_S_XYZ; + virtual void m_sorting_pass() override; void m_process_S(size_t idx1, size_t idx2, size_t& counter, bool); diff --git a/include/SPECK3D_INT_ENC.h b/include/SPECK3D_INT_ENC.h index d24e7adb..286dd4b3 100644 --- a/include/SPECK3D_INT_ENC.h +++ b/include/SPECK3D_INT_ENC.h @@ -8,8 +8,23 @@ namespace sperr { // // Main SPECK3D_INT_ENC class // -class SPECK3D_INT_ENC : public SPECK3D_INT { +template +class SPECK3D_INT_ENC : public SPECK3D_INT { private: + // + // Bring members from parent classes to this derived class. + // + using SPECK_INT::m_LIP; + using SPECK_INT::m_dims; + using SPECK_INT::m_LSP_new; + using SPECK_INT::m_threshold; + using SPECK_INT::m_coeff_buf; + using SPECK_INT::m_bit_buffer; + using SPECK_INT::m_sign_array; + using SPECK_INT::m_u64_garbage_val; + using SPECK3D_INT::m_LIS; + using SPECK3D_INT::m_partition_S_XYZ; + virtual void m_sorting_pass() override; void m_process_S(size_t idx1, size_t idx2, SigType, size_t& counter, bool); diff --git a/include/SPECK_INT.h b/include/SPECK_INT.h index b38bcf0b..b15a07ad 100644 --- a/include/SPECK_INT.h +++ b/include/SPECK_INT.h @@ -1,7 +1,9 @@ #ifndef SPECK_INT_H #define SPECK_INT_H +// // This is the base class of 1D, 2D, and 3D integer SPECK implementations. +// #include "sperr_helper.h" @@ -12,20 +14,23 @@ namespace sperr { template class SPECK_INT { - -using uint_type = T; -using vecui_type = std::vector; + using uint_type = T; + using vecui_type = std::vector; public: - // Constructor + // Constructor and destructor SPECK_INT(); - // Virtual destructor virtual ~SPECK_INT() = default; + // The length (1, 2, 4, 8) of the integer type in use + auto integer_len() const -> size_t; + void set_dims(dims_type); - // Retrieve the full length of a SPECK bitstream from its header - auto get_speck_full_len(const void*) const -> uint64_t; + // Retrieve info of a SPECK bitstream from its header + auto get_num_bitplanes(const void*) const -> uint8_t; + auto get_speck_bits(const void*) const -> uint64_t; + auto get_stream_full_len(const void*) const -> uint64_t; // Actions virtual void encode(); @@ -45,21 +50,21 @@ using vecui_type = std::vector; protected: // Core SPECK procedures virtual void m_clean_LIS() = 0; - virtual void m_initialize_lists() = 0; virtual void m_sorting_pass() = 0; + virtual void m_initialize_lists() = 0; virtual void m_refinement_pass_encode(); virtual void m_refinement_pass_decode(); // Data members dims_type m_dims = {0, 0, 0}; uint_type m_threshold = 0; + Bitmask m_LSP_mask; vecui_type m_coeff_buf; vecb_type m_sign_array; Bitstream m_bit_buffer; - Bitmask m_LSP_mask; std::vector m_LIP, m_LSP_new; - const size_t m_u64_garbage_val = std::numeric_limits::max(); + const uint64_t m_u64_garbage_val = std::numeric_limits::max(); const size_t m_header_size = 9; // 9 bytes uint64_t m_bit_idx = 0; // current bit idx when decoding diff --git a/include/SPERR3D.h b/include/SPERR3D.h index f751bfda..84c4a454 100644 --- a/include/SPERR3D.h +++ b/include/SPERR3D.h @@ -6,13 +6,11 @@ namespace sperr { class SPERR3D : public SPERR_Driver { - public: - // - // Constructor - // - SPERR3D(); protected: + virtual void m_instantiate_encoder() override; + virtual void m_instantiate_decoder() override; + virtual void m_wavelet_xform() override; virtual void m_inverse_wavelet_xform() override; diff --git a/include/SPERR_Driver.h b/include/SPERR_Driver.h index 6b9c6862..16a61850 100644 --- a/include/SPERR_Driver.h +++ b/include/SPERR_Driver.h @@ -9,6 +9,8 @@ #include "Conditioner.h" #include "SPECK_INT.h" +#include + namespace sperr { class SPERR_Driver { @@ -21,10 +23,10 @@ class SPERR_Driver { // // Input // - // Accept incoming data: copy from a raw memory block + // Accept incoming data: copy from a raw memory block. + // `len` is the number of values. template - void copy_data(const T* p, // Input: pointer to the memory block - size_t len); // Input: number of values + void copy_data(const T* p, size_t len); // Accept incoming data: take ownership of a memory block void take_data(std::vector&&); @@ -52,11 +54,24 @@ class SPERR_Driver { virtual auto decompress() -> RTNType; protected: + // Default to use 64-bit integers, but can also use smaller sizes. + // + UINTType m_uint_flag = UINTType::UINT64; + std::variant, + std::vector, + std::vector, + std::vector> + m_vals_ui; + std::variant>, + std::unique_ptr>, + std::unique_ptr>, + std::unique_ptr>> + m_encoder, m_decoder; + dims_type m_dims = {0, 0, 0}; double m_q = 1.0; // 1.0 is a better initial value than 0.0 vecd_type m_vals_d; - vecui_t m_vals_ui; // Unsigned integers to be passed to the encoder - vecb_type m_sign_array; // Signs to be passed to the encoder + vecb_type m_sign_array; // Signs to be passed to the encoder Conditioner::settings_type m_conditioning_settings = {true, false, false, false}; Conditioner::meta_type m_condi_bitstream; @@ -64,11 +79,14 @@ class SPERR_Driver { CDF97 m_cdf; Conditioner m_conditioner; - std::unique_ptr m_encoder = nullptr; - std::unique_ptr m_decoder = nullptr; - auto m_midtread_f2i() -> RTNType; - void m_midtread_i2f(); + // Derived classes instantiate the correct `m_encoder` and `m_decoder` depending on + // 3D/2D/1D classes, and the integer length in use. + virtual void m_instantiate_encoder() = 0; + virtual void m_instantiate_decoder() = 0; + + // Instantiate `m_vals_ui` based on the chosen integer length. + virtual void m_instantiate_int_vec(); // Both wavelet transforms operate on `m_vals_d`. virtual void m_wavelet_xform() = 0; @@ -79,9 +97,10 @@ class SPERR_Driver { virtual auto m_quantize() -> RTNType = 0; virtual auto m_inverse_quantize() -> RTNType = 0; - // Optional procedures for flexibility - virtual auto m_proc_1() -> RTNType; - virtual auto m_proc_2() -> RTNType; + // This base class provides two midtread quantization implementations, but derived classes + // can have other quantization methods. + auto m_midtread_f2i() -> RTNType; + void m_midtread_i2f(); }; }; // namespace sperr diff --git a/include/sperr_helper.h b/include/sperr_helper.h index 02a180d9..c044fcba 100644 --- a/include/sperr_helper.h +++ b/include/sperr_helper.h @@ -43,6 +43,8 @@ enum class SigType : unsigned char { Insig, Sig, NewlySig, Dunno, Garbage }; enum class SetType : unsigned char { TypeS, TypeI, Garbage }; +enum class UINTType : unsigned char { UINT64, UINT32, UINT16, UINT8 }; + // Return Type enum class RTNType { Good = 0, @@ -50,8 +52,7 @@ enum class RTNType { BitstreamWrongLen, IOError, InvalidParam, - QzLevelTooBig, // a very specific type of invalid param - EmptyStream, // a condition but not sure if it's an error + EmptyStream, // a condition but not sure if it's an error BitBudgetMet, VersionMismatch, ZSTDMismatch, diff --git a/src/Bitmask.cpp b/src/Bitmask.cpp index 56c71593..1f40ee12 100644 --- a/src/Bitmask.cpp +++ b/src/Bitmask.cpp @@ -1,6 +1,7 @@ #include "Bitmask.h" #include +#include sperr::Bitmask::Bitmask(size_t nbits) { @@ -32,6 +33,11 @@ void sperr::Bitmask::reset() std::fill(m_buf.begin(), m_buf.end(), 0); } +void sperr::Bitmask::reset_true() +{ + std::fill(m_buf.begin(), m_buf.end(), std::numeric_limits::max()); +} + auto sperr::Bitmask::read_long(size_t idx) const -> uint64_t { return m_buf[idx / 64]; @@ -61,3 +67,14 @@ void sperr::Bitmask::write_bit(size_t idx, bool bit) word &= ~mask; m_buf[wstart] = word; } + +auto sperr::Bitmask::view_buffer() const -> const std::vector& +{ + return m_buf; +} + +void sperr::Bitmask::use_bitstream(const void* p) +{ + const auto* pu64 = static_cast(p); + std::copy(pu64, pu64 + m_buf.size(), m_buf.begin()); +} diff --git a/src/Bitstream.cpp b/src/Bitstream.cpp index ab691b9e..ed025f07 100644 --- a/src/Bitstream.cpp +++ b/src/Bitstream.cpp @@ -69,7 +69,7 @@ auto sperr::Bitstream::rbit() -> bool m_bits = 64; } --m_bits; - bool bit = m_buffer & uint64_t{1}; + bool bit = m_buffer& uint64_t{1}; m_buffer >>= 1; return bit; } diff --git a/src/SPECK1D_INT.cpp b/src/SPECK1D_INT.cpp index d3c1a06c..27d6937c 100644 --- a/src/SPECK1D_INT.cpp +++ b/src/SPECK1D_INT.cpp @@ -5,7 +5,8 @@ #include #include -void sperr::SPECK1D_INT::m_clean_LIS() +template +void sperr::SPECK1D_INT::m_clean_LIS() { for (auto& list : m_LIS) { auto it = std::remove_if(list.begin(), list.end(), @@ -18,7 +19,8 @@ void sperr::SPECK1D_INT::m_clean_LIS() m_LIP.erase(it, m_LIP.end()); } -void sperr::SPECK1D_INT::m_initialize_lists() +template +void sperr::SPECK1D_INT::m_initialize_lists() { const auto total_len = m_dims[0]; auto num_of_parts = sperr::num_of_partitions(total_len); @@ -36,7 +38,8 @@ void sperr::SPECK1D_INT::m_initialize_lists() m_LIS[sets[1].part_level].emplace_back(sets[1]); } -auto sperr::SPECK1D_INT::m_partition_set(const Set1D& set) const -> std::array +template +auto sperr::SPECK1D_INT::m_partition_set(const Set1D& set) const -> std::array { std::array subsets; // Prepare the 1st set @@ -52,3 +55,8 @@ auto sperr::SPECK1D_INT::m_partition_set(const Set1D& set) const -> std::array; +template class sperr::SPECK1D_INT; +template class sperr::SPECK1D_INT; +template class sperr::SPECK1D_INT; diff --git a/src/SPECK1D_INT_DEC.cpp b/src/SPECK1D_INT_DEC.cpp index 5338be95..1fd443a7 100644 --- a/src/SPECK1D_INT_DEC.cpp +++ b/src/SPECK1D_INT_DEC.cpp @@ -5,7 +5,8 @@ #include // std::memcpy() #include -void sperr::SPECK1D_INT_DEC::m_sorting_pass() +template +void sperr::SPECK1D_INT_DEC::m_sorting_pass() { // Since we have a separate representation of LIP, let's process that list first // @@ -23,7 +24,8 @@ void sperr::SPECK1D_INT_DEC::m_sorting_pass() } } -void sperr::SPECK1D_INT_DEC::m_process_S(size_t idx1, size_t idx2, size_t& counter, bool read) +template +void sperr::SPECK1D_INT_DEC::m_process_S(size_t idx1, size_t idx2, size_t& counter, bool read) { auto& set = m_LIS[idx1][idx2]; bool is_sig = true; @@ -40,7 +42,8 @@ void sperr::SPECK1D_INT_DEC::m_process_S(size_t idx1, size_t idx2, size_t& count } } -void sperr::SPECK1D_INT_DEC::m_process_P(size_t loc, size_t& counter, bool read) +template +void sperr::SPECK1D_INT_DEC::m_process_P(size_t loc, size_t& counter, bool read) { bool is_sig = true; const auto pixel_idx = m_LIP[loc]; @@ -62,7 +65,8 @@ void sperr::SPECK1D_INT_DEC::m_process_P(size_t loc, size_t& counter, bool read) } } -void sperr::SPECK1D_INT_DEC::m_code_S(size_t idx1, size_t idx2) +template +void sperr::SPECK1D_INT_DEC::m_code_S(size_t idx1, size_t idx2) { auto subsets = m_partition_set(m_LIS[idx1][idx2]); auto sig_counter = size_t{0}; @@ -96,3 +100,8 @@ void sperr::SPECK1D_INT_DEC::m_code_S(size_t idx1, size_t idx2) m_process_S(newidx1, m_LIS[newidx1].size() - 1, sig_counter, read); } } + +template class sperr::SPECK1D_INT_DEC; +template class sperr::SPECK1D_INT_DEC; +template class sperr::SPECK1D_INT_DEC; +template class sperr::SPECK1D_INT_DEC; diff --git a/src/SPECK1D_INT_ENC.cpp b/src/SPECK1D_INT_ENC.cpp index 4a628702..20ca6058 100644 --- a/src/SPECK1D_INT_ENC.cpp +++ b/src/SPECK1D_INT_ENC.cpp @@ -5,7 +5,8 @@ #include // std::memcpy() #include -void sperr::SPECK1D_INT_ENC::m_sorting_pass() +template +void sperr::SPECK1D_INT_ENC::m_sorting_pass() { // Since we have a separate representation of LIP, let's process that list first! // @@ -23,11 +24,12 @@ void sperr::SPECK1D_INT_ENC::m_sorting_pass() } } -void sperr::SPECK1D_INT_ENC::m_process_S(size_t idx1, - size_t idx2, - SigType sig, - size_t& counter, - bool output) +template +void sperr::SPECK1D_INT_ENC::m_process_S(size_t idx1, + size_t idx2, + SigType sig, + size_t& counter, + bool output) { // Significance type cannot be NewlySig! assert(sig != SigType::NewlySig); @@ -65,7 +67,8 @@ void sperr::SPECK1D_INT_ENC::m_process_S(size_t idx1, } } -void sperr::SPECK1D_INT_ENC::m_process_P(size_t loc, SigType sig, size_t& counter, bool output) +template +void sperr::SPECK1D_INT_ENC::m_process_P(size_t loc, SigType sig, size_t& counter, bool output) { const auto pixel_idx = m_LIP[loc]; @@ -91,7 +94,10 @@ void sperr::SPECK1D_INT_ENC::m_process_P(size_t loc, SigType sig, size_t& counte } } -void sperr::SPECK1D_INT_ENC::m_code_S(size_t idx1, size_t idx2, std::array subset_sigs) +template +void sperr::SPECK1D_INT_ENC::m_code_S(size_t idx1, + size_t idx2, + std::array subset_sigs) { auto subsets = m_partition_set(m_LIS[idx1][idx2]); auto sig_counter = size_t{0}; @@ -130,7 +136,8 @@ void sperr::SPECK1D_INT_ENC::m_code_S(size_t idx1, size_t idx2, std::array +auto sperr::SPECK1D_INT_ENC::m_decide_significance(const Set1D& set) const -> std::pair { assert(set.length != 0); @@ -145,3 +152,8 @@ auto sperr::SPECK1D_INT_ENC::m_decide_significance(const Set1D& set) const else return {SigType::Insig, 0}; } + +template class sperr::SPECK1D_INT_ENC; +template class sperr::SPECK1D_INT_ENC; +template class sperr::SPECK1D_INT_ENC; +template class sperr::SPECK1D_INT_ENC; diff --git a/src/SPECK3D_INT.cpp b/src/SPECK3D_INT.cpp index c3a585ad..c633d182 100644 --- a/src/SPECK3D_INT.cpp +++ b/src/SPECK3D_INT.cpp @@ -15,7 +15,8 @@ auto sperr::Set3D::is_empty() const -> bool return (length_z == 0 || length_y == 0 || length_x == 0); } -void sperr::SPECK3D_INT::m_clean_LIS() +template +void sperr::SPECK3D_INT::m_clean_LIS() { for (auto& list : m_LIS) { auto it = std::remove_if(list.begin(), list.end(), @@ -28,7 +29,8 @@ void sperr::SPECK3D_INT::m_clean_LIS() m_LIP.erase(it, m_LIP.end()); } -void sperr::SPECK3D_INT::m_initialize_lists() +template +void sperr::SPECK3D_INT::m_initialize_lists() { std::array num_of_parts; // how many times each dimension could be partitioned? num_of_parts[0] = sperr::num_of_partitions(m_dims[0]); @@ -97,7 +99,8 @@ void sperr::SPECK3D_INT::m_initialize_lists() m_bit_buffer.reserve(m_coeff_buf.size()); // a reasonable starting point } -auto sperr::SPECK3D_INT::m_partition_S_XYZ(const Set3D& set) -> std::array +template +auto sperr::SPECK3D_INT::m_partition_S_XYZ(const Set3D& set) -> std::array { const auto split_x = std::array{set.length_x - set.length_x / 2, set.length_x / 2}; const auto split_y = std::array{set.length_y - set.length_y / 2, set.length_y / 2}; @@ -202,7 +205,8 @@ auto sperr::SPECK3D_INT::m_partition_S_XYZ(const Set3D& set) -> std::array std::array +template +auto sperr::SPECK3D_INT::m_partition_S_XY(const Set3D& set) -> std::array { std::array subsets; @@ -265,7 +269,8 @@ auto sperr::SPECK3D_INT::m_partition_S_XY(const Set3D& set) -> std::array std::array +template +auto sperr::SPECK3D_INT::m_partition_S_Z(const Set3D& set) -> std::array { std::array subsets; @@ -300,3 +305,8 @@ auto sperr::SPECK3D_INT::m_partition_S_Z(const Set3D& set) -> std::array; +template class sperr::SPECK3D_INT; +template class sperr::SPECK3D_INT; +template class sperr::SPECK3D_INT; diff --git a/src/SPECK3D_INT_DEC.cpp b/src/SPECK3D_INT_DEC.cpp index 53e37f10..6a0fb47a 100644 --- a/src/SPECK3D_INT_DEC.cpp +++ b/src/SPECK3D_INT_DEC.cpp @@ -5,7 +5,8 @@ #include // std::memcpy() #include -void sperr::SPECK3D_INT_DEC::m_sorting_pass() +template +void sperr::SPECK3D_INT_DEC::m_sorting_pass() { // Since we have a separate representation of LIP, let's process that list first // @@ -23,7 +24,8 @@ void sperr::SPECK3D_INT_DEC::m_sorting_pass() } } -void sperr::SPECK3D_INT_DEC::m_process_S(size_t idx1, size_t idx2, size_t& counter, bool read) +template +void sperr::SPECK3D_INT_DEC::m_process_S(size_t idx1, size_t idx2, size_t& counter, bool read) { auto& set = m_LIS[idx1][idx2]; @@ -41,7 +43,8 @@ void sperr::SPECK3D_INT_DEC::m_process_S(size_t idx1, size_t idx2, size_t& count } } -void sperr::SPECK3D_INT_DEC::m_process_P(size_t loc, size_t& counter, bool read) +template +void sperr::SPECK3D_INT_DEC::m_process_P(size_t loc, size_t& counter, bool read) { bool is_sig = true; const auto pixel_idx = m_LIP[loc]; @@ -63,7 +66,8 @@ void sperr::SPECK3D_INT_DEC::m_process_P(size_t loc, size_t& counter, bool read) } } -void sperr::SPECK3D_INT_DEC::m_code_S(size_t idx1, size_t idx2) +template +void sperr::SPECK3D_INT_DEC::m_code_S(size_t idx1, size_t idx2) { auto subsets = m_partition_S_XYZ(m_LIS[idx1][idx2]); const auto set_end = @@ -90,3 +94,8 @@ void sperr::SPECK3D_INT_DEC::m_code_S(size_t idx1, size_t idx2) } } } + +template class sperr::SPECK3D_INT_DEC; +template class sperr::SPECK3D_INT_DEC; +template class sperr::SPECK3D_INT_DEC; +template class sperr::SPECK3D_INT_DEC; diff --git a/src/SPECK3D_INT_ENC.cpp b/src/SPECK3D_INT_ENC.cpp index fc4f25f9..cea638c3 100644 --- a/src/SPECK3D_INT_ENC.cpp +++ b/src/SPECK3D_INT_ENC.cpp @@ -5,7 +5,8 @@ #include // std::memcpy() #include -void sperr::SPECK3D_INT_ENC::m_sorting_pass() +template +void sperr::SPECK3D_INT_ENC::m_sorting_pass() { // Since we have a separate representation of LIP, let's process that list first! // @@ -23,11 +24,12 @@ void sperr::SPECK3D_INT_ENC::m_sorting_pass() } } -void sperr::SPECK3D_INT_ENC::m_process_S(size_t idx1, - size_t idx2, - SigType sig, - size_t& counter, - bool output) +template +void sperr::SPECK3D_INT_ENC::m_process_S(size_t idx1, + size_t idx2, + SigType sig, + size_t& counter, + bool output) { // Significance type cannot be NewlySig! assert(sig != SigType::NewlySig); @@ -79,7 +81,8 @@ void sperr::SPECK3D_INT_ENC::m_process_S(size_t idx1, } } -void sperr::SPECK3D_INT_ENC::m_process_P(size_t loc, SigType sig, size_t& counter, bool output) +template +void sperr::SPECK3D_INT_ENC::m_process_P(size_t loc, SigType sig, size_t& counter, bool output) { const auto pixel_idx = m_LIP[loc]; @@ -105,7 +108,10 @@ void sperr::SPECK3D_INT_ENC::m_process_P(size_t loc, SigType sig, size_t& counte } } -void sperr::SPECK3D_INT_ENC::m_code_S(size_t idx1, size_t idx2, std::array subset_sigs) +template +void sperr::SPECK3D_INT_ENC::m_code_S(size_t idx1, + size_t idx2, + std::array subset_sigs) { auto subsets = m_partition_S_XYZ(m_LIS[idx1][idx2]); @@ -145,7 +151,8 @@ void sperr::SPECK3D_INT_ENC::m_code_S(size_t idx1, size_t idx2, std::array +auto sperr::SPECK3D_INT_ENC::m_decide_significance(const Set3D& set) const -> std::pair> { assert(!set.is_empty()); @@ -169,3 +176,8 @@ auto sperr::SPECK3D_INT_ENC::m_decide_significance(const Set3D& set) const return {SigType::Insig, {0, 0, 0}}; } + +template class sperr::SPECK3D_INT_ENC; +template class sperr::SPECK3D_INT_ENC; +template class sperr::SPECK3D_INT_ENC; +template class sperr::SPECK3D_INT_ENC; diff --git a/src/SPECK_INT.cpp b/src/SPECK_INT.cpp index 2c29c761..e8cc297e 100644 --- a/src/SPECK_INT.cpp +++ b/src/SPECK_INT.cpp @@ -9,6 +9,20 @@ template sperr::SPECK_INT::SPECK_INT() { static_assert(std::is_integral_v); + static_assert(std::is_unsigned_v); +} + +template +auto sperr::SPECK_INT::integer_len() const -> size_t +{ + if constexpr (std::is_same_v) + return sizeof(uint64_t); + else if constexpr (std::is_same_v) + return sizeof(uint32_t); + else if constexpr (std::is_same_v) + return sizeof(uint16_t); + else + return sizeof(uint8_t); } template @@ -18,15 +32,31 @@ void sperr::SPECK_INT::set_dims(dims_type dims) } template -auto sperr::SPECK_INT::get_speck_full_len(const void* buf) const -> uint64_t +auto sperr::SPECK_INT::get_num_bitplanes(const void* buf) const -> uint8_t { - // Given the header definition, directly go retrieve the value stored in the bytes 1--9. + // Given the header definition, directly retrieve the value stored in the first byte. + const uint8_t* const ptr = static_cast(buf); + uint8_t bitplanes = 0; + std::memcpy(&bitplanes, ptr, sizeof(bitplanes)); + return bitplanes; +} + +template +auto sperr::SPECK_INT::get_speck_bits(const void* buf) const -> uint64_t +{ + // Given the header definition, directly retrieve the value stored in bytes 1--9. const uint8_t* const ptr = static_cast(buf); uint64_t num_bits = 0; std::memcpy(&num_bits, ptr + 1, sizeof(num_bits)); + return num_bits; +} + +template +auto sperr::SPECK_INT::get_stream_full_len(const void* buf) const -> uint64_t +{ + auto num_bits = get_speck_bits(buf); while (num_bits % 8 != 0) ++num_bits; - return (m_header_size + num_bits / 8); } @@ -45,7 +75,7 @@ void sperr::SPECK_INT::encode() const auto max_coeff = *std::max_element(m_coeff_buf.cbegin(), m_coeff_buf.cend()); m_num_bitplanes = 1; m_threshold = 1; - while (m_threshold * uint_type{2} <= max_coeff) { + while (m_threshold* uint_type{2} <= max_coeff) { m_threshold *= uint_type{2}; m_num_bitplanes++; } @@ -220,3 +250,8 @@ void sperr::SPECK_INT::m_refinement_pass_decode() m_LSP_mask.write_bit(idx, true); m_LSP_new.clear(); } + +template class sperr::SPECK_INT; +template class sperr::SPECK_INT; +template class sperr::SPECK_INT; +template class sperr::SPECK_INT; diff --git a/src/SPERR3D.cpp b/src/SPERR3D.cpp index 138cb30c..447472ba 100644 --- a/src/SPERR3D.cpp +++ b/src/SPERR3D.cpp @@ -2,13 +2,46 @@ #include "SPECK3D_INT_DEC.h" #include "SPECK3D_INT_ENC.h" -// -// Constructor -// -sperr::SPERR3D::SPERR3D() +void sperr::SPERR3D::m_instantiate_encoder() { - m_encoder = std::make_unique(); - m_decoder = std::make_unique(); + switch (m_uint_flag) { + case UINTType::UINT64: + if (m_encoder.index() != 0 || std::get_if<0>(&m_encoder) == nullptr) + m_encoder = std::make_unique>(); + break; + case UINTType::UINT32: + if (m_encoder.index() != 1 || std::get_if<0>(&m_encoder) == nullptr) + m_encoder = std::make_unique>(); + break; + case UINTType::UINT16: + if (m_encoder.index() != 2 || std::get_if<0>(&m_encoder) == nullptr) + m_encoder = std::make_unique>(); + break; + default: + if (m_encoder.index() != 3 || std::get_if<0>(&m_encoder) == nullptr) + m_encoder = std::make_unique>(); + } +} + +void sperr::SPERR3D::m_instantiate_decoder() +{ + switch (m_uint_flag) { + case UINTType::UINT64: + if (m_decoder.index() != 0 || std::get_if<0>(&m_decoder) == nullptr) + m_decoder = std::make_unique>(); + break; + case UINTType::UINT32: + if (m_decoder.index() != 1 || std::get_if<1>(&m_decoder) == nullptr) + m_decoder = std::make_unique>(); + break; + case UINTType::UINT16: + if (m_decoder.index() != 2 || std::get_if<2>(&m_decoder) == nullptr) + m_decoder = std::make_unique>(); + break; + default: + if (m_decoder.index() != 3 || std::get_if<3>(&m_decoder) == nullptr) + m_decoder = std::make_unique>(); + } } void sperr::SPERR3D::m_wavelet_xform() diff --git a/src/SPERR_Driver.cpp b/src/SPERR_Driver.cpp index 2ba46954..bc3aa0f0 100644 --- a/src/SPERR_Driver.cpp +++ b/src/SPERR_Driver.cpp @@ -27,14 +27,14 @@ auto sperr::SPERR_Driver::use_bitstream(const void* p, size_t len) -> RTNType { // So let's clean up everything at the very beginning of this routine. m_vals_d.clear(); - m_vals_ui.clear(); m_sign_array.clear(); m_condi_bitstream.fill(0); m_speck_bitstream.clear(); + std::visit([](auto& vec) { vec.clear(); }, m_vals_ui); const uint8_t* const ptr = static_cast(p); - // Step 1: extract conditioner stream + // Bitstream parser 1: extract conditioner stream const auto condi_size = m_condi_bitstream.size(); if (condi_size > len) return RTNType::BitstreamWrongLen; @@ -42,9 +42,9 @@ auto sperr::SPERR_Driver::use_bitstream(const void* p, size_t len) -> RTNType size_t pos = condi_size; // `m_condi_bitstream` might be indicating that the field is a constant field. - // In that case, there will be no more speck or sperr streams. - // Let's detect that case here and return early if it is true. - // It will be up to the decompress() routine to restore the actual constant field. + // In that case, there will be no more speck or sperr streams. + // Let's detect that case here and return early if it is true. + // It will be up to the decompress() routine to restore the actual constant field. auto constant = m_conditioner.parse_constant(m_condi_bitstream); if (std::get<0>(constant)) { if (condi_size == len) @@ -53,14 +53,31 @@ auto sperr::SPERR_Driver::use_bitstream(const void* p, size_t len) -> RTNType return RTNType::BitstreamWrongLen; } - // Step 2: extract SPECK stream from it + // Bitstream parser 2: extract SPECK stream from it const uint8_t* const speck_p = ptr + pos; - const auto speck_full_len = m_decoder->get_speck_full_len(speck_p); + const auto speck_full_len = std::visit( + [speck_p](const auto& decoder) { return decoder->get_stream_full_len(speck_p); }, m_decoder); if (speck_full_len != len - pos) return RTNType::BitstreamWrongLen; m_speck_bitstream.resize(speck_full_len); std::copy(speck_p, speck_p + speck_full_len, m_speck_bitstream.begin()); + // Integer length decision 1: decide the integer length to use + const uint32_t num_bitplanes = std::visit( + [speck_p](const auto& decoder) { return decoder->get_num_bitplanes(speck_p); }, m_decoder); + if (num_bitplanes <= 8) + m_uint_flag = UINTType::UINT8; + else if (num_bitplanes <= 16) + m_uint_flag = UINTType::UINT16; + else if (num_bitplanes <= 32) + m_uint_flag = UINTType::UINT32; + else + m_uint_flag = UINTType::UINT64; + + // Integer length decision 2: make sure `m_vals_ui` and `m_decoder` use the decided length + m_instantiate_int_vec(); + m_instantiate_decoder(); + return RTNType::Good; } @@ -94,51 +111,107 @@ void sperr::SPERR_Driver::set_dims(dims_type dims) m_dims = dims; } +void sperr::SPERR_Driver::m_instantiate_int_vec() +{ + switch (m_uint_flag) { + case UINTType::UINT64: + if (m_vals_ui.index() != 0) + m_vals_ui = std::vector(); + break; + case UINTType::UINT32: + if (m_vals_ui.index() != 1) + m_vals_ui = std::vector(); + break; + case UINTType::UINT16: + if (m_vals_ui.index() != 2) + m_vals_ui = std::vector(); + break; + default: + if (m_vals_ui.index() != 3) + m_vals_ui = std::vector(); + } +} + auto sperr::SPERR_Driver::m_midtread_f2i() -> RTNType { // Make sure that the rounding mode is what we wanted. // Here are two methods of querying the current rounding mode; not sure - // how they compare, so test both of them for now. + // how they compare, so test both of them for now. assert(FE_TONEAREST == FLT_ROUNDS); assert(FE_TONEAREST == std::fegetround()); const auto total_vals = m_vals_d.size(); const auto q1 = 1.0 / m_q; auto vals_ll = std::vector(total_vals); - m_vals_ui.resize(total_vals); m_sign_array.resize(total_vals); std::feclearexcept(FE_INVALID); std::transform(m_vals_d.cbegin(), m_vals_d.cend(), vals_ll.begin(), [q1](auto d) { return std::llrint(d * q1); }); if (std::fetestexcept(FE_INVALID)) return RTNType::FE_Invalid; - std::transform(vals_ll.cbegin(), vals_ll.cend(), m_vals_ui.begin(), - [](auto ll) { return static_cast(std::abs(ll)); }); + + // Extract signs from the quantized integers. std::transform(vals_ll.cbegin(), vals_ll.cend(), m_sign_array.begin(), [](auto ll) { return ll >= 0; }); + // Decide integer length + const auto maxmag = *std::max_element(vals_ll.cbegin(), vals_ll.cend(), + [](auto a, auto b) { return std::abs(a) < std::abs(b); }); + if (maxmag <= std::numeric_limits::max()) + m_uint_flag = UINTType::UINT8; + else if (maxmag <= std::numeric_limits::max()) + m_uint_flag = UINTType::UINT16; + else if (maxmag <= std::numeric_limits::max()) + m_uint_flag = UINTType::UINT32; + else + m_uint_flag = UINTType::UINT64; + + // Use the correct integer length for `m_vals_ui`, and keep the integer magnitude. + m_instantiate_int_vec(); + std::visit([total_vals](auto& vec) { vec.resize(total_vals); }, m_vals_ui); + switch (m_uint_flag) { + case UINTType::UINT64: + assert(m_vals_ui.index() == 0); + std::transform(vals_ll.cbegin(), vals_ll.cend(), std::get_if<0>(&m_vals_ui)->begin(), + [](auto ll) { return static_cast(std::abs(ll)); }); + break; + case UINTType::UINT32: + assert(m_vals_ui.index() == 1); + std::transform(vals_ll.cbegin(), vals_ll.cend(), std::get_if<1>(&m_vals_ui)->begin(), + [](auto ll) { return static_cast(std::abs(ll)); }); + break; + case UINTType::UINT16: + assert(m_vals_ui.index() == 2); + std::transform(vals_ll.cbegin(), vals_ll.cend(), std::get_if<2>(&m_vals_ui)->begin(), + [](auto ll) { return static_cast(std::abs(ll)); }); + break; + default: + assert(m_vals_ui.index() == 3); + std::transform(vals_ll.cbegin(), vals_ll.cend(), std::get_if<3>(&m_vals_ui)->begin(), + [](auto ll) { return static_cast(std::abs(ll)); }); + } + return RTNType::Good; } void sperr::SPERR_Driver::m_midtread_i2f() { - assert(m_sign_array.size() == m_vals_ui.size()); + assert(m_sign_array.size() == std::visit([](auto& vec) { return vec.size(); }, m_vals_ui)); const auto tmpd = std::array{-1.0, 1.0}; - const auto q = m_q; m_vals_d.resize(m_sign_array.size()); - std::transform(m_vals_ui.cbegin(), m_vals_ui.cend(), m_sign_array.cbegin(), m_vals_d.begin(), - [tmpd, q](auto i, auto b) { return q * static_cast(i) * tmpd[b]; }); -} - -auto sperr::SPERR_Driver::m_proc_1() -> RTNType -{ - return RTNType::Good; -} - -auto sperr::SPERR_Driver::m_proc_2() -> RTNType -{ - return RTNType::Good; + std::visit( + [&vals_d = m_vals_d, &sign_array = m_sign_array, tmpd, q = m_q](auto& vec) { + std::transform( + vec.cbegin(), vec.cend(), sign_array.cbegin(), vals_d.begin(), + [tmpd, q](auto i, auto b) { return (q * static_cast(i) * tmpd[b]); }); + }, + m_vals_ui); + + // std::visit() obscures the intention, but the task is really the same as the following: + // + // std::transform(m_vals_ui.cbegin(), m_vals_ui.cend(), m_sign_array.cbegin(), m_vals_d.begin(), + // [tmpd, q](auto i, auto b) { return q * static_cast(i) * tmpd[b]; }); } auto sperr::SPERR_Driver::compress() -> RTNType @@ -178,10 +251,37 @@ auto sperr::SPERR_Driver::compress() -> RTNType return rtn; // Step 4: Integer SPECK encoding - m_encoder->set_dims(m_dims); - m_encoder->use_coeffs(std::move(m_vals_ui), std::move(m_sign_array)); - m_encoder->encode(); - m_encoder->write_encoded_bitstream(m_speck_bitstream); + m_instantiate_encoder(); + std::visit([&dims = m_dims](auto& encoder) { encoder->set_dims(dims); }, m_encoder); + switch (m_uint_flag) { + case UINTType::UINT64: + assert(m_vals_ui.index() == 0); + assert(m_encoder.index() == 0); + std::get<0>(m_encoder)->use_coeffs(std::move(std::get<0>(m_vals_ui)), + std::move(m_sign_array)); + break; + case UINTType::UINT32: + assert(m_vals_ui.index() == 1); + assert(m_encoder.index() == 1); + std::get<1>(m_encoder)->use_coeffs(std::move(std::get<1>(m_vals_ui)), + std::move(m_sign_array)); + break; + case UINTType::UINT16: + assert(m_vals_ui.index() == 2); + assert(m_encoder.index() == 2); + std::get<2>(m_encoder)->use_coeffs(std::move(std::get<2>(m_vals_ui)), + std::move(m_sign_array)); + break; + default: + assert(m_vals_ui.index() == 3); + assert(m_encoder.index() == 3); + std::get<3>(m_encoder)->use_coeffs(std::move(std::get<3>(m_vals_ui)), + std::move(m_sign_array)); + } + std::visit([](auto& encoder) { encoder->encode(); }, m_encoder); + std::visit([&speck_bitstream = m_speck_bitstream]( + auto& encoder) { encoder->write_encoded_bitstream(speck_bitstream); }, + m_encoder); return RTNType::Good; } @@ -189,12 +289,12 @@ auto sperr::SPERR_Driver::compress() -> RTNType auto sperr::SPERR_Driver::decompress() -> RTNType { m_vals_d.clear(); - m_vals_ui.clear(); + std::visit([](auto& vec) { vec.clear(); }, m_vals_ui); m_sign_array.clear(); const auto total_vals = uint64_t(m_dims[0]) * m_dims[1] * m_dims[2]; // `m_condi_bitstream` might be indicating a constant field, so let's see if that's - // the case, and if it is, we don't need to go through wavelet and speck stuff anymore. + // the case, and if it is, we don't need to go through wavelet and speck stuff anymore. auto constant = m_conditioner.parse_constant(m_condi_bitstream); if (std::get<0>(constant)) { auto val = std::get<1>(constant); @@ -205,11 +305,30 @@ auto sperr::SPERR_Driver::decompress() -> RTNType // Step 1: Integer SPECK decode. assert(!m_speck_bitstream.empty()); - m_decoder->set_dims(m_dims); - m_decoder->use_bitstream(m_speck_bitstream); - m_decoder->decode(); - m_vals_ui = m_decoder->release_coeffs(); - m_sign_array = m_decoder->release_signs(); + std::visit([&dims = m_dims](auto& decoder) { decoder->set_dims(dims); }, m_decoder); + std::visit([&speck_bitstream = + m_speck_bitstream](auto& decoder) { decoder->use_bitstream(speck_bitstream); }, + m_decoder); + std::visit([](auto& decoder) { decoder->decode(); }, m_decoder); + switch (m_uint_flag) { // `m_uint_flag` was properly set during `use_bitstream()`. + case UINTType::UINT64: + assert(m_decoder.index() == 0); + m_vals_ui = std::get<0>(m_decoder)->release_coeffs(); + break; + case UINTType::UINT32: + assert(m_decoder.index() == 1); + m_vals_ui = std::get<1>(m_decoder)->release_coeffs(); + break; + case UINTType::UINT16: + assert(m_decoder.index() == 2); + m_vals_ui = std::get<2>(m_decoder)->release_coeffs(); + break; + default: + assert(m_decoder.index() == 3); + m_vals_ui = std::get<3>(m_decoder)->release_coeffs(); + break; + } + m_sign_array = std::visit([](auto& decoder) { return decoder->release_signs(); }, m_decoder); // Step 2: Inverse quantization auto rtn = m_inverse_quantize(); diff --git a/test_scripts/CMakeLists.txt b/test_scripts/CMakeLists.txt index 016d90af..041729a6 100644 --- a/test_scripts/CMakeLists.txt +++ b/test_scripts/CMakeLists.txt @@ -1,5 +1,5 @@ -#dd_executable( test_sperr_helper sperr_helper_unit_test.cpp ) -#arget_link_libraries( test_sperr_helper PUBLIC SPERR gtest_main ) +add_executable( test_sperr_helper sperr_helper_unit_test.cpp ) +target_link_libraries( test_sperr_helper PUBLIC SPERR gtest_main ) add_executable( bitstream bitstream_unit_test.cpp ) target_link_libraries( bitstream PUBLIC SPERR gtest_main ) @@ -14,7 +14,7 @@ target_link_libraries( test_speck_int PUBLIC SPERR gtest_main ) #target_link_libraries( test_sperr3d PUBLIC SPERR gtest_main ) include(GoogleTest) -#test_discover_tests( test_sperr_helper ) +gtest_discover_tests( test_sperr_helper ) gtest_discover_tests( bitstream ) gtest_discover_tests( test_dwt ) gtest_discover_tests( test_speck_int ) diff --git a/test_scripts/bitstream_unit_test.cpp b/test_scripts/bitstream_unit_test.cpp index 72a39f89..973cc69b 100644 --- a/test_scripts/bitstream_unit_test.cpp +++ b/test_scripts/bitstream_unit_test.cpp @@ -234,4 +234,40 @@ TEST(Bitmask, RandomReadWrite) EXPECT_EQ(m1.read_bit(i), vec[i]) << "at idx = " << i; } +#if __cplusplus >= 202002L +TEST(Bitmask, BufferTransfer) +{ + auto src = Mask(60); + src.write_long(0, 78344ul); + auto buf = src.view_buffer(); + + auto dst = Mask(60); + dst.use_bitstream(buf.data()); + EXPECT_EQ(src, dst); + + src.resize(120); + src.write_long(100, 19837ul); + buf = src.view_buffer(); + + dst.resize(120); + dst.use_bitstream(buf.data()); + EXPECT_EQ(src, dst); + + src.resize(128); + buf = src.view_buffer(); + + dst.resize(128); + dst.use_bitstream(buf.data()); + EXPECT_EQ(src, dst); + + src.resize(150); + src.write_long(130, 19837ul); + buf = src.view_buffer(); + + dst.resize(150); + dst.use_bitstream(buf.data()); + EXPECT_EQ(src, dst); +} +#endif + } // namespace diff --git a/test_scripts/speck_int_unit_test.cpp b/test_scripts/speck_int_unit_test.cpp index eaacbe94..b1b74df3 100644 --- a/test_scripts/speck_int_unit_test.cpp +++ b/test_scripts/speck_int_unit_test.cpp @@ -13,8 +13,9 @@ namespace { +template auto ProduceRandomArray(size_t len, float stddev, uint32_t seed) - -> std::pair, std::vector> + -> std::pair, std::vector> { std::mt19937 gen{seed}; std::normal_distribution d{0.0, stddev}; @@ -22,10 +23,15 @@ auto ProduceRandomArray(size_t len, float stddev, uint32_t seed) auto tmp = std::vector(len); std::generate(tmp.begin(), tmp.end(), [&gen, &d](){ return d(gen); }); - auto coeffs = std::vector(len); + auto coeffs = std::vector(len); auto signs = std::vector(len, true); for (size_t i = 0; i < len; i++) { - coeffs[i] = std::round(std::abs(tmp[i])); + auto l = std::lround(std::abs(tmp[i])); + if (l > std::numeric_limits::max()) + coeffs[i] = std::numeric_limits::max(); + else + coeffs[i] = l; + // Only specify signs for non-zero values. if (coeffs[i] != 0) signs[i] = tmp[i] > 0.f; } @@ -40,7 +46,7 @@ TEST(SPECK1D_INT, minimal) { const auto dims = sperr::dims_type{40, 1, 1}; - auto input = sperr::vecui_t(dims[0], 0); + auto input = std::vector(dims[0], 0); auto input_signs = sperr::vecb_type(input.size(), true); input[4] = 1; input[7] = 3; input_signs[7] = false; @@ -53,34 +59,120 @@ TEST(SPECK1D_INT, minimal) input[32] = 32; input[39] = 32; input_signs[39] = false; - // Encode - auto encoder = sperr::SPECK1D_INT_ENC(); + // + // Test 1-byte integer + // + { + auto encoder = sperr::SPECK1D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); auto bitstream = sperr::vec8_type(); encoder.write_encoded_bitstream(bitstream); - // Decode - auto decoder = sperr::SPECK1D_INT_DEC(); + auto decoder = sperr::SPECK1D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); auto output = decoder.release_coeffs(); auto output_signs = decoder.release_signs(); + EXPECT_EQ(encoder.integer_len(), 1); + EXPECT_EQ(decoder.integer_len(), 1); EXPECT_EQ(input, output); EXPECT_EQ(input_signs, output_signs); + } + + // + // Test 2-byte integer + // + auto input16 = std::vector(dims[0], 0); + std::copy(input.cbegin(), input.cend(), input16.begin()); + input16[30] = 300; + { + auto encoder = sperr::SPECK1D_INT_ENC(); + encoder.use_coeffs(input16, input_signs); + encoder.set_dims(dims); + encoder.encode(); + auto bitstream = sperr::vec8_type(); + encoder.write_encoded_bitstream(bitstream); + + auto decoder = sperr::SPECK1D_INT_DEC(); + decoder.set_dims(dims); + decoder.use_bitstream(bitstream); + decoder.decode(); + auto output = decoder.release_coeffs(); + auto output_signs = decoder.release_signs(); + + EXPECT_EQ(encoder.integer_len(), 2); + EXPECT_EQ(decoder.integer_len(), 2); + EXPECT_EQ(input16, output); + EXPECT_EQ(input_signs, output_signs); + } + + // + // Test 4-byte integer + // + auto input32 = std::vector(dims[0], 0); + std::copy(input16.cbegin(), input16.cend(), input32.begin()); + input32[20] = 70'000; + { + auto encoder = sperr::SPECK1D_INT_ENC(); + encoder.use_coeffs(input32, input_signs); + encoder.set_dims(dims); + encoder.encode(); + auto bitstream = sperr::vec8_type(); + encoder.write_encoded_bitstream(bitstream); + + auto decoder = sperr::SPECK1D_INT_DEC(); + decoder.set_dims(dims); + decoder.use_bitstream(bitstream); + decoder.decode(); + auto output = decoder.release_coeffs(); + auto output_signs = decoder.release_signs(); + + EXPECT_EQ(encoder.integer_len(), 4); + EXPECT_EQ(decoder.integer_len(), 4); + EXPECT_EQ(input32, output); + EXPECT_EQ(input_signs, output_signs); + } + + // + // Test 8-byte integer + // + auto input64 = std::vector(dims[0], 0); + std::copy(input32.cbegin(), input32.cend(), input64.begin()); + input64[23] = 5'000'700'000; + { + auto encoder = sperr::SPECK1D_INT_ENC(); + encoder.use_coeffs(input64, input_signs); + encoder.set_dims(dims); + encoder.encode(); + auto bitstream = sperr::vec8_type(); + encoder.write_encoded_bitstream(bitstream); + + auto decoder = sperr::SPECK1D_INT_DEC(); + decoder.set_dims(dims); + decoder.use_bitstream(bitstream); + decoder.decode(); + auto output = decoder.release_coeffs(); + auto output_signs = decoder.release_signs(); + + EXPECT_EQ(encoder.integer_len(), 8); + EXPECT_EQ(decoder.integer_len(), 8); + EXPECT_EQ(input64, output); + EXPECT_EQ(input_signs, output_signs); + } } TEST(SPECK1D_INT, Random1) { const auto dims = sperr::dims_type{2000, 1, 1}; - auto [input, input_signs] = ProduceRandomArray(dims[0], 2.9, 1); + auto [input, input_signs] = ProduceRandomArray(dims[0], 2.9, 1); // Encode - auto encoder = sperr::SPECK1D_INT_ENC(); + auto encoder = sperr::SPECK1D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); @@ -88,7 +180,7 @@ TEST(SPECK1D_INT, Random1) encoder.write_encoded_bitstream(bitstream); // Decode - auto decoder = sperr::SPECK1D_INT_DEC(); + auto decoder = sperr::SPECK1D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); @@ -103,10 +195,10 @@ TEST(SPECK1D_INT, Random2) { const auto dims = sperr::dims_type{63 * 79 * 128, 1, 1}; - auto [input, input_signs] = ProduceRandomArray(dims[0], 499.0, 2); + auto [input, input_signs] = ProduceRandomArray(dims[0], 499.0, 2); // Encode - auto encoder = sperr::SPECK3D_INT_ENC(); + auto encoder = sperr::SPECK3D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); @@ -114,7 +206,7 @@ TEST(SPECK1D_INT, Random2) encoder.write_encoded_bitstream(bitstream); // Decode - auto decoder = sperr::SPECK3D_INT_DEC(); + auto decoder = sperr::SPECK3D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); @@ -130,10 +222,10 @@ TEST(SPECK1D_INT, RandomRandom) const auto dims = sperr::dims_type{63 * 64 * 119, 1, 1}; auto rd = std::random_device(); - auto [input, input_signs] = ProduceRandomArray(dims[0], 8345.3, rd()); + auto [input, input_signs] = ProduceRandomArray(dims[0], 8345.3, rd()); // Encode - auto encoder = sperr::SPECK1D_INT_ENC(); + auto encoder = sperr::SPECK1D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); @@ -141,7 +233,7 @@ TEST(SPECK1D_INT, RandomRandom) encoder.write_encoded_bitstream(bitstream); // Decode - auto decoder = sperr::SPECK1D_INT_DEC(); + auto decoder = sperr::SPECK1D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); @@ -159,8 +251,7 @@ TEST(SPECK3D_INT, minimal) { const auto dims = sperr::dims_type{4, 3, 8}; const auto total_vals = dims[0] * dims[1] * dims[2]; - - auto input = sperr::vecui_t(total_vals, 0); + auto input = std::vector(total_vals, 0); auto input_signs = sperr::vecb_type(input.size(), true); input[4] = 1; input[7] = 3; input_signs[7] = false; @@ -173,24 +264,114 @@ TEST(SPECK3D_INT, minimal) input[32] = 32; input[39] = 32; input_signs[39] = false; - // Encode - auto encoder = sperr::SPECK3D_INT_ENC(); + // + // Test 1-byte integers + // + { + auto encoder = sperr::SPECK3D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); auto bitstream = sperr::vec8_type(); encoder.write_encoded_bitstream(bitstream); - // Decode - auto decoder = sperr::SPECK3D_INT_DEC(); + auto decoder = sperr::SPECK3D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); auto output = decoder.release_coeffs(); auto output_signs = decoder.release_signs(); + EXPECT_EQ(encoder.integer_len(), 1); + EXPECT_EQ(decoder.integer_len(), 1); + EXPECT_EQ(input, output); EXPECT_EQ(input, output); EXPECT_EQ(input_signs, output_signs); + } + + // + // Test 2-byte integers + // + auto input16 = std::vector(total_vals, 0); + std::copy(input.begin(), input.end(), input16.begin()); + input16[30] = 300; input_signs[30] = false; + { + auto encoder = sperr::SPECK3D_INT_ENC(); + encoder.use_coeffs(input16, input_signs); + encoder.set_dims(dims); + encoder.encode(); + auto bitstream = sperr::vec8_type(); + encoder.write_encoded_bitstream(bitstream); + + auto decoder = sperr::SPECK3D_INT_DEC(); + decoder.set_dims(dims); + decoder.use_bitstream(bitstream); + decoder.decode(); + auto output = decoder.release_coeffs(); + auto output_signs = decoder.release_signs(); + + EXPECT_EQ(encoder.integer_len(), 2); + EXPECT_EQ(decoder.integer_len(), 2); + EXPECT_EQ(input16, output); + EXPECT_EQ(input16, output); + EXPECT_EQ(input_signs, output_signs); + } + + // + // Test 4-byte integers + // + auto input32 = std::vector(total_vals, 0); + std::copy(input16.begin(), input16.end(), input32.begin()); + input32[20] = 7'0300; input_signs[20] = false; + { + auto encoder = sperr::SPECK3D_INT_ENC(); + encoder.use_coeffs(input32, input_signs); + encoder.set_dims(dims); + encoder.encode(); + auto bitstream = sperr::vec8_type(); + encoder.write_encoded_bitstream(bitstream); + + auto decoder = sperr::SPECK3D_INT_DEC(); + decoder.set_dims(dims); + decoder.use_bitstream(bitstream); + decoder.decode(); + auto output = decoder.release_coeffs(); + auto output_signs = decoder.release_signs(); + + EXPECT_EQ(encoder.integer_len(), 4); + EXPECT_EQ(decoder.integer_len(), 4); + EXPECT_EQ(input32, output); + EXPECT_EQ(input32, output); + EXPECT_EQ(input_signs, output_signs); + } + + // + // Test 8-byte integers + // + auto input64 = std::vector(total_vals, 0); + std::copy(input32.begin(), input32.end(), input64.begin()); + input64[27] = 5'000'700'990; input_signs[27] = false; + { + auto encoder = sperr::SPECK3D_INT_ENC(); + encoder.use_coeffs(input64, input_signs); + encoder.set_dims(dims); + encoder.encode(); + auto bitstream = sperr::vec8_type(); + encoder.write_encoded_bitstream(bitstream); + + auto decoder = sperr::SPECK3D_INT_DEC(); + decoder.set_dims(dims); + decoder.use_bitstream(bitstream); + decoder.decode(); + auto output = decoder.release_coeffs(); + auto output_signs = decoder.release_signs(); + + EXPECT_EQ(encoder.integer_len(), 8); + EXPECT_EQ(decoder.integer_len(), 8); + EXPECT_EQ(input64, output); + EXPECT_EQ(input64, output); + EXPECT_EQ(input_signs, output_signs); + } } TEST(SPECK3D_INT, Random1) @@ -198,10 +379,10 @@ TEST(SPECK3D_INT, Random1) const auto dims = sperr::dims_type{10, 20, 30}; const auto total_vals = dims[0] * dims[1] * dims[2]; - auto [input, input_signs] = ProduceRandomArray(total_vals, 2.9, 1); + auto [input, input_signs] = ProduceRandomArray(total_vals, 2.9, 1); // Encode - auto encoder = sperr::SPECK3D_INT_ENC(); + auto encoder = sperr::SPECK3D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); @@ -209,7 +390,7 @@ TEST(SPECK3D_INT, Random1) encoder.write_encoded_bitstream(bitstream); // Decode - auto decoder = sperr::SPECK3D_INT_DEC(); + auto decoder = sperr::SPECK3D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); @@ -225,10 +406,10 @@ TEST(SPECK3D_INT, Random2) const auto dims = sperr::dims_type{63, 79, 128}; const auto total_vals = dims[0] * dims[1] * dims[2]; - auto [input, input_signs] = ProduceRandomArray(total_vals, 499.0, 2); + auto [input, input_signs] = ProduceRandomArray(total_vals, 499.0, 2); // Encode - auto encoder = sperr::SPECK3D_INT_ENC(); + auto encoder = sperr::SPECK3D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); @@ -236,7 +417,7 @@ TEST(SPECK3D_INT, Random2) encoder.write_encoded_bitstream(bitstream); // Decode - auto decoder = sperr::SPECK3D_INT_DEC(); + auto decoder = sperr::SPECK3D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); @@ -253,10 +434,10 @@ TEST(SPECK3D_INT, RandomRandom) const auto total_vals = dims[0] * dims[1] * dims[2]; auto rd = std::random_device(); - auto [input, input_signs] = ProduceRandomArray(total_vals, 8345.3, rd()); + auto [input, input_signs] = ProduceRandomArray(total_vals, 8345.3, rd()); // Encode - auto encoder = sperr::SPECK3D_INT_ENC(); + auto encoder = sperr::SPECK3D_INT_ENC(); encoder.use_coeffs(input, input_signs); encoder.set_dims(dims); encoder.encode(); @@ -264,7 +445,7 @@ TEST(SPECK3D_INT, RandomRandom) encoder.write_encoded_bitstream(bitstream); // Decode - auto decoder = sperr::SPECK3D_INT_DEC(); + auto decoder = sperr::SPECK3D_INT_DEC(); decoder.set_dims(dims); decoder.use_bitstream(bitstream); decoder.decode(); @@ -276,5 +457,4 @@ TEST(SPECK3D_INT, RandomRandom) } - } // namespace