From 3ed947fef54441981b65f3cde56126d3d062cf4d Mon Sep 17 00:00:00 2001 From: Laurent Le Brun Date: Mon, 7 Jun 2021 00:48:39 +0200 Subject: [PATCH] Inlining: automatically inline trivial values This is a conservative first step. Inline variables only if their value is a constant (it doesn't depend on another variable). --- src/ast.fs | 5 +- src/rewriter.fs | 60 ++++++++++++- .../real/yx_long_way_from_home.frag.expected | 86 +++++++++---------- tests/unit/function_comma.expected | 4 +- tests/unit/inline.expected | 5 ++ tests/unit/inline.frag | 8 ++ tests/unit/inout.expected | 46 +++++----- tests/unit/macros.expected | 4 +- tests/unit/many_variables.expected | 16 ++-- 9 files changed, 149 insertions(+), 85 deletions(-) diff --git a/src/ast.fs b/src/ast.fs index 360c18f9..4aef9518 100644 --- a/src/ast.fs +++ b/src/ast.fs @@ -4,10 +4,13 @@ open Options.Globals type Ident(name: string) = let mutable newName = name + let mutable inlined = newName.StartsWith("i_") + member this.Name = newName member this.OldName = name member this.Rename(n) = newName <- n - member this.MustBeInlined = this.Name.StartsWith("i_") + member this.ToBeInlined = inlined + member this.Inline() = inlined <- true // Real identifiers cannot start with a digit, but the temporary ids of the rename pass are numbers. 
member this.IsUniqueId = System.Char.IsDigit this.Name.[0] diff --git a/src/rewriter.fs b/src/rewriter.fs index da87699e..45cde208 100644 --- a/src/rewriter.fs +++ b/src/rewriter.fs @@ -57,7 +57,7 @@ let private stripSpaces str = result.ToString() -let private declsNotToInline (d: Ast.DeclElt list) = d |> List.filter (fun x -> not x.name.MustBeInlined) +let private declsNotToInline (d: Ast.DeclElt list) = d |> List.filter (fun x -> not x.name.ToBeInlined) let private bool = function | true -> Var (Ident "true") // Int (1, "") @@ -132,9 +132,9 @@ let rec private simplifyExpr env = function | Dot(e, field) when options.canonicalFieldNames <> "" -> Dot(e, renameField field) - | Var s as e when s.MustBeInlined -> + | Var s as e -> match env.vars.TryFind s.Name with - | Some (_, {init = Some init}) -> init |> mapExpr env + | Some (_, {name = id; init = Some init}) when id.ToBeInlined -> init |> mapExpr env | _ -> e // pi is acos(-1), pi/2 is acos(0) @@ -165,6 +165,56 @@ let private rwTypeSpec = function let rwType (ty: Type) = makeType (rwTypeSpec ty.name) (Option.map stripSpaces ty.typeQ) +// Return the list of variables used in the statements, with the number of references. +let collectReferences stmtList = + let count = Dictionary() + let collectLocalUses _ = function + | Var v as e -> + match count.TryGetValue(v.Name) with + | true, n -> count.[v.Name] <- n + 1 + | false, _ -> count.[v.Name] <- 1 + e + | e -> e + for expr in stmtList do + mapStmt (mapEnv collectLocalUses id) expr |> ignore + count + +// Mark variables as inlinable when possible. +// For now, only mark a variable when: +// - the variable is used only once in the current block +// - the variable is not used in a sub-block (e.g. inside a loop) +// - the init value is trivial (doesn't depend on a variable) +let findInlinable block = + // Variables that are defined in this scope. + let localDefs = Dictionary() + // List of expressions in the current block. Do not look in sub-blocks. 
+ let mutable localExpr = [] + for stmt: Stmt in block do + match stmt with + | Decl (_, li) -> + for def in li do + // can only inline if it has a value + match def.init with + | None -> () + | Some init -> + localExpr <- init :: localExpr + // Inline only if the init value doesn't depend on other variables. + let deps = collectReferences [Expr init] + if deps.Count = 0 then + localDefs.[def.name.Name] <- def.name + | Expr e + | Jump (_, Some e) -> localExpr <- e :: localExpr + | Verbatim _ | Jump (_, None) | Block _ | If _| ForE _ | ForD _ | While _ | DoWhile _ -> () + + let localReferences = collectReferences (List.map Expr localExpr) + let allReferences = collectReferences block + + for def in localDefs do + if not def.Value.ToBeInlined then + match localReferences.TryGetValue(def.Key), allReferences.TryGetValue(def.Key) with + | (true, 1), (true, 1) -> def.Value.Inline() + | _ -> () + let private simplifyStmt = function | Block [] as e -> e | Block b -> @@ -174,6 +224,8 @@ let private simplifyStmt = function // Remove inner empty blocks let b = b |> List.filter (function Block [] | Decl (_, []) -> false | _ -> true) + + findInlinable b // Try to remove blocks by using the comma operator let returnExp = b |> Seq.tryPick (function Jump(JumpKeyword.Return, e) -> e | _ -> None) @@ -216,6 +268,8 @@ let simplify li = li |> reorderTopLevel |> mapTopLevel (mapEnv simplifyExpr simplifyStmt) + // A second pass, because some variables might now be inlinable. 
+ |> mapTopLevel (mapEnv simplifyExpr simplifyStmt) |> List.map (function | TLDecl (ty, li) -> TLDecl (rwType ty, declsNotToInline li) | TLVerbatim s -> TLVerbatim (stripSpaces s) diff --git a/tests/real/yx_long_way_from_home.frag.expected b/tests/real/yx_long_way_from_home.frag.expected index bf92bf9d..f149851f 100644 --- a/tests/real/yx_long_way_from_home.frag.expected +++ b/tests/real/yx_long_way_from_home.frag.expected @@ -64,12 +64,11 @@ const char *yx_long_way_from_home_frag = "m.y+=sin(m.x*2.)*.05;" "m.y-=length(sin(m.xz*.5))*.1;" "m.z+=sin(m.x*.5)*.5;" - "float y=.03;" "m.z+=step(.5,mod(m.x,1.))*.3-.15;" "m.x=mod(m.x,.5)-.25;" - "float l=t(m.xz),z=smoothstep(.1,.13,l);" - "m.y+=.1-z*y;" - "m.y-=smoothstep(.05,0.,abs(l-.16))*.004;" + "float y=t(m.xz),z=smoothstep(.1,.13,y);" + "m.y+=.1-z*.03;" + "m.y-=smoothstep(.05,0.,abs(y-.16))*.004;" "m.y-=(1.-z)*.01*h(m.xz);" "}" "m.y-=smoothstep(2.,0.,length(n.xz+vec2(-1.5,3.5)))*.2;" @@ -91,15 +90,15 @@ const char *yx_long_way_from_home_frag = "vec3 f=cross(vec3(-1,-1,-1),v);" "return f;" "}" - "vec3 e(vec3 v,float m)" + "vec3 e(vec3 v,float y)" "{" "v=normalize(v);" - "vec3 y=normalize(p(v)),f=normalize(cross(v,y));" + "vec3 f=normalize(p(v)),m=normalize(cross(v,f));" "vec2 n=i;" "n.x=n.x*2.*pi;" - "n.y=pow(n.y,1./(m+1.));" + "n.y=pow(n.y,1./(y+1.));" "float x=sqrt(1.-n.y*n.y);" - "return cos(n.x)*x*y+sin(n.x)*x*f+n.y*v;" + "return cos(n.x)*x*f+sin(n.x)*x*m+n.y*v;" "}" "vec3 x(vec3 v)" "{" @@ -145,39 +144,38 @@ const char *yx_long_way_from_home_frag = "}" "vec3 h(vec3 v,vec3 m)" "{" - "float x=.65,z=.18;" - "vec3 y=normalize(vec3((x-.5)*2.,z*2.,-1));" - "const float n=.0001;" - "const vec3 c=vec3(1.,.6,.2)*2.;" - "vec3 r=vec3(1),o=vec3(0);" - "for(int g=0;g<10;++g)" + "vec3 x=normalize(vec3(.3,.36,-1));" + "const float y=.0001;" + "const vec3 n=vec3(1.,.6,.2)*2.;" + "vec3 z=vec3(1),c=vec3(0);" + "for(int r=0;r<10;++r)" "{" - "vec3 a,p;" - "float t;" - "if(d(v,m,a,p,t))" + "vec3 a,o;" + "float p;" + "if(d(v,m,a,o,p))" 
"{" - "float k=1.;" - "vec3 b=vec3(1);" + "float t=1.;" + "vec3 g=vec3(1);" "if(f==1)" - "b=vec3(.7);" - "k*=k;" + "g=vec3(.7);" + "t*=t;" "{" - "v=a+p*.002;" - "vec3 h=reflect(m,p),u=e(p,1.);" - "m=normalize(mix(h,u,k));" - "r*=b;" + "v=a+o*.002;" + "vec3 h=reflect(m,o),u=e(o,1.);" + "m=normalize(mix(h,u,t));" + "z*=g;" "}" - "vec3 h=d(y,n);" - "float u=dot(p,h);" - "vec3 S,R;" - "float B;" - "if(u>0.&&!d(a+p*.002,h,S,R,B))" - "o+=r*u*c;" + "vec3 h=d(x,y);" + "float u=dot(o,h);" + "vec3 b,k;" + "float S;" + "if(u>0.&&!d(a+o*.002,h,b,k,S))" + "c+=z*u*n;" "i=s(i.y);" "}" "else" - " if(abs(t)>.1)" - "return o+l(m)*r;" + " if(abs(p)>.1)" + "return c+l(m)*z;" "else" " break;" "}" @@ -194,26 +192,26 @@ const char *yx_long_way_from_home_frag = "void main()" "{" "vec2 v=gl_FragCoord.xy/iResolution.xy-.5;" - "float m=iTime+(v.x+iResolution.x*v.y)*1.51269;" - "i=s(m);" + "float y=iTime+(v.x+iResolution.x*v.y)*1.51269;" + "i=s(y);" "v+=(i-.5)/iResolution.xy;" "v.x*=iResolution.x/iResolution.y;" - "const vec3 f=vec3(-4,2,3),y=vec3(0,0,0);" - "const float x=distance(f,y);" + "const vec3 m=vec3(-4,2,3),f=vec3(0,0,0);" + "const float x=distance(m,f);" "const vec2 z=vec2(1,2)*.015;" "vec3 c=vec3(0),r=normalize(vec3(v,2.));" "vec2 t=d();" "c.xy+=t*z;" "r.xy-=t*z*r.z/x;" - "vec3 l=y-f;" - "float p=-atan(l.y,length(l.xz)),a=-atan(l.x,l.z);" + "vec3 l=f-m;" + "float p=-atan(l.y,length(l.xz)),o=-atan(l.x,l.z);" "c.yz*=n(p);" "r.yz*=n(p);" - "c.xz*=n(a);" - "r.xz*=n(a);" - "c+=f;" - "vec4 u=vec4(h(c,r),1);" - "gl_FragColor=!isnan(u.x)&&u.x>=0.?u:vec4(0);" + "c.xz*=n(o);" + "r.xz*=n(o);" + "c+=m;" + "vec4 a=vec4(h(c,r),1);" + "gl_FragColor=!isnan(a.x)&&a.x>=0.?a:vec4(0);" "}"; #endif // YX_LONG_WAY_FROM_HOME_FRAG_EXPECTED_ diff --git a/tests/unit/function_comma.expected b/tests/unit/function_comma.expected index f3806369..203c212e 100644 --- a/tests/unit/function_comma.expected +++ b/tests/unit/function_comma.expected @@ -11,8 +11,8 @@ const char *function_comma_frag = "}" "float foo()" 
"{" - "float a=1.2,b=2.3;" - "return min((a=1.,b+a),0.);" + "float a=1.2;" + "return min((a=1.,2.3+a),0.);" "}" "float bar()" "{" diff --git a/tests/unit/inline.expected b/tests/unit/inline.expected index 4d1f7ee3..eb3f264a 100644 --- a/tests/unit/inline.expected +++ b/tests/unit/inline.expected @@ -10,3 +10,8 @@ int vars(int arg,int arg2) { return arg*(arg+arg2); } +int arithmetic2() +{ + int a=2,c=a+3; + return 4*a*c; +} diff --git a/tests/unit/inline.frag b/tests/unit/inline.frag index 5eb279dd..2ac344fc 100644 --- a/tests/unit/inline.frag +++ b/tests/unit/inline.frag @@ -20,3 +20,11 @@ int vars(int arg, int arg2) int i_c = i_a + i_b; return i_a * i_c; } + +int arithmetic2() +{ + int a = 2; + int b = 3; + int c = a + b; + return 4 * a * c; +} diff --git a/tests/unit/inout.expected b/tests/unit/inout.expected index 3a2e68d8..e05fde7b 100644 --- a/tests/unit/inout.expected +++ b/tests/unit/inout.expected @@ -7,22 +7,20 @@ in vec3 c,v; out vec4 o; void main() { - vec3 n=normalize(v),f=normalize(c),u=vec3(.1,.2,.3),z=vec3(.5,.5,.5); - float x=1.5; - vec3 p=texture(e,reflect(-n,f)).xyz,d=texture(e,refract(-n,f,1./x)).xyz,s=mix(u*d,p,.1); - o=vec4(s,1.); + vec3 l=normalize(v),u=normalize(c),f=vec3(.1,.2,.3),z=vec3(.5,.5,.5),x=texture(e,reflect(-l,u)).xyz,p=texture(e,refract(-l,u,1./1.5)).xyz,d=mix(f*p,x,.1); + o=vec4(d,1.); } -vec3 r(vec3 z,vec3 n,vec3 C) +vec3 r(vec3 z,vec3 l,vec3 s) { - float y=1.-clamp(dot(n,C),0.,1.); - return y*y*y*y*y*(1.-z)+z; + float C=1.-clamp(dot(l,s),0.,1.); + return C*C*C*C*C*(1.-z)+z; } -vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b) +vec3 r(vec3 l,vec3 y,vec3 u,vec3 f,vec3 z,float w) { - vec3 C=normalize(n+w); - float Z=1.+2048.*(1.-b)*(1.-b); - vec3 Y=u,X=vec3(pow(clamp(dot(C,f),0.,1.),Z)*(Z+4.)/8.),W=r(z,n,C); - return mix(Y,X,W); + vec3 s=normalize(l+y); + float b=1.+2048.*(1.-w)*(1.-w); + vec3 Z=f,Y=vec3(pow(clamp(dot(s,u),0.,1.),b)*(b+4.)/8.),X=r(z,l,s); + return mix(Z,Y,X); } // tests/unit/inout2.frag @@ -31,25 +29,23 @@ vec3 
r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b) uniform samplerCube e; uniform float t; -uniform vec3 a,m,l,i; +uniform vec3 m,i,a,n; in vec3 c,v; out vec4 o; -vec3 r(vec3 z,vec3 n,vec3 C) +vec3 r(vec3 z,vec3 l,vec3 s) { - float y=1.-clamp(dot(n,C),0.,1.); - return y*y*y*y*y*(1.-z)+z; + float C=1.-clamp(dot(l,s),0.,1.); + return C*C*C*C*C*(1.-z)+z; } void main() { - vec3 n=normalize(v),f=normalize(c),u=m,z=i; - float V=.5; - vec3 s=l+mix(u*a,a,V); - o=vec4(s,1.); + vec3 l=normalize(v),u=normalize(c),f=i,z=n,d=a+mix(f*m,m,.5); + o=vec4(d,1.); } -vec3 r(vec3 n,vec3 w,vec3 f,vec3 u,vec3 z,float b) +vec3 r(vec3 l,vec3 y,vec3 u,vec3 f,vec3 z,float w) { - vec3 C=normalize(n+w); - float Z=1.+2048.*(1.-b)*(1.-b); - vec3 Y=u,X=vec3(pow(clamp(dot(C,f),0.,1.),Z)*(Z+4.)/8.),W=r(z,n,C); - return mix(Y,X,W); + vec3 s=normalize(l+y); + float b=1.+2048.*(1.-w)*(1.-w); + vec3 Z=f,Y=vec3(pow(clamp(dot(s,u),0.,1.),b)*(b+4.)/8.),X=r(z,l,s); + return mix(Z,Y,X); } diff --git a/tests/unit/macros.expected b/tests/unit/macros.expected index 9f7ec15e..749ed5f0 100644 --- a/tests/unit/macros.expected +++ b/tests/unit/macros.expected @@ -23,8 +23,8 @@ const char *macros_frag = "#define p$\n" "int t()" "{" - "int t=1,r=2,u=3,Z=4,Y=5,X=6,W=7,V=8,U=9,T=10,S=11,R=12;" - "return t+Y+R;" + "int t=2,r=3,u=4,Z=6,Y=7,X=8,W=9,V=10,U=11;" + "return 18;" "}"; #endif // MACROS_EXPECTED_ diff --git a/tests/unit/many_variables.expected b/tests/unit/many_variables.expected index 4d21a695..4e4f6b56 100644 --- a/tests/unit/many_variables.expected +++ b/tests/unit/many_variables.expected @@ -6,12 +6,12 @@ "int t(float t,float o,float l,float f,float a,float n,float i,float r,float u,float e,float Z,float Y)" "{" "float X=t,W=o,V=l,U=f,T=a,S=n;" - "int R=1,Q=2,P=3,O=4,N=5,M=6,L=7,K=8;" - "float J=i,I=r,H=u,G=e,F=Z,E=Y;" - "int D=1,C=2,B=3,A=4,z=5,y=6,x=7,w=8;" - "float v=0.;" - "int s=1,q=2,p=3,m=4,k=5,j=6,h=7,g=8;" - "float d=0.;" - "int c=1,b=2,at=3,ab=4,ac=5,ad=6,ag=7,ah=8;" - "return K+w+g+ah;" + "int 
R=1,Q=2,P=3,O=4,N=5,M=6,L=7;" + "float K=i,J=r,I=u,H=e,G=Z,F=Y;" + "int E=1,D=2,C=3,B=4,A=5,z=6,y=7;" + "float x=0.;" + "int w=1,v=2,s=3,q=4,p=5,m=6,k=7;" + "float j=0.;" + "int h=1,g=2,d=3,c=4,b=5,at=6,ab=7;" + "return 32;" "}",