Skip to content

Commit

Permalink
make singular unions castable to their underlying type (flyteorg#599)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Rammer <daniel@union.ai>
Signed-off-by: Gopal K. Vashishtha <gvashishtha@anduril.com>
  • Loading branch information
hamersaw authored and gvashishtha committed Aug 5, 2023
1 parent a1539da commit 719a290
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
12 changes: 11 additions & 1 deletion pkg/compiler/validators/typing.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,5 +356,15 @@ func getTypeChecker(t *flyte.LiteralType) typeChecker {
}

func AreTypesCastable(upstreamType, downstreamType *flyte.LiteralType) bool {
return getTypeChecker(downstreamType).CastsFrom(upstreamType)
typeChecker := getTypeChecker(downstreamType)

// if upstream is a singular union we check if the downstream type is castable from the union variant
if upstreamType.GetUnionType() != nil && len(upstreamType.GetUnionType().GetVariants()) == 1 {
variants := upstreamType.GetUnionType().GetVariants()
if len(variants) == 1 && typeChecker.CastsFrom(variants[0]) {
return true
}
}

return typeChecker.CastsFrom(upstreamType)
}
26 changes: 26 additions & 0 deletions pkg/compiler/validators/typing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,32 @@ func TestUnionCasting(t *testing.T) {
)
assert.False(t, castable, "Union types can only be cast to a union that contains a superset of variants")
})

t.Run("SingularUnionToUnderlyingType", func(t *testing.T) {
castable := AreTypesCastable(
&core.LiteralType{
Type: &core.LiteralType_UnionType{
UnionType: &core.UnionType{
Variants: []*core.LiteralType{
{
Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING},
Structure: &core.TypeStructure{
Tag: "string",
},
},
},
},
},
},
&core.LiteralType{
Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING},
Structure: &core.TypeStructure{
Tag: "string",
},
},
)
assert.True(t, castable, "Singular unions should be castable to their underlying type")
})
}

func TestCollectionCasting(t *testing.T) {
Expand Down

0 comments on commit 719a290

Please sign in to comment.