Skip to content

Commit

Permalink
Implement copy and move constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
inakleinbottle committed Nov 3, 2023
1 parent 349d4fd commit 7ebdd30
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 205 deletions.
2 changes: 2 additions & 0 deletions scalars/include/roughpy/scalars/scalar_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class RPY_EXPORT ScalarArray

constexpr dimn_t size() const noexcept { return m_size; }

private:
const void* raw_pointer() const noexcept;
};

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion scalars/src/scalar/arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ inline void scalar_inplace_arithmetic(
} else {
Scalar tmp(dst_type);
void* tmp_ptr = tmp.mut_pointer();
scalars::dtl::scalar_convert_copy(tmp_ptr, dst_info, src, src_info);
scalars::dtl::scalar_convert_copy(tmp_ptr, dst_info, src, src_info, 1);
do_op(dst, tmp_ptr, dst_info, std::forward<Op>(op));
}
}
Expand Down
252 changes: 58 additions & 194 deletions scalars/src/scalar/casts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@

#include <roughpy/scalars/scalar.h>

#include "casts.h"
#include "do_macro.h"
#include <roughpy/core/alloc.h>
#include <roughpy/scalars/scalar_types.h>
#include "casts.h"

using namespace rpy;
using namespace scalars;
Expand All @@ -42,9 +43,9 @@ static inline enable_if_t<
is_constructible<D, const T&>::value
&& is_trivially_default_constructible<D>::value,
bool>
write_result(D* dst, const T& src) noexcept
write_result(D* dst, const T* src, dimn_t count) noexcept
{
construct_inplace(dst, src);
for (dimn_t i = 0; i < count; ++i) { construct_inplace(dst++, *(src++)); }
return true;
}

