Skip to content

Commit

Permalink
Add 'equivalent' method to Type (facebookincubator#1740)
Browse files Browse the repository at this point in the history
Summary:
The `kindEquals` method only checks if the `TypeKind` of two types matches recursively.
However, certain types like DecimalType and FixedSizedArrayType require their parameters
to match as well. This type of equivalence is needed by the SignatureBinder.
A new `equivalent` method has been added to Type for this requirement.
Example: Two FixedSizedArrayTypes are equivalent only if their lengths are equal.
Two DecimalTypes are equivalent only if their precision/scale are the same.

Pull Request resolved: facebookincubator#1740

Reviewed By: pedroerp

Differential Revision: D36986040

Pulled By: mbasmanova

fbshipit-source-id: 495cbe1c01aa87aea47c1c9d4f891e3cc784f4b2
  • Loading branch information
majetideepak authored and liushengxuan committed Jul 1, 2022
1 parent a5729d8 commit ab9a45c
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 141 deletions.
4 changes: 2 additions & 2 deletions velox/expression/SignatureBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ bool SignatureBinder::tryBind() {
if (actualTypes_.size() > formalArgsCnt) {
auto& type = actualTypes_[formalArgsCnt - 1];
for (auto i = formalArgsCnt; i < actualTypes_.size(); i++) {
if (!type->kindEquals(actualTypes_[i]) &&
if (!type->equivalent(*actualTypes_[i]) &&
actualTypes_[i]->kind() != TypeKind::UNKNOWN) {
return false;
}
Expand Down Expand Up @@ -98,7 +98,7 @@ bool SignatureBinder::tryBind(
return true;
}

return it->second->kindEquals(actualType);
return it->second->equivalent(*actualType);
}

TypePtr SignatureBinder::tryResolveType(
Expand Down
25 changes: 18 additions & 7 deletions velox/expression/tests/SignatureBinderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ void testSignatureBinder(
ASSERT_TRUE(expectedReturnType->kindEquals(returnType));
}

void assertCannotResolve(
const std::shared_ptr<exec::FunctionSignature>& signature,
const std::vector<TypePtr>& actualTypes) {
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_FALSE(binder.tryBind());
}

TEST(SignatureBinderTest, generics) {
// array(T), T -> boolean
{
Expand All @@ -41,6 +48,17 @@ TEST(SignatureBinderTest, generics) {
.build();

testSignatureBinder(signature, {ARRAY(BIGINT()), BIGINT()}, BOOLEAN());
testSignatureBinder(
signature, {ARRAY(DECIMAL(20, 3)), DECIMAL(20, 3)}, BOOLEAN());
assertCannotResolve(signature, {ARRAY(DECIMAL(20, 3)), DECIMAL(20, 4)});
testSignatureBinder(
signature,
{ARRAY(FIXED_SIZE_ARRAY(20, BIGINT())), FIXED_SIZE_ARRAY(20, BIGINT())},
BOOLEAN());
assertCannotResolve(
signature,
{ARRAY(FIXED_SIZE_ARRAY(20, BIGINT())),
FIXED_SIZE_ARRAY(10, BIGINT())});
}

// array(array(T)), array(T) -> boolean
Expand Down Expand Up @@ -146,13 +164,6 @@ TEST(SignatureBinderTest, variableArity) {
}
}

void assertCannotResolve(
const std::shared_ptr<exec::FunctionSignature>& signature,
const std::vector<TypePtr>& actualTypes) {
exec::SignatureBinder binder(*signature, actualTypes);
ASSERT_FALSE(binder.tryBind());
}

TEST(SignatureBinderTest, unresolvable) {
// integer -> varchar
{
Expand Down
47 changes: 39 additions & 8 deletions velox/type/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ const std::shared_ptr<const Type>& ArrayType::childAt(uint32_t idx) const {
ArrayType::ArrayType(std::shared_ptr<const Type> child)
: child_{std::move(child)} {}

bool ArrayType::operator==(const Type& other) const {
bool ArrayType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand Down Expand Up @@ -230,6 +230,20 @@ std::string FixedSizeArrayType::toString() const {
return ss.str();
}

bool FixedSizeArrayType::equivalent(const Type& other) const {
if (!ArrayType::equivalent(other)) {
return false;
}
auto otherFixedSizeArray = dynamic_cast<const FixedSizeArrayType*>(&other);
if (!otherFixedSizeArray) {
return false;
}
if (fixedElementsWidth() != otherFixedSizeArray->fixedElementsWidth()) {
return false;
}
return true;
}

const std::shared_ptr<const Type>& MapType::childAt(uint32_t idx) const {
if (idx == 0) {
return keyType();
Expand Down Expand Up @@ -329,7 +343,7 @@ std::optional<uint32_t> RowType::getChildIdxIfExists(
return std::nullopt;
}

bool RowType::operator==(const Type& other) const {
bool RowType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand All @@ -341,11 +355,21 @@ bool RowType::operator==(const Type& other) const {
return false;
}
for (size_t i = 0; i < size(); ++i) {
// todo: case sensitivity
if (nameOf(i) != otherTyped.nameOf(i)) {
if (*childAt(i) != *otherTyped.childAt(i)) {
return false;
}
if (*childAt(i) != *otherTyped.childAt(i)) {
}
return true;
}

bool RowType::operator==(const Type& other) const {
if (!this->equivalent(other)) {
return false;
}
auto& otherTyped = other.asRow();
for (size_t i = 0; i < size(); ++i) {
// todo: case sensitivity
if (nameOf(i) != otherTyped.nameOf(i)) {
return false;
}
}
Expand Down Expand Up @@ -411,7 +435,7 @@ bool Type::kindEquals(const std::shared_ptr<const Type>& other) const {
return true;
}

bool MapType::operator==(const Type& other) const {
bool MapType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand All @@ -422,7 +446,7 @@ bool MapType::operator==(const Type& other) const {
return *keyType_ == *otherMap.keyType_ && *valueType_ == *otherMap.valueType_;
}

bool FunctionType::operator==(const Type& other) const {
bool FunctionType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
Expand Down Expand Up @@ -450,13 +474,20 @@ folly::dynamic FunctionType::serialize() const {
OpaqueType::OpaqueType(const std::type_index& typeIndex)
: typeIndex_(typeIndex) {}

bool OpaqueType::operator==(const Type& other) const {
bool OpaqueType::equivalent(const Type& other) const {
if (&other == this) {
return true;
}
if (other.kind() != TypeKind::OPAQUE) {
return false;
}
return true;
}

bool OpaqueType::operator==(const Type& other) const {
if (!this->equivalent(other)) {
return false;
}
auto& otherTyped = *reinterpret_cast<const OpaqueType*>(&other);
return typeIndex_ == otherTyped.typeIndex_;
}
Expand Down
Loading

0 comments on commit ab9a45c

Please sign in to comment.