Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go): Unions include runtime validation #5403

Merged
merged 2 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions generators/go/cmd/fern-go-model/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,6 @@ func TestTime(t *testing.T) {
})

t.Run("union (optional)", func(t *testing.T) {
empty := union.NewUnionWithOptionalTimeFromDate(nil)

emptyBytes, err := json.Marshal(empty)
require.NoError(t, err)
assert.Equal(t, `{"type":"date"}`, string(emptyBytes))

value := union.NewUnionWithOptionalTimeFromDate(&date)

bytes, err := json.Marshal(value)
Expand Down
2 changes: 1 addition & 1 deletion generators/go/internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1703,7 +1703,7 @@ func zeroValueForPrimitive(primitive fernir.PrimitiveTypeV1) string {
case fernir.PrimitiveTypeV1DateTime, fernir.PrimitiveTypeV1Date:
return "time.Time{}"
case fernir.PrimitiveTypeV1Uuid:
return "uuid.UUID{}"
return "uuid.Nil"
case fernir.PrimitiveTypeV1Base64:
return "nil"
}
Expand Down
60 changes: 60 additions & 0 deletions generators/go/internal/generator/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ func (t *typeVisitor) VisitUnion(union *ir.UnionTypeDeclaration) error {

// Implement the json.Marshaler interface.
t.writer.P("func (", receiver, " ", t.typeName, ") MarshalJSON() ([]byte, error) {")
t.writer.P("if err := ", receiver, ".validate(); err != nil {")
t.writer.P("return nil, err")
t.writer.P("}")
if t.unionVersion != UnionVersionV1 {
t.writer.P("switch ", receiver, ".", discriminantName, " {")
}
Expand Down Expand Up @@ -659,6 +662,63 @@ func (t *typeVisitor) VisitUnion(union *ir.UnionTypeDeclaration) error {
t.writer.P("}")
t.writer.P()

// Generate the validate method.
t.writer.P("func (", receiver, " *", t.typeName, ") validate() error {")
t.writer.P("if ", receiver, " == nil {")
t.writer.P(`return fmt.Errorf("type %T is nil", `, receiver, ")")
t.writer.P("}")
t.writer.P("var fields []string")
for _, unionType := range union.Types {
var (
isLiteral bool
isOptional bool
date *date
)
if unionType.Shape.SingleProperty != nil {
isLiteral = isLiteralType(unionType.Shape.SingleProperty.Type, t.writer.types)
isOptional = isOptionalType(unionType.Shape.SingleProperty.Type, t.writer.types)
date = maybeDate(unionType.Shape.SingleProperty.Type, isOptional)
}
zeroValue := "nil"
if unionType.Shape.PropertiesType == "singleProperty" {
zeroValue = zeroValueForTypeReference(unionType.Shape.SingleProperty.Type, t.writer.types)
}
unionTypeValue := receiver + "." + unionType.DiscriminantValue.Name.PascalCase.UnsafeName
if isLiteral {
unionTypeValue = receiver + "." + unionType.DiscriminantValue.Name.CamelCase.SafeName
}
if date != nil && !isOptional {
t.writer.P("if !", unionTypeValue, ".IsZero() {")
} else {
t.writer.P("if ", unionTypeValue, " != ", zeroValue, " {")
}
t.writer.P(`fields = append(fields, "`, unionType.DiscriminantValue.WireValue, `")`)
t.writer.P("}")
}
t.writer.P("if len(fields) == 0 {")
t.writer.P("if ", receiver, ".", discriminantName, ` != "" {`)
t.writer.P(`return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", `, receiver, ", ", receiver, ".", discriminantName, ")")
t.writer.P("}")
t.writer.P(`return fmt.Errorf("type %T is empty", `, receiver, ")")
t.writer.P("}")
t.writer.P("if len(fields) > 1 {")
t.writer.P(`return fmt.Errorf("type %T defines values for %s, but only one value is allowed", `, receiver, ", fields)")
t.writer.P("}")
t.writer.P("if ", receiver, ".", discriminantName, ` != "" {`)
t.writer.P("field := fields[0]")
t.writer.P("if ", receiver, ".", discriminantName, " != field {")
t.writer.P("return fmt.Errorf(")
t.writer.P(`"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",`)
t.writer.P(receiver, ", ")
t.writer.P(receiver, ".", discriminantName, ", ")
t.writer.P(receiver, ", ")
t.writer.P(")")
t.writer.P("}")
t.writer.P("}")
t.writer.P("return nil")
t.writer.P("}")
t.writer.P()

return nil
}

Expand Down
42 changes: 41 additions & 1 deletion generators/go/internal/testdata/model/alias/fixtures/imdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type Foo struct {

func (f *Foo) GetId() uuid.UUID {
if f == nil {
return uuid.UUID{}
return uuid.Nil
}
return f.Id
}
Expand Down Expand Up @@ -224,6 +224,9 @@ func (u *Union) UnmarshalJSON(data []byte) error {
}

func (u Union) MarshalJSON() ([]byte, error) {
if err := u.validate(); err != nil {
return nil, err
}
switch u.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", u.Type, u)
Expand Down Expand Up @@ -269,6 +272,43 @@ func (u *Union) Accept(visitor UnionVisitor) error {
}
}

func (u *Union) validate() error {
if u == nil {
return fmt.Errorf("type %T is nil", u)
}
var fields []string
if u.FooAlias != nil {
fields = append(fields, "fooAlias")
}
if u.BarAlias != nil {
fields = append(fields, "barAlias")
}
if u.DoubleAlias != 0 {
fields = append(fields, "doubleAlias")
}
if len(fields) == 0 {
if u.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", u, u.Type)
}
return fmt.Errorf("type %T is empty", u)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", u, fields)
}
if u.Type != "" {
field := fields[0]
if u.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
u,
u.Type,
u,
)
}
}
return nil
}

type Unknown = interface{}

type Uuid = uuid.UUID
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (t *Type) GetSeven() time.Time {

func (t *Type) GetEight() uuid.UUID {
if t == nil {
return uuid.UUID{}
return uuid.Nil
}
return t.Eight
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ type Foo struct {

func (f *Foo) GetId() uuid.UUID {
if f == nil {
return uuid.UUID{}
return uuid.Nil
}
return f.Id
}
Expand Down
68 changes: 68 additions & 0 deletions generators/go/internal/testdata/model/extends/fixtures/imdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,9 @@ func (n *NestedUnion) UnmarshalJSON(data []byte) error {
}

func (n NestedUnion) MarshalJSON() ([]byte, error) {
if err := n.validate(); err != nil {
return nil, err
}
switch n.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", n.Type, n)
Expand All @@ -283,6 +286,37 @@ func (n *NestedUnion) Accept(visitor NestedUnionVisitor) error {
}
}

func (n *NestedUnion) validate() error {
if n == nil {
return fmt.Errorf("type %T is nil", n)
}
var fields []string
if n.One != nil {
fields = append(fields, "one")
}
if len(fields) == 0 {
if n.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", n, n.Type)
}
return fmt.Errorf("type %T is empty", n)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", n, fields)
}
if n.Type != "" {
field := fields[0]
if n.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
n,
n.Type,
n,
)
}
}
return nil
}

type Union struct {
Type string
Docs string
Expand Down Expand Up @@ -339,6 +373,9 @@ func (u *Union) UnmarshalJSON(data []byte) error {
}

func (u Union) MarshalJSON() ([]byte, error) {
if err := u.validate(); err != nil {
return nil, err
}
switch u.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", u.Type, u)
Expand All @@ -359,3 +396,34 @@ func (u *Union) Accept(visitor UnionVisitor) error {
return visitor.VisitOne(u.One)
}
}

func (u *Union) validate() error {
if u == nil {
return fmt.Errorf("type %T is nil", u)
}
var fields []string
if u.One != nil {
fields = append(fields, "one")
}
if len(fields) == 0 {
if u.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", u, u.Type)
}
return fmt.Errorf("type %T is empty", u)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", u, fields)
}
if u.Type != "" {
field := fields[0]
if u.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
u,
u.Type,
u,
)
}
}
return nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (f *Foo) GetBar() *bar.Bar {

func (f *Foo) GetUuid() uuid.UUID {
if f == nil {
return uuid.UUID{}
return uuid.Nil
}
return f.Uuid
}
Expand Down
40 changes: 40 additions & 0 deletions generators/go/internal/testdata/model/ir/fixtures/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ func (a *AuthScheme) UnmarshalJSON(data []byte) error {
}

func (a AuthScheme) MarshalJSON() ([]byte, error) {
if err := a.validate(); err != nil {
return nil, err
}
switch a.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", a.Type, a)
Expand Down Expand Up @@ -176,6 +179,43 @@ func (a *AuthScheme) Accept(visitor AuthSchemeVisitor) error {
}
}

func (a *AuthScheme) validate() error {
if a == nil {
return fmt.Errorf("type %T is nil", a)
}
var fields []string
if a.Bearer != nil {
fields = append(fields, "bearer")
}
if a.Basic != nil {
fields = append(fields, "basic")
}
if a.Header != nil {
fields = append(fields, "header")
}
if len(fields) == 0 {
if a.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", a, a.Type)
}
return fmt.Errorf("type %T is empty", a)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", a, fields)
}
if a.Type != "" {
field := fields[0]
if a.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
a,
a.Type,
a,
)
}
}
return nil
}