Expand All @@ -53,10 +54,10 @@ static inline enable_if_t<
is_constructible<D, const T&>::value
&& !is_trivially_default_constructible<D>::value,
bool>
write_result(D* dst, const T& src) noexcept
write_result(D* dst, const T* src, dimn_t count) noexcept
{
try {
*dst = D(src);
for (dimn_t i = 0; i < count; ++i) { dst[i] = D(src[i]); }
} catch (...) {
return false;
}
Expand All @@ -65,93 +66,75 @@ write_result(D* dst, const T& src) noexcept

template <typename D, typename T>
static inline enable_if_t<!is_constructible<D, const T&>::value, bool>
write_result(D*, const T&) noexcept
write_result(D*, const T*, dimn_t) noexcept
{
return false;
}

template <typename D>
static inline bool write_result(D* dst, const rational_poly_scalar& value)
noexcept {
static inline bool write_single_poly(D* dst, const rational_poly_scalar& value)
{
if (value.empty()) {
write_result(dst, 0);
auto tmp = 0;
write_result(dst, &tmp, 1);
return true;
}

if (value.size() == 1) {
auto kv = value.begin();
if (kv->key() == monomial()) {
// Try converting via rational
return write_result(dst, kv->value());
return write_result(dst, &kv->value(), 1);
}
}

return false;
}


template <typename T>
template <typename D>
static inline bool
convert_impl(void* dst, const TypeInfo& info, const T& src) noexcept
write_result(D* dst, const rational_poly_scalar* value, dimn_t count) noexcept
{
switch (info.code) {
case devices::TypeCode::Int:
switch (info.bytes) {
case 1: return write_result((int8_t*) dst, src);
case 2: return write_result((int16_t*) dst, src);
case 4: return write_result((int32_t*) dst, src);
case 8: return write_result((int64_t*) dst, src);
}
break;
case devices::TypeCode::UInt:
switch (info.bytes) {
case 1: return write_result((uint8_t*) dst, src);
case 2: return write_result((uint16_t*) dst, src);
case 4: return write_result((uint32_t*) dst, src);
case 8: return write_result((uint64_t*) dst, src);
}
break;
case devices::TypeCode::Float:
switch (info.bytes) {
case 2: return write_result((half*) dst, src);
case 4: return write_result((float*) dst, src);
case 8: return write_result((double*) dst, src);
}
break;
case devices::TypeCode::OpaqueHandle: break; // not supported
case devices::TypeCode::BFloat:
if (info.bytes == 2) {
return write_result((bfloat16*) dst, src);
}
break;
case devices::TypeCode::Complex:
// TODO: implement complex conversions
break;
case devices::TypeCode::Bool: break; // not supported
case devices::TypeCode::ArbitraryPrecisionInt: break;
case devices::TypeCode::ArbitraryPrecisionUInt: break;
case devices::TypeCode::ArbitraryPrecisionFloat: break;
case devices::TypeCode::ArbitraryPrecisionComplex: break;
case devices::TypeCode::Rational:
// later we might actually have a fixed precision rational.
case devices::TypeCode::ArbitraryPrecisionRational:
return write_result((rational_scalar_type*) dst, src);
case devices::TypeCode::APRationalPolynomial:
return write_result((rational_poly_scalar*) dst, src);
try {
for (dimn_t i = 0; i < count; ++i) {
if (!write_single_poly(dst++, value[i])) { return false; }
}
} catch (...) {
return false;
}
return false;
}

// For now, just cheat with half and bfloat16 and cast them to floats
static inline bool
convert_impl(void* dst, const TypeInfo& info, const half& src) noexcept
{
return convert_impl(dst, info, float(src));
return true;
}

static inline bool
convert_impl(void* dst, const TypeInfo& info, const bfloat16& src) noexcept
template <typename T>
static inline bool convert_impl(
void* dst,
const TypeInfo& info,
const T* src,
dimn_t count
) noexcept
{
return convert_impl(dst, info, float(src));
#define X(TP) return write_result((T*) dst, src, count)
DO_FOR_EACH_X(info)
#undef X
return false;
}
//
//// For now, just cheat with half and bfloat16 and cast them to floats
// static inline bool
// convert_impl(void* dst, const TypeInfo& info, const half* src, dimn_t count)
// noexcept
//{
// return convert_impl(dst, info, src, count);
// }
//
// static inline bool
// convert_impl(void* dst, const TypeInfo& info, const bfloat16* src, dimn_t
// count)
// noexcept
//{
// return convert_impl(dst, info, src, count);
// }

bool scalars::dtl::scalar_convert_copy(
void* dst,
Expand All @@ -160,141 +143,22 @@ bool scalars::dtl::scalar_convert_copy(
) noexcept
{
auto src_info = src.type_info();
switch (src_info.code) {
case devices::TypeCode::Int:
switch (src_info.bytes) {
case 1:
return convert_impl(dst, dst_type, src.as_type<int8_t>());
case 2:
return convert_impl(dst, dst_type, src.as_type<int16_t>());
case 4:
return convert_impl(dst, dst_type, src.as_type<int32_t>());
case 8:
return convert_impl(dst, dst_type, src.as_type<int64_t>());
}
break;
case devices::TypeCode::UInt:
switch (src_info.bytes) {
case 1:
return convert_impl(dst, dst_type, src.as_type<uint8_t>());
case 2:
return convert_impl(dst, dst_type, src.as_type<uint16_t>());
case 4:
return convert_impl(dst, dst_type, src.as_type<uint32_t>());
case 8:
return convert_impl(dst, dst_type, src.as_type<uint64_t>());
}
break;
case devices::TypeCode::Float:
switch (src_info.bytes) {
case 2: return convert_impl(dst, dst_type, src.as_type<half>());
case 4:
return convert_impl(dst, dst_type, src.as_type<float>());
case 8:
return convert_impl(dst, dst_type, src.as_type<double>());
}
break;
case devices::TypeCode::OpaqueHandle: break;
case devices::TypeCode::BFloat:
if (src_info.bytes == 2) {
return convert_impl(dst, dst_type, src.as_type<bfloat16>());
}
break;
case devices::TypeCode::Complex: break;
case devices::TypeCode::Bool: break;
case devices::TypeCode::ArbitraryPrecision: break;
case devices::TypeCode::ArbitraryPrecisionUInt: break;
case devices::TypeCode::ArbitraryPrecisionFloat: break;
case devices::TypeCode::ArbitraryPrecisionComplex: break;
case devices::TypeCode::Rational:
// Later we might have a fixed precision rational.
case devices::TypeCode::ArbitraryPrecisionRational:
return convert_impl(
dst,
dst_type,
src.as_type<rational_scalar_type>()
);
case devices::TypeCode::Polynomial:
return convert_impl(
dst,
dst_type,
src.as_type<rational_poly_scalar>()
);
}

#define X(TP) return convert_impl(dst, dst_type, (const TP*) src.pointer(), 1)
DO_FOR_EACH_X(src_info)
#undef X
return false;
}



bool rpy::scalars::dtl::scalar_convert_copy(
void* dst,
devices::TypeInfo dst_type,
const void* src,
devices::TypeInfo src_type
devices::TypeInfo src_type,
dimn_t count
) noexcept
{
switch (src_type.code) {
case devices::TypeCode::Int:
switch (src_type.bytes) {
case 1:
return convert_impl(dst, dst_type, *((int8_t*) src));
case 2:
return convert_impl(dst, dst_type, *((int16_t*) src));
case 4:
return convert_impl(dst, dst_type, *((int32_t*) src));
case 8:
return convert_impl(dst, dst_type, *((int64_t*) src));
}
break;
case devices::TypeCode::UInt:
switch (src_type.bytes) {
case 1:
return convert_impl(dst, dst_type, *((uint8_t*) src));
case 2:
return convert_impl(dst, dst_type, *((uint16_t*) src));
case 4:
return convert_impl(dst, dst_type, *((uint32_t*) src));
case 8:
return convert_impl(dst, dst_type, *((uint64_t*) src));
}
break;
case devices::TypeCode::Float:
switch (src_type.bytes) {
case 2: return convert_impl(dst, dst_type, *((half*) src));
case 4:
return convert_impl(dst, dst_type, *((float*) src));
case 8:
return convert_impl(dst, dst_type, *((double*) src));
}
break;
case devices::TypeCode::OpaqueHandle: break;
case devices::TypeCode::BFloat:
if (src_type.bytes == 2) {
return convert_impl(dst, dst_type, *((bfloat16*) src));
}
break;
case devices::TypeCode::Complex: break;
case devices::TypeCode::Bool: break;
case devices::TypeCode::ArbitraryPrecision: break;
case devices::TypeCode::ArbitraryPrecisionUInt: break;
case devices::TypeCode::ArbitraryPrecisionFloat: break;
case devices::TypeCode::ArbitraryPrecisionComplex: break;
case devices::TypeCode::Rational:
// Later we might have a fixed precision rational.
case devices::TypeCode::ArbitraryPrecisionRational:
return convert_impl(
dst,
dst_type,
*((rational_scalar_type*) src)
);
case devices::TypeCode::Polynomial:
return convert_impl(
dst,
dst_type,
*((rational_poly_scalar*) src)
);
}

#define X(TP) return convert_impl<TP>(dst, dst_type, (const TP*) src, count)
DO_FOR_EACH_X(src_type)
#undef X
return false;
}
4 changes: 2 additions & 2 deletions scalars/src/scalar/casts.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ namespace rpy {
namespace scalars {
namespace dtl {


bool scalar_convert_copy(
void* dst,
devices::TypeInfo dst_type,
const void* src,
devices::TypeInfo src_type
devices::TypeInfo src_type,
dimn_t count=1
) noexcept;


Expand Down
Loading

0 comments on commit 7ebdd30

Please sign in to comment.