Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python DU refinements: provide access to DU case constructors to python code #3558

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions src/Fable.Cli/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

* Remove support for Python 3.9. Add GH testing for Python 3.12 (by @dbrattli)
* Support (un)curry up to 20 arguments (by @MangelMaxime)
* Expose discriminated union constructors to Python code (by @smoothdeveloper).

#### Dart

Expand Down
66 changes: 52 additions & 14 deletions src/Fable.Transforms/Python/Fable2Python.fs
Original file line number Diff line number Diff line change
Expand Up @@ -3688,6 +3688,16 @@ module Util =
let transformUnion (com: IPythonCompiler) ctx (ent: Fable.Entity) (entName: string) classMembers =
let fieldIds = getUnionFieldsAsIdents com ctx ent

let genTypeArgument =
let gen =
getGenericTypeParams [ fieldIds[1].Type ]
|> Set.toList
|> List.tryHead

let ta = Expression.name (gen |> Option.defaultValue "Any")
let id = ident com ctx fieldIds[1]
Arg.arg (id, annotation = ta)

let args, isOptional =
let args =
fieldIds[0]
Expand All @@ -3697,20 +3707,8 @@ module Util =
Arg.arg (id, annotation = ta))
|> List.singleton

let varargs =
fieldIds[1]
|> ident com ctx
|> (fun id ->
let gen =
getGenericTypeParams [ fieldIds[1].Type ]
|> Set.toList
|> List.tryHead

let ta = Expression.name (gen |> Option.defaultValue "Any")
Arg.arg (id, annotation = ta))

let isOptional = Helpers.isOptional fieldIds
Arguments.arguments (args = args, vararg = varargs), isOptional
Arguments.arguments (args = args, vararg = genTypeArgument), isOptional

let body =
[ yield callSuperAsStatement []
Expand Down Expand Up @@ -3751,8 +3749,48 @@ module Util =

Statement.functionDef (name, Arguments.arguments (), body = body, returns = returnType, decoratorList = decorators)

let constructors =
[
for tag, case in ent.UnionCases |> Seq.indexed do
let name = Identifier case.Name
let args =
Arguments.arguments
[
for field in case.UnionCaseFields do
let ta, _ = typeAnnotation com ctx None field.FieldType
Arg.arg(com.GetIdentifier(ctx, field.Name), ta)
]
let decorators = [Expression.name "staticmethod"]
let values =
[
for field in case.UnionCaseFields do
let identifier : Fable.Ident =
{ Name = field.Name
Type = field.FieldType
IsMutable = false
IsThisArgument = true
IsCompilerGenerated = false
Range = None }
Fable.Expr.IdentExpr identifier
]
let unionExpr,_ =
Fable.Value(Fable.ValueKind.NewUnion(values, tag, ent.Ref, []),None)
|> transformAsExpr com ctx
let body =
[Statement.return' unionExpr]
let returnType =
match args.VarArg with
| None -> Expression.name entName
| Some _ ->
Expression.subscript(Expression.name entName, Expression.name genTypeArgument.Arg)
Statement.functionDef(name, args, body = body, returns = returnType, decoratorList = decorators)
]
let baseExpr = libValue com ctx "types" "Union" |> Some
let classMembers = List.append [ cases ] classMembers
let classMembers = [
cases
yield! constructors
yield! classMembers
]
declareType com ctx ent entName args isOptional body baseExpr classMembers

let transformClassWithCompilerGeneratedConstructor (com: IPythonCompiler) ctx (ent: Fable.Entity) (entName: string) classMembers =
Expand Down
15 changes: 15 additions & 0 deletions tests/Python/TestUnionType.fs
Original file line number Diff line number Diff line change
Expand Up @@ -205,3 +205,18 @@ let ``test Equality works in filter`` () =
|> Array.filter (fun r -> r.Case = MyUnion3.Case1)
|> Array.length
|> equal 2

#if FABLE_COMPILER
open Fable.Core
[<Fact>]
let ``test constructor exposed to python code`` () =
let u0 = MyUnion.Case0
let u1 = MyUnion.Case1 "a"
let u2 = MyUnion.Case2 ("a","b")
let u3 = MyUnion.Case3 ("a","b","c")
let v0 : MyUnion = PyInterop.emitPyExpr () "MyUnion.Case0()"
let v1 : MyUnion = PyInterop.emitPyExpr "a" "MyUnion.Case1($0)"
let v2 : MyUnion = PyInterop.emitPyExpr ("a","b") "MyUnion.Case2($0,$1)"
let v3 : MyUnion = PyInterop.emitPyExpr ("a","b","c") "MyUnion.Case3($0,$1,$2)"
equal [u0;u1;u2;u3] [v0;v1;v2;v3]
#endif
Loading