type AuthSchemesRequirement string

const (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ func (e *Environments) UnmarshalJSON(data []byte) error {
}

func (e Environments) MarshalJSON() ([]byte, error) {
if err := e.validate(); err != nil {
return nil, err
}
switch e.Type {
default:
return nil, fmt.Errorf("invalid type %s in %T", e.Type, e)
Expand All @@ -151,6 +154,40 @@ func (e *Environments) Accept(visitor EnvironmentsVisitor) error {
}
}

func (e *Environments) validate() error {
if e == nil {
return fmt.Errorf("type %T is nil", e)
}
var fields []string
if e.SingleBaseUrl != nil {
fields = append(fields, "singleBaseUrl")
}
if e.MultipleBaseUrls != nil {
fields = append(fields, "multipleBaseUrls")
}
if len(fields) == 0 {
if e.Type != "" {
return fmt.Errorf("type %T defines a discriminant set to %q but the field is not set", e, e.Type)
}
return fmt.Errorf("type %T is empty", e)
}
if len(fields) > 1 {
return fmt.Errorf("type %T defines values for %s, but only one value is allowed", e, fields)
}
if e.Type != "" {
field := fields[0]
if e.Type != field {
return fmt.Errorf(
"type %T defines a discriminant set to %q, but it does not match the %T field; either remove or update the discriminant to match",
e,
e.Type,
e,
)
}
}
return nil
}

type EnvironmentsConfig struct {
DefaultEnvironment *EnvironmentId `json:"defaultEnvironment,omitempty" url:"defaultEnvironment,omitempty"`
Environments *Environments `json:"environments,omitempty" url:"environments,omitempty"`
Expand Down
Loading
Loading