Skip to content

Commit

Permalink
Added else if Conditional Statements to the Metal backend #34 (#39)
Browse files Browse the repository at this point in the history
* Added else if Conditional Statements to the Metal backend #34

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix: Fixed the code for the mentioned issues #34

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor Fixes

* Fixed and Ran tests for issue #34

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor Fixes for else_if in MetalCrossGLCodeGen.

* Added If test condition

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor Fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixing else if output for test cases

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor Fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: Fixed Parser

* Fixed MetalParser.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix: Made requested changes for the MetalParser.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix : Made the changes in Metalparser.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Minor Fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: fixed multiple else if statement generation

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: samthakur587 <samunder268@gmail.com>
  • Loading branch information
3 people authored Aug 29, 2024
1 parent 421e229 commit c5bc4c5
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 18 deletions.
Binary file added .DS_Store
Binary file not shown.
8 changes: 4 additions & 4 deletions crosstl/src/backend/Metal/MetalAst.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ def __repr__(self):


class IfNode(ASTNode):
def __init__(self, condition, if_body, else_body=None):
self.condition = condition
self.if_body = if_body
def __init__(self, if_chain=None, else_if_chain=None, else_body=None):
self.if_chain = if_chain or []
self.else_if_chain = else_if_chain or []
self.else_body = else_body

def __repr__(self):
return f"IfNode(condition={self.condition}, if_body={self.if_body}, else_body={self.else_body})"
return f"IfNode(if_chain={self.if_chain}, else_if_chain={self.else_if_chain}, else_body={self.else_body})"


class ForNode(ASTNode):
Expand Down
21 changes: 16 additions & 5 deletions crosstl/src/backend/Metal/MetalCrossGLCodeGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,23 @@ def generate_for_loop(self, node, indent, is_main):
return code

def generate_if_statement(self, node, indent, is_main):
condition = self.generate_expression(node.condition, is_main)

code = f"if ({condition}) {{\n"
code += self.generate_function_body(node.if_body, indent + 1, is_main)
code += " " * indent + "}"
code = ""
if node.if_chain:
# Handle the if chain
for condition, body in node.if_chain:
code += f"if ({self.generate_expression(condition, is_main)}) {{\n"
code += self.generate_function_body(body, indent + 1, is_main)
code += " " * indent + "}"
# Handling the else if chain
if node.else_if_chain:
for condition, body in node.else_if_chain:
code += (
f" else if ({self.generate_expression(condition, is_main)}) {{\n"
)
code += self.generate_function_body(body, indent + 1, is_main)
code += " " * indent + "}"

# Handling the else condition
if node.else_body:
code += " else {\n"
code += self.generate_function_body(node.else_body, indent + 1, is_main)
Expand Down
2 changes: 2 additions & 0 deletions crosstl/src/backend/Metal/MetalLexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
("BOOL", r"\bbool\b"),
("VOID", r"\bvoid\b"),
("RETURN", r"\breturn\b"),
("ELSE_IF", r"\belse\s+if\b"),
("IF", r"\bif\b"),
("ELSE", r"\belse\b"),
("FOR", r"\bfor\b"),
Expand Down Expand Up @@ -75,6 +76,7 @@
"bool": "BOOL",
"void": "VOID",
"return": "RETURN",
"else if": "ELSE_IF",
"if": "IF",
"else": "ELSE",
"for": "FOR",
Expand Down
25 changes: 19 additions & 6 deletions crosstl/src/backend/Metal/MetalParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,16 +305,29 @@ def parse_variable_declaration_or_assignment(self):
return expr

def parse_if_statement(self):
self.eat("IF")
self.eat("LPAREN")
condition = self.parse_expression()
self.eat("RPAREN")
if_body = self.parse_block()
if_chain = []
else_if_chain = []
else_body = None
while self.current_token[0] == "IF":
self.eat("IF")
self.eat("LPAREN")
condition = self.parse_expression()
self.eat("RPAREN")
body = self.parse_block()
if_chain.append((condition, body))
while self.current_token[0] == "ELSE_IF":
self.eat("ELSE_IF")
self.eat("LPAREN")
condition = self.parse_expression()
self.eat("RPAREN")
body = self.parse_block()
else_if_chain.append((condition, body))

