Skip to content

Commit

Permalink
More work on scalar arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Nov 8, 2023
1 parent 38f840d commit 241ce4e
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 36 deletions.
9 changes: 7 additions & 2 deletions scalars/include/roughpy/scalars/key_scalar_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#include <roughpy/core/types.h>
#include <roughpy/platform/serialization.h>

RPY_WARNING_PUSH
RPY_CLANG_DISABLE_WARNING(HidingNonVirtualFunction)

namespace rpy {
namespace scalars {

Expand All @@ -60,11 +63,11 @@ class RPY_EXPORT KeyScalarArray : public ScalarArray
dimn_t count
) noexcept;

explicit operator ScalarArray() && noexcept;

RPY_NO_DISCARD KeyScalarArray copy_or_move() &&;

KeyScalarArray& operator=(const ScalarArray& other) noexcept;
KeyScalarArray& operator=(const KeyScalarArray& other);
KeyScalarArray& operator=(const ScalarArray& other);
KeyScalarArray& operator=(KeyScalarArray&& other) noexcept;
KeyScalarArray& operator=(ScalarArray&& other) noexcept;

Expand All @@ -82,4 +85,6 @@ class RPY_EXPORT KeyScalarArray : public ScalarArray
}// namespace scalars
}// namespace rpy


RPY_WARNING_POP
#endif// ROUGHPY_SCALARS_KEY_SCALAR_ARRAY_H_
46 changes: 30 additions & 16 deletions scalars/include/roughpy/scalars/scalar_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@
#ifndef ROUGHPY_SCALARS_SCALAR_ARRAY_H_
#define ROUGHPY_SCALARS_SCALAR_ARRAY_H_

#include "scalars_fwd.h"
#include "packed_scalar_type_ptr.h"
#include "scalars_fwd.h"

#include <roughpy/platform/serialization.h>
#include <roughpy/device/buffer.h>
#include <roughpy/platform/serialization.h>

