Skip to content

Commit

Permalink
feat(go): Unions include runtime validation (#5403)
Browse files Browse the repository at this point in the history
  • Loading branch information
amckinney authored Dec 12, 2024
1 parent 7c73af2 commit 643ba46
Show file tree
Hide file tree
Showing 60 changed files with 4,514 additions and 146 deletions.
6 changes: 0 additions & 6 deletions generators/go/cmd/fern-go-model/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,6 @@ func TestTime(t *testing.T) {
})

t.Run("union (optional)", func(t *testing.T) {
empty := union.NewUnionWithOptionalTimeFromDate(nil)

emptyBytes, err := json.Marshal(empty)
require.NoError(t, err)
assert.Equal(t, `{"type":"date"}`, string(emptyBytes))

value := union.NewUnionWithOptionalTimeFromDate(&date)

bytes, err := json.Marshal(value)
Expand Down
2 changes: 1 addition & 1 deletion generators/go/internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1703,7 +1703,7 @@ func zeroValueForPrimitive(primitive fernir.PrimitiveTypeV1) string {
case fernir.PrimitiveTypeV1DateTime, fernir.PrimitiveTypeV1Date:
return "time.Time{}"
case fernir.PrimitiveTypeV1Uuid:
return "uuid.UUID{}"
return "uuid.Nil"
case fernir.PrimitiveTypeV1Base64:
return "nil"
}
Expand Down
60 changes: 60 additions & 0 deletions generators/go/internal/generator/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ func (t *typeVisitor) VisitUnion(union *ir.UnionTypeDeclaration) error {

// Implement the json.Marshaler interface.
t.writer.P("func (", receiver, " ", t.typeName, ") MarshalJSON() ([]byte, error) {")
t.writer.P("if err := ", receiver, ".validate(); err != nil {")
t.writer.P("return nil, err")
t.writer.P("}")
if t.unionVersion != UnionVersionV1 {
t.writer.P("switch ", receiver, ".", discriminantName, " {")
}
Expand Down Expand Up @@ -659,6 +662,63 @@ func (t *typeVisitor) VisitUnion(union *ir.UnionTypeDeclaration) error {
t.writer.P("}")
t.writer.P()

// Generate the validate method.
t.writer.P("func (", receiver, " *", t.typeName, ") validate() error {")
t.writer.P("if ", receiver, " == nil {")
t.writer.P(`return fmt.Errorf("type %T is nil", `, receiver, ")")
t.writer.P("}")
t.writer.P("var fields []string")
for _, unionType := range union.Types {
var (
isLiteral bool
isOptional bool
date *date
)
if unionType.Shape.SingleProperty != nil {
isLiteral = isLiteralType(unionType.Shape.SingleProperty.Type, t.writer.types)
isOptional = isOptionalType(unionType.Shape.SingleProperty.Type, t.writer.types)
date = maybeDate(unionType.Shape.SingleProperty.Type, isOptional)
}
zeroValue := "nil"
if unionType.Shape.PropertiesType == "singleProperty" {
zeroValue = zeroValueForTypeReference(unionType.Shape.SingleProperty.Type, t.writer.types)
}
unionTypeValue := receiver + "." + unionType.DiscriminantValue.Name.PascalCase.UnsafeName
if isLiteral {
unionTypeValue = receiver + "." + unionType.DiscriminantValue.Name.CamelCase.SafeName
}
if date != nil && !isOptional {
t.writer.P("if !", unionTypeValue, ".IsZero() {")
} else {
t.writer.P("if ", unionTypeValue, " != ", zeroValue, " {")
}
t.writer.P(`fields = append(fields, "`, unionType.DiscriminantValue.WireValue, `")`)
t.writer.P("}")
}
t.writer.P("if len(fields) == 0 {")
t.writer.P("if ", receiver, ".", discriminantName, ` != "" {`)
t.writer.P(`return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", `, receiver, ", ", receiver, ".", discriminantName, ")")
t.writer.P("}")
t.writer.P(`return fmt.Errorf("type %T is empty", `, receiver, ")")
t.writer.P("}")
t.writer.P("if len(fields) > 1 {")
t.writer.P(`return fmt.Errorf("type %T defines values for %s, but only one value is allowed", `, receiver, ", fields)")
t.writer.P("}")
t.writer.P("if ", receiver, ".", discriminantName, ` != "" {`)
t.writer.P("field := fields[0]")
t.writer.P("if ", receiver, ".", discriminantName, " != field {")
t.writer.P("return fmt.Errorf(")
t.writer.P(`"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",`)
t.writer.P(receiver, ", ")
t.writer.P(receiver, ".", discriminantName, ", ")
t.writer.P(receiver, ", ")
t.writer.P(")")
t.writer.P("}")
t.writer.P("}")
t.writer.P("return nil")
t.writer.P("}")
t.writer.P()

return nil
}

Expand Down
42 changes: 41 additions & 1 deletion generators/go/internal/testdata/model/alias/fixtures/imdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type Foo struct {

func (f *Foo) GetId() uuid.UUID {
if f == nil {
return uuid.UUID{}
return uuid.Nil
}
return f.Id
}
Expand Down Expand Up @@ -224,6 +224,9 @@ func (u *Union) UnmarshalJSON(data []byte) error {
}

func (u Union) MarshalJSON() ([]byte, error) {
if err := u.validate(); err != nil {
return nil, err
}
switch u.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", u.Type, u)
Expand Down Expand Up @@ -269,6 +272,43 @@ func (u *Union) Accept(visitor UnionVisitor) error {
}
}

func (u *Union) validate() error {
if u == nil {
return fmt.Errorf("type %T is nil", u)
}
var fields []string
if u.FooAlias != nil {
fields = append(fields, "fooAlias")
}
if u.BarAlias != nil {
fields = append(fields, "barAlias")
}
if u.DoubleAlias != 0 {
fields = append(fields, "doubleAlias")
}
if len(fields) == 0 {
if u.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", u, u.Type)
}
return fmt.Errorf("type %T is empty", u)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", u, fields)
}
if u.Type != "" {
field := fields[0]
if u.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
u,
u.Type,
u,
)
}
}
return nil
}

type Unknown = interface{}

type Uuid = uuid.UUID
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (t *Type) GetSeven() time.Time {

func (t *Type) GetEight() uuid.UUID {
if t == nil {
return uuid.UUID{}
return uuid.Nil
}
return t.Eight
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type Foo struct {

func (f *Foo) GetId() uuid.UUID {
if f == nil {
return uuid.UUID{}
return uuid.Nil
}
return f.Id
}
Expand Down
68 changes: 68 additions & 0 deletions generators/go/internal/testdata/model/extends/fixtures/imdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ func (n *NestedUnion) UnmarshalJSON(data []byte) error {
}

func (n NestedUnion) MarshalJSON() ([]byte, error) {
if err := n.validate(); err != nil {
return nil, err
}
switch n.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", n.Type, n)
Expand All @@ -283,6 +286,37 @@ func (n *NestedUnion) Accept(visitor NestedUnionVisitor) error {
}
}

func (n *NestedUnion) validate() error {
if n == nil {
return fmt.Errorf("type %T is nil", n)
}
var fields []string
if n.One != nil {
fields = append(fields, "one")
}
if len(fields) == 0 {
if n.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", n, n.Type)
}
return fmt.Errorf("type %T is empty", n)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", n, fields)
}
if n.Type != "" {
field := fields[0]
if n.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
n,
n.Type,
n,
)
}
}
return nil
}

type Union struct {
Type string
Docs string
Expand Down Expand Up @@ -339,6 +373,9 @@ func (u *Union) UnmarshalJSON(data []byte) error {
}

func (u Union) MarshalJSON() ([]byte, error) {
if err := u.validate(); err != nil {
return nil, err
}
switch u.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", u.Type, u)
Expand All @@ -359,3 +396,34 @@ func (u *Union) Accept(visitor UnionVisitor) error {
return visitor.VisitOne(u.One)
}
}

func (u *Union) validate() error {
if u == nil {
return fmt.Errorf("type %T is nil", u)
}
var fields []string
if u.One != nil {
fields = append(fields, "one")
}
if len(fields) == 0 {
if u.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", u, u.Type)
}
return fmt.Errorf("type %T is empty", u)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", u, fields)
}
if u.Type != "" {
field := fields[0]
if u.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
u,
u.Type,
u,
)
}
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (f *Foo) GetBar() *bar.Bar {

func (f *Foo) GetUuid() uuid.UUID {
if f == nil {
return uuid.UUID{}
return uuid.Nil
}
return f.Uuid
}
Expand Down
40 changes: 40 additions & 0 deletions generators/go/internal/testdata/model/ir/fixtures/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ func (a *AuthScheme) UnmarshalJSON(data []byte) error {
}

func (a AuthScheme) MarshalJSON() ([]byte, error) {
if err := a.validate(); err != nil {
return nil, err
}
switch a.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", a.Type, a)
Expand Down Expand Up @@ -176,6 +179,43 @@ func (a *AuthScheme) Accept(visitor AuthSchemeVisitor) error {
}
}

func (a *AuthScheme) validate() error {
if a == nil {
return fmt.Errorf("type %T is nil", a)
}
var fields []string
if a.Bearer != nil {
fields = append(fields, "bearer")
}
if a.Basic != nil {
fields = append(fields, "basic")
}
if a.Header != nil {
fields = append(fields, "header")
}
if len(fields) == 0 {
if a.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", a, a.Type)
}
return fmt.Errorf("type %T is empty", a)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", a, fields)
}
if a.Type != "" {
field := fields[0]
if a.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
a,
a.Type,
a,
)
}
}
return nil
}