if self.current_token[0] == "ELSE":
self.eat("ELSE")
else_body = self.parse_block()
return IfNode(condition, if_body, else_body)

return IfNode(if_chain, else_if_chain, else_body)

def parse_for_statement(self):
self.eat("FOR")
Expand Down
58 changes: 58 additions & 0 deletions tests/test_backend/test_metal/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,5 +273,63 @@ def test_function_call():
pytest.fail("Struct parsing not implemented.")


def test_else_if():
code = """
#include <metal_stdlib>
using namespace metal;
struct Vertex_INPUT {
float3 position [[attribute(0)]];
};
struct Vertex_OUTPUT {
float4 position [[position]];
float2 vUV;
};
vertex Vertex_OUTPUT vertex_main(Vertex_INPUT input [[stage_in]]) {
Vertex_OUTPUT output;
output.position = float4(input.position, 1.0);
if (input.position.x == input.position.y) {
output.vUV = float2(0.0, 0.0);}
else if (input.position.x > input.position.y) {
output.vUV = float2(1.0, 1.0);
}
else if (input.position.x < input.position.y) {
output.vUV = float2(-1.0, -1.0);
}
else {
output.vUV = float2(0.0, 0.0);
}
return output;
}
struct Fragment_INPUT {
float2 vUV [[stage_in]];
};
struct Fragment_OUTPUT {
float4 fragColor [[color(0)]];
};
fragment Fragment_OUTPUT fragment_main(Fragment_INPUT input [[stage_in]]) {
Fragment_OUTPUT output;
if (input.vUV.x == input.vUV.y) {
output.fragColor = float4(0.0, 1.0, 0.0, 1.0);
} else if (input.vUV.x > input.vUV.y) {
output.fragColor = float4(1.0, 0.0, 0.0, 1.0);
} else {
output.fragColor = float4(0.0, 0.0, 1.0, 1.0);
}
return output;
}
"""

tokens = tokenize_code(code)
ast = parse_code(tokens)
generated_code = generate_code(ast)
print(generated_code)


if __name__ == "__main__":
pytest.main()
29 changes: 26 additions & 3 deletions tests/test_backend/test_metal/test_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
def tokenize_code(code: str) -> List:
"""Helper function to tokenize code."""
lexer = MetalLexer(code)
return lexer.tokenize()
return lexer.tokens


def test_struct_tokenization():
Expand All @@ -20,8 +20,10 @@ def test_struct_tokenization():
float2 vUV;
};
"""
tokens = tokenize_code(code)
print(tokens)
try:
tokenize_code(code)
except SyntaxError:
pytest.fail("Struct tokenization not implemented.")


def test_if_tokenization():
Expand Down Expand Up @@ -103,5 +105,26 @@ def test_function_call_tokenization():
pytest.fail("Function call tokenization not implemented.")


def test_if_else_tokenization():
code = """
float perlinNoise(float2 p) {
if (p.x == p.y) {
return 0.0;
}
else if (p.x == 0.0) {
return 1.0;
}
else {
return fract(sin(dot(p, float2(12.9898, 78.233))) * 43758.5453);
}
}
"""
try:
tokenize_code(code)
except SyntaxError:
pytest.fail("If-else statement tokenization not implemented.")


if __name__ == "__main__":
pytest.main()
29 changes: 29 additions & 0 deletions tests/test_backend/test_metal/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,34 @@ def test_function_call():
pytest.fail("Struct parsing not implemented.")


def test_if_else():
code = """
float perlinNoise(float2 p) {
if (p.x == p.y) {
return 0.0;
}
if (p.x > p.y) {
return 1.0;
}
else if (p.x > p.y) {
return 1.0;
}
else if (p.x < p.y) {
return -1.0;
}
else {
return fract(sin(dot(p, float2(12.9898, 78.233))) * 43758.5453);
}
}
"""
try:
tokens = tokenize_code(code)
parse_code(tokens)
except SyntaxError:
pytest.fail("If-else statement parsing not implemented.")


if __name__ == "__main__":
pytest.main()

0 comments on commit c5bc4c5

Please sign in to comment.