diff --git a/crosstl/src/translator/ast.py b/crosstl/src/translator/ast.py index a4c681ea..f8f2f0b5 100644 --- a/crosstl/src/translator/ast.py +++ b/crosstl/src/translator/ast.py @@ -97,13 +97,22 @@ def __repr__(self): class IfNode(ASTNode): - def __init__(self, condition, if_body, else_body=None): - self.condition = condition + def __init__( + self, + if_condition, + if_body, + else_if_conditions=[], + else_if_bodies=[], + else_body=None, + ): + self.if_condition = if_condition self.if_body = if_body + self.else_if_conditions = else_if_conditions + self.else_if_bodies = else_if_bodies self.else_body = else_body def __repr__(self): - return f"IfNode(condition={self.condition}, if_body={self.if_body}, else_body={self.else_body})" + return f"IfNode(if_condition={self.if_condition}, if_body={self.if_body}, else_if_conditions={self.else_if_conditions}, else_if_bodies={self.else_if_bodies}, else_body={self.else_body})" class ForNode(ASTNode): diff --git a/crosstl/src/translator/codegen/directx_codegen.py b/crosstl/src/translator/codegen/directx_codegen.py index fb6e9543..47eedd48 100644 --- a/crosstl/src/translator/codegen/directx_codegen.py +++ b/crosstl/src/translator/codegen/directx_codegen.py @@ -252,10 +252,19 @@ def generate_assignment(self, node, shader_type=None): def generate_if(self, node, indent, shader_type=None): indent_str = " " * indent - code = f"{indent_str}if ({self.generate_expression(node.condition, shader_type)}) {{\n" + code = f"{indent_str}if ({self.generate_expression(node.if_condition, shader_type)}) {{\n" for stmt in node.if_body: code += self.generate_statement(stmt, indent + 1, shader_type) code += f"{indent_str}}}" + + for else_if_condition, else_if_body in zip( + node.else_if_conditions, node.else_if_bodies + ): + code += f" else if ({self.generate_expression(else_if_condition, shader_type)}) {{\n" + for stmt in else_if_body: + code += self.generate_statement(stmt, indent + 1, shader_type) + code += f"{indent_str}}}" + if node.else_body: code += " else {\n" for stmt in node.else_body: diff --git a/crosstl/src/translator/codegen/metal_codegen.py b/crosstl/src/translator/codegen/metal_codegen.py index 3e34eaf8..2a258c92 100644 --- a/crosstl/src/translator/codegen/metal_codegen.py +++ b/crosstl/src/translator/codegen/metal_codegen.py @@ -304,10 +304,19 @@ def generate_assignment(self, node, shader_type=None): def generate_if(self, node, indent, shader_type=None): indent_str = " " * indent - code = f"{indent_str}if ({self.generate_expression(node.condition, shader_type)}) {{\n" + code = f"{indent_str}if ({self.generate_expression(node.if_condition, shader_type)}) {{\n" for stmt in node.if_body: code += self.generate_statement(stmt, indent + 1, shader_type) code += f"{indent_str}}}" + + for else_if_condition, else_if_body in zip( + node.else_if_conditions, node.else_if_bodies + ): + code += f" else if ({self.generate_expression(else_if_condition, shader_type)}) {{\n" + for stmt in else_if_body: + code += self.generate_statement(stmt, indent + 1, shader_type) + code += f"{indent_str}}}" + if node.else_body: code += " else {\n" for stmt in node.else_body: diff --git a/crosstl/src/translator/codegen/opengl_codegen.py b/crosstl/src/translator/codegen/opengl_codegen.py index 5b8651e1..0a462c1a 100644 --- a/crosstl/src/translator/codegen/opengl_codegen.py +++ b/crosstl/src/translator/codegen/opengl_codegen.py @@ -174,10 +174,19 @@ def generate_assignment(self, node, shader_type=None): def generate_if(self, node, indent, shader_type=None): indent_str = " " * indent - code = f"{indent_str}if ({self.generate_expression(node.condition, shader_type)}) {{\n" + code = f"{indent_str}if ({self.generate_expression(node.if_condition, shader_type)}) {{\n" for stmt in node.if_body: code += self.generate_statement(stmt, indent + 1, shader_type) code += f"{indent_str}}}" + + for else_if_condition, else_if_body in zip( + node.else_if_conditions, node.else_if_bodies + ): + code += f" else if ({self.generate_expression(else_if_condition, shader_type)}) {{\n" + for stmt in else_if_body: + code += self.generate_statement(stmt, indent + 1, shader_type) + code += f"{indent_str}}}" + if node.else_body: code += " else {\n" for stmt in node.else_body: diff --git a/crosstl/src/translator/parser.py b/crosstl/src/translator/parser.py index bf5862c6..82899e60 100644 --- a/crosstl/src/translator/parser.py +++ b/crosstl/src/translator/parser.py @@ -467,18 +467,31 @@ def parse_if_statement(self): """ self.eat("IF") self.eat("LPAREN") - condition = self.parse_expression() + if_condition = self.parse_expression() self.eat("RPAREN") self.eat("LBRACE") if_body = self.parse_body() self.eat("RBRACE") + else_if_condition = [] + else_if_body = [] else_body = None + + while self.current_token[0] == "ELSE" and self.peak(1)[0] == "IF": + self.eat("ELSE") + self.eat("IF") + self.eat("LPAREN") + else_if_condition.append(self.parse_expression()) + self.eat("RPAREN") + self.eat("LBRACE") + else_if_body.append(self.parse_body()) + self.eat("RBRACE") + if self.current_token[0] == "ELSE": self.eat("ELSE") self.eat("LBRACE") else_body = self.parse_body() self.eat("RBRACE") - return IfNode(condition, if_body, else_body) + return IfNode(if_condition, if_body, else_if_condition, else_if_body, else_body) def peak(self, n): """Peek ahead in the token list diff --git a/tests/test_translator/test_codegen/test_directx_codegen.py b/tests/test_translator/test_codegen/test_directx_codegen.py index 05510651..8435bd6c 100644 --- a/tests/test_translator/test_codegen/test_directx_codegen.py +++ b/tests/test_translator/test_codegen/test_directx_codegen.py @@ -189,6 +189,58 @@ def test_else_statement(): pytest.fail("Struct parsing not implemented.") +def test_else_if_statement(): + code = """ + shader PerlinNoise { + vertex { + input vec3 position; + output vec2 vUV; + + void main() { + vUV = position.xy * 10.0; + if (vUV.x < 0.5) { + vUV.x = 0.25; + } + if (vUV.x < 0.25) { + vUV.x = 0.0; + } else if (vUV.x < 0.75) { + vUV.x = 0.5; + } else if (vUV.x < 1.0) { + vUV.x = 0.75; + } else { + vUV.x = 0.0; + } + gl_Position = vec4(position, 1.0); + } + } + + // Fragment Shader + fragment { + input vec2 vUV; + output vec4 fragColor; + + void main() { + if (vUV.x > 0.75) { + fragColor = vec4(1.0, 1.0, 1.0, 1.0); + } else if (vUV.x > 0.5) { + fragColor = vec4(0.5, 0.5, 0.5, 1.0); + } else { + fragColor = vec4(0.0, 0.0, 0.0, 1.0); + } + fragColor = vec4(color, 1.0); + } + } + } + """ + try: + tokens = tokenize_code(code) + ast = parse_code(tokens) + code = generate_code(ast) + print(code) + except SyntaxError: + pytest.fail("Struct parsing not implemented.") + + def test_function_call(): code = """ shader PerlinNoise { diff --git a/tests/test_translator/test_codegen/test_metal_codegen.py b/tests/test_translator/test_codegen/test_metal_codegen.py index 1d1c226d..75d1475d 100644 --- a/tests/test_translator/test_codegen/test_metal_codegen.py +++ b/tests/test_translator/test_codegen/test_metal_codegen.py @@ -189,6 +189,60 @@ def test_else_statement(): pytest.fail("Struct parsing not implemented.") +def test_else_if_statement(): + code = """ + shader PerlinNoise { + vertex { + input vec3 position; + output vec2 vUV; + + void main() { + vUV = position.xy * 10.0; + if (vUV.x < 0.5) { + vUV.x = 0.25; + } + if (vUV.x < 0.25) { + vUV.x = 0.0; + } else if (vUV.x < 0.75) { + vUV.x = 0.5; + } else if (vUV.x < 1.0) { + vUV.x = 0.75; + } else { + vUV.x = 0.0; + } + gl_Position = vec4(position, 1.0); + } + } + + // Fragment Shader + fragment { + input vec2 vUV; + output vec4 fragColor; + + void main() { + if (vUV.x > 0.75) { + fragColor = vec4(1.0, 1.0, 1.0, 1.0); + } else if (vUV.x > 0.5) { + fragColor = vec4(0.5, 0.5, 0.5, 1.0); + } else if (vUV.x > 0.25) { + fragColor = vec4(0.25, 0.25, 0.25, 1.0); + } else { + fragColor = vec4(0.0, 0.0, 0.0, 1.0); + } + fragColor = vec4(color, 1.0); + } + } + } + """ + try: + tokens = tokenize_code(code) + ast = parse_code(tokens) + code = generate_code(ast) + print(code) + except SyntaxError: + pytest.fail("Struct parsing not implemented.") + + def test_function_call(): code = """ shader PerlinNoise { diff --git a/tests/test_translator/test_codegen/test_opengl_codegen.py b/tests/test_translator/test_codegen/test_opengl_codegen.py index 13ab84d2..a789e6ee 100644 --- a/tests/test_translator/test_codegen/test_opengl_codegen.py +++ b/tests/test_translator/test_codegen/test_opengl_codegen.py @@ -189,6 +189,60 @@ def test_else_statement(): pytest.fail("Struct parsing not implemented.") +def test_else_if_statement(): + code = """ + shader PerlinNoise { + vertex { + input vec3 position; + output vec2 vUV; + + void main() { + vUV = position.xy * 10.0; + if (vUV.x < 0.5) { + vUV.x = 0.25; + } + if (vUV.x < 0.25) { + vUV.x = 0.0; + } else if (vUV.x < 0.75) { + vUV.x = 0.5; + } else if (vUV.x < 1.0) { + vUV.x = 0.75; + } else { + vUV.x = 0.0; + } + gl_Position = vec4(position, 1.0); + } + } + + // Fragment Shader + fragment { + input vec2 vUV; + output vec4 fragColor; + + void main() { + if (vUV.x > 0.75) { + fragColor = vec4(1.0, 1.0, 1.0, 1.0); + } else if (vUV.x > 0.5) { + fragColor = vec4(0.5, 0.5, 0.5, 1.0); + } else if (vUV.x > 0.25) { + fragColor = vec4(0.25, 0.25, 0.25, 1.0); + } else { + fragColor = vec4(0.0, 0.0, 0.0, 1.0); + } + fragColor = vec4(color, 1.0); + } + } + } + """ + try: + tokens = tokenize_code(code) + ast = parse_code(tokens) + code = generate_code(ast) + print(code) + except SyntaxError: + pytest.fail("Struct parsing not implemented.") + + def test_function_call(): code = """ shader PerlinNoise { diff --git a/tests/test_translator/test_lexer.py b/tests/test_translator/test_lexer.py index 754ba6a0..b77c40da 100644 --- a/tests/test_translator/test_lexer.py +++ b/tests/test_translator/test_lexer.py @@ -62,6 +62,27 @@ def test_else_statement_tokenization(): pytest.fail("Struct parsing not implemented.") +def test_else_if_statement_tokenization(): + code = """ + if (!a) { + return b; + } + if (!b) { + return a; + } else if (a < b) { + return b; + } else if (a > b) { + return a; + } else { + return 0; + } + """ + try: + tokenize_code(code) + except SyntaxError: + pytest.fail("Struct parsing not implemented.") + + def test_function_call_tokenization(): code = """ shader PerlinNoise { diff --git a/tests/test_translator/test_parser.py b/tests/test_translator/test_parser.py index dd65d7f2..672f075e 100644 --- a/tests/test_translator/test_parser.py +++ b/tests/test_translator/test_parser.py @@ -169,6 +169,58 @@ def test_else_statement(): pytest.fail("Struct parsing not implemented.") +def test_else_if_statement(): + code = """ + shader PerlinNoise { + vertex { + input vec3 position; + output vec2 vUV; + + void main() { + vUV = position.xy * 10.0; + if (vUV.x < 0.5) { + vUV.x = 0.25; + } + if (vUV.x < 0.25) { + vUV.x = 0.0; + } else if (vUV.x < 0.75) { + vUV.x = 0.5; + } else if (vUV.x < 1.0) { + vUV.x = 0.75; + } else { + vUV.x = 0.0; + } + gl_Position = vec4(position, 1.0); + } + } + + // Fragment Shader + fragment { + input vec2 vUV; + output vec4 fragColor; + + void main() { + if (vUV.x > 0.75) { + fragColor = vec4(1.0, 1.0, 1.0, 1.0); + } else if (vUV.x > 0.5) { + fragColor = vec4(0.5, 0.5, 0.5, 1.0); + } else if (vUV.x > 0.25) { + fragColor = vec4(0.25, 0.25, 0.25, 1.0); + } else { + fragColor = vec4(0.0, 0.0, 0.0, 1.0); + } + fragColor = vec4(color, 1.0); + } + } + } + """ + try: + tokens = tokenize_code(code) + parse_code(tokens) + except SyntaxError: + pytest.fail("Struct parsing not implemented.") + + def test_function_call(): code = """ shader PerlinNoise {