Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'equivalent' method to Type #1740

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions velox/expression/SignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool SignatureBinder::tryBind() {
if (actualTypes_.size() > formalArgsCnt) {
auto& type = actualTypes_[formalArgsCnt - 1];
for (auto i = formalArgsCnt; i < actualTypes_.size(); i++) {
if (!type->kindEquals(actualTypes_[i]) &&
if (!type->equivalent(*actualTypes_[i]) &&
actualTypes_[i]->kind() != TypeKind::UNKNOWN) {
return false;
}
Expand Down Expand Up @@ -98,7 +98,7 @@ bool SignatureBinder::tryBind(
return true;
}

return it->second->kindEquals(actualType);
return it->second->equivalent(*actualType);
}

TypePtr SignatureBinder::tryResolveType(
Expand Down
25 changes: 18 additions & 7 deletions velox/expression/tests/SignatureBinderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ void testSignatureBinder(
ASSERT_TRUE(expectedReturnType->kindEquals(returnType));
}

void assertCannotResolve(
const std::shared_ptr<exec::FunctionSignature>& signature,
const std::vector<TypePtr>& actualTypes) {
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_FALSE(binder.tryBind());
}

TEST(SignatureBinderTest, generics) {
// array(T), T -> boolean
{
Expand All @@ -41,6 +48,17 @@ TEST(SignatureBinderTest, generics) {
.build();

testSignatureBinder(signature, {ARRAY(BIGINT()), BIGINT()}, BOOLEAN());
testSignatureBinder(
signature, {ARRAY(DECIMAL(20, 3)), DECIMAL(20, 3)}, BOOLEAN());
assertCannotResolve(signature, {ARRAY(DECIMAL(20, 3)), DECIMAL(20, 4)});
testSignatureBinder(
signature,
{ARRAY(FIXED_SIZE_ARRAY(20, BIGINT())), FIXED_SIZE_ARRAY(20, BIGINT())},
BOOLEAN());
assertCannotResolve(
signature,
{ARRAY(FIXED_SIZE_ARRAY(20, BIGINT())),
FIXED_SIZE_ARRAY(10, BIGINT())});
}

// array(array(T)), array(T) -> boolean
Expand Down Expand Up @@ -146,13 +164,6 @@ TEST(SignatureBinderTest, variableArity) {
}
}

void assertCannotResolve(
const std::shared_ptr<exec::FunctionSignature>& signature,
const std::vector<TypePtr>& actualTypes) {
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_FALSE(binder.tryBind());
}

TEST(SignatureBinderTest, unresolvable) {
// integer -> varchar
{
Expand Down
47 changes: 39 additions & 8 deletions velox/type/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ const std::shared_ptr<const Type>& ArrayType::childAt(uint32_t idx) const {
ArrayType::ArrayType(std::shared_ptr<const Type> child)
: child_{std::move(child)} {}

bool ArrayType::operator==(const Type& other) const {
bool ArrayType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand Down Expand Up @@ -230,6 +230,20 @@ std::string FixedSizeArrayType::toString() const {
return ss.str();
}

bool FixedSizeArrayType::equivalent(const Type& other) const {
if (!ArrayType::equivalent(other)) {
return false;
}
auto otherFixedSizeArray = dynamic_cast<const FixedSizeArrayType*>(&other);
if (!otherFixedSizeArray) {
return false;
}
if (fixedElementsWidth() != otherFixedSizeArray->fixedElementsWidth()) {
return false;
}
return true;
}

const std::shared_ptr<const Type>& MapType::childAt(uint32_t idx) const {
if (idx == 0) {
return keyType();
Expand Down Expand Up @@ -329,7 +343,7 @@ std::optional<uint32_t> RowType::getChildIdxIfExists(
return std::nullopt;
}

bool RowType::operator==(const Type& other) const {
bool RowType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand All @@ -341,11 +355,21 @@ bool RowType::operator==(const Type& other) const {
return false;
}
for (size_t i = 0; i < size(); ++i) {
// todo: case sensitivity
if (nameOf(i) != otherTyped.nameOf(i)) {
if (*childAt(i) != *otherTyped.childAt(i)) {
return false;
}
if (*childAt(i) != *otherTyped.childAt(i)) {
}
return true;
}

bool RowType::operator==(const Type& other) const {
if (!this->equivalent(other)) {
return false;
}
auto& otherTyped = other.asRow();
for (size_t i = 0; i < size(); ++i) {
// todo: case sensitivity
if (nameOf(i) != otherTyped.nameOf(i)) {
return false;
}
}
Expand Down Expand Up @@ -411,7 +435,7 @@ bool Type::kindEquals(const std::shared_ptr<const Type>& other) const {
return true;
}

bool MapType::operator==(const Type& other) const {
bool MapType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand All @@ -422,7 +446,7 @@ bool MapType::operator==(const Type& other) const {
return *keyType_ == *otherMap.keyType_ && *valueType_ == *otherMap.valueType_;
}

bool FunctionType::operator==(const Type& other) const {
bool FunctionType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand Down Expand Up @@ -450,13 +474,20 @@ folly::dynamic FunctionType::serialize() const {
OpaqueType::OpaqueType(const std::type_index& typeIndex)
: typeIndex_(typeIndex) {}

bool OpaqueType::operator==(const Type& other) const {
bool OpaqueType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
if (other.kind() != TypeKind::OPAQUE) {
return false;
}
return true;
}

bool OpaqueType::operator==(const Type& other) const {
if (!this->equivalent(other)) {
return false;
}
auto& otherTyped = *reinterpret_cast<const OpaqueType*>(&other);
return typeIndex_ == otherTyped.typeIndex_;
}
Expand Down
124 changes: 71 additions & 53 deletions velox/type/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,19 @@ class Type : public Tree<const std::shared_ptr<const Type>>,

virtual std::string toString() const = 0;

virtual bool operator==(const Type& other) const = 0;
// Types are weakly matched.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use /// for method and class-level comments.

// Examples: Two RowTypes are equivalent if the children types are equivalent,
// but the children names could be different. Two OpaqueTypes are equivalent
// if the typeKind matches, but the typeIndex could be different.
virtual bool equivalent(const Type& other) const = 0;

// Types are strongly matched.
// Examples: Two RowTypes are == if the children types and the children names
// are same. Two OpaqueTypes are == if the typeKind and the typeIndex are
// same. Same as equivalent for most types except for Row, Opaque types.
virtual bool operator==(const Type& other) const {
return this->equivalent(other);
}

inline bool operator!=(const Type& other) const {
return !(*this == other);
Expand All @@ -492,10 +504,10 @@ class Type : public Tree<const std::shared_ptr<const Type>>,

static std::shared_ptr<const Type> create(const folly::dynamic& obj);

// recursive kind hashing (ignores names)
// recursive kind hashing (uses only typeKind)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: perhaps, fix the comment to start with a capital letter and end with a period. Same below.

size_t hashKind() const;

// recursive kind match (ignores names)
// recursive kind match (uses only typeKind)
bool kindEquals(const std::shared_ptr<const Type>& other) const;

template <TypeKind KIND, typename... CHILDREN>
Expand Down Expand Up @@ -556,8 +568,49 @@ class TypeBase : public Type {
}
};

using ShortDecimalType = DecimalType<TypeKind::SHORT_DECIMAL>;
using LongDecimalType = DecimalType<TypeKind::LONG_DECIMAL>;
template <TypeKind KIND>
class ScalarType : public TypeBase<KIND> {
public:
uint32_t size() const override {
return 0;
}

const std::shared_ptr<const Type>& childAt(uint32_t) const override {
throw std::invalid_argument{"scalar type has no children"};
}

std::string toString() const override {
return TypeTraits<KIND>::name;
}

size_t cppSizeInBytes() const override {
if (TypeTraits<KIND>::isFixedWidth) {
return sizeof(typename TypeTraits<KIND>::NativeType);
}
// TODO: velox throws here for non fixed width types.
return Type::cppSizeInBytes();
}

FOLLY_NOINLINE static const std::shared_ptr<const ScalarType<KIND>> create();

bool equivalent(const Type& other) const override {
return KIND == other.kind();
}

// TODO: velox implementation is in cpp
folly::dynamic serialize() const override {
folly::dynamic obj = folly::dynamic::object;
obj["name"] = "Type";
obj["type"] = TypeTraits<KIND>::name;
return obj;
}
};

template <TypeKind KIND>
const std::shared_ptr<const ScalarType<KIND>> ScalarType<KIND>::create() {
static const auto instance = std::make_shared<const ScalarType<KIND>>();
return instance;
}

/// This class represents the fixed-point numbers.
/// The parameter "precision" represents the number of digits the
Expand All @@ -577,7 +630,7 @@ class DecimalType : public ScalarType<KIND> {
VELOX_CHECK_LE(precision, kMaxPrecision);
}

inline bool operator==(const Type& otherDecimal) const override {
inline bool equivalent(const Type& otherDecimal) const override {
if (this->kind() != otherDecimal.kind()) {
return false;
}
Expand Down Expand Up @@ -611,49 +664,8 @@ class DecimalType : public ScalarType<KIND> {
const uint8_t scale_;
};

template <TypeKind KIND>
class ScalarType : public TypeBase<KIND> {
public:
uint32_t size() const override {
return 0;
}

const std::shared_ptr<const Type>& childAt(uint32_t) const override {
throw std::invalid_argument{"scalar type has no children"};
}

std::string toString() const override {
return TypeTraits<KIND>::name;
}

size_t cppSizeInBytes() const override {
if (TypeTraits<KIND>::isFixedWidth) {
return sizeof(typename TypeTraits<KIND>::NativeType);
}
// TODO: velox throws here for non fixed width types.
return Type::cppSizeInBytes();
}

FOLLY_NOINLINE static const std::shared_ptr<const ScalarType<KIND>> create();

bool operator==(const Type& other) const override {
return KIND == other.kind();
}

// TODO: velox implementation is in cpp
folly::dynamic serialize() const override {
folly::dynamic obj = folly::dynamic::object;
obj["name"] = "Type";
obj["type"] = TypeTraits<KIND>::name;
return obj;
}
};

template <TypeKind KIND>
const std::shared_ptr<const ScalarType<KIND>> ScalarType<KIND>::create() {
static const auto instance = std::make_shared<const ScalarType<KIND>>();
return instance;
}
using ShortDecimalType = DecimalType<TypeKind::SHORT_DECIMAL>;
using LongDecimalType = DecimalType<TypeKind::LONG_DECIMAL>;

class UnknownType : public TypeBase<TypeKind::UNKNOWN> {
public:
Expand All @@ -675,7 +687,7 @@ class UnknownType : public TypeBase<TypeKind::UNKNOWN> {
return 0;
}

bool operator==(const Type& other) const override {
bool equivalent(const Type& other) const override {
return TypeKind::UNKNOWN == other.kind();
}

Expand Down Expand Up @@ -703,7 +715,7 @@ class ArrayType : public TypeBase<TypeKind::ARRAY> {

std::string toString() const override;

bool operator==(const Type& other) const override;
bool equivalent(const Type& other) const override;

folly::dynamic serialize() const override;

Expand Down Expand Up @@ -734,6 +746,8 @@ class FixedSizeArrayType : public ArrayType {
return "FIXED_SIZE_ARRAY";
}

bool equivalent(const Type& other) const override;

std::string toString() const override;

private:
Expand Down Expand Up @@ -762,7 +776,7 @@ class MapType : public TypeBase<TypeKind::MAP> {

const std::shared_ptr<const Type>& childAt(uint32_t idx) const override;

bool operator==(const Type& other) const override;
bool equivalent(const Type& other) const override;

folly::dynamic serialize() const override;

Expand Down Expand Up @@ -798,6 +812,8 @@ class RowType : public TypeBase<TypeKind::ROW> {
return names_.at(idx);
}

bool equivalent(const Type& other) const override;

bool operator==(const Type& other) const override;

std::string toString() const override;
Expand Down Expand Up @@ -847,7 +863,7 @@ class FunctionType : public TypeBase<TypeKind::FUNCTION> {
return children_;
}

bool operator==(const Type& other) const override;
bool equivalent(const Type& other) const override;

std::string toString() const override;

Expand Down Expand Up @@ -884,6 +900,8 @@ class OpaqueType : public TypeBase<TypeKind::OPAQUE> {

std::string toString() const override;

bool equivalent(const Type& other) const override;

bool operator==(const Type& other) const override;

const std::type_index& typeIndex() const {
Expand Down
Loading