namespace rpy {
namespace scalars {
Expand Down Expand Up @@ -64,16 +64,25 @@ class RPY_EXPORT ScalarArray
dimn_t m_size = 0;

static bool check_pointer_and_size(const void* ptr, dimn_t size);
protected:

protected:
type_pointer packed_type() const noexcept { return p_type_and_mode; }

public:
ScalarArray();
ScalarArray(const ScalarArray& other);
ScalarArray(ScalarArray&& other) noexcept;

explicit ScalarArray(const ScalarType* type, dimn_t size = 0);
explicit ScalarArray(devices::TypeInfo info, dimn_t size = 0);

explicit ScalarArray(const ScalarType* type, const void* data, dimn_t size);
explicit ScalarArray(devices::TypeInfo info, const void* data, dimn_t size);

explicit ScalarArray(const ScalarType* type, void* data, dimn_t size);
explicit ScalarArray(devices::TypeInfo info, void* data, dimn_t size);


explicit ScalarArray(const ScalarType* type, devices::Buffer&& buffer);
explicit ScalarArray(devices::TypeInfo info, devices::Buffer&& buffer);

Expand All @@ -94,7 +103,11 @@ class RPY_EXPORT ScalarArray
ScalarArray& operator=(const ScalarArray& other);
ScalarArray& operator=(ScalarArray&& other) noexcept;

bool is_owning() const noexcept {

ScalarArray copy_or_clone() &&;

bool is_owning() const noexcept
{
return p_type_and_mode.get_enumeration() == discriminator_type::Owned;
}

Expand All @@ -105,12 +118,14 @@ class RPY_EXPORT ScalarArray
constexpr dimn_t size() const noexcept { return m_size; }
dimn_t capacity() const noexcept;
constexpr bool empty() const noexcept { return m_size == 0; }
constexpr bool is_null() const noexcept {
constexpr bool is_null() const noexcept
{
return p_type_and_mode.is_null() && empty();
}
constexpr bool is_const() const noexcept {
return p_type_and_mode.get_enumeration() ==
discriminator_type::BorrowConst;
constexpr bool is_const() const noexcept
{
return p_type_and_mode.get_enumeration()
== discriminator_type::BorrowConst;
}
devices::Device device() const noexcept;

Expand All @@ -119,33 +134,33 @@ class RPY_EXPORT ScalarArray
const devices::Buffer& buffer() const;
devices::Buffer& mut_buffer();


Scalar operator[](dimn_t i) const;
Scalar operator[](dimn_t i);


RPY_SERIAL_SAVE_FN();
RPY_SERIAL_LOAD_FN();

private:
void check_for_ptr_access(bool mut=false) const;
void check_for_ptr_access(bool mut = false) const;

public:
template <typename T>
Slice<T> as_mut_slice() {
Slice<T> as_mut_slice()
{
check_for_ptr_access(true);
return {static_cast<T*>(raw_mut_pointer()), m_size};
}

template <typename T>
Slice<const T> as_slice() {
Slice<const T> as_slice()
{
check_for_ptr_access(true);
return {static_cast<const T*>(raw_pointer()), m_size};
}

private:
const void* raw_pointer(dimn_t i=0) const noexcept;
void* raw_mut_pointer(dimn_t i=0) noexcept;
const void* raw_pointer(dimn_t i = 0) const noexcept;
void* raw_mut_pointer(dimn_t i = 0) noexcept;
};

template <typename T>
Expand Down Expand Up @@ -198,7 +213,6 @@ ScalarArray::ScalarArray(const T* data, dimn_t size)
RPY_SERIAL_EXTERN_SAVE_CLS(ScalarArray)
RPY_SERIAL_EXTERN_LOAD_CLS(ScalarArray)


}// namespace scalars
}// namespace rpy

Expand Down
111 changes: 93 additions & 18 deletions scalars/src/key_scalar_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,42 @@
// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "key_scalar_array.h"
#include "scalar_type.h"

#include <algorithm>

using namespace rpy;
using namespace rpy::scalars;

KeyScalarArray::~KeyScalarArray() {
if (m_owns_keys) {
delete[] p_keys;
}
KeyScalarArray::~KeyScalarArray()
{
if (m_owns_keys) { delete[] p_keys; }
p_keys = nullptr;
m_owns_keys = false;
}
KeyScalarArray::KeyScalarArray(const KeyScalarArray& other)
: ScalarArray(other)
: ScalarArray(other),
p_keys()
{
if (other.p_keys != nullptr && other.m_owns_keys) {
m_owns_keys = true;
allocate_keys();
std::copy_n(other.p_keys, other.size(), const_cast<key_type*>(p_keys));
}
}
KeyScalarArray::KeyScalarArray(KeyScalarArray&& other) noexcept
: ScalarArray(std::move(other)), p_keys(other.p_keys), m_owns_keys(other.m_owns_keys)
{

}
KeyScalarArray::KeyScalarArray(KeyScalarArray&& other) noexcept {}
KeyScalarArray::KeyScalarArray(ScalarArray&& sa) noexcept {}
KeyScalarArray::KeyScalarArray(ScalarArray base, const key_type* keys) {}
KeyScalarArray::KeyScalarArray(const ScalarType* type) noexcept {}
KeyScalarArray::KeyScalarArray(ScalarArray&& sa) noexcept
: ScalarArray(std::move(sa))
{}
KeyScalarArray::KeyScalarArray(ScalarArray base, const key_type* keys)
: ScalarArray(std::move(base)), p_keys(keys), m_owns_keys(false)
{}
KeyScalarArray::KeyScalarArray(const ScalarType* type) noexcept
: ScalarArray(type)
{}
KeyScalarArray::KeyScalarArray(const ScalarType* type, dimn_t n) noexcept
: ScalarArray(type, n)
{}
Expand All @@ -52,23 +70,80 @@ KeyScalarArray::KeyScalarArray(
const void* begin,
dimn_t count
) noexcept
: ScalarArray(type, begin, count)
{}
KeyScalarArray::operator ScalarArray() && noexcept { return ScalarArray(); }
KeyScalarArray KeyScalarArray::copy_or_move() && { return KeyScalarArray(); }
KeyScalarArray& KeyScalarArray::operator=(const ScalarArray& other) noexcept
KeyScalarArray KeyScalarArray::copy_or_move() && {
if (m_owns_keys) {
return std::move(*this);
}

KeyScalarArray result(static_cast<ScalarArray&&>(*this).copy_or_clone());
result.allocate_keys();
std::copy_n(p_keys, size(), result.keys());

return KeyScalarArray();
}
KeyScalarArray& KeyScalarArray::operator=(const KeyScalarArray& other)
{
if (&other != this) {
this->~KeyScalarArray();
ScalarArray::operator=(other);
if (other.m_owns_keys) {
allocate_keys();
std::copy_n(other.p_keys, other.size(), keys());
m_owns_keys = true;
} else {
p_keys = other.p_keys;
m_owns_keys = false;
}

}
return *this;
}

KeyScalarArray& KeyScalarArray::operator=(const ScalarArray& other)
{
return ScalarArray::operator=(other);
if (&other != this) {
this->~KeyScalarArray();
ScalarArray::operator=(other);
}
return *this;
}
KeyScalarArray& KeyScalarArray::operator=(KeyScalarArray&& other) noexcept
{
if (&other != this) {
this->~KeyScalarArray();

ScalarArray::operator=(std::move(other));
p_keys = other.p_keys;
other.p_keys = nullptr;
m_owns_keys = other.m_owns_keys;

}
return *this;
}
KeyScalarArray& KeyScalarArray::operator=(ScalarArray&& other) noexcept
{
return ScalarArray::operator=(other);
if (&other != this) {
this->~KeyScalarArray();
ScalarArray::operator=(std::move(other));
}
return *this;
}
key_type* KeyScalarArray::keys() {
if (m_owns_keys) {
return const_cast<key_type *>(p_keys);
}
return nullptr;
}
key_type* KeyScalarArray::keys() { return p_keys; }
void KeyScalarArray::allocate_scalars(idimn_t count) {

auto type = this->type();
if (count >= 0 && type) {
*this = (*type)->allocate(static_cast<dimn_t>(count));
}
}
void KeyScalarArray::allocate_keys(idimn_t count) {
if (count == -1 && p_keys == nullptr) {
p_keys = new key_type[size()];
}
}
void KeyScalarArray::allocate_keys(idimn_t count) {}
58 changes: 58 additions & 0 deletions scalars/src/scalar_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,44 @@ bool ScalarArray::check_pointer_and_size(const void* ptr, dimn_t size)
if (size > 0) { RPY_CHECK(ptr != nullptr); }
return true;
}
ScalarArray::ScalarArray()
: const_borrowed(nullptr), m_size(0)
{}
ScalarArray::ScalarArray(const ScalarArray& other)
: p_type_and_mode(other.p_type_and_mode),
m_size(other.m_size)
{
switch (p_type_and_mode.get_enumeration()) {
case dtl::ScalarArrayStorageModel::BorrowConst:
const_borrowed = other.const_borrowed;
break;
case dtl::ScalarArrayStorageModel::BorrowMut:
mut_borrowed = other.mut_borrowed;
break;
case dtl::ScalarArrayStorageModel::Owned:
owned_buffer = other.owned_buffer;
break;
}
}
ScalarArray::ScalarArray(ScalarArray&& other) noexcept
: p_type_and_mode(other.p_type_and_mode),
m_size(other.m_size)
{
switch(p_type_and_mode.get_enumeration()) {
case dtl::ScalarArrayStorageModel::BorrowConst:
const_borrowed = other.const_borrowed;
other.const_borrowed = nullptr;
break;
case dtl::ScalarArrayStorageModel::BorrowMut:
mut_borrowed = other.mut_borrowed;
other.mut_borrowed = nullptr;
break;
case dtl::ScalarArrayStorageModel::Owned:
owned_buffer = std::move(other.owned_buffer);
break;
}
}

ScalarArray::ScalarArray(const ScalarType* type, dimn_t size)
{
RPY_DBG_ASSERT(type != nullptr);
Expand All @@ -57,6 +95,23 @@ ScalarArray::ScalarArray(devices::TypeInfo info, dimn_t size)
= devices::get_host_device()->raw_alloc(size, info.alignment);
}
}
ScalarArray::ScalarArray(const ScalarType* type, const void* data, dimn_t size)
: p_type_and_mode(type, dtl::ScalarArrayStorageModel::BorrowConst),
const_borrowed(data), m_size(size)
{}
ScalarArray::ScalarArray(devices::TypeInfo info, const void* data, dimn_t size)
: p_type_and_mode(info, dtl::ScalarArrayStorageModel::BorrowConst),
const_borrowed(data), m_size(size)
{}
ScalarArray::ScalarArray(const ScalarType* type, void* data, dimn_t size)
: p_type_and_mode(type, dtl::ScalarArrayStorageModel::BorrowMut),
mut_borrowed(data), m_size(size)
{}
ScalarArray::ScalarArray(devices::TypeInfo info, void* data, dimn_t size)
: p_type_and_mode(info, dtl::ScalarArrayStorageModel::BorrowMut),
mut_borrowed(data), m_size(size)
{}

ScalarArray::ScalarArray(const ScalarType* type, devices::Buffer&& buffer)
: p_type_and_mode(type, dtl::ScalarArrayStorageModel::Owned),
owned_buffer(std::move(buffer)),
Expand Down Expand Up @@ -154,6 +209,9 @@ ScalarArray& ScalarArray::operator=(ScalarArray&& other) noexcept
return *this;
}


ScalarArray ScalarArray::copy_or_clone() && { return ScalarArray(); }

optional<const ScalarType*> ScalarArray::type() const noexcept
{
if (p_type_and_mode.is_pointer()) { return p_type_and_mode.get_pointer(); }
Expand Down

0 comments on commit 241ce4e

Please sign in to comment.