type AuthSchemesRequirement string

const (
Expand Down
37 changes: 37 additions & 0 deletions generators/go/internal/testdata/model/ir/fixtures/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ func (e *Environments) UnmarshalJSON(data []byte) error {
}

func (e Environments) MarshalJSON() ([]byte, error) {
if err := e.validate(); err != nil {
return nil, err
}
switch e.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", e.Type, e)
Expand All @@ -151,6 +154,40 @@ func (e *Environments) Accept(visitor EnvironmentsVisitor) error {
}
}

func (e *Environments) validate() error {
if e == nil {
return fmt.Errorf("type %T is nil", e)
}
var fields []string
if e.SingleBaseUrl != nil {
fields = append(fields, "singleBaseUrl")
}
if e.MultipleBaseUrls != nil {
fields = append(fields, "multipleBaseUrls")
}
if len(fields) == 0 {
if e.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", e, e.Type)
}
return fmt.Errorf("type %T is empty", e)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", e, fields)
}
if e.Type != "" {
field := fields[0]
if e.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
e,
e.Type,
e,
)
}
}
return nil
}

type EnvironmentsConfig struct {
DefaultEnvironment *EnvironmentId `json:"defaultEnvironment,omitempty" url:"defaultEnvironment,omitempty"`
Environments *Environments `json:"environments,omitempty" url:"environments,omitempty"`
Expand Down
Loading

0 comments on commit 643ba46

Please sign in to comment.