diff --git a/.github/workflows/tests-and-docs.yml b/.github/workflows/tests-and-docs.yml
index ba08f5970a3..a338b47ddd8 100644
--- a/.github/workflows/tests-and-docs.yml
+++ b/.github/workflows/tests-and-docs.yml
@@ -86,7 +86,7 @@ jobs:
# Install tree-sitter parser (for Python component unit tests)
- name: Install tree-sitter parsers
working-directory: .
- run: python skema/program_analysis/tree_sitter_parsers/build_parsers.py --all
+ run: python skema/program_analysis/tree_sitter_parsers/build_parsers.py --ci --all
# docs (API)
diff --git a/docs/dev/cast_frontend.md b/docs/dev/cast_frontend.md
new file mode 100644
index 00000000000..b49e01d4364
--- /dev/null
+++ b/docs/dev/cast_frontend.md
@@ -0,0 +1,9 @@
+## CAST FrontEnd Generation Notes
+### Using Var vs Name nodes
+Currently in the CAST generation we have a convention on when to use Var and Name nodes.
+The GroMEt generation depends on these being consistent, otherwise there will be errors in the generation.
+In the future this convention might change, or be eliminated altogether, but for now this is the current set of rules.
+
+- If the variable in question is being stored into (i.e. as the result of an assignment), then we use Var. Even if it's a variable that has already been defined.
+- If the variable in question is being read from (i.e. being used in an expression), then we use Name.
+- Whenever we're creating a function call Call() node, the name of the function is specified using the Name node.
diff --git a/mkdocs.yml b/mkdocs.yml
index 84d88cd8272..d338cf679bf 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -45,6 +45,7 @@ nav:
- Generating code2fn model coverage reports: "dev/generating_code2fn_model_coverage.md"
- Using code ingestion frontends: "dev/using_code_ingestion_frontends.md"
- Using tree-sitter preprocessor: "dev/using_tree_sitter_preprocessor.md"
+ - CAST Front-end generation: "dev/cast_frontend.md"
- Coverage:
- Code2fn coverage reports: "coverage/code2fn_coverage/report.html"
- TA1 Integration Dashboard: "https://integration-dashboard.terarium.ai/TA1"
diff --git a/pyproject.toml b/pyproject.toml
index af3f401680b..20028ab88d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ dependencies=[
"typing_extensions", # see https://github.com/pydantic/pydantic/issues/5821#issuecomment-1559196859
"fastapi~=0.100.0",
"starlette",
+ "httpx",
"pydantic>=2.0.0",
"uvicorn",
"python-multipart",
@@ -42,7 +43,7 @@ dynamic = ["readme"]
# Pygraphviz is often tricky to install, so we reserve it for the dev extras
# list.
# - six: Required by auto-generated Swagger models
-dev = ["pytest", "pytest-cov", "pytest-xdist", "httpx", "black", "mypy", "coverage", "pygraphviz", "six"]
+dev = ["pytest", "pytest-cov", "pytest-xdist", "pytest-asyncio", "black", "mypy", "coverage", "pygraphviz", "six"]
demo = ["notebook"]
diff --git a/skema/gromet/fn/gromet_fn_module.py b/skema/gromet/fn/gromet_fn_module.py
index ec1a94f11e3..2d36d86e703 100644
--- a/skema/gromet/fn/gromet_fn_module.py
+++ b/skema/gromet/fn/gromet_fn_module.py
@@ -191,7 +191,7 @@ def fn_array(self, fn_array):
def metadata_collection(self):
"""Gets the metadata_collection of this GrometFNModule. # noqa: E501
- Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GroMEt object. # noqa: E501
+ Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GrometFNModule object. # noqa: E501
:return: The metadata_collection of this GrometFNModule. # noqa: E501
:rtype: list[list[Metadata]]
@@ -202,7 +202,7 @@ def metadata_collection(self):
def metadata_collection(self, metadata_collection):
"""Sets the metadata_collection of this GrometFNModule.
- Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GroMEt object. # noqa: E501
+ Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GrometFNModule object. # noqa: E501
:param metadata_collection: The metadata_collection of this GrometFNModule. # noqa: E501
:type: list[list[Metadata]]
diff --git a/skema/img2mml/eqn2mml.py b/skema/img2mml/eqn2mml.py
index 9c7e3260d0b..e523f4521e5 100644
--- a/skema/img2mml/eqn2mml.py
+++ b/skema/img2mml/eqn2mml.py
@@ -7,7 +7,7 @@
from typing import Text
from typing_extensions import Annotated
-from fastapi import APIRouter, FastAPI, Response, Request, Query, UploadFile
+from fastapi import APIRouter, FastAPI, status, Response, Request, Query, UploadFile
from skema.rest.proxies import SKEMA_MATHJAX_ADDRESS
from skema.img2mml.api import (
get_mathml_from_bytes,
@@ -86,23 +86,23 @@ def process_latex_equation(eqn: Text) -> Response:
"/img2mml/healthcheck",
summary="Check health of eqn2mml service",
response_model=int,
- status_code=200,
+ status_code=status.HTTP_200_OK,
)
def img2mml_healthcheck() -> int:
- return 200
+ return status.HTTP_200_OK
@router.get(
"/latex2mml/healthcheck",
summary="Check health of mathjax service",
response_model=int,
- status_code=200,
+ status_code=status.HTTP_200_OK,
)
def latex2mml_healthcheck() -> int:
try:
return int(requests.get(f"{SKEMA_MATHJAX_ADDRESS}/healthcheck").status_code)
except:
- return 500
+ return status.HTTP_500_INTERNAL_SERVER_ERROR
@router.post("/image/mml", summary="Get MathML representation of an equation image")
diff --git a/skema/isa/data.py b/skema/isa/data.py
new file mode 100644
index 00000000000..8e5ecbc1334
--- /dev/null
+++ b/skema/isa/data.py
@@ -0,0 +1,44 @@
+# -*- coding: utf-8 -*-
+
+mml = """
+ """
+
+expected = 'digraph G {\n0 [color=blue, label="Div(Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H))"];\n1 [color=blue, label="D(1, t)(H)"];\n2 [color=blue, label="Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H)"];\n3 [color=blue, label="Γ"];\n4 [color=blue, label="H^(n+2)"];\n5 [color=blue, label="H"];\n6 [color=blue, label="n+2"];\n7 [color=blue, label="n"];\n8 [color=blue, label="2"];\n9 [color=blue, label="Abs(Grad(H))^(n-1)"];\n10 [color=blue, label="Abs(Grad(H))"];\n11 [color=blue, label="Grad(H)"];\n12 [color=blue, label="n-1"];\n13 [color=blue, label="1"];\n1 -> 0 [color=blue, label="="];\n2 -> 0 [color=blue, label="Div"];\n3 -> 2 [color=blue, label="*"];\n4 -> 2 [color=blue, label="*"];\n5 -> 4 [color=blue, label="^"];\n6 -> 4 [color=blue, label="^"];\n7 -> 6 [color=blue, label="+"];\n8 -> 6 [color=blue, label="+"];\n9 -> 2 [color=blue, label="*"];\n10 -> 9 [color=blue, label="^"];\n11 -> 10 [color=blue, label="Abs"];\n5 -> 11 [color=blue, label="Grad"];\n12 -> 9 [color=blue, label="^"];\n7 -> 12 [color=blue, label="+"];\n13 -> 12 [color=blue, label="-"];\n11 -> 2 [color=blue, label="*"];\n}\n'
\ No newline at end of file
diff --git a/skema/isa/isa_service.py b/skema/isa/isa_service.py
index 011bac98f96..bc0d2b4de16 100644
--- a/skema/isa/isa_service.py
+++ b/skema/isa/isa_service.py
@@ -1,12 +1,15 @@
# -*- coding: utf-8 -*-
-from fastapi import FastAPI, File
+from fastapi import Depends, FastAPI, APIRouter, status
from skema.isa.lib import align_mathml_eqs
+import skema.isa.data as isa_data
+from skema.rest import utils
from pydantic import BaseModel
+import httpx
-# Create a web app using FastAPI
+from skema.rest.proxies import SKEMA_RS_ADDESS
-app = FastAPI()
+router = APIRouter()
# Model for ISA_Result
@@ -15,17 +18,39 @@ class ISA_Result(BaseModel):
union_graph: str = None
-@app.get("/ping", summary="Ping endpoint to test health of service")
-def ping():
- return "The ISA service is running."
+@router.get(
+ "/healthcheck",
+ summary="Status of ISA service",
+ response_model=int,
+ status_code=status.HTTP_200_OK
+)
+async def healthcheck(client: httpx.AsyncClient = Depends(utils.get_client)) -> int:
+ res = await client.get(f"{SKEMA_RS_ADDESS}/ping")
+ return res.status_code
-@app.put("/align-eqns", summary="Align two MathML equations")
+@router.post(
+ "/align-eqns",
+ summary="Align two MathML equations"
+)
async def align_eqns(
- file1: str, file2: str, mention_json1: str = "", mention_json2: str = ""
+ mml1: str, mml2: str, mention_json1: str = "", mention_json2: str = ""
) -> ISA_Result:
- """
+ f"""
Endpoint for align two MathML equations.
+
+ ### Python example
+
+ ```
+ import requests
+
+ request = {{
+ "mml1": {isa_data.mml},
+ "mml2": {isa_data.mml}
+ }}
+
+ response=requests.post("/isa/align-eqns", json=request)
+ res = response.json()
"""
(
matching_ratio,
@@ -36,8 +61,16 @@ async def align_eqns(
aligned_indices2,
union_graph,
perfectly_matched_indices1,
- ) = align_mathml_eqs(file1, file2, mention_json1, mention_json2)
- ir = ISA_Result()
- ir.matching_ratio = matching_ratio
- ir.union_graph = union_graph.to_string()
- return ir
+ ) = align_mathml_eqs(mml1, mml2, mention_json1, mention_json2)
+ return ISA_Result(
+ matching_ratio = matching_ratio,
+ union_graph = union_graph.to_string()
+ )
+
+
+app = FastAPI()
+app.include_router(
+ router,
+ prefix="/isa",
+ tags=["isa"],
+)
diff --git a/skema/isa/lib.py b/skema/isa/lib.py
index 12cce55c017..a6590b8dc92 100644
--- a/skema/isa/lib.py
+++ b/skema/isa/lib.py
@@ -2,13 +2,14 @@
"""
All the functions required by performing incremental structure alignment (ISA)
Author: Liang Zhang (liangzh@arizona.edu)
-Updated date: August 24, 2023
+Updated date: December 18, 2023
"""
import json
import warnings
from typing import List, Any, Union, Dict
from numpy import ndarray
from pydot import Dot
+from skema.rest.proxies import SKEMA_RS_ADDESS
warnings.filterwarnings("ignore")
import requests
@@ -173,8 +174,9 @@ def generate_graph(file: str = "", render: bool = False) -> pydot.Dot:
content = f.read()
digraph = requests.put(
- "http://localhost:8080/mathml/math-exp-graph", data=content.encode("utf-8")
+ f"{SKEMA_RS_ADDESS}/mathml/math-exp-graph", data=content.encode("utf-8")
)
+
if render:
src = Source(digraph.text)
src.render("doctest-output/mathml_exp_tree", view=True)
@@ -671,8 +673,8 @@ def check_square_array(arr: np.ndarray) -> List[int]:
def align_mathml_eqs(
- file1: str = "",
- file2: str = "",
+ mml1: str = "",
+ mml2: str = "",
mention_json1: str = "",
mention_json2: str = "",
mode: int = 2,
@@ -685,7 +687,7 @@ def align_mathml_eqs(
[1] Fishkind, D. E., Adali, S., Patsolic, H. G., Meng, L., Singh, D., Lyzinski, V., & Priebe, C. E. (2019).
Seeded graph matching. Pattern recognition, 87, 203-215.
- Input: the paths of the two equation MathMLs; mention_json1: the mention file of paper 1; mention_json1: the mention file of paper 2;
+    Input: mml1 & mml2: the file path or contents of the two equation MathMLs; mention_json1: the mention file of paper 1; mention_json2: the mention file of paper 2;
mode 0: without considering any priors; mode 1: having a heuristic prior
with the similarity of node labels; mode 2: using the variable definitions
Output:
@@ -698,8 +700,8 @@ def align_mathml_eqs(
union_graph: the visualization of the alignment result
perfectly_matched_indices1: strictly matched node indices in Graph 1
"""
- graph1 = generate_graph(file1)
- graph2 = generate_graph(file2)
+ graph1 = generate_graph(mml1)
+ graph2 = generate_graph(mml2)
amatrix1, node_labels1 = generate_amatrix(graph1)
amatrix2, node_labels2 = generate_amatrix(graph2)
diff --git a/skema/program_analysis/CAST/fortran/node_helper.py b/skema/program_analysis/CAST/fortran/node_helper.py
index f3d83853dd7..51f614a3586 100644
--- a/skema/program_analysis/CAST/fortran/node_helper.py
+++ b/skema/program_analysis/CAST/fortran/node_helper.py
@@ -1,8 +1,10 @@
+import itertools
from typing import List, Dict
-from skema.program_analysis.CAST2FN.model.cast import SourceRef
from tree_sitter import Node
+from skema.program_analysis.CAST2FN.model.cast import SourceRef
+
CONTROL_CHARACTERS = [
",",
"=",
@@ -41,7 +43,7 @@ def __init__(self, source: str, source_file_name: str):
# get_identifier optimization variables
self.source_lines = source.splitlines(keepends=True)
self.line_lengths = [len(line) for line in self.source_lines]
- self.line_length_sums = [sum(self.line_lengths[:i+1]) for i in range(len(self.source_lines))]
+ self.line_length_sums = list(itertools.accumulate(self.line_lengths))#[sum(self.line_lengths[:i+1]) for i in range(len(self.source_lines))]
def get_source_ref(self, node: Node) -> SourceRef:
"""Given a node and file name, return a CAST SourceRef object."""
@@ -96,6 +98,9 @@ def get_children_by_types(node: Node, types: List):
"""Takes in a node and a list of types as inputs and returns all children matching those types. Otherwise, return an empty list"""
return [child for child in node.children if child.type in types]
+def get_children_except_types(node: Node, types: List):
+ """Takes in a node and a list of types as inputs and returns all children not matching those types. Otherwise, return an empty list"""
+ return [child for child in node.children if child.type not in types]
def get_first_child_index(node, type: str):
"""Get the index of the first child of node with type type."""
diff --git a/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py b/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py
index 5bd6942017f..2e6a608a4a3 100644
--- a/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py
+++ b/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py
@@ -34,7 +34,7 @@ def preprocess(
"""
# NOTE: The order of preprocessing steps does matter. We have to run the GCC preprocessor before correcting the continuation lines or there could be issues
- # TODO: Create single location for generating include base path
+ # TODO: Create single location for generating include base path
source = source_path.read_text()
# Get paths for intermediate products
@@ -67,7 +67,7 @@ def preprocess(
# Step 2: Correct include directives to remove system references
source = fix_include_directives(source)
-
+
# Step 3: Process with gcc c-preprocessor
include_base_directory = Path(source_path.parent, f"include_{source_path.stem}")
if not include_base_directory.exists():
@@ -75,13 +75,13 @@ def preprocess(
source = run_c_preprocessor(source, include_base_directory)
if out_gcc:
gcc_path.write_text(source)
-
+
# Step 4: Prepare for tree-sitter
# This step removes any additional preprocessor directives added or not removed by GCC
source = "\n".join(
["!" + line if line.startswith("#") else line for line in source.splitlines()]
)
-
+
# Step 5: Check for unsupported idioms
if out_unsupported:
unsupported_path.write_text(
@@ -173,7 +173,7 @@ def fix_include_directives(source: str) -> str:
def run_c_preprocessor(source: str, include_base_path: Path) -> str:
"""Run the gcc c-preprocessor. Its run from the context of the include_base_path, so that it can find all included files"""
result = run(
- ["gcc", "-cpp", "-E", "-"],
+ ["gcc", "-cpp", "-E", "-x", "f95", "-"],
input=source,
text=True,
capture_output=True,
@@ -183,8 +183,14 @@ def run_c_preprocessor(source: str, include_base_path: Path) -> str:
return result.stdout
+def convert_assigned(source: str) -> str:
+ """Convered ASSIGNED GO TO to COMPUTED GO TO"""
+ pass
+
+
def convert_to_free_form(source: str) -> str:
"""If fixed-form Fortran source, convert to free-form"""
+
def validate_parse_tree(source: str) -> bool:
"""Parse source with tree-sitter and check if an error is returned."""
language = Language(INSTALLED_LANGUAGES_FILEPATH, "fortran")
@@ -204,7 +210,7 @@ def validate_parse_tree(source: str) -> bool:
)
if validate_parse_tree(free_source):
return free_source
-
+
return source
diff --git a/skema/program_analysis/CAST/fortran/ts2cast.py b/skema/program_analysis/CAST/fortran/ts2cast.py
index 542a7ece5be..8f26983ed32 100644
--- a/skema/program_analysis/CAST/fortran/ts2cast.py
+++ b/skema/program_analysis/CAST/fortran/ts2cast.py
@@ -10,6 +10,7 @@
from skema.program_analysis.CAST2FN.model.cast import (
Module,
SourceRef,
+ ModelBreak,
Assignment,
LiteralValue,
Var,
@@ -33,6 +34,7 @@
NodeHelper,
remove_comments,
get_children_by_types,
+ get_children_except_types,
get_first_child_by_type,
get_control_children,
get_non_control_children,
@@ -42,7 +44,20 @@
from skema.program_analysis.CAST.fortran.util import generate_dummy_source_refs
from skema.program_analysis.CAST.fortran.preprocessor.preprocess import preprocess
-from skema.program_analysis.tree_sitter_parsers.build_parsers import INSTALLED_LANGUAGES_FILEPATH
+from skema.program_analysis.tree_sitter_parsers.build_parsers import (
+ INSTALLED_LANGUAGES_FILEPATH,
+)
+
+builtin_statements = set(
+ [
+ "read_statement",
+ "write_statement",
+ "rewind_statement",
+ "open_statement",
+ "print_statement",
+ ]
+)
+
class TS2CAST(object):
def __init__(self, source_file_path: str):
@@ -50,33 +65,30 @@ def __init__(self, source_file_path: str):
self.path = Path(source_file_path)
self.source_file_name = self.path.name
self.source = preprocess(self.path)
-
+
# Run tree-sitter on preprocessor output to generate parse tree
parser = Parser()
- parser.set_language(
- Language(
- INSTALLED_LANGUAGES_FILEPATH,
- "fortran"
- )
- )
+ parser.set_language(Language(INSTALLED_LANGUAGES_FILEPATH, "fortran"))
self.tree = parser.parse(bytes(self.source, "utf8"))
self.root_node = remove_comments(self.tree.root_node)
-
+
# Walking data
self.variable_context = VariableContext()
self.node_helper = NodeHelper(self.source, self.source_file_name)
# Start visiting
self.out_cast = self.generate_cast()
- #print(self.out_cast[0].to_json_str())
-
+ # print(self.out_cast[0].to_json_str())
+
def generate_cast(self) -> List[CAST]:
- '''Interface for generating CAST.'''
+ """Interface for generating CAST."""
modules = self.run(self.root_node)
- return [CAST([generate_dummy_source_refs(module)], "Fortran") for module in modules]
-
+ return [
+ CAST([generate_dummy_source_refs(module)], "Fortran") for module in modules
+ ]
+
def run(self, root) -> List[Module]:
- '''Top level visitor function. Will return between 1-3 Module objects.'''
+ """Top level visitor function. Will return between 1-3 Module objects."""
# A program can have between 1-3 modules
# 1. A module body
# 2. A program body
@@ -98,17 +110,18 @@ def run(self, root) -> List[Module]:
body.extend(child_cast)
elif isinstance(child_cast, AstNode):
body.append(child_cast)
- modules.append(Module(
- name=None,
- body=body,
- source_refs=[self.node_helper.get_source_ref(root)]
- ))
-
+ modules.append(
+ Module(
+ name=None,
+ body=body,
+ source_refs=[self.node_helper.get_source_ref(root)],
+ )
+ )
return modules
def visit(self, node: Node):
- if node.type in ["program", "module"] :
+ if node.type in ["program", "module"]:
return self.visit_module(node)
elif node.type == "internal_procedures":
return self.visit_internal_procedures(node)
@@ -126,7 +139,11 @@ def visit(self, node: Node):
return self.visit_identifier(node)
elif node.type == "name":
return self.visit_name(node)
- elif node.type in ["math_expression", "relational_expression"]:
+ elif node.type in [
+ "unary_expression",
+ "math_expression",
+ "relational_expression",
+ ]:
return self.visit_math_expression(node)
elif node.type in [
"number_literal",
@@ -137,11 +154,13 @@ def visit(self, node: Node):
return self.visit_literal(node)
elif node.type == "keyword_statement":
return self.visit_keyword_statement(node)
+ elif node.type in builtin_statements:
+ return self.visit_fortran_builtin_statement(node)
elif node.type == "extent_specifier":
return self.visit_extent_specifier(node)
- elif node.type == "do_loop_statement":
+ elif node.type in ["do_loop_statement"]:
return self.visit_do_loop_statement(node)
- elif node.type == "if_statement":
+ elif node.type in ["if_statement", "else_if_clause", "else_clause"]:
return self.visit_if_statement(node)
elif node.type == "logical_expression":
return self.visit_logical_expression(node)
@@ -153,9 +172,9 @@ def visit(self, node: Node):
return self._visit_passthrough(node)
def visit_module(self, node: Node) -> Module:
- '''Visitor for program and module statement. Returns a Module object'''
+ """Visitor for program and module statement. Returns a Module object"""
self.variable_context.push_context()
-
+
program_body = []
for child in node.children[1:-1]: # Ignore the start and end program statement
child_cast = self.visit(child)
@@ -163,17 +182,17 @@ def visit_module(self, node: Node) -> Module:
program_body.extend(child_cast)
elif isinstance(child_cast, AstNode):
program_body.append(child_cast)
-
+
self.variable_context.pop_context()
-
+
return Module(
- name=None, #TODO: Fill out name field
+ name=None, # TODO: Fill out name field
body=program_body,
- source_refs = [self.node_helper.get_source_ref(node)]
+ source_refs=[self.node_helper.get_source_ref(node)],
)
def visit_internal_procedures(self, node: Node) -> List[FunctionDef]:
- '''Visitor for internal procedures. Returns list of FunctionDef'''
+ """Visitor for internal procedures. Returns list of FunctionDef"""
internal_procedures = get_children_by_types(node, ["function", "subroutine"])
return [self.visit(procedure) for procedure in internal_procedures]
@@ -213,16 +232,20 @@ def visit_function_def(self, node):
self.variable_context.push_context()
# Top level statement node
- statement_node = get_children_by_types(node, ["subroutine_statement", "function_statement"])[0]
+
+ statement_node = get_children_by_types(
+ node, ["subroutine_statement", "function_statement"]
+ )[0]
+
name_node = get_first_child_by_type(statement_node, "name")
name = self.visit(
name_node
) # Visit the name node to add it to the variable context
# If this is a function, check for return type and return value
- intrinsic_type = None
- return_value = None
if node.type == "function":
+ intrinsic_type = None
+ return_value = None
signature_qualifiers = get_children_by_types(
statement_node, ["intrinsic_type", "function_result"]
)
@@ -235,20 +258,20 @@ def visit_function_def(self, node):
elif qualifier.type == "function_result":
return_value = self.visit(
get_first_child_by_type(qualifier, "identifier")
- ) # TODO: UPDATE NODES
- self.variable_context.add_return_value(return_value.val.name)
-
- # #TODO: What happens if function doesn't return anything?
- # If this is a function, and there is no explicit results variable, then we will assume the return value is the name of the function
- if not return_value:
- self.variable_context.add_return_value(
- self.node_helper.get_identifier(name_node)
- )
+ ).val
+ self.variable_context.add_return_value(return_value.name)
+
+ # NOTE: In the case of a function specifically, if there is no explicit return value, the return value will be the name of the function
+ # TODO: Should this be a node instead
+ if not return_value:
+ self.variable_context.add_return_value(
+ self.node_helper.get_identifier(name_node)
+ )
+ return_value = self.visit(name_node)
- # If funciton has both, then we also need to update the type of the return value in the variable context
- # It does not explicity have to be declared
- if return_value and intrinsic_type:
- self.variable_context.update_type(return_value.val.name, intrinsic_type)
+    # If the function has an explicit intrinsic type, then we also need to update the type of the return value in the variable context
+ if intrinsic_type:
+ self.variable_context.update_type(return_value.name, intrinsic_type)
# Generating the function arguments by walking the parameters node
func_args = []
@@ -301,17 +324,20 @@ def visit_function_call(self, node):
# A subroutine and function won't neccessarily have an arguments node.
# So we should be careful about trying to access it.
- function_node = get_children_by_types(node, ["unary_expression", "subroutine", "identifier", "derived_type_member_expression"])[0]
-
+ function_node = get_children_by_types(
+ node,
+ [
+ "unary_expression",
+ "subroutine",
+ "identifier",
+ "derived_type_member_expression",
+ ],
+ )[0]
if function_node.type == "derived_type_member_expression":
- func = Attribute(
- value=None,
- attr=None
- )
- return None
-
+ return self.visit_derived_type_member_expression(function_node)
+
arguments_node = get_first_child_by_type(node, "argument_list")
-
+
# If this is a unary expression (+foo()) the identifier will be nested.
# TODO: If this is a non '+' unary expression, how do we add it to the CAST?
if function_node.type == "unary_expression":
@@ -357,7 +383,9 @@ def visit_keyword_statement(self, node):
if node.type == "keyword_statement":
if "continue" in identifier or "go to" in identifier:
return self._visit_no_op(node)
-
+ if "exit" in identifier:
+ return ModelBreak(source_refs=[self.node_helper.get_source_ref(node)])
+
# In Fortran the return statement doesn't return a value (there is the obsolete "alternative return")
# We keep track of values that need to be returned in the variable context
return_values = self.variable_context.context_return_values[
@@ -365,30 +393,49 @@ def visit_keyword_statement(self, node):
] # TODO: Make function for this
if len(return_values) == 1:
- # TODO: Fix this case
value = self.variable_context.get_node(list(return_values)[0])
elif len(return_values) > 1:
value = LiteralValue(
value_type="Tuple",
- value=[
- Var(
- val=self.variable_context.get_node(ret),
- type=self.variable_context.get_type(ret),
- default_value=None,
- source_refs=None,
- )
- for ret in return_values
- ],
- source_code_data_type=None, # TODO: REFACTOR
+ value=[self.variable_context.get_node(ret) for ret in return_values],
+ source_code_data_type=None,
source_refs=None,
)
else:
- value = LiteralValue(val=None, type=None, source_refs=None)
+ value = LiteralValue(value=None, value_type=None, source_refs=None)
return ModelReturn(
value=value, source_refs=[self.node_helper.get_source_ref(node)]
)
+ def visit_fortran_builtin_statement(self, node):
+ """Visitor for Fortran keywords that are not classified as keyword_statement by tree-sitter"""
+ # All of the node types that fall into this category end with _statment.
+ # So the function name will be the node type with _statement removed (write, read, open, ...)
+ func = self.get_gromet_function_node(node.type.replace("_statement", ""))
+
+ arguments = []
+
+ return Call(
+ func=func,
+ arguments=arguments,
+ source_language="Fortran",
+ source_language_version=None,
+ source_refs=[self.node_helper.get_source_ref(node)],
+ )
+
+ def visit_print_statement(self, node):
+ func = self.get_gromet_function_node("print")
+
+ arguments = []
+
+ return Call(
+ func=func,
+ arguments=arguments,
+ source_language=None,
+ source_language_version=None,
+ )
+
def visit_use_statement(self, node):
# (use)
# (use)
@@ -410,7 +457,7 @@ def visit_use_statement(self, node):
alias=import_alias,
all=import_all,
symbol=None,
- source_refs=None,
+ source_refs=[self.node_helper.get_source_ref(node)],
)
else:
imports = []
@@ -448,24 +495,15 @@ def visit_do_loop_statement(self, node) -> Loop:
(body) ...
"""
- # First check for
- # TODO: Add do until Loop support
- while_statement_node = get_first_child_by_type(node, "while_statement")
- if while_statement_node:
+ loop_control_node = get_first_child_by_type(node, "loop_contrel_expression")
+ if not loop_control_node:
return self._visit_while(node)
# If there is a loop control expression, the first body node will be the node after the loop_control_expression
# It is valid Fortran to have a single itteration do loop as well.
- # TODO: Add support for single itteration do-loop
# NOTE: This code is for the creation of the main body. The do loop will still add some additional nodes at the end of this body.
- body = []
body_start_index = 1 + get_first_child_index(node, "loop_control_expression")
- for body_node in node.children[body_start_index:]:
- child_cast = self.visit(body_node)
- if isinstance(child_cast, List):
- body.extend(child_cast)
- elif isinstance(child_cast, AstNode):
- body.append(child_cast)
+ body = self.generate_cast_body(node.children[body_start_index:])
# For the init and expression fields, we first need to determine if we are in a regular "do" or a "do while" loop
# PRE:
@@ -580,101 +618,66 @@ def visit_if_statement(self, node):
# (else_clause)
# (end_if_statement)
- if_condition = self.visit(get_first_child_by_type(node, "parenthesized_expression"))
-
- child_types = [child.type for child in node.children]
-
- try:
- elseif_index = child_types.index("elseif_clause")
- except ValueError:
- elseif_index = -1
+ # TODO: Can you have a parenthesized expression as a body node
+ body_nodes = get_children_except_types(
+ node,
+ [
+ "if",
+ "elseif",
+ "else",
+ "then",
+ "parenthesized_expression",
+ "elseif_clause",
+ "else_clause",
+ "end_if_statement",
+ ],
+ )
+ body = self.generate_cast_body(body_nodes)
- try:
- else_index = child_types.index("else_clause")
- except ValueError:
- else_index = -1
+ expr_node = get_first_child_by_type(node, "parenthesized_expression")
+ expr = None
+ if expr_node:
+ expr = self.visit(expr_node)
- if elseif_index != -1:
- body_stop_index = elseif_index
- else:
- body_stop_index = else_index
+ elseif_nodes = get_children_by_types(node, ["elseif_clause"])
+ elseif_cast = [self.visit(elseif_clause) for elseif_clause in elseif_nodes]
+ for i in range(len(elseif_cast) - 1):
+ elseif_cast[i].orelse = [elseif_cast[i + 1]]
- # Single line if conditions don't have a 'then' or 'end if' clause.
- # So the starting index for the body can either be 2 or 3.
- then_index = get_first_child_index(node, "then")
- if then_index:
- body_start_index = then_index+1
- else:
- body_start_index = 2
- body_stop_index = len(node.children)
-
- prev = None
- orelse = None
- # If there are else_if statements, they need
- if elseif_index != -1:
- orelse = ModelIf()
- prev = orelse
- for condition in node.children[elseif_index:else_index]:
- if condition.type == "comment":
- continue
- elseif_expr = self.visit(condition.children[2])
- elseif_body = [self.visit(child) for child in condition.children[4:]]
-
- prev.orelse = ModelIf(elseif_expr, elseif_body, [])
- prev = prev.orelse
-
- if else_index != -1:
- else_body = [
- self.visit(child) for child in node.children[else_index].children[1:]
- ]
- if prev:
- prev.orelse = else_body
- else:
- orelse = else_body
+ else_node = get_first_child_by_type(node, "else_clause")
+ else_cast = None
+ if else_node:
+ else_cast = self.visit(else_node)
- # TODO: This orelse logic has gotten a little complex, we might want to refactor this.
- if isinstance(orelse, ModelIf):
- orelse = orelse.orelse
- if orelse:
- if isinstance(orelse, ModelIf):
- orelse = [orelse]
+ orelse = []
+ if len(elseif_cast) > 0:
+ orelse = [elseif_cast[0]]
+ elif else_cast:
+ orelse = else_cast.body
- body = []
- for child in node.children[body_start_index:body_stop_index]:
- child_cast = self.visit(child)
- if isinstance(child_cast, AstNode):
- body.append(child_cast)
- elif isinstance(child_cast, List):
- body.extend(child_cast)
-
- return ModelIf(
- expr=self.visit(node.children[1]),
- body=body,
- orelse=orelse if orelse else [],
- )
+ return ModelIf(expr=expr, body=body, orelse=orelse)
def visit_logical_expression(self, node):
"""Visitior for logical expression (i.e. true and false) which is used in compound conditional"""
# If this is a .not. operator, we need to pass it on to the math_expression visitor
if len(node.children) < 3:
return self.visit_math_expression(node)
-
+
literal_value_false = LiteralValue("Boolean", False)
literal_value_true = LiteralValue("Boolean", True)
-
+
# AND: Right side goes in body if, left side in condition
- # OR: Right side goes in body else, left side in condition
+ # OR: Right side goes in body else, left side in condition
left, operator, right = node.children
-
+
# First we need to check if this is logical and or a logical or
# The tehcnical types for these are \.or\. and \.and\. so to simplify things we can use the in keyword
- is_or = "or" in operator.type
-
- top_if = ModelIf()
+ is_or = "or" in operator.type
+ top_if = ModelIf()
top_if_expr = self.visit(left)
top_if.expr = top_if_expr
-
+
bottom_if_expr = self.visit(right)
if is_or:
top_if.orelse = [bottom_if_expr]
@@ -786,6 +789,10 @@ def visit_math_expression(self, node):
for operand in get_non_control_children(node):
operands.append(self.visit(operand))
+ # For operators, we will only need the name node since we are not allocating space
+ if operand.type == "identifier":
+ operands[-1] = operands[-1].val
+
return Operator(
source_language="Fortran",
interpreter=None,
@@ -822,13 +829,15 @@ def visit_variable_declaration(self, node) -> List:
type_map = {
"integer": "Integer",
"real": "AbstractFloat",
- "double precision": None,
- "complex": None,
+ "double precision": "AbstractFloat",
+ "complex": "Tuple", # Complex is a Tuple (rational,irrational),
"logical": "Boolean",
"character": "String",
}
# NOTE: Identifiers are case sensitive, so we always need to make sure we are comparing to the lower() version
- variable_type = type_map[self.node_helper.get_identifier(intrinsic_type_node).lower()]
+ variable_type = type_map[
+ self.node_helper.get_identifier(intrinsic_type_node).lower()
+ ]
elif derived_type_node:
variable_type = self.node_helper.get_identifier(
get_first_child_by_type(derived_type_node, "type_name", recurse=True),
@@ -866,15 +875,11 @@ def visit_variable_declaration(self, node) -> List:
)
),
right=self.visit(variable.children[2]),
- source_refs=[
- self.node_helper.get_source_ref(variable)
- ],
+ source_refs=[self.node_helper.get_source_ref(variable)],
)
)
- vars[-1].left.type = "dimension"
- self.variable_context.update_type(
- vars[-1].left.val.name, "dimension"
- )
+ vars[-1].left.type = "List"
+ self.variable_context.update_type(vars[-1].left.val.name, "List")
else:
# If its a regular assignment, we can update the type normally
vars.append(self.visit(variable))
@@ -892,8 +897,8 @@ def visit_variable_declaration(self, node) -> List:
# Declaring a dimension variable using the x(1:5) format. It will look like a call expression in tree-sitter.
# We treat it like an identifier by visiting its identifier node. Then the type gets overridden by "dimension"
vars.append(self.visit(get_first_child_by_type(variable, "identifier")))
- vars[-1].type = "dimension"
- self.variable_context.update_type(vars[-1].val.name, "dimension")
+ vars[-1].type = "List"
+ self.variable_context.update_type(vars[-1].val.name, "List")
# By default, all variables are added to a function's list of return values
# If the intent is actually in, then we need to remove them from the list
@@ -964,7 +969,7 @@ def visit_derived_type(self, node: Node) -> RecordDef:
# If we tell the variable context we are in a record definition, it will append the type name as a prefix to all defined variables.
self.variable_context.enter_record_definition(record_name)
- # TODO: Full support for this requires handling the contains statement generally
+ # NOTE: visit any procedures declared in a derived_type_procedures block
funcs = []
derived_type_procedures_node = get_first_child_by_type(
node, "derived_type_procedures"
@@ -1034,15 +1039,17 @@ def visit_derived_type_member_expression(self, node) -> Attribute:
else:
# We shouldn't be accessing get_node directly, since it may not exist in the case of an import.
# Instead, we should visit the identifier node which will add it to the variable context automatically if it doesn't exist.
- value = self.visit(get_first_child_by_type(node, "identifier", recurse=True))
+ value = self.visit(
+ get_first_child_by_type(node, "identifier", recurse=True)
+ )
# NOTE: Attribue should be a Name node, NOT a string or Var node
- #attr = self.node_helper.get_identifier(
+ # attr = self.node_helper.get_identifier(
# get_first_child_by_type(node, "type_member", recurse=True)
- #)
- #print(self.node_helper.get_identifier(get_first_child_by_type(node, "type_member", recurse=True)))
+ # )
+ # print(self.node_helper.get_identifier(get_first_child_by_type(node, "type_member", recurse=True)))
attr = self.visit_name(get_first_child_by_type(node, "type_member"))
-
+
return Attribute(
value=value,
attr=attr,
@@ -1118,23 +1125,27 @@ def _visit_while(self, node) -> Loop:
"""
while_statement_node = get_first_child_by_type(node, "while_statement")
- # The first body node will be the node after the while_statement
- body = []
- body_start_index = 1 + get_first_child_index(node, "while_statement")
- for body_node in node.children[body_start_index:]:
- child_cast = self.visit(body_node)
- if isinstance(child_cast, List):
- body.extend(child_cast)
- elif isinstance(child_cast, AstNode):
- body.append(child_cast)
+ # Fortran has certain while(True) constructs that won't contain a while_statement node
+ if not while_statement_node:
+ body_start_index = 0
+ expr = LiteralValue(
+ value_type="Boolean",
+ value="True",
+ )
+ else:
+ body_start_index = 1 + get_first_child_index(node, "while_statement")
+ # We don't have explicit handling for parenthesized_expression, but the passthrough handler will make sure that we visit the expression correctly.
+ expr = self.visit(
+ get_first_child_by_type(
+ while_statement_node, "parenthesized_expression"
+ )
+ )
- # We don't have explicit handling for parenthesized_expression, but the passthrough handler will make sure that we visit the expression correctly.
- expr = self.visit(
- get_first_child_by_type(while_statement_node, "parenthesized_expression")
- )
+ # The first body node will be the node after the while_statement
+ body = self.generate_cast_body(node.children[body_start_index:])
return Loop(
- pre=[], # TODO: Should pre and post contain anything?
+ pre=[],
expr=expr,
body=body,
post=[],
@@ -1186,9 +1197,9 @@ def _visit_no_op(self, node):
func=self.get_gromet_function_node("no_op"),
source_language=None,
source_language_version=None,
- arguments=[]
+ arguments=[],
)
-
+
def get_gromet_function_node(self, func_name: str) -> Name:
# Idealy, we would be able to create a dummy node and just call the name visitor.
# However, tree-sitter does not allow you to create or modify nodes, so we have to recreate the logic here.
@@ -1196,3 +1207,19 @@ def get_gromet_function_node(self, func_name: str) -> Name:
return self.variable_context.get_node(func_name)
return self.variable_context.add_variable(func_name, "function", None)
+
+ def generate_cast_body(self, body_nodes: List):
+ body = []
+ for node in body_nodes:
+ cast = self.visit(node)
+ if isinstance(cast, AstNode):
+ body.append(cast)
+ elif isinstance(cast, List):
+ body.extend(cast)
+
+ # Gromet doesn't support empty bodies, so we should create a no_op instead
+ if len(body) == 0:
+ body.append(self._visit_no_op(None))
+
+ # TODO: How to add more support for source references
+ return body
diff --git a/skema/program_analysis/CAST/fortran/util.py b/skema/program_analysis/CAST/fortran/util.py
index 07c12ee809c..b4a9bc72d39 100644
--- a/skema/program_analysis/CAST/fortran/util.py
+++ b/skema/program_analysis/CAST/fortran/util.py
@@ -1,13 +1,15 @@
from typing import List
from skema.program_analysis.CAST2FN.model.cast import AstNode, LiteralValue, SourceRef
+DUMMY_SOURCE_REF = [SourceRef("", -1, -1, -1, -1)]
+DUMMY_SOURCE_CODE_DATA_TYPE = ["Fortran", "Fotran95", "None"]
def generate_dummy_source_refs(node: AstNode) -> AstNode:
"""Walks a tree of AstNodes replacing any null SourceRefs with a dummy value"""
if isinstance(node, LiteralValue) and not node.source_code_data_type:
- node.source_code_data_type = ["Fortran", "Fotran95", "None"]
+ node.source_code_data_type = DUMMY_SOURCE_CODE_DATA_TYPE
if not node.source_refs:
- node.source_refs = [SourceRef("", -1, -1, -1, -1)]
+ node.source_refs = DUMMY_SOURCE_REF
for attribute_str in node.attribute_map:
attribute = getattr(node, attribute_str)
diff --git a/skema/program_analysis/CAST/fortran/variable_context.py b/skema/program_analysis/CAST/fortran/variable_context.py
index eca184bee97..cfdd8a69b69 100644
--- a/skema/program_analysis/CAST/fortran/variable_context.py
+++ b/skema/program_analysis/CAST/fortran/variable_context.py
@@ -8,14 +8,17 @@ class VariableContext(object):
def __init__(self):
self.context = [{}] # Stack of context dictionaries
self.context_return_values = [set()] # Stack of context return values
+
+ # All symbols will use a separate naming convention to prevent two scopes using the same symbol name
+ # The name will be a dot notation list of scopes i.e. scope1.scope2.symbol
+ self.all_symbols_scopes = []
self.all_symbols = {}
- self.record_definitions = {}
-
+
# The prefix is used to handle adding Record types to the variable context.
# This gives each symbol a unqique name. For example "a" would become "type_name.a"
# For nested type definitions (derived type in a module), multiple prefixes can be added.
self.prefix = []
-
+
# Flag neccessary to declare if a function is internal or external
self.internal = False
@@ -36,7 +39,7 @@ def push_context(self):
def pop_context(self):
"""Pop the current variable context off of the stack and remove any references to those symbols."""
-
+
# If the internal flag is set, then all new scopes will use the top-level context
if self.internal:
return None
@@ -45,7 +48,10 @@ def pop_context(self):
# Remove symbols from all_symbols variable
for symbol in context:
- self.all_symbols.pop(symbol)
+ if isinstance(self.all_symbols[symbol], List):
+ self.all_symbols[symbol].pop()
+ else:
+ self.all_symbols.pop(symbol)
self.context_return_values.pop()
@@ -68,8 +74,13 @@ def add_variable(self, symbol: str, type: str, source_refs: List) -> Name:
}
# Add reference to all_symbols
- self.all_symbols[full_symbol_name] = self.context[-1][full_symbol_name]
-
+ if full_symbol_name in self.all_symbols:
+ if isinstance(self.all_symbols[full_symbol_name], List):
+ self.all_symbols[full_symbol_name].append(self.context[-1][full_symbol_name])
+ else:
+ self.all_symbols[full_symbol_name] = [self.all_symbols[full_symbol_name], self.context[-1][full_symbol_name]]
+ else:
+ self.all_symbols[full_symbol_name] = self.context[-1][full_symbol_name]
return cast_name
def is_variable(self, symbol: str) -> bool:
@@ -77,16 +88,25 @@ def is_variable(self, symbol: str) -> bool:
return symbol in self.all_symbols
def get_node(self, symbol: str) -> Dict:
+ if isinstance(self.all_symbols[symbol], List):
+ return self.all_symbols[symbol][-1]["node"]
+
return self.all_symbols[symbol]["node"]
def get_type(self, symbol: str) -> str:
+ if isinstance(self.all_symbols[symbol], List):
+ return self.all_symbols[symbol][-1]["type"]
+
return self.all_symbols[symbol]["type"]
def update_type(self, symbol: str, type: str):
"""Update the type associated with a given symbol"""
# Generate the full symbol name using the prefix
full_symbol_name = ".".join(self.prefix + [symbol])
- self.all_symbols[full_symbol_name]["type"] = type
+ if isinstance(self.all_symbols[full_symbol_name], List):
+ self.all_symbols[full_symbol_name][-1]["type"] = type
+ else:
+ self.all_symbols[full_symbol_name]["type"] = type
def add_return_value(self, symbol):
self.context_return_values[-1].add(symbol)
diff --git a/skema/program_analysis/CAST/matlab/__init__.py b/skema/program_analysis/CAST/matlab/__init__.py
index 3c4351b21c3..e69de29bb2d 100644
--- a/skema/program_analysis/CAST/matlab/__init__.py
+++ b/skema/program_analysis/CAST/matlab/__init__.py
@@ -1,3 +0,0 @@
-__pdoc__ = {
- 'tests': False
-}
diff --git a/skema/program_analysis/CAST/matlab/matlab_to_cast.py b/skema/program_analysis/CAST/matlab/matlab_to_cast.py
index abb7408ff19..da171083daf 100644
--- a/skema/program_analysis/CAST/matlab/matlab_to_cast.py
+++ b/skema/program_analysis/CAST/matlab/matlab_to_cast.py
@@ -110,11 +110,10 @@ def visit(self, node):
]:return self.visit_identifier(node)
elif node.type == "if_statement":
return self.visit_if_statement(node)
-# elif node.type in [
-# "for_statement",
-# "iterator",
-# "while_statement"
-# ]: return self.visit_loop(node)
+ elif node.type == "iterator":
+ return self.visit_iterator(node)
+ elif node.type == "for_statement":
+ return self.visit_for_statement(node)
elif node.type in [
"cell",
"matrix"
@@ -139,6 +138,8 @@ def visit(self, node):
]: return self.visit_operator(node)
elif node.type == "string":
return self.visit_string(node)
+ elif node.type == "range":
+ return self.visit_range(node)
elif node.type == "switch_statement":
return self.visit_switch_statement(node)
else:
@@ -155,15 +156,16 @@ def visit_assignment(self, node):
def visit_boolean(self, node):
""" Translate Tree-sitter boolean node """
+ value_type = "Boolean"
for child in node.children:
# set the first letter to upper case for python
value = child.type
value = value[0].upper() + value[1:].lower()
# store as string, use Python Boolean capitalization.
return LiteralValue(
- value_type="Boolean",
+ value_type=value_type,
value = value,
- source_code_data_type=["matlab", MATLAB_VERSION, "boolean"],
+ source_code_data_type=["matlab", MATLAB_VERSION, value_type],
source_refs=[self.node_helper.get_source_ref(node)],
)
@@ -210,7 +212,6 @@ def visit_identifier(self, node):
val = self.visit_name(node),
type = self.variable_context.get_type(identifier) if
self.variable_context.is_variable(identifier) else "Unknown",
- default_value = "LiteralValue",
source_refs = [self.node_helper.get_source_ref(node)],
)
@@ -248,12 +249,135 @@ def get_conditional(conditional_node):
return first
- # General loop translator for all MATLAB loop types
- # def visit_loop(self, node) -> Loop:
- # """ Translate Tree-sitter for_loop node into CAST Loop node """
- # return Loop (
- # source_refs = [self.node_helper.get_source_ref(node)]
- # )
+ # CAST has no Iterator node, so we return a partially
+ # completed Loop object
+ # MATLAB iterators are either matrices or ranges.
+ def visit_iterator(self, node) -> Loop:
+
+ itr_var = self.visit(get_first_child_by_type(node, "identifier"))
+ source_ref = self.node_helper.get_source_ref(node)
+
+ # process matrix iterator
+ matrix_node = get_first_child_by_type(node, "matrix")
+ if matrix_node is not None:
+ row_node = get_first_child_by_type(matrix_node, "row")
+ if row_node is not None:
+ mat = [self.visit(child) for child in
+ get_keyword_children(row_node)]
+ mat_idx = 0
+ mat_len = len(mat)
+
+
+ return Loop(
+ pre = [
+ Assignment(
+ left = "_mat",
+ right = mat,
+ source_refs = [source_ref]
+ ),
+ Assignment(
+ left = "_mat_len",
+ right = mat_len,
+ source_refs = [source_ref]
+ ),
+ Assignment(
+ left = "_mat_idx",
+ right = mat_idx,
+ source_refs = [source_ref]
+ ),
+ Assignment(
+ left = itr_var,
+ right = mat[mat_idx],
+ source_refs = [source_ref]
+ )
+ ],
+ expr = self.get_operator(
+ op = "<",
+ operands = ["_mat_idx", "_mat_len"],
+ source_refs = [source_ref]
+ ),
+ body = [
+ Assignment(
+ left = "_mat_idx",
+ right = self.get_operator(
+ op = "+",
+ operands = ["_mat_idx", 1],
+ source_refs = [source_ref]
+ ),
+ source_refs = [source_ref]
+ ),
+ Assignment(
+ left = itr_var,
+ right = "_mat[_mat_idx]",
+ source_refs = [source_ref]
+ )
+ ],
+ post = []
+ )
+
+
+
+ # process range iterator
+ range_node = get_first_child_by_type(node, "range")
+ if range_node is not None:
+ numbers = [self.visit(child) for child in
+ get_children_by_types(range_node, ["number"])]
+ start = numbers[0]
+ step = 1
+ stop = 0
+ if len(numbers) == 2:
+ stop = numbers[1]
+
+ elif len(numbers) == 3:
+ step = numbers[1]
+ stop = numbers[2]
+
+ range_name_node = self.variable_context.get_gromet_function_node("range")
+ iter_name_node = self.variable_context.get_gromet_function_node("iter")
+ next_name_node = self.variable_context.get_gromet_function_node("next")
+ generated_iter_name_node = self.variable_context.generate_iterator()
+ stop_condition_name_node = self.variable_context.generate_stop_condition()
+
+ return Loop(
+ pre = [
+ Assignment(
+ left = itr_var,
+ right = start,
+ source_refs = [source_ref]
+ )
+ ],
+ expr = self.get_operator(
+ op = "<=",
+ operands = [itr_var, stop],
+ source_refs = [source_ref]
+ ),
+ body = [
+ Assignment(
+ left = itr_var,
+ right = self.get_operator(
+ op = "+",
+ operands = [itr_var, step],
+ source_refs = [source_ref]
+ ),
+ source_refs = [source_ref]
+ )
+ ],
+ post = []
+ )
+
+
+ def visit_range(self, node):
+ return None
+
+ def visit_for_statement(self, node) -> Loop:
+ """ Translate Tree-sitter for loop node into CAST Loop node """
+
+ loop = self.visit(get_first_child_by_type(node, "iterator"))
+ loop.source_refs=[self.node_helper.get_source_ref(node)]
+ loop.body = self.get_block(node) + loop.body
+
+ return loop
+
def visit_matrix(self, node):
""" Translate the Tree-sitter cell node into a List """
@@ -271,10 +395,11 @@ def get_values(element, ret)-> List:
if len(values) > 0:
value = values[0]
+ value_type="List",
return LiteralValue(
- value_type="List",
+ value_type=value_type,
value = value,
- source_code_data_type=["matlab", MATLAB_VERSION, "matrix"],
+ source_code_data_type=["matlab", MATLAB_VERSION, value_type],
source_refs=[self.node_helper.get_source_ref(node)],
)
@@ -314,41 +439,41 @@ def visit_number(self, node) -> LiteralValue:
literal_value = self.node_helper.get_identifier(node)
# Check if this is a real value, or an Integer
if "e" in literal_value.lower() or "." in literal_value:
+ value_type = "AbstractFloat"
return LiteralValue(
- value_type="AbstractFloat",
+ value_type=value_type,
value=float(literal_value),
- source_code_data_type=["matlab", MATLAB_VERSION, "real"],
+ source_code_data_type=["matlab", MATLAB_VERSION, value_type],
source_refs=[self.node_helper.get_source_ref(node)]
)
+ value_type = "Integer"
return LiteralValue(
- value_type="Integer",
+ value_type=value_type,
value=int(literal_value),
- source_code_data_type=["matlab", MATLAB_VERSION, "integer"],
+ source_code_data_type=["matlab", MATLAB_VERSION, value_type],
source_refs=[self.node_helper.get_source_ref(node)]
)
def visit_operator(self, node):
- """return an Operator based on the Tree-sitter node """
+ """return an operator based on the Tree-sitter node """
# The operator will be the first control character
op = self.node_helper.get_identifier(
get_control_children(node)[0]
)
# the operands will be the keyword children
operands=[self.visit(child) for child in get_keyword_children(node)]
- return Operator(
- source_language="matlab",
- interpreter=INTERPRETER,
- version=MATLAB_VERSION,
+ return self.get_operator(
op = op,
operands = operands,
source_refs=[self.node_helper.get_source_ref(node)],
)
def visit_string(self, node):
+ value_type = "Character"
return LiteralValue(
- value_type="Character",
+ value_type=value_type,
value=self.node_helper.get_identifier(node),
- source_code_data_type=["matlab", MATLAB_VERSION, "character"],
+ source_code_data_type=["matlab", MATLAB_VERSION, value_type],
source_refs=[self.node_helper.get_source_ref(node)]
)
@@ -363,51 +488,55 @@ def visit_switch_statement(self, node):
"string",
"unary_operator"
]
-
- def get_operator(op, operands, source_refs):
- """ return an Operator representing the case test """
- return Operator(
- source_language = "matlab",
- interpreter = INTERPRETER,
- version = MATLAB_VERSION,
- op = op,
- operands = operands,
- source_refs = source_refs
- )
- def get_case_expression(case_node, identifier):
- """ return an Operator representing the case test """
+ def get_case_expression(case_node, switch_var):
+ """ return an operator representing the case test """
source_refs=[self.node_helper.get_source_ref(case_node)]
cell_node = get_first_child_by_type(case_node, "cell")
# multiple case arguments
if (cell_node):
+ value_type="List",
operand = LiteralValue(
- value_type="List",
+ value_type=value_type,
value = self.visit(cell_node),
- source_code_data_type=["matlab", MATLAB_VERSION, "unknown"],
+ source_code_data_type=["matlab", MATLAB_VERSION, value_type],
source_refs=[self.node_helper.get_source_ref(cell_node)]
)
- return get_operator("in", [identifier, operand], source_refs)
+ return self.get_operator(
+ op = "in",
+ operands = [switch_var, operand],
+ source_refs = source_refs
+ )
# single case argument
operand = [self.visit(node) for node in
get_children_by_types(case_node, case_node_types)][0]
- return get_operator("==", [identifier, operand], source_refs)
+ return self.get_operator(
+ op = "==",
+ operands = [switch_var, operand],
+ source_refs = source_refs
+ )
- def get_model_if(case_node, identifier):
+ def get_model_if(case_node, switch_var):
""" return conditional logic representing the case """
return ModelIf(
- expr = get_case_expression(case_node, identifier),
+ expr = get_case_expression(case_node, switch_var),
body = self.get_block(case_node),
orelse = [],
source_refs=[self.node_helper.get_source_ref(case_node)]
)
- # switch statement identifier
- identifier = self.visit(get_first_child_by_type(node, "identifier"))
-
+ # switch variable is usually an identifier
+ switch_var = get_first_child_by_type(node, "identifier")
+ if switch_var is not None:
+ switch_var = self.visit(switch_var)
+
+ # however it can be a function call
+ else:
+ switch_var = self.visit(get_first_child_by_type(node, "function_call"))
+
# n case clauses as 'if then' nodes
case_nodes = get_children_by_types(node, ["case_clause"])
- model_ifs = [get_model_if(node, identifier) for node in case_nodes]
+ model_ifs = [get_model_if(node, switch_var) for node in case_nodes]
for i, model_if in enumerate(model_ifs[1:]):
model_ifs[i].orelse = [model_if]
@@ -426,6 +555,21 @@ def get_block(self, node) -> List[AstNode]:
return [self.visit(child) for child in
get_keyword_children(block)]
+ def get_operator(self, op, operands, source_refs):
+ """ return an operator representing the arguments """
+ return Operator(
+ source_language = "matlab",
+ interpreter = INTERPRETER,
+ version = MATLAB_VERSION,
+ op = op,
+ operands = operands,
+ source_refs = source_refs
+ )
+
+ def get_gromet_function_node(self, func_name: str) -> Name:
+ if self.variable_context.is_variable(func_name):
+ return self.variable_context.get_node(func_name)
+
# skip control nodes and other junk
def _visit_passthrough(self, node):
if len(node.children) == 0:
diff --git a/skema/program_analysis/CAST/matlab/tests/__init__.py b/skema/program_analysis/CAST/matlab/tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/skema/program_analysis/CAST/matlab/tests/test_assignment.py b/skema/program_analysis/CAST/matlab/tests/test_assignment.py
index 90393c927bd..633e4dce64d 100644
--- a/skema/program_analysis/CAST/matlab/tests/test_assignment.py
+++ b/skema/program_analysis/CAST/matlab/tests/test_assignment.py
@@ -7,9 +7,8 @@
def test_boolean():
""" Test assignment of literal boolean types. """
# we translate these MATLAB keywords into capitalized strings for Python
- nodes = cast("x = true; y = false")
- check(nodes[0], Assignment(left = "x", right = "True"))
- check(nodes[1], Assignment(left = "y", right = "False"))
+ check(cast("x = true")[0], Assignment(left = "x", right = "True"))
+ check(cast("y = false")[0], Assignment(left = "y", right = "False"))
def test_number_zero_integer():
""" Test assignment of integer and real numbers."""
@@ -42,9 +41,20 @@ def test_identifier():
def test_operator():
""" Test assignment of operator"""
check(
- cast("Vtot = V1PF+V1AZ;")[0],
+ cast("x = x + 1")[0],
Assignment(
- left = "Vtot",
- right = Operator(op = "+",operands = ["V1PF", "V1AZ"])
+ left = "x",
+ right = Operator(op = "+",operands = ["x", 1])
)
)
+
+def test_matrix():
+ """ Test assignment of matrix"""
+ check(
+ cast("x = [1 cat 'dog' ]")[0],
+ Assignment(
+ left = "x",
+ right = [1, 'cat', "'dog'"]
+ )
+ )
+
diff --git a/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py b/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py
index 4d32d563ff6..0d988111795 100644
--- a/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py
+++ b/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py
@@ -1,11 +1,18 @@
+import os.path
from skema.program_analysis.CAST.matlab.matlab_to_cast import MatlabToCast
from skema.program_analysis.CAST.matlab.tests.utils import (check, cast)
from skema.program_analysis.CAST2FN.model.cast import Assignment
def test_file_ingest():
""" Test the ability of the CAST translator to read from a file"""
- filename = "skema/program_analysis/CAST/matlab/tests/data/matlab.m"
- cast = MatlabToCast(source_path = filename).out_cast
+
+ filepath = "skema/program_analysis/CAST/matlab/tests/data/matlab.m"
+ if not os.path.exists(filepath):
+ filepath = "data/matlab.m"
+
+
+
+ cast = MatlabToCast(source_path = filepath).out_cast
module = cast.nodes[0]
nodes = module.body
check(nodes[0], Assignment(left = "y", right = "b"))
diff --git a/skema/program_analysis/CAST/matlab/tests/test_loop.py b/skema/program_analysis/CAST/matlab/tests/test_loop.py
index 354ce0582e8..dcd409691e8 100644
--- a/skema/program_analysis/CAST/matlab/tests/test_loop.py
+++ b/skema/program_analysis/CAST/matlab/tests/test_loop.py
@@ -1,13 +1,123 @@
from skema.program_analysis.CAST.matlab.tests.utils import (check, cast)
-from skema.program_analysis.CAST2FN.model.cast import Loop
+from skema.program_analysis.CAST2FN.model.cast import (
+ Assignment,
+ Call,
+ Loop,
+ Operator
+)
-# Test the for loop and others
-def no_test_for_loop():
+# Test the for loop incrementing by 1
+def test_implicit_step():
""" Test the MATLAB for loop syntax elements"""
source = """
- for n = 1:10
- x = do_something(n)
+ for n = 0:10
+ disp(n)
end
"""
nodes = cast(source)
- check(nodes[0], Loop())
+ check(nodes[0],
+ Loop(
+ pre = [Assignment(left = "n", right = 0)],
+ expr = Operator(op = "<=", operands = ["n", 10]),
+ body = [
+ Call(
+ func = "disp",
+ arguments = ["n"]
+ ),
+ Assignment(
+ left = "n",
+ right = Operator(
+ op = "+",
+ operands = ["n", 1]
+ )
+ )
+ ],
+ post = []
+ )
+ )
+
+# Test the for loop incrementing by n
+def test_explicit_step():
+ """ Test the MATLAB for loop syntax elements"""
+ source = """
+ for n = 0:2:10
+ disp(n)
+ end
+ """
+ nodes = cast(source)
+ check(nodes[0],
+ Loop(
+ pre = [Assignment(left = "n", right = 0)],
+ expr = Operator(op = "<=", operands = ["n", 10]),
+ body = [
+ Call(
+ func = "disp",
+ arguments = ["n"]
+ ),
+ Assignment(
+ left = "n",
+ right = Operator(
+ op = "+",
+ operands = ["n", 2]
+ )
+ )
+ ],
+ post = []
+ )
+ )
+
+
+
+
+# Test the for loop using matrix steps
+def test_matrix():
+ """ Test the MATLAB for loop syntax elements"""
+ source = """
+ for k = [10 3 5 6]
+ disp(k)
+ end
+ """
+ nodes = cast(source)
+ check(nodes[0],
+ Loop(
+ pre = [
+ Assignment(
+ left = "_mat",
+ right = [10, 3, 5, 6]
+ ),
+ Assignment(
+ left = "_mat_len",
+ right = 4
+ ),
+ Assignment(
+ left = "_mat_idx",
+ right = 0
+ ),
+ Assignment(
+ left = "k",
+ right = 10
+ )
+ ],
+ expr = Operator(op = "<", operands = ["_mat_idx", "_mat_len"]),
+ body = [
+ Call(
+ func = "disp",
+ arguments = ["k"]
+ ),
+ Assignment(
+ left = "_mat_idx",
+ right = Operator(
+ op = "+",
+ operands = ["_mat_idx", 1]
+ )
+ ),
+ Assignment(
+ left = "k",
+ right = "_mat[_mat_idx]"
+ )
+ ],
+ post = []
+
+ )
+ )
+
diff --git a/skema/program_analysis/CAST/matlab/tests/test_switch.py b/skema/program_analysis/CAST/matlab/tests/test_switch.py
index 07582d91f39..e01454cad95 100644
--- a/skema/program_analysis/CAST/matlab/tests/test_switch.py
+++ b/skema/program_analysis/CAST/matlab/tests/test_switch.py
@@ -1,11 +1,12 @@
from skema.program_analysis.CAST.matlab.tests.utils import (check, cast)
from skema.program_analysis.CAST2FN.model.cast import (
Assignment,
+ Call,
ModelIf,
Operator
)
-def test_case_clause_1_argument():
+def test_1_argument():
""" Test CAST from single argument case clause."""
source = """
switch s
@@ -37,7 +38,7 @@ def test_case_clause_1_argument():
)
)
-def test_case_clause_n_arguments():
+def test_n_arguments():
""" Test CAST from multipe argument case clause."""
source = """
@@ -60,3 +61,32 @@ def test_case_clause_n_arguments():
orelse = [Assignment(left="n", right = 0)]
)
)
+
+def test_call_argument():
+ """ Test CAST using the value of a function call """
+
+ source = """
+ switch fd(i,j)
+ case 0
+ x = 5
+ end
+
+ """
+ # switch statement translated into conditional
+ check(
+ cast(source)[0],
+ ModelIf(
+ expr = Operator(
+ op = "==",
+ operands = [
+ Call (
+ func = "fd",
+ arguments = ["i","j"]
+ ),
+ 0
+ ]
+ ),
+ body = [Assignment(left="x", right = 5)],
+ orelse = []
+ )
+ )
diff --git a/skema/program_analysis/CAST/matlab/tests/utils.py b/skema/program_analysis/CAST/matlab/tests/utils.py
index 8345715cee8..9375b95df56 100644
--- a/skema/program_analysis/CAST/matlab/tests/utils.py
+++ b/skema/program_analysis/CAST/matlab/tests/utils.py
@@ -13,6 +13,11 @@
Name,
Var
)
+from skema.program_analysis.CAST2FN.visitors.cast_to_agraph_visitor import (
+ CASTToAGraphVisitor,
+)
+from skema.program_analysis.CAST2FN.cast import CAST
+
def check(result, expected = None):
""" Test for match with the same datatypes. """
@@ -37,6 +42,11 @@ def check(result, expected = None):
check(result.expr, expected.expr)
check(result.body, expected.body)
check(result.orelse, expected.orelse)
+ elif isinstance(result, Loop):
+ check(result.pre, expected.pre)
+ check(result.expr, expected.expr)
+ check(result.body, expected.body)
+ check(result.post, expected.post)
elif isinstance(result, LiteralValue):
check(result.value, expected)
elif isinstance(result, Var):
@@ -48,16 +58,28 @@ def check(result, expected = None):
# every CAST node has a source_refs element
if isinstance(result, AstNode):
- assert not result.source_refs == None
+ assert result.source_refs is not None
# we curently produce a CAST object with a single Module in the nodes list.
def cast(source):
""" Return the MatlabToCast output """
# there should only be one CAST object in the cast output list
cast = MatlabToCast(source = source).out_cast
+ # the cast should be parsable
+ # assert validate(cast) == True
# there should be one module in the CAST object
assert len(cast.nodes) == 1
module = cast.nodes[0]
assert isinstance(module, Module)
# return the module body node list
return module.body
+
+def validate(cast):
+ """ Test that the cast can be parsed """
+ try:
+ foo = CASTToAGraphVisitor(cast)
+ foo.to_pdf("/dev/null")
+ return True
+ except:
+ return False
+
diff --git a/skema/program_analysis/CAST/matlab/tokens.py b/skema/program_analysis/CAST/matlab/tokens.py
index 7dea817f1cb..25b060f09e1 100644
--- a/skema/program_analysis/CAST/matlab/tokens.py
+++ b/skema/program_analysis/CAST/matlab/tokens.py
@@ -26,6 +26,7 @@
'function_arguments',
'function_call',
'function_definition',
+ 'function_output',
'identifier',
'if',
'if_statement',
@@ -46,20 +47,19 @@
'switch_statement',
'unary_operator',
- # keywords to be supported
+ # keywords currently being added
'break_statement',
'continue_statement',
- 'field_expression',
'for',
'for_statement',
- 'function_output',
'iterator',
+ 'range',
+
+ # keywords to be supported
+ 'field_expression',
'lambda',
'line_continuation',
'multioutput_variable',
- 'range',
- 'while',
- 'while_statement'
]
""" anything not a keyword """
diff --git a/skema/program_analysis/CAST/matlab/variable_context.py b/skema/program_analysis/CAST/matlab/variable_context.py
index 4bc8486b9f2..3d6db267e6a 100644
--- a/skema/program_analysis/CAST/matlab/variable_context.py
+++ b/skema/program_analysis/CAST/matlab/variable_context.py
@@ -72,3 +72,18 @@ def get_node(self, symbol: str) -> Dict:
def get_type(self, symbol: str) -> str:
return self.all_symbols[symbol]["type"]
+
+ def get_gromet_function_node(self, func_name: str) -> Name:
+ if self.is_variable(func_name):
+ return self.get_node(func_name)
+
+ def generate_iterator(self):
+ symbol = f"generated_iter_{self.iterator_id}"
+ self.iterator_id += 1
+ return self.add_variable(symbol, "iterator", None)
+
+ def generate_stop_condition(self):
+ symbol = f"sc_{self.stop_condition_id}"
+ self.stop_condition_id += 1
+ return self.add_variable(symbol, "boolean", None)
+
diff --git a/skema/program_analysis/CAST/python/node_helper.py b/skema/program_analysis/CAST/python/node_helper.py
index abef0b50bcc..5e66f0fb567 100644
--- a/skema/program_analysis/CAST/python/node_helper.py
+++ b/skema/program_analysis/CAST/python/node_helper.py
@@ -1,3 +1,4 @@
+import itertools
from typing import List, Dict
from skema.program_analysis.CAST2FN.model.cast import SourceRef
@@ -24,11 +25,58 @@
"not"
]
+# Whatever constructs we see in the left
+# part of the for loop construct
+# for LEFT in RIGHT:
+FOR_LOOP_LEFT_TYPES = [
+ "identifier",
+ "tuple_pattern",
+ "pattern_list",
+ "list_pattern"
+]
+
+# Whatever constructs we see in the right
+# part of the for loop construct
+# for LEFT in RIGHT:
+FOR_LOOP_RIGHT_TYPES = [
+ "call",
+ "identifier",
+ "list",
+ "tuple"
+]
+
+# Whatever constructs we see in the conditional
+# part of the while loop
+WHILE_COND_TYPES = [
+ "boolean_operator",
+ "call",
+ "comparison_operator"
+]
+
class NodeHelper():
def __init__(self, source: str, source_file_name: str):
self.source = source
self.source_file_name = source_file_name
+ # get_identifier optimization variables
+ self.source_lines = source.splitlines(keepends=True)
+ self.line_lengths = [len(line) for line in self.source_lines]
+ self.line_length_sums = [0] + list(itertools.accumulate(self.line_lengths))
+
+ def get_identifier(self, node: Node) -> str:
+ """Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point"""
+ start_line, start_column = node.start_point
+ end_line, end_column = node.end_point
+
+ # Edge case for when an identifier is on the very first line of the code
+ # We can't index into the line_length_sums
+ start_index = self.line_length_sums[start_line] + start_column
+ if start_line == end_line:
+ end_index = start_index + (end_column-start_column)
+ else:
+ end_index = self.line_length_sums[end_line] + end_column
+
+ return self.source[start_index:end_index]
def get_source_ref(self, node: Node) -> SourceRef:
"""Given a node and file name, return a CAST SourceRef object."""
@@ -36,30 +84,6 @@ def get_source_ref(self, node: Node) -> SourceRef:
row_end, col_end = node.end_point
return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end)
-
- def get_identifier(self, node: Node) -> str:
- """Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point"""
- line_num = 0
- column_num = 0
- in_identifier = False
- identifier = ""
- for i, char in enumerate(self.source):
- if line_num == node.start_point[0] and column_num == node.start_point[1]:
- in_identifier = True
- elif line_num == node.end_point[0] and column_num == node.end_point[1]:
- break
-
- if char == "\n":
- line_num += 1
- column_num = 0
- else:
- column_num += 1
-
- if in_identifier:
- identifier += char
-
- return identifier
-
def get_operator(self, node: Node) -> str:
"""Given a unary/binary operator node, return the operator it contains"""
return node.type
diff --git a/skema/program_analysis/CAST/python/ts2cast.py b/skema/program_analysis/CAST/python/ts2cast.py
index 720c1f40569..cf51f6b6e55 100644
--- a/skema/program_analysis/CAST/python/ts2cast.py
+++ b/skema/program_analysis/CAST/python/ts2cast.py
@@ -25,7 +25,8 @@
ModelIf,
RecordDef,
Attribute,
- ScalarType
+ ScalarType,
+ StructureType
)
from skema.program_analysis.CAST.python.node_helper import (
@@ -35,7 +36,10 @@
get_first_child_index,
get_last_child_index,
get_control_children,
- get_non_control_children
+ get_non_control_children,
+ FOR_LOOP_LEFT_TYPES,
+ FOR_LOOP_RIGHT_TYPES,
+ WHILE_COND_TYPES
)
from skema.program_analysis.CAST.python.util import (
generate_dummy_source_refs,
@@ -71,6 +75,9 @@ def __init__(self, source_file_path: str, from_file = True):
)
)
+ # Additional variables used in generation
+ self.var_count = 0
+
# Tree walking structures
self.variable_context = VariableContext()
self.node_helper = NodeHelper(self.source, self.source_file_name)
@@ -82,6 +89,7 @@ def __init__(self, source_file_path: str, from_file = True):
def generate_cast(self) -> List[CAST]:
'''Interface for generating CAST.'''
module = self.run(self.tree.root_node)
+ module.name = self.source_file_name
return CAST([generate_dummy_source_refs(module)], "Python")
def run(self, root) -> List[Module]:
@@ -107,16 +115,26 @@ def visit(self, node: Node):
return self.visit_return(node)
elif node.type == "call":
return self.visit_call(node)
+ elif node.type == "if_statement":
+ return self.visit_if_statement(node)
+ elif node.type == "comparison_operator":
+ return self.visit_comparison_op(node)
elif node.type == "assignment":
return self.visit_assignment(node)
elif node.type == "identifier":
return self.visit_identifier(node)
- elif node.type =="unary_operator":
+ elif node.type == "unary_operator":
return self.visit_unary_op(node)
- elif node.type =="binary_operator":
+ elif node.type == "binary_operator":
return self.visit_binary_op(node)
- elif node.type in ["integer"]:
+ elif node.type in ["integer", "list"]:
return self.visit_literal(node)
+ elif node.type in ["list_pattern", "pattern_list", "tuple_pattern"]:
+ return self.visit_pattern(node)
+ elif node.type == "while_statement":
+ return self.visit_while(node)
+ elif node.type == "for_statement":
+ return self.visit_for(node)
else:
return self._visit_passthrough(node)
@@ -220,6 +238,21 @@ def visit_call(self, node: Node) -> Call:
elif isinstance(cast, AstNode):
func_args.append(cast)
+ if func_name.val.name == "range":
+ start_step_value = LiteralValue(
+ ScalarType.INTEGER,
+ value="1",
+ source_code_data_type=["Python", PYTHON_VERSION, str(type(1))],
+ source_refs=[ref]
+ )
+ # Add a step value
+ if len(func_args) == 2:
+ func_args.append(start_step_value)
+ # Add a start and step value
+ elif len(func_args) == 1:
+ func_args.insert(0, start_step_value)
+ func_args.append(start_step_value)
+
# Function calls only want the 'Name' part of the 'Var' that the visit returns
return Call(
func=func_name.val,
@@ -227,6 +260,84 @@ def visit_call(self, node: Node) -> Call:
source_refs=[ref]
)
+ def visit_comparison_op(self, node: Node):
+ ref = self.node_helper.get_source_ref(node)
+ op = get_op(self.node_helper.get_operator(node.children[1]))
+ left, _, right = node.children
+
+ left_cast = get_name_node(self.visit(left))
+ right_cast = get_name_node(self.visit(right))
+
+ return Operator(
+ op=op,
+ operands=[left_cast, right_cast],
+ source_refs=[ref]
+ )
+
+ def visit_if_statement(self, node: Node) -> ModelIf:
+ if_condition = self.visit(get_first_child_by_type(node, "comparison_operator"))
+
+ # Get the body of the if true part
+ if_true = get_children_by_types(node, "block")[0].children
+
+ # Because in tree-sitter the else if, and else aren't nested, but they're
+ # in a flat level order, we need to do some arranging of the pieces
+ # in order to get the correct CAST nested structure that we use
+ # Visit all the alternatives, generate CAST for each one
+ # and then join them all together
+ alternatives = get_children_by_types(node, ["elif_clause","else_clause"])
+
+ if_true_cast = []
+        for child in if_true:
+            cast = self.visit(child)
+            if isinstance(cast, List):
+                if_true_cast.extend(cast)
+            elif isinstance(cast, AstNode):
+                if_true_cast.append(cast)
+
+        # If we have ts nodes in alternatives, then we're guaranteed
+        # at least an else at the end of the if-statement construct
+        # We generate the cast for the final else statement, and then
+        # reverse the rest of the if-elses that we have, so we can
+        # create the CAST correctly
+        final_else_cast = []
+        if len(alternatives) > 0:
+            final_else = alternatives.pop()
+            alternatives.reverse()
+            final_else_body = get_children_by_types(final_else, "block")[0].children
+            for child in final_else_body:
+                cast = self.visit(child)
+                if isinstance(cast, List):
+                    final_else_cast.extend(cast)
+                elif isinstance(cast, AstNode):
+                    final_else_cast.append(cast)
+
+ # We go through any additional if-else nodes that we may have,
+ # generating their ModelIf CAST and appending the tail of the
+ # overall if-else construct, starting with the else at the very end
+ # We do this tail appending so that when we finish generating CAST the
+ # resulting ModelIf CAST is in the correct order
+ alternatives_cast = None
+ for ts_node in alternatives:
+ assert ts_node.type == "elif_clause"
+ temp_cast = self.visit_if_statement(ts_node)
+ if alternatives_cast == None:
+ temp_cast.orelse = final_else_cast
+ else:
+ temp_cast.orelse = [alternatives_cast]
+ alternatives_cast = temp_cast
+
+ if alternatives_cast == None:
+ if_false_cast = final_else_cast
+ else:
+ if_false_cast = [alternatives_cast]
+
+ return ModelIf(
+ expr=if_condition,
+ body=if_true_cast,
+ orelse=if_false_cast,
+ source_refs=[self.node_helper.get_source_ref(node)]
+ )
def visit_assignment(self, node: Node) -> Assignment:
left, _, right = node.children
@@ -275,7 +386,6 @@ def visit_binary_op(self, node: Node) -> Operator:
Binary Ops
left OP right
where left and right can either be operators or literals
-
"""
ref = self.node_helper.get_source_ref(node)
op = get_op(self.node_helper.get_operator(node.children[1]))
@@ -290,6 +400,17 @@ def visit_binary_op(self, node: Node) -> Operator:
source_refs=[ref]
)
+ def visit_pattern(self, node: Node):
+ pattern_cast = []
+ for elem in node.children:
+ cast = self.visit(elem)
+ if isinstance(cast, List):
+ pattern_cast.extend(cast)
+ elif isinstance(cast, AstNode):
+ pattern_cast.append(cast)
+
+ return LiteralValue(value_type=StructureType.TUPLE, value=pattern_cast)
+
def visit_identifier(self, node: Node) -> Var:
identifier = self.node_helper.get_identifier(node)
@@ -336,6 +457,173 @@ def visit_literal(self, node: Node) -> Any:
source_code_data_type=["Python", PYTHON_VERSION, str(type(True))],
source_refs=[literal_source_ref]
)
+ elif literal_type == "list":
+ list_items = []
+ for elem in node.children:
+ cast = self.visit(elem)
+ if isinstance(cast, List):
+ list_items.extend(cast)
+ elif isinstance(cast, AstNode):
+ list_items.append(cast)
+
+ return LiteralValue(
+ value_type=StructureType.LIST,
+ value = list_items,
+ source_code_data_type=["Python", PYTHON_VERSION, str(type([0]))],
+ source_refs=[literal_source_ref]
+ )
+ elif literal_type == "tuple":
+ tuple_items = []
+ for elem in node.children:
+                cast = self.visit(elem)
+ if isinstance(cast, List):
+ tuple_items.extend(cast)
+ elif isinstance(cast, AstNode):
+ tuple_items.append(cast)
+
+ return LiteralValue(
+            value_type=StructureType.TUPLE,
+ value = tuple_items,
+            source_code_data_type=["Python", PYTHON_VERSION, str(type((0,)))],
+ source_refs=[literal_source_ref]
+ )
+
+
+
+ def visit_while(self, node: Node) -> Loop:
+ ref = self.node_helper.get_source_ref(node)
+
+ # Push a variable context since a loop
+ # can create variables that only it can see
+ self.variable_context.push_context()
+
+ loop_cond_node = get_children_by_types(node, WHILE_COND_TYPES)[0]
+ loop_body_node = get_children_by_types(node, "block")[0].children
+
+ loop_cond = self.visit(loop_cond_node)
+
+ loop_body = []
+ for node in loop_body_node:
+ cast = self.visit(node)
+ if isinstance(cast, List):
+ loop_body.extend(cast)
+ elif isinstance(cast, AstNode):
+ loop_body.append(cast)
+
+ self.variable_context.pop_context()
+
+ return Loop(
+ pre=[],
+ expr=loop_cond,
+ body=loop_body,
+ post=[],
+            source_refs = [ref]
+ )
+
+ def visit_for(self, node: Node) -> Loop:
+ ref = self.node_helper.get_source_ref(node)
+
+ # Pre: left, right
+ loop_cond_left = get_children_by_types(node, FOR_LOOP_LEFT_TYPES)[0]
+ loop_cond_right = get_children_by_types(node, FOR_LOOP_RIGHT_TYPES)[-1]
+
+ # Construct pre and expr value using left and right as needed
+ # need calls to "_Iterator"
+
+ self.variable_context.push_context()
+ iterator_name = self.variable_context.generate_iterator()
+ stop_cond_name = self.variable_context.generate_stop_condition()
+ iter_func = self.get_gromet_function_node("iter")
+ next_func = self.get_gromet_function_node("next")
+
+ loop_cond_left_cast = self.visit(loop_cond_left)
+ loop_cond_right_cast = self.visit(loop_cond_right)
+
+ loop_pre = []
+ loop_pre.append(
+ Assignment(
+ left = Var(iterator_name, "Iterator"),
+ right = Call(
+ iter_func,
+ arguments=[loop_cond_right_cast]
+ )
+ )
+ )
+
+ loop_pre.append(
+ Assignment(
+ left=LiteralValue(
+ "Tuple",
+ [
+ loop_cond_left_cast,
+ Var(iterator_name, "Iterator"),
+ Var(stop_cond_name, "Boolean"),
+ ],
+ source_code_data_type = ["Python",PYTHON_VERSION,"Tuple"],
+ source_refs=ref
+ ),
+ right=Call(
+ next_func,
+ arguments=[Var(iterator_name, "Iterator")],
+ ),
+ )
+
+ )
+
+ loop_expr = Operator(
+ source_language="Python",
+ interpreter="Python",
+ version=PYTHON_VERSION,
+ op="ast.Eq",
+ operands=[
+ stop_cond_name,
+ LiteralValue(
+ ScalarType.BOOLEAN,
+ False,
+ ["Python", PYTHON_VERSION, "boolean"],
+ source_refs=ref,
+ )
+ ],
+ source_refs=ref
+ )
+
+ loop_body_node = get_children_by_types(node, "block")[0].children
+ loop_body = []
+ for node in loop_body_node:
+ cast = self.visit(node)
+ if isinstance(cast, List):
+ loop_body.extend(cast)
+ elif isinstance(cast, AstNode):
+ loop_body.append(cast)
+
+ # Insert an additional call to 'next' at the end of the loop body,
+ # to facilitate looping in GroMEt
+ loop_body.append(
+ Assignment(
+ left=LiteralValue(
+ "Tuple",
+ [
+ loop_cond_left_cast,
+ Var(iterator_name, "Iterator"),
+ Var(stop_cond_name, "Boolean"),
+ ],
+ ),
+ right=Call(
+ next_func,
+ arguments=[Var(iterator_name, "Iterator")],
+ ),
+ )
+ )
+
+ self.variable_context.pop_context()
+ return Loop(
+ pre=loop_pre,
+ expr=loop_expr,
+ body=loop_body,
+ post=[],
+            source_refs = [ref]
+ )
+
def visit_name(self, node):
# First, we will check if this name is already defined, and if it is return the name node generated previously
@@ -355,6 +643,14 @@ def _visit_passthrough(self, node):
child_cast = self.visit(child)
if child_cast:
return child_cast
+
+ def get_gromet_function_node(self, func_name: str) -> Name:
+        # Ideally, we would be able to create a dummy node and just call the name visitor.
+ # However, tree-sitter does not allow you to create or modify nodes, so we have to recreate the logic here.
+ if self.variable_context.is_variable(func_name):
+ return self.variable_context.get_node(func_name)
+
+ return self.variable_context.add_variable(func_name, "function", None)
def get_name_node(node):
# Given a CAST node, if it's type Var, then we extract the name node out of it
diff --git a/skema/program_analysis/CAST/python/util.py b/skema/program_analysis/CAST/python/util.py
index f315c44f2a4..ceb12c60a5e 100644
--- a/skema/program_analysis/CAST/python/util.py
+++ b/skema/program_analysis/CAST/python/util.py
@@ -26,6 +26,12 @@ def get_op(operator):
'-': 'ast.Sub',
'*': 'ast.Mult',
'/': 'ast.Div',
+ '==' : 'ast.Eq',
+ '!=' : 'ast.NotEq',
+ '<' : 'ast.Lt',
+ '<=' : 'ast.LtE',
+ '>' : 'ast.Gt',
+ '>=' : 'ast.GtE',
# ast.UAdd: 'ast.UAdd',
# ast.USub: 'ast.USub',
# ast.FloorDiv: 'ast.FloorDiv',
@@ -38,12 +44,6 @@ def get_op(operator):
# ast.BitXor: 'ast.BitXor',
# ast.And: 'ast.And',
# ast.Or: 'ast.Or',
- # ast.Eq: 'ast.Eq',
- # ast.NotEq: 'ast.NotEq',
- # ast.Lt: 'ast.Lt',
- # ast.LtE: 'ast.LtE',
- # ast.Gt: 'ast.Gt',
- # ast.GtE: 'ast.GtE',
# ast.In: 'ast.In',
# ast.NotIn: 'ast.NotIn',
# ast.Not: 'ast.Not',
diff --git a/skema/program_analysis/CAST/pythonAST/modules_list.py b/skema/program_analysis/CAST/pythonAST/modules_list.py
index a26cbfcdaf2..04f504a98c9 100644
--- a/skema/program_analysis/CAST/pythonAST/modules_list.py
+++ b/skema/program_analysis/CAST/pythonAST/modules_list.py
@@ -321,7 +321,12 @@ def find_func_in_module(module_name, func_name):
import sys
sys.path.append(os.getcwd())
- module_import = importlib.import_module(module_name)
+ # TODO: Support find_func_in_module for Fortran source code as well
+ try:
+ module_import = importlib.import_module(module_name)
+    except Exception:
+ return False
+
funcs = list(dir(module_import))
return func_name in funcs
diff --git a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py
index 4e81daf28d7..b9247be87fb 100644
--- a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py
+++ b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py
@@ -84,20 +84,19 @@ def insert_gromet_object(t: list, obj):
If the table we're trying to insert into doesn't already exist, then we
first create it, and then insert the value.
"""
+
+ if t == None:
+ t = []
# Logic for generating port ids
if isinstance(obj, GrometPort):
- if t == None:
- obj.id = 1
- else:
- current_box = obj.box
- current_box_ports = [port for port in t if port.box == current_box]
- obj.id = len(current_box_ports) + 1
+ obj.id = 1
+ for port in reversed(t):
+ if port.box == obj.box:
+ obj.id = port.id + 1
+ break
- if t == None:
- t = []
t.append(obj)
-
return t
@@ -2521,6 +2520,7 @@ def visit_call(
from_assignment = False
from_call = False
from_operator = False
+ from_loop = False
func_name, qual_func_name = get_func_name(node)
if isinstance(parent_cast_node, AnnCastAssignment):
@@ -2529,6 +2529,8 @@ def visit_call(
from_call = True
elif isinstance(parent_cast_node, AnnCastOperator):
from_operator = True
+ elif isinstance(parent_cast_node, AnnCastLoop):
+ from_loop = True
if isinstance(node.func, AnnCastAttribute):
self.visit(node.func, parent_gromet_fn, parent_cast_node)
@@ -2733,7 +2735,7 @@ def visit_call(
)
# if isinstance(arg.right)
- if from_call or from_operator or from_assignment:
+ if from_call or from_operator or from_assignment or from_loop:
# Operator and calls need a pof appended here because they dont
# do it themselves
# At some point we would like the call handler to always append a POF
@@ -2913,6 +2915,9 @@ def wire_return_node(self, node, gromet_fn):
if isinstance(node, AnnCastLiteralValue):
if is_tuple(node):
self.pack_return_tuple(node, gromet_fn)
+ else:
+ gromet_fn.opo = insert_gromet_object(gromet_fn.opo, GrometPort(box=len(gromet_fn.b)))
+ gromet_fn.wfopo = insert_gromet_object(gromet_fn.wfopo, GrometWire(src=len(gromet_fn.opo),tgt=len(gromet_fn.pof)))
return
elif isinstance(node, AnnCastVar):
var_name = node.val.name
@@ -3014,7 +3019,6 @@ def handle_function_def(
# can clear the local variable environment
var_environment["local"] = deepcopy(prev_local_env)
-
@_visit.register
def visit_function_def(
self, node: AnnCastFunctionDef, parent_gromet_fn, parent_cast_node
@@ -3241,7 +3245,6 @@ def visit_literal_value(
)
code_data_metadata = SourceCodeDataType(
- gromet_type="source_code_data_type",
provenance=generate_provenance(),
source_language=ref[0],
source_language_version=ref[1],
diff --git a/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py b/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py
index 1c4fe73d0ed..e32ebabf32a 100644
--- a/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py
+++ b/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py
@@ -582,6 +582,10 @@ def _(self, node: LiteralValue):
node_uid = uuid.uuid4()
self.G.add_node(node_uid, label=f"Boolean: {str(node.value)}")
return node_uid
+ elif node.value_type == ScalarType.CHARACTER:
+ node_uid = uuid.uuid4()
+ self.G.add_node(node_uid, label=f"Character: {str(node.value)}")
+ return node_uid
elif node.value_type == ScalarType.ABSTRACTFLOAT:
node_uid = uuid.uuid4()
self.G.add_node(node_uid, label=f"abstractFloat: {node.value}")
@@ -592,7 +596,10 @@ def _(self, node: LiteralValue):
return node_uid
elif node.value_type == StructureType.TUPLE:
node_uid = uuid.uuid4()
- self.G.add_node(node_uid, label=f"Tuple (...)")
+ self.G.add_node(node_uid, label=f"Tuple")
+ tuple_elems = self.visit_list(node.value)
+ for elem_uid in tuple_elems:
+ self.G.add_edge(node_uid, elem_uid)
return node_uid
elif node.value_type == None:
node_uid = uuid.uuid4()
diff --git a/skema/program_analysis/JSON2GroMEt/json2gromet.py b/skema/program_analysis/JSON2GroMEt/json2gromet.py
index 2c174e72120..59f8df020c2 100644
--- a/skema/program_analysis/JSON2GroMEt/json2gromet.py
+++ b/skema/program_analysis/JSON2GroMEt/json2gromet.py
@@ -32,8 +32,10 @@ def json_to_gromet(path: str) -> GrometFNModuleCollection:
sys.modules["skema.gromet.metadata"], inspect.isclass
):
instance = metadata_object()
- if "metadata_type" in instance.attribute_map:
- gromet_metadata_map[instance.metadata_type] = metadata_object
+ if "is_metadatum" in instance.attribute_map and instance.is_metadatum:
+ gromet_metadata_map[metadata_name] = metadata_object
+ else:
+ gromet_fn_map[metadata_name] = metadata_object
def get_obj_type(obj: Dict) -> Any:
"""Given a dictionary representing a Gromet object (i.e. BoxFunction), return an instance of that object.
@@ -42,10 +44,10 @@ def get_obj_type(obj: Dict) -> Any:
# First check if we already have a mapping to a data-class memeber. All Gromet FN and most Gromet Metadata classes will fall into this category.
# There are a few Gromet Metadata fields such as Provenance that do not have a "metadata_type" field
- if "gromet_type" in obj:
+ if "gromet_type" in obj and ("is_metadatum" not in obj or obj["is_metadatum"] != True):
return gromet_fn_map[obj["gromet_type"]]()
- elif "metadata_type" in obj:
- return gromet_metadata_map[obj["metadata_type"]]()
+    elif "gromet_type" in obj and obj.get("is_metadatum"):
+ return gromet_metadata_map[obj["gromet_type"]]()
# If there is not a mapping to an object, we will check the fields to see if they match an existing class in the data-model.
# For example: (id, box, metadata) would map to GrometPort
diff --git a/skema/program_analysis/comment_extractor/comment_extractor.py b/skema/program_analysis/comment_extractor/comment_extractor.py
index f43023ddec7..9a8bb4e1ac3 100644
--- a/skema/program_analysis/comment_extractor/comment_extractor.py
+++ b/skema/program_analysis/comment_extractor/comment_extractor.py
@@ -245,7 +245,7 @@ def extract_comments_multi(
request: MultiFileCommentRequest,
) -> MultiFileCommentResponse:
"""Wrapper for processing multiple source files at a time."""
- return MultiFileCommentResponse.parse_obj(
+ return MultiFileCommentResponse(**
{
"files": {
file_name: extract_comments_single(file_request)
diff --git a/skema/program_analysis/comment_extractor/server.py b/skema/program_analysis/comment_extractor/server.py
index f570ff67314..3b8ff908e3f 100644
--- a/skema/program_analysis/comment_extractor/server.py
+++ b/skema/program_analysis/comment_extractor/server.py
@@ -83,7 +83,7 @@ async def comments_extract_zip(
}
return comment_service.extract_comments_multi(
- MultiFileCommentRequest.parse_obj(request)
+ MultiFileCommentRequest(**request)
)
app = FastAPI()
diff --git a/skema/program_analysis/comment_extractor/tests/test_comment_server.py b/skema/program_analysis/comment_extractor/tests/test_comment_server.py
index 9b5691cd219..bc9457677b6 100644
--- a/skema/program_analysis/comment_extractor/tests/test_comment_server.py
+++ b/skema/program_analysis/comment_extractor/tests/test_comment_server.py
@@ -14,7 +14,7 @@ def test_comments_get_supported_languages():
response = client.get("/comment_service/comments-get-supported-languages")
assert response.status_code == 200
- languages = comment_service.SupportedLanguageResponse.parse_obj(response.json())
+ languages = comment_service.SupportedLanguageResponse(**response.json())
assert isinstance(languages, comment_service.SupportedLanguageResponse)
assert len(languages.languages) > 0
@@ -37,7 +37,7 @@ def test_comments_extract():
response = client.post("/comment_service/comments-extract", json=request)
assert response.status_code == 200
- comments = comment_service.SingleFileCommentResponse.parse_obj(response.json())
+ comments = comment_service.SingleFileCommentResponse(**response.json())
assert isinstance(comments, comment_service.SingleFileCommentResponse)
@@ -72,5 +72,5 @@ def test_comments_extract_zip():
)
assert response.status_code == 200
- comments = comment_service.MultiFileCommentResponse.parse_obj(response.json())
+ comments = comment_service.MultiFileCommentResponse(**response.json())
assert isinstance(comments, comment_service.MultiFileCommentResponse)
\ No newline at end of file
diff --git a/skema/program_analysis/gromet_wire_diagnosis.py b/skema/program_analysis/gromet_wire_diagnosis.py
new file mode 100644
index 00000000000..b4ae814551e
--- /dev/null
+++ b/skema/program_analysis/gromet_wire_diagnosis.py
@@ -0,0 +1,208 @@
+import argparse
+from skema.program_analysis.JSON2GroMEt import json2gromet
+from skema.gromet.metadata import SourceCodeReference
+
+# Ways to expand
+# Check loop, condition FN indices
+# Check bf call FN indices
+# Boxes associated with ports
+
+def disp_wire(wire):
+ return f"src:{wire.src}<-->tgt:{wire.tgt}"
+
+def get_length(gromet_item):
+ # For any gromet object we can generically retrieve the length, since they all exist
+ # in lists
+ return len(gromet_item) if gromet_item != None else 0
+
+def check_wire(gromet_wire, src_port_count, tgt_port_count, wire_type = "", metadata=None):
+ # The current wiring checks are
+ # Checking if the ports on both ends of the wire are below or over the bounds
+ error_detected = False
+ if gromet_wire.src < 0:
+ error_detected = True
+ print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has negative src port.")
+ if gromet_wire.src == 0:
+ error_detected = True
+ print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has zero src port.")
+ if gromet_wire.src > src_port_count:
+ error_detected = True
+ print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has a src port that goes over the boundary of {src_port_count} src ports.")
+
+ if gromet_wire.tgt < 0:
+ error_detected = True
+ print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has negative tgt port.")
+ if gromet_wire.tgt == 0:
+ error_detected = True
+ print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has zero tgt port.")
+ if gromet_wire.tgt > tgt_port_count:
+ error_detected = True
+ print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has a tgt port that goes over the boundary of {tgt_port_count} tgt ports.")
+
+
+ if error_detected:
+ if metadata == None:
+ print("No line number information exists for this particular wire!")
+ else:
+ print(f"Wire is associated with source code lines start:{metadata.line_begin} end:{metadata.line_end}")
+ print()
+
+ return error_detected
+
+def find_metadata_idx(gromet_fn):
+ """
+ Attempts to find a metadata associated with this fn
+ If it finds something, return it, otherwise return None
+ """
+ if gromet_fn.b != None:
+ for b in gromet_fn.b:
+ if b.metadata != None:
+ return b.metadata
+
+ if gromet_fn.bf != None:
+ for bf in gromet_fn.bf:
+ if bf.metadata != None:
+ return bf.metadata
+
+ return None
+
+def analyze_fn_wiring(gromet_fn, metadata_collection):
+ # Acquire information for all the ports, if they exist
+ pif_length = get_length(gromet_fn.pif)
+ pof_length = get_length(gromet_fn.pof)
+ opi_length = get_length(gromet_fn.opi)
+ opo_length = get_length(gromet_fn.opo)
+ pil_length = get_length(gromet_fn.pil)
+ pol_length = get_length(gromet_fn.pol)
+ pic_length = get_length(gromet_fn.pic)
+ poc_length = get_length(gromet_fn.poc)
+
+ # Find a SourceCodeReference metadata that we can extract line number information for
+ # so we can display some line number information about potential errors in the wiring
+ # NOTE: Can we make this extraction more accurate?
+ metadata_idx = find_metadata_idx(gromet_fn)
+ metadata = None
+ if metadata_idx != None:
+ for md in metadata_collection[metadata_idx - 1]:
+ if isinstance(md, SourceCodeReference):
+ metadata = md
+
+ wopio_length = get_length(gromet_fn.wopio)
+ if wopio_length > 0:
+        for wire in gromet_fn.wopio:
+            check_wire(wire, opo_length, opi_length, "wopio", metadata)
+
+ ######################## loop (bl) wiring
+
+ wlopi_length = get_length(gromet_fn.wlopi)
+ if wlopi_length > 0:
+ for wire in gromet_fn.wlopi:
+ check_wire(wire, pil_length, opi_length, "wlopi", metadata)
+
+ wll_length = get_length(gromet_fn.wll)
+ if wll_length > 0:
+ for wire in gromet_fn.wll:
+ check_wire(wire, pil_length, pol_length, "wll", metadata)
+
+ wlf_length = get_length(gromet_fn.wlf)
+ if wlf_length > 0:
+ for wire in gromet_fn.wlf:
+ check_wire(wire, pif_length, pol_length, "wlf", metadata)
+
+ wlc_length = get_length(gromet_fn.wlc)
+ if wlc_length > 0:
+ for wire in gromet_fn.wlc:
+ check_wire(wire, pic_length, pol_length, "wlc", metadata)
+
+ wlopo_length = get_length(gromet_fn.wlopo)
+ if wlopo_length > 0:
+ for wire in gromet_fn.wlopo:
+ check_wire(wire, opo_length, pol_length, "wlopo", metadata)
+
+ ######################## function (bf) wiring
+ wfopi_length = get_length(gromet_fn.wfopi)
+ if wfopi_length > 0:
+ for wire in gromet_fn.wfopi:
+ check_wire(wire, pif_length, opi_length, "wfopi", metadata)
+
+ wfl_length = get_length(gromet_fn.wfl)
+ if wfl_length > 0:
+ for wire in gromet_fn.wfl:
+ check_wire(wire, pil_length, pof_length, "wfl", metadata)
+
+ wff_length = get_length(gromet_fn.wff)
+ if wff_length > 0:
+ for wire in gromet_fn.wff:
+ check_wire(wire, pif_length, pof_length, "wff", metadata)
+
+ wfc_length = get_length(gromet_fn.wfc)
+ if wfc_length > 0:
+ for wire in gromet_fn.wfc:
+ check_wire(wire, pic_length, pof_length, "wfc", metadata)
+
+ wfopo_length = get_length(gromet_fn.wfopo)
+ if wfopo_length > 0:
+ for wire in gromet_fn.wfopo:
+ check_wire(wire, opo_length, pof_length, "wfopo", metadata)
+
+ ######################## condition (bc) wiring
+ wcopi_length = get_length(gromet_fn.wcopi)
+ if wcopi_length > 0:
+ for wire in gromet_fn.wcopi:
+ check_wire(wire, pic_length, opi_length, "wcopi", metadata)
+
+ wcl_length = get_length(gromet_fn.wcl)
+ if wcl_length > 0:
+ for wire in gromet_fn.wcl:
+ check_wire(wire, pil_length, poc_length, "wcl", metadata)
+
+ wcf_length = get_length(gromet_fn.wcf)
+ if wcf_length > 0:
+ for wire in gromet_fn.wcf:
+ check_wire(wire, pif_length, poc_length, "wcf", metadata)
+
+ wcc_length = get_length(gromet_fn.wcc)
+ if wcc_length > 0:
+ for wire in gromet_fn.wcc:
+ check_wire(wire, pic_length, poc_length, "wcc", metadata)
+
+ wcopo_length = get_length(gromet_fn.wcopo)
+ if wcopo_length > 0:
+ for wire in gromet_fn.wcopo:
+ check_wire(wire, opo_length, poc_length, "wcopo", metadata)
+
+
+def wiring_analyzer(gromet_obj):
+ # TODO: Multifiles
+
+ for module in gromet_obj.modules:
+ # first_module = gromet_obj.modules[0]
+ metadata = []
+ # Analyze base FN
+ print(f"Analyzing {module.name}")
+ analyze_fn_wiring(module.fn, module.metadata_collection)
+
+ # Analyze the rest of the FN_array
+ for fn in module.fn_array:
+ analyze_fn_wiring(fn, module.metadata_collection)
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description="Attempts to analyze GroMEt JSON for issues"
+    )
+ parser.add_argument(
+ "gromet_file_path",
+ help="input GroMEt JSON file"
+ )
+
+ options = parser.parse_args()
+ return options
+
+if __name__ == "__main__":
+ args = get_args()
+ gromet_obj = json2gromet.json_to_gromet(args.gromet_file_path)
+
+ wiring_analyzer(gromet_obj)
+
+
+
diff --git a/skema/program_analysis/tests/test_conditional_cast.py b/skema/program_analysis/tests/test_conditional_cast.py
new file mode 100644
index 00000000000..ad0ba999976
--- /dev/null
+++ b/skema/program_analysis/tests/test_conditional_cast.py
@@ -0,0 +1,201 @@
+# import json NOTE: json and Path aren't used right now,
+# from pathlib import Path but will be used in the future
+from skema.program_analysis.CAST.python.ts2cast import TS2CAST
+from skema.program_analysis.CAST2FN.model.cast import (
+ Assignment,
+ Var,
+ Name,
+ LiteralValue,
+ ModelIf,
+ Operator
+)
+
+def cond1():
+ return """
+x = 2
+
+if x < 5:
+ x = x + 1
+else:
+ x = x - 3
+ """
+
+def cond2():
+ return """
+x = 2
+y = 3
+
+if x < 5:
+ x = 1
+ y = 2
+ x = x * y
+else:
+ x = x - 3
+ """
+
+def cond3():
+ return """
+x = 2
+y = 4
+
+if x < 5:
+ x = x + y
+ y = 1
+elif x > 10:
+ y = x + 2
+ x = 1
+elif x == 30:
+ x = 1
+ y = 2
+ z = x * y
+else:
+ x = 0
+ y = x - 2
+ """
+
+def generate_cast(test_file_string):
+ # use Python to CAST
+ out_cast = TS2CAST(test_file_string, from_file=False).out_cast
+
+ return out_cast
+
+def test_cond1():
+ exp_cast = generate_cast(cond1())
+
+ # Test basic conditional
+ asg_node = exp_cast.nodes[0].body[0]
+ cond_node = exp_cast.nodes[0].body[1]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '2'
+
+ assert isinstance(cond_node, ModelIf)
+ cond_expr = cond_node.expr
+ cond_body = cond_node.body
+ cond_else = cond_node.orelse
+
+ assert isinstance(cond_expr, Operator)
+ assert cond_expr.op == "ast.Lt"
+ assert isinstance(cond_expr.operands[0], Name)
+ assert isinstance(cond_expr.operands[1], LiteralValue)
+
+ assert len(cond_body) == 1
+ assert isinstance(cond_body[0], Assignment)
+ assert isinstance(cond_body[0].left, Var)
+ assert isinstance(cond_body[0].right, Operator)
+ assert cond_body[0].right.op == "ast.Add"
+
+ assert len(cond_else) == 1
+ assert isinstance(cond_else[0], Assignment)
+ assert isinstance(cond_else[0].left, Var)
+ assert isinstance(cond_else[0].right, Operator)
+ assert cond_else[0].right.op == "ast.Sub"
+
+
+def test_cond2():
+ exp_cast = generate_cast(cond2())
+
+ # Test multiple variable conditional
+ asg_node = exp_cast.nodes[0].body[0]
+ cond_node = exp_cast.nodes[0].body[2]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+ assert asg_node.left.val.id == 0
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '2'
+
+ asg_node = exp_cast.nodes[0].body[1]
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "y"
+ assert asg_node.left.val.id == 1
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '3'
+
+ assert isinstance(cond_node, ModelIf)
+ cond_expr = cond_node.expr
+ cond_body = cond_node.body
+ cond_else = cond_node.orelse
+
+ assert isinstance(cond_expr, Operator)
+ assert cond_expr.op == "ast.Lt"
+ assert isinstance(cond_expr.operands[0], Name)
+ assert cond_expr.operands[0].name == "x"
+ assert isinstance(cond_expr.operands[1], LiteralValue)
+ assert cond_expr.operands[1].value_type == "Integer"
+ assert cond_expr.operands[1].value == "5"
+
+ assert len(cond_body) == 3
+ assert isinstance(cond_body[0], Assignment)
+ assert isinstance(cond_body[0].left, Var)
+ assert cond_body[0].left.val.name == "x"
+ assert isinstance(cond_body[0].right, LiteralValue)
+ assert cond_body[0].right.value == "1"
+
+ assert isinstance(cond_body[1], Assignment)
+ assert isinstance(cond_body[1].left, Var)
+ assert cond_body[1].left.val.name == "y"
+ assert isinstance(cond_body[1].right, LiteralValue)
+ assert cond_body[1].right.value == "2"
+
+ assert isinstance(cond_body[2], Assignment)
+ assert isinstance(cond_body[2].left, Var)
+ assert isinstance(cond_body[2].right, Operator)
+
+ assert cond_body[2].right.op == "ast.Mult"
+
+ assert isinstance(cond_body[2].right.operands[0], Name)
+ assert cond_body[2].right.operands[0].name == "x"
+ assert cond_body[2].right.operands[0].id == 0
+ assert isinstance(cond_body[2].right.operands[1], Name)
+ assert cond_body[2].right.operands[1].name == "y"
+ assert cond_body[2].right.operands[1].id == 1
+
+ assert len(cond_else) == 1
+ assert isinstance(cond_else[0], Assignment)
+ assert isinstance(cond_else[0].left, Var)
+ assert isinstance(cond_else[0].right, Operator)
+ assert cond_else[0].right.op == "ast.Sub"
+
+def test_cond3():
+ exp_cast = generate_cast(cond3())
+
+ # Test nested ifs
+ cond_node = exp_cast.nodes[0].body[2]
+
+ assert isinstance(cond_node, ModelIf)
+ cond_body = cond_node.body
+ cond_else = cond_node.orelse
+
+ assert len(cond_body) == 2
+ assert len(cond_else) == 1
+ assert isinstance(cond_else[0], ModelIf)
+ nested_if = cond_else[0]
+ cond_body = nested_if.body
+ cond_else = nested_if.orelse
+
+ assert len(cond_body) == 2
+ assert len(cond_else) == 1
+ assert isinstance(cond_else[0], ModelIf)
+ nested_if = cond_else[0]
+ cond_body = nested_if.body
+ cond_else = nested_if.orelse
+
+ assert len(cond_body) == 3
+ assert len(cond_else) == 2
+ assert isinstance(cond_else[0], Assignment)
+ assert isinstance(cond_else[1], Assignment)
diff --git a/skema/program_analysis/tests/test_expression_cast.py b/skema/program_analysis/tests/test_expression_cast.py
index 5d1d1e72b1a..c0f7459caa9 100644
--- a/skema/program_analysis/tests/test_expression_cast.py
+++ b/skema/program_analysis/tests/test_expression_cast.py
@@ -66,3 +66,6 @@ def test_exp1():
assert isinstance(asg_node.right, LiteralValue)
assert asg_node.right.value_type == "Integer"
assert asg_node.right.value == '3'
+
+if __name__ == "__main__":
+ cast = generate_cast(exp0())
\ No newline at end of file
diff --git a/skema/program_analysis/tests/test_for_cast.py b/skema/program_analysis/tests/test_for_cast.py
new file mode 100644
index 00000000000..3693b8bbe75
--- /dev/null
+++ b/skema/program_analysis/tests/test_for_cast.py
@@ -0,0 +1,288 @@
+# import json NOTE: json and Path aren't used right now,
+# from pathlib import Path but will be used in the future
+from skema.program_analysis.CAST.python.ts2cast import TS2CAST
+from skema.program_analysis.CAST2FN.model.cast import (
+ Assignment,
+ Var,
+ Call,
+ Name,
+ LiteralValue,
+ ModelIf,
+ Loop,
+ Operator
+)
+
+def for1():
+ return """
+x = 7
+for i in range(10):
+ x = x + i
+ """
+
+def for2():
+ return """
+x = 1
+for a,b in range(10):
+ x = x + a + b
+ """
+
+def for3():
+ return """
+x = 1
+L = [1,2,3]
+
+for i in L:
+ x = x + i
+ """
+
+
+def generate_cast(test_file_string):
+ # use Python to CAST
+ out_cast = TS2CAST(test_file_string, from_file=False).out_cast
+
+ return out_cast
+
+def test_for1():
+ cast = generate_cast(for1())
+
+ asg_node = cast.nodes[0].body[0]
+ loop_node = cast.nodes[0].body[1]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '7'
+
+ assert isinstance(loop_node, Loop)
+ assert len(loop_node.pre) == 2
+
+ # Loop Pre
+ loop_pre = loop_node.pre
+ assert isinstance(loop_pre[0], Assignment)
+ assert isinstance(loop_pre[0].left, Var)
+ assert loop_pre[0].left.val.name == "generated_iter_0"
+
+ assert isinstance(loop_pre[0].right, Call)
+ assert loop_pre[0].right.func.name == "iter"
+ iter_args = loop_pre[0].right.arguments
+
+ assert len(iter_args) == 1
+ assert isinstance(iter_args[0], Call)
+ assert iter_args[0].func.name == "range"
+ assert len(iter_args[0].arguments) == 3
+
+ assert isinstance(iter_args[0].arguments[0], LiteralValue)
+ assert iter_args[0].arguments[0].value == "1"
+ assert isinstance(iter_args[0].arguments[1], LiteralValue)
+ assert iter_args[0].arguments[1].value == "10"
+ assert isinstance(iter_args[0].arguments[2], LiteralValue)
+ assert iter_args[0].arguments[2].value == "1"
+
+ assert isinstance(loop_pre[1], Assignment)
+ assert isinstance(loop_pre[1].left, LiteralValue)
+ assert loop_pre[1].left.value_type == "Tuple"
+
+ assert isinstance(loop_pre[1].left.value[0], Var)
+ assert loop_pre[1].left.value[0].val.name == "i"
+ assert isinstance(loop_pre[1].left.value[1], Var)
+ assert loop_pre[1].left.value[1].val.name == "generated_iter_0"
+ assert isinstance(loop_pre[1].left.value[2], Var)
+ assert loop_pre[1].left.value[2].val.name == "sc_0"
+
+ assert isinstance(loop_pre[1].right, Call)
+ assert loop_pre[1].right.func.name == "next"
+ assert len(loop_pre[1].right.arguments) == 1
+ assert loop_pre[1].right.arguments[0].val.name == "generated_iter_0"
+
+ # Loop Test
+ loop_test = loop_node.expr
+ assert isinstance(loop_test, Operator)
+ assert loop_test.op == "ast.Eq"
+ assert isinstance(loop_test.operands[0], Name)
+ assert loop_test.operands[0].name == "sc_0"
+
+ assert isinstance(loop_test.operands[1], LiteralValue)
+ assert loop_test.operands[1].value_type == "Boolean"
+
+ # Loop Body
+ loop_body = loop_node.body
+ next_call = loop_body[-1]
+ assert isinstance(next_call, Assignment)
+ assert isinstance(next_call.right, Call)
+ assert next_call.right.func.name == "next"
+ assert next_call.right.arguments[0].val.name == "generated_iter_0"
+
+
+def test_for2():
+ cast = generate_cast(for2())
+
+ asg_node = cast.nodes[0].body[0]
+ loop_node = cast.nodes[0].body[1]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '1'
+
+ assert isinstance(loop_node, Loop)
+ assert len(loop_node.pre) == 2
+
+ # Loop Pre
+ loop_pre = loop_node.pre
+ assert isinstance(loop_pre[0], Assignment)
+ assert isinstance(loop_pre[0].left, Var)
+ assert loop_pre[0].left.val.name == "generated_iter_0"
+
+ assert isinstance(loop_pre[0].right, Call)
+ assert loop_pre[0].right.func.name == "iter"
+ iter_args = loop_pre[0].right.arguments
+
+ assert len(iter_args) == 1
+ assert isinstance(iter_args[0], Call)
+ assert iter_args[0].func.name == "range"
+ assert len(iter_args[0].arguments) == 3
+
+ assert isinstance(iter_args[0].arguments[0], LiteralValue)
+ assert iter_args[0].arguments[0].value == "1"
+ assert isinstance(iter_args[0].arguments[1], LiteralValue)
+ assert iter_args[0].arguments[1].value == "10"
+ assert isinstance(iter_args[0].arguments[2], LiteralValue)
+ assert iter_args[0].arguments[2].value == "1"
+
+ assert isinstance(loop_pre[1], Assignment)
+ assert isinstance(loop_pre[1].left, LiteralValue)
+ assert loop_pre[1].left.value_type == "Tuple"
+
+ assert isinstance(loop_pre[1].left.value[0], LiteralValue)
+ assert loop_pre[1].left.value[0].value_type == "Tuple"
+
+ assert isinstance(loop_pre[1].left.value[0].value[0], Var)
+ assert loop_pre[1].left.value[0].value[0].val.name == "a"
+ assert isinstance(loop_pre[1].left.value[0].value[1], Var)
+ assert loop_pre[1].left.value[0].value[1].val.name == "b"
+
+ assert isinstance(loop_pre[1].left.value[1], Var)
+ assert loop_pre[1].left.value[1].val.name == "generated_iter_0"
+ assert isinstance(loop_pre[1].left.value[2], Var)
+ assert loop_pre[1].left.value[2].val.name == "sc_0"
+
+ assert isinstance(loop_pre[1].right, Call)
+ assert loop_pre[1].right.func.name == "next"
+ assert len(loop_pre[1].right.arguments) == 1
+ assert loop_pre[1].right.arguments[0].val.name == "generated_iter_0"
+
+ # Loop Test
+ loop_test = loop_node.expr
+ assert isinstance(loop_test, Operator)
+ assert loop_test.op == "ast.Eq"
+ assert isinstance(loop_test.operands[0], Name)
+ assert loop_test.operands[0].name == "sc_0"
+
+ assert isinstance(loop_test.operands[1], LiteralValue)
+ assert loop_test.operands[1].value_type == "Boolean"
+
+ # Loop Body
+ loop_body = loop_node.body
+ body_asg = loop_body[0]
+ assert isinstance(body_asg, Assignment)
+
+ assert isinstance(body_asg.right, Operator)
+ assert isinstance(body_asg.right.operands[0], Operator)
+ assert isinstance(body_asg.right.operands[0].operands[0], Name)
+ assert body_asg.right.operands[0].operands[0].name == "x"
+
+ assert isinstance(body_asg.right.operands[0].operands[1], Name)
+ assert body_asg.right.operands[0].operands[1].name == "a"
+
+ assert isinstance(body_asg.right.operands[1], Name)
+ assert body_asg.right.operands[1].name == "b"
+
+ next_call = loop_body[-1]
+ assert isinstance(next_call, Assignment)
+ assert isinstance(next_call.right, Call)
+ assert next_call.right.func.name == "next"
+ assert next_call.right.arguments[0].val.name == "generated_iter_0"
+
+
+def test_for3():
+ cast = generate_cast(for3())
+
+ asg_node = cast.nodes[0].body[0]
+ list_node = cast.nodes[0].body[1]
+ loop_node = cast.nodes[0].body[2]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '1'
+
+ assert isinstance(loop_node, Loop)
+ assert len(loop_node.pre) == 2
+
+ assert isinstance(list_node, Assignment)
+ assert isinstance(list_node.left, Var)
+ assert list_node.left.val.name == "L"
+
+ assert isinstance(list_node.right, LiteralValue)
+ assert list_node.right.value_type == "List"
+
+ # Loop Pre
+ loop_pre = loop_node.pre
+ assert isinstance(loop_pre[0], Assignment)
+ assert isinstance(loop_pre[0].left, Var)
+ assert loop_pre[0].left.val.name == "generated_iter_0"
+
+ assert isinstance(loop_pre[0].right, Call)
+ assert loop_pre[0].right.func.name == "iter"
+ iter_args = loop_pre[0].right.arguments
+
+ assert len(iter_args) == 1
+ assert isinstance(iter_args[0], Var)
+ assert iter_args[0].val.name == "L"
+
+ assert isinstance(loop_pre[1], Assignment)
+ assert isinstance(loop_pre[1].left, LiteralValue)
+ assert loop_pre[1].left.value_type == "Tuple"
+
+ assert isinstance(loop_pre[1].left.value[0], Var)
+ assert loop_pre[1].left.value[0].val.name == "i"
+ assert isinstance(loop_pre[1].left.value[1], Var)
+ assert loop_pre[1].left.value[1].val.name == "generated_iter_0"
+ assert isinstance(loop_pre[1].left.value[2], Var)
+ assert loop_pre[1].left.value[2].val.name == "sc_0"
+
+ assert isinstance(loop_pre[1].right, Call)
+ assert loop_pre[1].right.func.name == "next"
+ assert len(loop_pre[1].right.arguments) == 1
+ assert loop_pre[1].right.arguments[0].val.name == "generated_iter_0"
+
+ # Loop Test
+ loop_test = loop_node.expr
+ assert isinstance(loop_test, Operator)
+ assert loop_test.op == "ast.Eq"
+ assert isinstance(loop_test.operands[0], Name)
+ assert loop_test.operands[0].name == "sc_0"
+
+ assert isinstance(loop_test.operands[1], LiteralValue)
+ assert loop_test.operands[1].value_type == "Boolean"
+
+ # Loop Body
+ loop_body = loop_node.body
+ next_call = loop_body[-1]
+
+ assert isinstance(next_call, Assignment)
+ assert isinstance(next_call.right, Call)
+ assert next_call.right.func.name == "next"
+ assert next_call.right.arguments[0].val.name == "generated_iter_0"
diff --git a/skema/program_analysis/tests/test_identifier.py b/skema/program_analysis/tests/test_identifier.py
new file mode 100644
index 00000000000..46e976172b9
--- /dev/null
+++ b/skema/program_analysis/tests/test_identifier.py
@@ -0,0 +1,39 @@
+# import json NOTE: json and Path aren't used right now,
+# from pathlib import Path but will be used in the future
+from skema.program_analysis.CAST.python.ts2cast import TS2CAST
+from skema.program_analysis.CAST2FN.model.cast import (
+ Assignment,
+ Var,
+ Call,
+ Name,
+ LiteralValue,
+ ModelIf,
+ Loop,
+ Operator
+)
+
+def identifier1():
+ return """x = 2"""
+
+def generate_cast(test_file_string):
+ # use Python to CAST
+ out_cast = TS2CAST(test_file_string, from_file=False).out_cast
+
+ return out_cast
+
+
+# Tests to make sure that identifiers are correctly being generated
+def test_identifier1():
+ cast = generate_cast(identifier1())
+
+ asg_node = cast.nodes[0].body[0]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '2'
+
diff --git a/skema/program_analysis/tests/test_literal_returns.py b/skema/program_analysis/tests/test_literal_returns.py
new file mode 100644
index 00000000000..f59bb0d4a8e
--- /dev/null
+++ b/skema/program_analysis/tests/test_literal_returns.py
@@ -0,0 +1,141 @@
+# import json NOTE: json and Path aren't used right now,
+# from pathlib import Path but will be used in the future
+from skema.program_analysis.multi_file_ingester import process_file_system
+from skema.gromet.fn import (
+ GrometFNModuleCollection,
+ FunctionType,
+ TypedValue,
+)
+import ast
+
+from skema.program_analysis.CAST.pythonAST import py_ast_to_cast
+from skema.program_analysis.CAST2FN.model.cast import SourceRef
+from skema.program_analysis.CAST2FN import cast
+from skema.program_analysis.CAST2FN.cast import CAST
+from skema.program_analysis.run_ann_cast_pipeline import ann_cast_pipeline
+
+
+def return1():
+ return """
+def return_true():
+ return True
+ """
+
+def return2():
+ return """
+def return_true():
+ return True
+
+while (return_true()):
+ print("Test")
+ """
+
+
+def generate_gromet(test_file_string):
+ # use ast.Parse to get Python AST
+ contents = ast.parse(test_file_string)
+
+ # use Python to CAST
+ line_count = len(test_file_string.split("\n"))
+ convert = py_ast_to_cast.PyASTToCAST("temp")
+ C = convert.visit(contents, {}, {})
+ C.source_refs = [SourceRef("temp", None, None, 1, line_count)]
+ out_cast = cast.CAST([C], "python")
+
+ # use AnnCastPipeline to create GroMEt
+ gromet = ann_cast_pipeline(out_cast, gromet=True, to_file=False, from_obj=True)
+
+ return gromet
+
+def test_return1():
+ gromet = generate_gromet(return1())
+
+ base_fn = gromet.fn
+
+ assert len(base_fn.b) == 1
+
+ func_fn = gromet.fn_array[0]
+ assert len(func_fn.b) == 1
+
+ assert len(func_fn.opo) == 1
+ assert func_fn.opo[0].box == 1
+
+ assert len(func_fn.bf) == 1
+ assert func_fn.bf[0].function_type == FunctionType.LITERAL
+ assert func_fn.bf[0].value.value_type == "Boolean"
+ assert func_fn.bf[0].value.value == "True"
+
+ assert len(func_fn.pof) == 1
+ assert func_fn.pof[0].box == 1
+
+ assert len(func_fn.wfopo) == 1
+ assert func_fn.wfopo[0].src == 1 and func_fn.wfopo[0].tgt == 1
+
+
+def test_return2():
+ exp_gromet = generate_gromet(return2())
+
+ base_fn = exp_gromet.fn
+ assert len(base_fn.bl) == 1
+ assert base_fn.bl[0].condition == 2
+ assert base_fn.bl[0].body == 3
+
+ func_fn = exp_gromet.fn_array[0]
+ assert len(func_fn.b) == 1
+
+ assert len(func_fn.opo) == 1
+ assert func_fn.opo[0].box == 1
+
+ assert len(func_fn.bf) == 1
+ assert func_fn.bf[0].function_type == FunctionType.LITERAL
+ assert func_fn.bf[0].value.value_type == "Boolean"
+ assert func_fn.bf[0].value.value == "True"
+
+ assert len(func_fn.pof) == 1
+ assert func_fn.pof[0].box == 1
+
+ assert len(func_fn.wfopo) == 1
+ assert func_fn.wfopo[0].src == 1 and func_fn.wfopo[0].tgt == 1
+
+ predicate_fn = exp_gromet.fn_array[1]
+ assert len(predicate_fn.b) == 1
+ assert len(predicate_fn.opo) == 1
+ assert predicate_fn.opo[0].box == 1
+
+ assert len(predicate_fn.bf) == 1
+ assert predicate_fn.bf[0].body == 1
+
+ assert len(predicate_fn.pof) == 1
+ assert predicate_fn.pof[0].box == 1
+
+ assert len(predicate_fn.wfopo) == 1
+ assert predicate_fn.wfopo[0].src == 1
+ assert predicate_fn.wfopo[0].tgt == 1
+
+ loop_fn = exp_gromet.fn_array[2]
+ assert len(loop_fn.bf) == 1
+ assert loop_fn.bf[0].body == 4
+
+ loop_body_fn = exp_gromet.fn_array[3]
+ assert len(loop_body_fn.opo) == 1
+ assert loop_body_fn.opo[0].box == 1
+
+ assert len(loop_body_fn.bf) == 2
+ assert loop_body_fn.bf[1].function_type == FunctionType.LITERAL
+ assert loop_body_fn.bf[1].value.value_type == "List"
+
+ assert len(loop_body_fn.pif) == 1
+ assert loop_body_fn.pif[0].box == 1
+
+ assert len(loop_body_fn.pof) == 2
+ assert loop_body_fn.pof[0].box == 1
+ assert loop_body_fn.pof[1].box == 2
+
+ assert len(loop_body_fn.wff) == 1
+ assert loop_body_fn.wff[0].src == 1
+ assert loop_body_fn.wff[0].tgt == 2
+
+ assert len(loop_body_fn.wfopo) == 1
+ assert loop_body_fn.wfopo[0].src == 1
+ assert loop_body_fn.wfopo[0].tgt == 1
+
\ No newline at end of file
diff --git a/skema/program_analysis/tests/test_while_cast.py b/skema/program_analysis/tests/test_while_cast.py
new file mode 100644
index 00000000000..5d1f2613175
--- /dev/null
+++ b/skema/program_analysis/tests/test_while_cast.py
@@ -0,0 +1,146 @@
+# import json NOTE: json and Path aren't used right now,
+# from pathlib import Path but will be used in the future
+from skema.program_analysis.CAST.python.ts2cast import TS2CAST
+from skema.program_analysis.CAST2FN.model.cast import (
+ Assignment,
+ Var,
+ Call,
+ Name,
+ LiteralValue,
+ ModelIf,
+ Loop,
+ Operator
+)
+
+def while1():
+ return """
+x = 2
+while x < 5:
+ x = x + 1
+ """
+
+def while2():
+ return """
+x = 2
+y = 3
+
+while x < 5:
+ x = x + 1
+ x = x + y
+ """
+
+def generate_cast(test_file_string):
+ # use Python to CAST
+ out_cast = TS2CAST(test_file_string, from_file=False).out_cast
+
+ return out_cast
+
+def test_while1():
+ cast = generate_cast(while1())
+
+ asg_node = cast.nodes[0].body[0]
+ loop_node = cast.nodes[0].body[1]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '2'
+
+ assert isinstance(loop_node, Loop)
+ assert len(loop_node.pre) == 0
+
+ # Loop Test
+ loop_test = loop_node.expr
+ assert isinstance(loop_test, Operator)
+ assert loop_test.op == "ast.Lt"
+ assert isinstance(loop_test.operands[0], Name)
+ assert loop_test.operands[0].name == "x"
+
+ assert isinstance(loop_test.operands[1], LiteralValue)
+ assert loop_test.operands[1].value_type == "Integer"
+ assert loop_test.operands[1].value == "5"
+
+ # Loop Body
+ loop_body = loop_node.body
+ asg = loop_body[0]
+ assert isinstance(asg, Assignment)
+ assert isinstance(asg.left, Var)
+ assert asg.left.val.name == "x"
+
+ assert isinstance(asg.right, Operator)
+ assert asg.right.op == "ast.Add"
+ assert isinstance(asg.right.operands[0], Name)
+ assert isinstance(asg.right.operands[1], LiteralValue)
+ assert asg.right.operands[1].value == "1"
+
+def test_while2():
+ cast = generate_cast(while2())
+
+ asg_node = cast.nodes[0].body[0]
+ asg_node_2 = cast.nodes[0].body[1]
+ loop_node = cast.nodes[0].body[2]
+
+ assert isinstance(asg_node, Assignment)
+ assert isinstance(asg_node.left, Var)
+ assert isinstance(asg_node.left.val, Name)
+ assert asg_node.left.val.name == "x"
+
+ assert isinstance(asg_node.right, LiteralValue)
+ assert asg_node.right.value_type == "Integer"
+ assert asg_node.right.value == '2'
+
+ assert isinstance(asg_node_2, Assignment)
+ assert isinstance(asg_node_2.left, Var)
+ assert isinstance(asg_node_2.left.val, Name)
+ assert asg_node_2.left.val.name == "y"
+
+ assert isinstance(asg_node_2.right, LiteralValue)
+ assert asg_node_2.right.value_type == "Integer"
+ assert asg_node_2.right.value == '3'
+
+ assert isinstance(loop_node, Loop)
+ assert len(loop_node.pre) == 0
+
+ # Loop Test
+ loop_test = loop_node.expr
+ assert isinstance(loop_test, Operator)
+ assert loop_test.op == "ast.Lt"
+ assert isinstance(loop_test.operands[0], Name)
+ assert loop_test.operands[0].name == "x"
+
+ assert isinstance(loop_test.operands[1], LiteralValue)
+ assert loop_test.operands[1].value_type == "Integer"
+ assert loop_test.operands[1].value == "5"
+
+ # Loop Body
+ loop_body = loop_node.body
+ asg = loop_body[0]
+ assert isinstance(asg, Assignment)
+ assert isinstance(asg.left, Var)
+ assert asg.left.val.name == "x"
+
+ assert isinstance(asg.right, Operator)
+ assert asg.right.op == "ast.Add"
+ assert isinstance(asg.right.operands[0], Name)
+ assert asg.right.operands[0].name == "x"
+
+ assert isinstance(asg.right.operands[1], LiteralValue)
+ assert asg.right.operands[1].value == "1"
+
+ asg = loop_body[1]
+ assert isinstance(asg, Assignment)
+ assert isinstance(asg.left, Var)
+ assert asg.left.val.name == "x"
+
+ assert isinstance(asg.right, Operator)
+ assert asg.right.op == "ast.Add"
+ assert isinstance(asg.right.operands[0], Name)
+ assert asg.right.operands[0].name == "x"
+
+ assert isinstance(asg.right.operands[1], Name)
+ assert asg.right.operands[1].name == "y"
+
diff --git a/skema/program_analysis/tests/test_wiring_diagnosis.py b/skema/program_analysis/tests/test_wiring_diagnosis.py
new file mode 100644
index 00000000000..0f23252ed06
--- /dev/null
+++ b/skema/program_analysis/tests/test_wiring_diagnosis.py
@@ -0,0 +1,28 @@
+from skema.program_analysis.gromet_wire_diagnosis import check_wire
+from skema.gromet.fn import GrometWire
+
+
+def test_correct_wire():
+ correct_wire = GrometWire(src=1, tgt=1)
+ result = check_wire(correct_wire, 1, 1, "wff")
+ assert not result
+
+ correct_wire = GrometWire(src=3, tgt=4)
+ result = check_wire(correct_wire, 4, 5, "wlc")
+ assert not result
+
+    correct_wire = GrometWire(src=2, tgt=1)
+    assert not check_wire(correct_wire, 2, 1, "wff")
+
+def test_wrong_wire():
+ wrong_wire = GrometWire(src=0, tgt=-1)
+ result = check_wire(wrong_wire, 1, 1, "wff")
+ assert result
+
+ wrong_wire = GrometWire(src=20, tgt=2)
+ result = check_wire(wrong_wire, 19, 2, "wff")
+ assert result
+
+ wrong_wire = GrometWire(src=-1, tgt=2)
+ result = check_wire(wrong_wire, 1, 1, "wlc")
+ assert result
diff --git a/skema/program_analysis/tree_sitter_parsers/build_parsers.py b/skema/program_analysis/tree_sitter_parsers/build_parsers.py
index 2c73cbabd95..cf87c8d9bcf 100644
--- a/skema/program_analysis/tree_sitter_parsers/build_parsers.py
+++ b/skema/program_analysis/tree_sitter_parsers/build_parsers.py
@@ -74,6 +74,7 @@ def copy_to_site_packages():
flag = f"--{language}"
help_text = f"Include {language} language"
parser.add_argument(flag, action="store_true", help=help_text)
+ parser.add_argument("--ci", action="store_true", help="Copy to site packages if running on ci")
args = parser.parse_args()
if args.all:
@@ -82,4 +83,6 @@ def copy_to_site_packages():
selected_languages = [language for language, value in vars(args).items() if value]
build_parsers(selected_languages)
- copy_to_site_packages()
+
+ if args.ci:
+ copy_to_site_packages()
diff --git a/skema/rest/api.py b/skema/rest/api.py
index 03fe324ae4b..862fa58b577 100644
--- a/skema/rest/api.py
+++ b/skema/rest/api.py
@@ -1,10 +1,11 @@
import os
from typing import Dict
-from fastapi import FastAPI, Response, status
+from fastapi import Depends, FastAPI, Response, status
from fastapi.responses import PlainTextResponse
from skema.rest import (
+ config,
schema,
workflows,
proxies,
@@ -12,11 +13,14 @@
morae_proxy,
metal_proxy,
llm_proxy,
+ utils
)
+from skema.isa import isa_service
from skema.img2mml import eqn2mml
from skema.skema_py import server as code2fn
from skema.gromet.execution_engine import server as execution_engine
from skema.program_analysis.comment_extractor import server as comment_service
+import httpx
VERSION: str = os.environ.get("APP_VERSION", "????")
@@ -62,15 +66,27 @@
},
{
"name": "morae",
- "description": "",
+ "description": "Operations to MORAE.",
"externalDocs": {
"description": "Issues",
"url": "https://github.com/ml4ai/skema/issues?q=is%3Aopen+is%3Aissue+label%3AMORAE",
},
},
+ {
+ "name": "isa",
+ "description": "Operations to ISA",
+ "externalDocs": {
+ "description": "Issues",
+ "url": "https://github.com/ml4ai/skema/issues?q=is%3Aopen+is%3Aissue+label%3AISA",
+ },
+ },
{
"name": "text reading",
"description": "Unified proxy and integration code for MIT and SKEMA TR pipelines",
+ "externalDocs": {
+ "description": "Issues",
+ "url": "https://github.com/ml4ai/skema/issues?q=is%3Aopen+is%3Aissue+label%3AText%20Reading",
+ },
},
{
"name": "metal",
@@ -139,9 +155,25 @@
tags=["metal"]
)
+app.include_router(
+ isa_service.router,
+ prefix="/isa",
+ tags=["isa"]
+)
-@app.get("/version", tags=["core"], summary="API version")
-async def version() -> str:
+@app.head(
+ "/version",
+ tags=["core"],
+ summary="API version",
+ status_code=status.HTTP_200_OK
+)
+@app.get(
+ "/version",
+ tags=["core"],
+ summary="API version",
+ status_code=status.HTTP_200_OK
+)
+def version() -> str:
return PlainTextResponse(VERSION)
@@ -161,11 +193,11 @@ async def version() -> str:
},
},
)
-async def healthcheck(response: Response) -> schema.HealthStatus:
- morae_status = await morae_proxy.healthcheck()
+async def healthcheck(response: Response, client: httpx.AsyncClient = Depends(utils.get_client)) -> schema.HealthStatus:
+ morae_status = await morae_proxy.healthcheck(client)
mathjax_status = eqn2mml.latex2mml_healthcheck()
eqn2mml_status = eqn2mml.img2mml_healthcheck()
- code2fn_status = code2fn.ping()
+ code2fn_status = code2fn.healthcheck()
text_reading_status = integrated_text_reading_proxy.healthcheck()
metal_status = metal_proxy.healthcheck()
# check if any services failing and alter response status code accordingly
@@ -201,6 +233,7 @@ async def environment_variables() -> Dict:
"SKEMA_GRAPH_DB_HOST": proxies.SKEMA_GRAPH_DB_HOST,
"SKEMA_GRAPH_DB_PORT": proxies.SKEMA_GRAPH_DB_PORT,
"SKEMA_RS_ADDRESS": proxies.SKEMA_RS_ADDESS,
+ "SKEMA_RS_DEFAULT_TIMEOUT": config.SKEMA_RS_DEFAULT_TIMEOUT,
"SKEMA_MATHJAX_PROTOCOL": proxies.SKEMA_MATHJAX_PROTOCOL,
"SKEMA_MATHJAX_HOST": proxies.SKEMA_MATHJAX_HOST,
diff --git a/skema/rest/config.py b/skema/rest/config.py
new file mode 100644
index 00000000000..92038fbb368
--- /dev/null
+++ b/skema/rest/config.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+"""
+ENV-based config
+"""
+
+import os
+
+
+SKEMA_RS_DEFAULT_TIMEOUT = float(os.environ.get("SKEMA_RS_DEFAULT_TIMEOUT", "60.0"))
\ No newline at end of file
diff --git a/skema/rest/integrated_text_reading_proxy.py b/skema/rest/integrated_text_reading_proxy.py
index d167455d057..6cb64420242 100644
--- a/skema/rest/integrated_text_reading_proxy.py
+++ b/skema/rest/integrated_text_reading_proxy.py
@@ -11,9 +11,10 @@
import pandas as pd
import requests
+import httpx
from askem_extractions.data_model import AttributeCollection
from askem_extractions.importers import import_arizona
-from fastapi import APIRouter, FastAPI, UploadFile, Response, status
+from fastapi import APIRouter, Depends, FastAPI, UploadFile, Response, status
from skema.rest.proxies import SKEMA_TR_ADDRESS, MIT_TR_ADDRESS, OPENAI_KEY, COSMOS_ADDRESS
from skema.rest.schema import (
@@ -22,7 +23,7 @@
TextReadingDocumentResults,
TextReadingError, MiraGroundingInputs, MiraGroundingOutputItem, TextReadingEvaluationResults,
)
-from skema.rest.utils import compute_text_reading_evaluation
+from skema.rest import utils
router = APIRouter()
@@ -676,7 +677,7 @@ def quantitative_eval() -> TextReadingEvaluationResults:
# Read the SKEMA extractions
extractions = AttributeCollection.from_json(Path(__file__).parents[0] / "data" / "extractions_sidarthe_skema.json")
- return compute_text_reading_evaluation(gt_data, extractions)
+ return utils.compute_text_reading_evaluation(gt_data, extractions)
@router.post("/eval", response_model=TextReadingEvaluationResults, status_code=200)
@@ -716,7 +717,7 @@ def quantitative_eval(extractions_file: UploadFile,
extractions = AttributeCollection(
attributes=list(it.chain.from_iterable(c.attributes for c in collections)))
- return compute_text_reading_evaluation(gt_data, extractions, json_contents)
+ return utils.compute_text_reading_evaluation(gt_data, extractions, json_contents)
app = FastAPI()
diff --git a/skema/rest/llm_proxy.py b/skema/rest/llm_proxy.py
index 35fea8f90fb..a7454e1d453 100644
--- a/skema/rest/llm_proxy.py
+++ b/skema/rest/llm_proxy.py
@@ -11,10 +11,10 @@
from fastapi import APIRouter, FastAPI, File, UploadFile
from io import BytesIO
from zipfile import ZipFile
-import requests
from pathlib import Path
from pydantic import BaseModel, Field
from typing import List, Optional
+from skema.skema_py import server as code2fn
from skema.rest.proxies import SKEMA_OPENAI_KEY
import time
@@ -121,13 +121,14 @@ async def get_lines_of_model(zip_file: UploadFile = File()) -> List[Dynamics]:
function_name = parsed_output['model_function']
- # Get the FN from it
- url = "https://api.askem.lum.ai/code2fn/fn-given-filepaths"
- time.sleep(0.5)
- response_zip = requests.post(url, json=single_snippet_payload)
-
+ # FIXME: we should rewrite things to avoid this need
+ #time.sleep(0.5)
+ system = code2fn.System(**single_snippet_payload)
+ print(f"System:\t{system}")
+ response_zip = await code2fn.fn_given_filepaths(system)
+ #print(f"response_zip:\t{response_zip}")
# get metadata entry for function
- for entry in response_zip.json()['modules'][0]['fn_array']:
+ for entry in response_zip['modules'][0]['fn_array']:
try:
if entry['b'][0]['name'][0:len(function_name)] == function_name:
metadata_idx = entry['b'][0]['metadata']
@@ -135,17 +136,30 @@ async def get_lines_of_model(zip_file: UploadFile = File()) -> List[Dynamics]:
continue
# get line span using metadata
- for (i,metadata) in enumerate(response_zip.json()['modules'][0]['metadata_collection']):
+ for (i,metadata) in enumerate(response_zip['modules'][0]['metadata_collection']):
if i == (metadata_idx - 1):
line_begin = metadata[0]['line_begin']
line_end = metadata[0]['line_end']
- except:
+ # if the line_begin of meta entry 2 (base 0) and meta entry 3 (base 0) differ, we add a slice from [meta2.line_begin, meta3.line_begin)
+ # to capture all the imports, return a Dynamics.block with 2 entries, both of which need to be concatenated to pass forward
+ file_line_begin = response_zip['modules'][0]['metadata_collection'][2][0]['line_begin']
+
+ code_line_begin = response_zip['modules'][0]['metadata_collection'][3][0]['line_begin'] - 1
+
+ if (file_line_begin != code_line_begin) and (code_line_begin > file_line_begin):
+ block.append(f"L{file_line_begin}-L{code_line_begin}")
+
+ block.append(f"L{line_begin}-L{line_end}")
+ except Exception as e:
print("Failed to parse dynamics")
+ print(f"e:\t{e}")
description = "Failed to parse dynamics"
line_begin = 0
line_end = 0
+ block.append(f"L{line_begin}-L{line_end}")
+
+
- block.append(f"L{line_begin}-L{line_end}")
output = Dynamics(name=file, description=description, block=block)
outputs.append(output)
diff --git a/skema/rest/morae_proxy.py b/skema/rest/morae_proxy.py
index ce1bc1e2651..c88d40cb9e9 100644
--- a/skema/rest/morae_proxy.py
+++ b/skema/rest/morae_proxy.py
@@ -6,8 +6,10 @@
from typing import Any, Dict, List, Text
from skema.rest.proxies import SKEMA_RS_ADDESS
-from fastapi import APIRouter
-import requests
+from fastapi import APIRouter, Depends
+from skema.rest import utils
+# TODO: replace use of requests with httpx
+import httpx
router = APIRouter()
@@ -15,26 +17,30 @@
# FIXME: make GrometFunctionModuleCollection a pydantic model via code gen
@router.post("/model", summary="Pushes gromet (function network) to the graph database", include_in_schema=False)
-async def post_model(gromet: Dict[Text, Any]):
- return requests.post(f"{SKEMA_RS_ADDESS}/models", json=gromet).json()
+async def post_model(gromet: Dict[Text, Any], client: httpx.AsyncClient = Depends(utils.get_client)):
+ res = await client.post(f"{SKEMA_RS_ADDESS}/models", json=gromet)
+ return res.json()
@router.get("/models", summary="Gets function network IDs from the graph database")
-async def get_models() -> List[int]:
- request = requests.get(f"{SKEMA_RS_ADDESS}/models")
- print(f"request: {request}")
- return request.json()
+async def get_models(client: httpx.AsyncClient = Depends(utils.get_client)) -> List[int]:
+ res = await client.get(f"{SKEMA_RS_ADDESS}/models")
+ print(f"request: {res}")
+ return res.json()
@router.get("/ping", summary="Status of MORAE service")
-async def healthcheck() -> int:
- return requests.get(f"{SKEMA_RS_ADDESS}/ping").status_code
+async def healthcheck(client: httpx.AsyncClient = Depends(utils.get_client)) -> int:
+ res = await client.get(f"{SKEMA_RS_ADDESS}/ping")
+ return res.status_code
@router.get("/version", summary="Status of MORAE service")
-async def versioncheck() -> str:
- return requests.get(f"{SKEMA_RS_ADDESS}/version").text
+async def versioncheck(client: httpx.AsyncClient = Depends(utils.get_client)) -> str:
+ res = await client.get(f"{SKEMA_RS_ADDESS}/version")
+ return res.text
@router.post("/mathml/decapodes", summary="Gets Decapodes from a list of MathML strings")
-async def get_decapodes(mathml: List[str]) -> Dict[Text, Any]:
- return requests.put(f"{SKEMA_RS_ADDESS}/mathml/decapodes", json=mathml).json()
\ No newline at end of file
+async def get_decapodes(mathml: List[str], client: httpx.AsyncClient = Depends(utils.get_client)) -> Dict[Text, Any]:
+ res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/decapodes", json=mathml)
+ return res.json()
\ No newline at end of file
diff --git a/skema/rest/tests/test_eqn_to_latex.py b/skema/rest/tests/test_eqn_to_latex.py
index b0f5de39a5e..89cfa81e7b9 100644
--- a/skema/rest/tests/test_eqn_to_latex.py
+++ b/skema/rest/tests/test_eqn_to_latex.py
@@ -1,14 +1,14 @@
+import base64
from pathlib import Path
-from fastapi.testclient import TestClient
+from httpx import AsyncClient
from skema.rest.workflows import app
import pytest
import json
-client = TestClient(app)
-
@pytest.mark.ci_only
-def test_post_image_to_latex():
+@pytest.mark.asyncio
+async def test_post_image_to_latex():
"""Test case for /images/equations-to-latex endpoint."""
cwd = Path(__file__).parents[0]
@@ -18,7 +18,9 @@ def test_post_image_to_latex():
}
endpoint = "/images/equations-to-latex"
- response = client.post(endpoint, files=files)
+ # see https://fastapi.tiangolo.com/advanced/async-tests/#async-tests
+ async with AsyncClient(app=app, base_url="http://eqn-to-latex-test") as ac:
+ response = await ac.post(endpoint, files=files)
expected = "\\frac{d H}{dt}=\\nabla \\cdot {(\\Gamma*H^{n+2}*\\left|\\nabla{H}\\right|^{n-1}*\\nabla{H})}"
# check for route's existence
assert (
@@ -32,3 +34,32 @@ def test_post_image_to_latex():
assert (
json.loads(response.text) == expected
), f"Response should be {expected}, but instead received {response.text}"
+
+
+@pytest.mark.ci_only
+@pytest.mark.asyncio
+async def test_post_image_to_latex_base64():
+ """Test case for /images/base64/equations-to-latex endpoint."""
+ cwd = Path(__file__).parents[0]
+ image_path = cwd / "data" / "img2latex" / "halfar.png"
+ with Path(image_path).open("rb") as infile:
+ img_bytes = infile.read()
+ img_b64 = base64.b64encode(img_bytes).decode("utf-8")
+
+ endpoint = "/images/base64/equations-to-latex"
+ # see https://fastapi.tiangolo.com/advanced/async-tests/#async-tests
+ async with AsyncClient(app=app, base_url="http://eqn-to-latex-base64-test") as ac:
+ response = await ac.post(endpoint, data=img_b64)
+ expected = "\\frac{d H}{dt}=\\nabla \\cdot {(\\Gamma*H^{n+2}*\\left|\\nabla{H}\\right|^{n-1}*\\nabla{H})}"
+ # check for route's existence
+ assert (
+ any(route.path == endpoint for route in app.routes) == True
+ ), "{endpoint} does not exist for app"
+ # check status code
+ assert (
+ response.status_code == 200
+ ), f"Request was unsuccessful (status code was {response.status_code} instead of 200)"
+ # check response
+ assert (
+ json.loads(response.text) == expected
+ ), f"Response should be {expected}, but instead received {response.text}"
\ No newline at end of file
diff --git a/skema/rest/tests/test_integrated_text_reading_proxy.py b/skema/rest/tests/test_integrated_text_reading_proxy.py
index 46c0bff8822..d4e726e2440 100644
--- a/skema/rest/tests/test_integrated_text_reading_proxy.py
+++ b/skema/rest/tests/test_integrated_text_reading_proxy.py
@@ -114,9 +114,9 @@ def test_extraction_evaluation():
results = response.json()
assert results['num_manual_annotations'] == 220, "There should be 220 gt manual annotations"
- assert results['precision'] == approx(0.7230769230768118), "Precision drastically different from the expected value"
- assert results['recall'] == approx(0.21363636363636362), "Recall drastically different from the expected value"
- assert results['f1'] == approx(0.32982456136828636), "F1 drastically different from the expected value"
+ assert results['precision'] == approx(0.5230769230768426), "Precision drastically different from the expected value"
+ assert results['recall'] == approx(0.154545454545454542), "Recall drastically different from the expected value"
+ assert results['f1'] == approx(0.23859649119285095), "F1 drastically different from the expected value"
def test_healthcheck():
diff --git a/skema/rest/tests/test_isa.py b/skema/rest/tests/test_isa.py
new file mode 100644
index 00000000000..96ace972810
--- /dev/null
+++ b/skema/rest/tests/test_isa.py
@@ -0,0 +1,40 @@
+import json
+
+from fastapi.testclient import TestClient
+from skema.isa.isa_service import app
+import skema.isa.data as isa_data
+import pytest
+
+client = TestClient(app)
+
+
+@pytest.mark.ci_only
+def test_align_eqns():
+ """Test case for /align-eqns endpoint."""
+
+ halfar_dome_eqn = isa_data.mml
+ mention_json1_content = ""
+ mention_json2_content = ""
+ data = {
+ "mml1": halfar_dome_eqn,
+ "mml2": halfar_dome_eqn,
+ "mention_json1": mention_json1_content,
+ "mention_json2": mention_json2_content,
+ }
+
+ endpoint = "/isa/align-eqns"
+ response = client.post(endpoint, params=data)
+ expected = isa_data.expected
+
+ # check status code
+ assert (
+ response.status_code == 200
+ ), f"Request was unsuccessful (status code was {response.status_code} instead of 200)"
+ # check response of matching_ratio
+ assert (
+ json.loads(response.text)["matching_ratio"] == 1.0
+ ), f"Response should be 1.0, but instead received {response.text}"
+ # check response of union_graph
+ assert (
+ json.loads(response.text)["union_graph"] == expected
+ ), f"Response should be {expected}, but instead received {response.text}"
diff --git a/skema/rest/tests/test_model_to_amr.py b/skema/rest/tests/test_model_to_amr.py
index 51c240818ed..b59d23a4140 100644
--- a/skema/rest/tests/test_model_to_amr.py
+++ b/skema/rest/tests/test_model_to_amr.py
@@ -12,14 +12,21 @@
)
from skema.rest.llm_proxy import Dynamics
from skema.rest.proxies import SKEMA_RS_ADDESS
-from skema.skema_py.server import System
-import time
+from skema.skema_py import server as code2fn
+import json
+import httpx
+import pytest
CHIME_SIR_URL = (
"https://artifacts.askem.lum.ai/askem/data/models/zip-archives/CHIME-SIR-model.zip"
)
-def test_any_amr_chime_sir():
+SIDARTHE_URL = (
+ "https://artifacts.askem.lum.ai/askem/data/models/zip-archives/SIDARTHE.zip"
+)
+
+@pytest.mark.asyncio
+async def test_any_amr_chime_sir():
"""
Unit test for checking that Chime-SIR model produces any AMR. This test zip contains 4 versions of CHIME SIR.
This will test if just the core dynamics works, the whole script, and also rewritten scripts work.
@@ -36,16 +43,26 @@ def test_any_amr_chime_sir():
llm_mock_output = [dyn1, dyn2, dyn3, dyn4]
line_begin = []
+ import_begin = []
line_end = []
+ import_end = []
files = []
blobs = []
amrs = []
+
for linespan in llm_mock_output:
- lines = linespan.block[0].split("-")
+ blocks = len(linespan.block)
+ lines = linespan.block[blocks-1].split("-")
line_begin.append(
max(int(lines[0][1:]) - 1, 0)
) # Normalizing the 1-index response from llm_proxy
line_end.append(int(lines[1][1:]))
+ if blocks == 2:
+ lines = linespan.block[0].split("-")
+ import_begin.append(
+ max(int(lines[0][1:]) - 1, 0)
+ ) # Normalizing the 1-index response from llm_proxy
+ import_end.append(int(lines[1][1:]))
# So we are required to do the same when slicing the source code using its output.
with ZipFile(zip_bytes, "r") as zip:
@@ -62,39 +79,149 @@ def test_any_amr_chime_sir():
if line_begin[i] == line_end[i]:
print("failed linespan")
else:
- blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
+ if blocks == 2:
+ temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[i]:import_end[i]])
+ blobs[i] = temp + "\n" + "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
+ else:
+ blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
try:
- time.sleep(0.5)
- code_snippet_response = asyncio.run(
- code_snippets_to_pn_amr(
- System(
- files=[files[i]],
- blobs=[blobs[i]],
- )
- )
- )
+ async with httpx.AsyncClient() as client:
+ code_snippet_response = await code_snippets_to_pn_amr(
+ system=code2fn.System(
+ files=[files[i]],
+ blobs=[blobs[i]],
+ ),
+ client=client
+ )
+ # code_snippet_response = json.loads(code_snippet_response.body)
+ # print(f"code_snippet_response for test_any_amr_chime_sir: {code_snippet_response}")
if "model" in code_snippet_response:
+ code_snippet_response["header"]["name"] = "LLM-assisted code to amr model"
+ code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}"
+ code_snippet_response["header"]["linespan"] = f"{llm_mock_output[i]}"
amrs.append(code_snippet_response)
else:
print("snippets failure")
logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
- except:
- print("except hit")
+ except Exception as e:
+ print("Hit except to snippets failure")
+ print(f"Exception for test_any_amr_chime_sir:\t{e}")
logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
# we will return the amr with most states, in assumption it is the most "correct"
# by default it returns the first entry
- print(f"amrs: {amrs}\n")
- amr = amrs[0]
- print(f"initial amr: {amr}\n")
- for temp_amr in amrs:
- try:
- temp_len = len(temp_amr["model"]["states"])
- amr_len = len(amr["model"]["states"])
- if temp_len > amr_len:
- amr = temp_amr
- except:
- continue
+ print(f"{amrs}")
+ try:
+ amr = amrs[0]
+ for temp_amr in amrs:
+ try:
+ temp_len = len(temp_amr["model"]["states"])
+ amr_len = len(amr["model"]["states"])
+ if temp_len > amr_len:
+ amr = temp_amr
+ except:
+ continue
+ except Exception as e:
+ print(f"Exception for test_any_amr_chime_sir:\t{e}")
+ amr = logging
print(f"final amr: {amr}\n")
# For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy.
assert "model" in amr, f"'model' should be in AMR response, but got {amr}"
+@pytest.mark.asyncio
+async def test_any_amr_sidarthe():
+ """
+ Unit test for checking that the SIDARTHE model produces any AMR. This test zip contains 2 versions of SIDARTHE.
+ This will test whether both the commented and uncommented versions of the script produce an AMR.
+ """
+ response = requests.get(SIDARTHE_URL)
+ zip_bytes = BytesIO(response.content)
+
+ # NOTE: For CI we are unable to use the LLM assisted functions due to API keys
+ # So, we will instead mock the output for those functions instead
+ dyn1 = Dynamics(name="commented_Evaluation_Scenario_2.1.a.ii-Code_Version_A.py", description=None, block=["L1-L6","L7-L59"])
+ dyn2 = Dynamics(name="Evaluation_Scenario_2.1.a.ii-Code_Version_A.py", description=None, block=["L1-L6","L7-L18"])
+ llm_mock_output = [dyn1, dyn2]
+
+ line_begin = []
+ import_begin = []
+ line_end = []
+ import_end = []
+ files = []
+ blobs = []
+ amrs = []
+
+
+ for linespan in llm_mock_output:
+ blocks = len(linespan.block)
+ lines = linespan.block[blocks-1].split("-")
+ line_begin.append(
+ max(int(lines[0][1:]) - 1, 0)
+ ) # Normalizing the 1-index response from llm_proxy
+ line_end.append(int(lines[1][1:]))
+ if blocks == 2:
+ lines = linespan.block[0].split("-")
+ import_begin.append(
+ max(int(lines[0][1:]) - 1, 0)
+ ) # Normalizing the 1-index response from llm_proxy
+ import_end.append(int(lines[1][1:]))
+
+ # So we are required to do the same when slicing the source code using its output.
+ with ZipFile(zip_bytes, "r") as zip:
+ for file in zip.namelist():
+ file_obj = Path(file)
+ if file_obj.suffix in [".py"]:
+ files.append(file)
+ blobs.append(zip.open(file).read().decode("utf-8"))
+
+ # The source code is a string, so to slice using the line spans, we must first convert it to a list.
+ # Then we can convert it back to a string using .join
+ logging = []
+ for i in range(len(blobs)):
+ if line_begin[i] == line_end[i]:
+ print("failed linespan")
+ else:
+ if blocks == 2:
+ temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[i]:import_end[i]])
+ blobs[i] = temp + "\n" + "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
+ else:
+ blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
+ try:
+ async with httpx.AsyncClient() as client:
+ code_snippet_response = await code_snippets_to_pn_amr(
+ system=code2fn.System(
+ files=[files[i]],
+ blobs=[blobs[i]],
+ ),
+ client=client
+ )
+ if "model" in code_snippet_response:
+ code_snippet_response["header"]["name"] = "LLM-assisted code to amr model"
+ code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}"
+ code_snippet_response["header"]["linespan"] = f"{llm_mock_output[i]}"
+ amrs.append(code_snippet_response)
+ else:
+ print("snippets failure")
+ logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
+ except Exception as e:
+ print("Hit except to snippets failure")
+ print(f"Exception for test_any_amr_sidarthe:\t{e}")
+ logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
+ # we will return the amr with most states, in assumption it is the most "correct"
+ # by default it returns the first entry
+ print(f"{amrs}")
+ try:
+ amr = amrs[0]
+ for temp_amr in amrs:
+ try:
+ temp_len = len(temp_amr["model"]["states"])
+ amr_len = len(amr["model"]["states"])
+ if temp_len > amr_len:
+ amr = temp_amr
+ except:
+ continue
+ except Exception as e:
+ print(f"Exception for final amr of test_any_amr_sidarthe:\t{e}")
+ amr = logging
+ print(f"final amr: {amr}\n")
+ # For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy.
+ assert "model" in amr, f"'model' should be in AMR response, but got {amr}"
\ No newline at end of file
diff --git a/skema/rest/utils.py b/skema/rest/utils.py
index 51a4d151bc4..765add351a6 100644
--- a/skema/rest/utils.py
+++ b/skema/rest/utils.py
@@ -1,13 +1,24 @@
import itertools as it
+import httpx
from collections import defaultdict
from typing import Any, Dict
from askem_extractions.data_model import AttributeCollection, AttributeType, AnchoredEntity
from bs4 import BeautifulSoup, Comment
+from skema.rest import config
from skema.rest.schema import TextReadingEvaluationResults, AMRLinkingEvaluationResults
+# see https://stackoverflow.com/a/74401249
+async def get_client():
+ # create a new client for each request
+ async with httpx.AsyncClient(timeout=config.SKEMA_RS_DEFAULT_TIMEOUT, follow_redirects=True) as client:
+ # yield the client to the endpoint function
+ yield client
+ # close the client when the request is done
+
+
def fn_preprocessor(function_network: Dict[str, Any]):
fn_data = function_network.copy()
@@ -170,23 +181,32 @@ def compute_text_reading_evaluation(gt_data: list, attributes: AttributeCollecti
page = a["page"]
annotations_by_page[page].append(a)
+ def annotation_key(a: Dict):
+ return a['page'], tuple(a['start_xy']), a['text']
+
# Count the matches
tp, tn, fp, fn = 0, 0, 0, 0
+ matched_annotations = set()
for e in extractions:
+ matched = False
for m in e.mentions:
- if m.extraction_source is not None:
- te = m.extraction_source
- if te.page is not None:
- e_page = te.page
- page_annotations = annotations_by_page[e_page]
- matched = False
- for a in page_annotations:
- if extraction_matches_annotation(m, a, json_contents):
- matched = True
- tp += 1
- break
- if not matched:
- fp += 1
+ if not matched:
+ if m.extraction_source is not None:
+ te = m.extraction_source
+ if te.page is not None:
+ e_page = te.page
+ page_annotations = annotations_by_page[e_page]
+
+ for a in page_annotations:
+ key = annotation_key(a)
+ if key not in matched_annotations:
+ if extraction_matches_annotation(m, a, json_contents):
+ matched_annotations.add(key)
+ matched = True
+ tp += 1
+ break
+ if not matched:
+ fp += 1
recall = tp / len(gt_data)
precision = tp / (tp + fp + 0.00000000001)
diff --git a/skema/rest/workflows.py b/skema/rest/workflows.py
index 7ef76c6eaf8..5f0e53745db 100644
--- a/skema/rest/workflows.py
+++ b/skema/rest/workflows.py
@@ -3,20 +3,22 @@
End-to-end skema workflows
"""
import copy
-import requests
import time
from zipfile import ZipFile
from io import BytesIO
from typing import List
from pathlib import Path
+import httpx
+import json
+import requests
-from fastapi import APIRouter, File, UploadFile, FastAPI
+from fastapi import APIRouter, Depends, File, UploadFile, FastAPI, Request
from starlette.responses import JSONResponse
from skema.img2mml import eqn2mml
-from skema.img2mml.eqn2mml import image2mathml_db
+from skema.img2mml.eqn2mml import image2mathml_db, b64_image_to_mml
from skema.img2mml.api import get_mathml_from_bytes
-from skema.rest import schema, utils, llm_proxy
+from skema.rest import config, schema, utils, llm_proxy
from skema.rest.proxies import SKEMA_RS_ADDESS
from skema.skema_py import server as code2fn
@@ -27,7 +29,7 @@
@router.post(
"/images/base64/equations-to-amr", summary="Equations (base64 images) → MML → AMR"
)
-async def equations_to_amr(data: schema.EquationImagesToAMR):
+async def equations_to_amr(data: schema.EquationImagesToAMR, client: httpx.AsyncClient = Depends(utils.get_client)):
"""
Converts images of equations to AMR.
@@ -57,7 +59,7 @@ async def equations_to_amr(data: schema.EquationImagesToAMR):
]
payload = {"mathml": mml, "model": data.model}
# FIXME: why is this a PUT?
- res = requests.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload)
+ res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload)
if res.status_code != 200:
return JSONResponse(
status_code=400,
@@ -71,7 +73,7 @@ async def equations_to_amr(data: schema.EquationImagesToAMR):
# equation images -> mml -> latex
@router.post("/images/equations-to-latex", summary="Equations (images) → MML → LaTeX")
-async def equations_to_latex(data: UploadFile):
+async def equations_to_latex(data: UploadFile, client: httpx.AsyncClient = Depends(utils.get_client)):
"""
Converts images of equations to LaTeX.
@@ -96,8 +98,9 @@ async def equations_to_latex(data: UploadFile):
# pass image bytes to get_mathml_from_bytes function
mml_res = get_mathml_from_bytes(image_bytes, image2mathml_db)
proxy_url = f"{SKEMA_RS_ADDESS}/mathml/latex"
+ print(f"MMML:\t{mml_res}")
print(f"Proxying request to {proxy_url}")
- response = requests.post(proxy_url, data=mml_res)
+ response = await client.post(proxy_url, data=mml_res)
# Check the response
if response.status_code == 200:
# The request was successful
@@ -109,9 +112,51 @@ async def equations_to_latex(data: UploadFile):
return f"Error: {response.status_code} {response.text}"
+# equation images -> base64 -> mml -> latex
+@router.post("/images/base64/equations-to-latex", summary="Equations (images) → MML → LaTeX")
+async def equations_to_latex(request: Request, client: httpx.AsyncClient = Depends(utils.get_client)):
+ """
+ Converts images of equations to LaTeX.
+
+ ### Python example
+
+ Endpoint for generating LaTeX from an input image.
+
+ ```
+ from pathlib import Path
+ import base64
+ import requests
+
+ url = "http://127.0.0.1:8000/workflows/images/base64/equations-to-latex"
+ with Path("test.png").open("rb") as infile:
+ img_bytes = infile.read()
+ img_b64 = base64.b64encode(img_bytes).decode("utf-8")
+ r = requests.post(url, data=img_b64)
+ print(r.text)
+ ```
+ """
+ # Read image data
+ img_b64 = await request.body()
+ mml_res = b64_image_to_mml(img_b64)
+
+ # pass image bytes to get_mathml_from_bytes function
+ proxy_url = f"{SKEMA_RS_ADDESS}/mathml/latex"
+ print(f"MML:\t{mml_res}")
+ print(f"Proxying request to {proxy_url}")
+ response = await client.post(proxy_url, data=mml_res)
+ # Check the response
+ if response.status_code == 200:
+ # The request was successful
+ return response.text
+ else:
+ # The request failed
+ print(f"Error: {response.status_code}")
+ print(response.text)
+ return f"Error: {response.status_code} {response.text}"
+
# tex equations -> pmml -> amr
@router.post("/latex/equations-to-amr", summary="Equations (LaTeX) → pMML → AMR")
-async def equations_to_amr(data: schema.EquationLatexToAMR):
+async def equations_to_amr(data: schema.EquationLatexToAMR, client: httpx.AsyncClient = Depends(utils.get_client)):
"""
Converts equations (in LaTeX) to AMR.
@@ -131,7 +176,7 @@ async def equations_to_amr(data: schema.EquationLatexToAMR):
utils.clean_mml(eqn2mml.get_mathml_from_latex(tex)) for tex in data.equations
]
payload = {"mathml": mml, "model": data.model}
- res = requests.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload)
+ res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload)
if res.status_code != 200:
return JSONResponse(
status_code=400,
@@ -145,9 +190,9 @@ async def equations_to_amr(data: schema.EquationLatexToAMR):
# pmml -> amr
@router.post("/pmml/equations-to-amr", summary="Equations pMML → AMR")
-async def equations_to_amr(data: schema.MmlToAMR):
+async def equations_to_amr(data: schema.MmlToAMR, client: httpx.AsyncClient = Depends(utils.get_client)):
payload = {"mathml": data.equations, "model": data.model}
- res = requests.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload)
+ res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload)
if res.status_code != 200:
return JSONResponse(
status_code=400,
@@ -161,15 +206,18 @@ async def equations_to_amr(data: schema.MmlToAMR):
# code snippets -> fn -> petrinet amr
@router.post("/code/snippets-to-pn-amr", summary="Code snippets → PetriNet AMR")
-async def code_snippets_to_pn_amr(system: code2fn.System):
+async def code_snippets_to_pn_amr(system: code2fn.System, client: httpx.AsyncClient = Depends(utils.get_client)):
gromet = await code2fn.fn_given_filepaths(system)
- gromet, logs = utils.fn_preprocessor(gromet)
- res = requests.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet)
+ gromet, _ = utils.fn_preprocessor(gromet)
+ # print(f"gromet:{gromet}")
+ # print(f"client.follow_redirects:\t{client.follow_redirects}")
+ # print(f"client.timeout:\t{client.timeout}")
+ res = await client.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet)
if res.status_code != 200:
return JSONResponse(
status_code=400,
content={
- "error": f"MORAE PUT /models/PN failed to process payload",
+ "error": f"MORAE PUT /models/PN failed to process payload ({res.text})",
"payload": gromet,
},
)
@@ -199,10 +247,10 @@ async def code_snippets_to_rn_amr(system: code2fn.System):
@router.post(
"/code/codebase-to-pn-amr", summary="Code repo (zip archive) → PetriNet AMR"
)
-async def repo_to_pn_amr(zip_file: UploadFile = File()):
+async def repo_to_pn_amr(zip_file: UploadFile = File(), client: httpx.AsyncClient = Depends(utils.get_client)):
gromet = await code2fn.fn_given_filepaths_zip(zip_file)
- gromet, logs = utils.fn_preprocessor(gromet)
- res = requests.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet)
+ gromet, _ = utils.fn_preprocessor(gromet)
+ res = await client.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet)
if res.status_code != 200:
return JSONResponse(
status_code=400,
@@ -219,7 +267,7 @@ async def repo_to_pn_amr(zip_file: UploadFile = File()):
"/code/llm-assisted-codebase-to-pn-amr",
summary="Code repo (zip archive) → PetriNet AMR",
)
-async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()):
+async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File(), client: httpx.AsyncClient = Depends(utils.get_client)):
"""Codebase->AMR workflow using an llm to extract the dynamics line span.
### Python example
```
@@ -238,16 +286,27 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()):
print(f"Time response linespan: {time.time()}")
line_begin = []
+ import_begin = []
line_end = []
+ import_end = []
files = []
blobs = []
amrs = []
+
+ # There could now be multiple blocks that we need to handle and adjoin together
for linespan in linespans:
- lines = linespan.block[0].split("-")
+ blocks = len(linespan.block)
+ lines = linespan.block[blocks-1].split("-")
line_begin.append(
max(int(lines[0][1:]) - 1, 0)
) # Normalizing the 1-index response from llm_proxy
line_end.append(int(lines[1][1:]))
+ if blocks == 2:
+ lines = linespan.block[0].split("-")
+ import_begin.append(
+ max(int(lines[0][1:]) - 1, 0)
+ ) # Normalizing the 1-index response from llm_proxy
+ import_end.append(int(lines[1][1:]))
# So we are required to do the same when slicing the source code using its output.
with ZipFile(BytesIO(zip_file.file.read()), "r") as zip:
@@ -260,28 +319,38 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()):
# The source code is a string, so to slice using the line spans, we must first convert it to a list.
# Then we can convert it back to a string using .join
logging = []
+ import_counter = 0
for i in range(len(blobs)):
if line_begin[i] == line_end[i]:
print("failed linespan")
else:
- blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
+ if len(linespans[i].block) == 2:
+ temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[import_counter]:import_end[import_counter]])
+ blobs[i] = temp + "\n" + "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
+ import_counter += 1
+ else:
+ blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]])
try:
- time.sleep(0.5)
print(f"Time call code-snippets: {time.time()}")
- code_snippet_response = await code_snippets_to_pn_amr(
- code2fn.System(
- files=[files[i]],
- blobs=[blobs[i]],
- )
- )
+ gromet = await code2fn.fn_given_filepaths(code2fn.System(
+ files=[files[i]],
+ blobs=[blobs[i]],
+ ))
+ gromet, _ = utils.fn_preprocessor(gromet)
+ code_snippet_response = await client.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet)
+ code_snippet_response = code_snippet_response.json()
print(f"Time response code-snippets: {time.time()}")
if "model" in code_snippet_response:
+ code_snippet_response["header"]["name"] = "LLM-assisted code to amr model"
+ code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}"
+ code_snippet_response["header"]["linespan"] = f"{linespans[i]}"
amrs.append(code_snippet_response)
else:
print("snippets failure")
logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
- except:
+ except Exception as e:
print("Hit except to snippets failure")
+ print(f"Exception:\t{e}")
logging.append(f"{files[i]} failed to parse an AMR from the dynamics")
# we will return the amr with most states, in assumption it is the most "correct"
# by default it returns the first entry
@@ -318,6 +387,34 @@ async def repo_to_rn_amr(zip_file: UploadFile = File()):
)
return res.json()
"""
+"""
+# code snippets -> fn -> Vec -> ????
+@router.post("/isa/code-align", summary="ISA aided inference")
+async def code_snippets_to_isa_align(system: code2fn.System, client: httpx.AsyncClient = Depends(utils.get_client)):
+ gromet = await code2fn.fn_given_filepaths(system)
+ gromet, _ = utils.fn_preprocessor(gromet)
+ # print(f"gromet:{gromet}")
+ # print(f"client.follow_redirects:\t{client.follow_redirects}")
+ # print(f"client.timeout:\t{client.timeout}")
+ res = await client.put(f"{SKEMA_RS_ADDESS}/models/MET", json=gromet)
+ # res is a vector of MET's from the code (assuming it could extract correctly)
+ if res.status_code != 200:
+ return JSONResponse(
+ status_code=400,
+ content={
+ "error": f"MORAE PUT /models/MET failed to process payload ({res.text})",
+ "payload": gromet,
+ },
+ )
+
+ # Liang, if you want to put your ISA portion here?
+ # ISA:
+ #
+ #
+ #
+ #
+ return res.json()
+"""
app = FastAPI()
app.include_router(router)
\ No newline at end of file
diff --git a/skema/skema-rs/mathml/src/acset.rs b/skema/skema-rs/mathml/src/acset.rs
index 37e95470879..82f78d62fcd 100644
--- a/skema/skema-rs/mathml/src/acset.rs
+++ b/skema/skema-rs/mathml/src/acset.rs
@@ -367,7 +367,6 @@ impl From> for PetriNet {
terms.push(term.clone());
}
}
-
for term in terms.iter() {
for param in &term.parameters {
let parameters = Parameter {
@@ -425,7 +424,6 @@ impl From> for PetriNet {
for i in paired_term_indices.iter().rev() {
terms.remove(*i);
}
-
// Now we replace unpaired terms with subterms, by their subterms and repeat the process
// but first we need to inherit the dynamic state to each sub term
diff --git a/skema/skema-rs/mathml/src/ast.rs b/skema/skema-rs/mathml/src/ast.rs
index ee058f7e231..99a8479c85f 100644
--- a/skema/skema-rs/mathml/src/ast.rs
+++ b/skema/skema-rs/mathml/src/ast.rs
@@ -2,16 +2,17 @@ use derive_new::new;
use std::fmt;
pub mod operator;
-
+use serde::{Deserialize, Serialize};
use operator::Operator;
+//use crate::ast::MathExpression::SummationOp;
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub struct Mi(pub String);
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub struct Mrow(pub Vec);
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub enum Type {
Integer,
Rational,
@@ -27,22 +28,35 @@ pub enum Type {
Matrix,
}
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub struct Ci {
pub r#type: Option,
pub content: Box,
pub func_of: Option>,
}
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub struct Differential {
pub diff: Box,
pub func: Box,
}
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
+pub struct SummationMath {
+ pub op: Box,
+ pub func: Box,
+}
+
+/// Hat operation
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
+pub struct HatComp {
+ pub op: Box,
+ pub comp: Box,
+}
+
/// The MathExpression enum is not faithful to the corresponding element type in MathML 3
/// (https://www.w3.org/TR/MathML3/appendixa.html#parsing_MathExpression)
-#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Hash, Default, new)]
+#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Hash, Default, new, Deserialize, Serialize)]
pub enum MathExpression {
Mi(Mi),
Mo(Operator),
@@ -66,8 +80,10 @@ pub enum MathExpression {
//GroupTuple(Vec),
Ci(Ci),
Differential(Differential),
+ SummationMath(SummationMath),
AbsoluteSup(Box, Box),
Absolute(Box, Box),
+ HatComp(HatComp),
//Differential(Box, Box),
#[default]
None,
@@ -110,6 +126,14 @@ impl fmt::Display for MathExpression {
write!(f, "{superscript:?}")
}
MathExpression::Mtext(text) => write!(f, "{}", text),
+ MathExpression::SummationMath(SummationMath { op, func }) => {
+ write!(f, "{op}")?;
+ write!(f, "{func}")
+ }
+ MathExpression::HatComp(HatComp { op, comp }) => {
+ write!(f, "{op}")?;
+ write!(f, "{comp}")
+ }
expression => write!(f, "{expression:?}"),
}
}
diff --git a/skema/skema-rs/mathml/src/ast/operator.rs b/skema/skema-rs/mathml/src/ast/operator.rs
index f1e8dcb8eba..04a542ce87e 100644
--- a/skema/skema-rs/mathml/src/ast/operator.rs
+++ b/skema/skema-rs/mathml/src/ast/operator.rs
@@ -1,23 +1,46 @@
use crate::ast::Ci;
+use crate::ast::MathExpression;
use derive_new::new;
use std::fmt;
+use serde::{Deserialize, Serialize};
/// Derivative operator, in line with Spivak notation: http://ceres-solver.org/spivak_notation.html
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub struct Derivative {
pub order: u8,
pub var_index: u8,
pub bound_var: Ci,
}
+
/// Partial derivative operator
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub struct PartialDerivative {
pub order: u8,
pub var_index: u8,
pub bound_var: Ci,
}
-#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)]
+/// Summation operator with under and over components
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
+pub struct SumUnderOver {
+ pub op: Box,
+ pub under: Box,
+ pub over: Box,
+}
+
+/// Hat operation
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
+pub struct HatOp {
+ pub comp: Box,
+}
+
+/// Handles grad operations with subscript. E.g. ∇_{x}
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
+pub struct GradSub {
+ pub sub: Box,
+}
+
+#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)]
pub enum Operator {
Add,
Multiply,
@@ -33,6 +56,7 @@ pub enum Operator {
Power,
Comma,
Grad,
+ GradSub(GradSub),
Dot,
Period,
Div,
@@ -52,6 +76,11 @@ pub enum Operator {
Arccsc,
Arccot,
Mean,
+ Sum,
+ SumUnderOver(SumUnderOver),
+ Cross,
+ Hat,
+ HatOp(HatOp),
// Catchall for operators we haven't explicitly defined as enum variants yet.
Other(String),
}
@@ -68,7 +97,7 @@ impl fmt::Display for Operator {
Operator::Lparen => write!(f, "("),
Operator::Rparen => write!(f, ")"),
Operator::Compose => write!(f, "."),
- Operator::Comma => write!(f, ""),
+ Operator::Comma => write!(f, ","),
Operator::Factorial => write!(f, "!"),
Operator::Derivative(Derivative {
order,
@@ -101,10 +130,20 @@ impl fmt::Display for Operator {
Operator::Arccot => write!(f, "Arccot"),
Operator::Mean => write!(f, "Mean"),
Operator::Grad => write!(f, "Grad"),
- Operator::Dot => write!(f, "Dot"),
+ Operator::GradSub(GradSub {sub}) =>{
+ write!(f, "Grad_{sub}")
+ }
+ Operator::Dot => write!(f, "⋅"),
Operator::Period => write!(f, ""),
Operator::Div => write!(f, "Div"),
Operator::Abs => write!(f, "Abs"),
+ Operator::Sum => write!(f, "∑"),
+ Operator::SumUnderOver(SumUnderOver { op, under, over }) => {
+ write!(f, "{op}_{{{under}}}^{{{over}}}")
+ }
+ Operator::Cross => write!(f, "×"),
+ Operator::Hat => write!(f, "Hat"),
+ Operator::HatOp(HatOp { comp }) => write!(f, "Hat({comp})"),
}
}
}
diff --git a/skema/skema-rs/mathml/src/expression.rs b/skema/skema-rs/mathml/src/expression.rs
index 410aa0308db..7a438e0f279 100644
--- a/skema/skema-rs/mathml/src/expression.rs
+++ b/skema/skema-rs/mathml/src/expression.rs
@@ -1,18 +1,12 @@
-use crate::{
- ast::{
- operator::Operator,
- Math, MathExpression,
- MathExpression::{Mfrac, Mn, Mo, Mover, Msqrt, Msubsup, Msup},
- Mi, Mrow,
- },
- petri_net::recognizers::recognize_leibniz_differential_operator,
-};
+use crate::ast::{operator::Operator, MathExpression, Mi};
+use crate::parsers::math_expression_tree::MathExpressionTree;
use petgraph::{graph::NodeIndex, Graph};
use std::{clone::Clone, collections::VecDeque};
/// Struct for representing mathematical expressions in order to align with source code.
pub type MathExpressionGraph<'a> = Graph;
+use petgraph::dot::Dot;
use std::string::ToString;
#[derive(Debug, PartialEq, Eq, Clone)]
@@ -23,13 +17,6 @@ pub enum Atom {
}
/// Intermediate data structure to support the generation of graphs of mathematical expressions
-#[derive(Debug, Default, PartialEq, Clone)]
-pub struct Expression {
- pub ops: Vec,
- pub args: Vec,
- pub name: String,
-}
-
#[derive(Debug, PartialEq, Clone)]
pub enum Expr {
Atom(Atom),
@@ -40,177 +27,325 @@ pub enum Expr {
},
}
-/// Check if the fraction is a derivative expressed in Leibniz notation. If yes, mutate it to
-/// remove the 'd' prefixes.
-pub fn is_derivative(
- numerator: &mut Box,
- denominator: &mut Box,
-) -> bool {
- if recognize_leibniz_differential_operator(numerator, denominator).is_ok() {
- if let MathExpression::Mrow(Mrow(x)) = &mut **numerator {
- x.remove(0);
- }
-
- if let MathExpression::Mrow(Mrow(x)) = &mut **denominator {
- x.remove(0);
- }
- return true;
- }
- false
-}
-
-/// Identify if there is an implicit multiplication operator, and if so, add an
-/// explicit multiplication operator.
-fn insert_explicit_multiplication_operator(pre: &mut Expression) {
- if pre.args.len() >= pre.ops.len() {
- pre.ops.push(Operator::Multiply);
+fn is_unary_operator(op: &Operator) -> bool {
+ match op {
+ Operator::Sqrt
+ | Operator::Factorial
+ | Operator::Exp
+ | Operator::Grad
+ | Operator::Div
+ | Operator::Abs
+ | Operator::Derivative(_)
+ | Operator::Sin
+ | Operator::Cos
+ | Operator::Tan
+ | Operator::Sec
+ | Operator::Csc
+ | Operator::Cot
+ | Operator::Arcsin
+ | Operator::Arccos
+ | Operator::Arctan
+ | Operator::Arcsec
+ | Operator::Arccsc
+ | Operator::Arccot
+ | Operator::Mean => true,
+ _ => false,
}
}
-impl MathExpression {
- /// Convert a MathExpression struct to a Expression struct.
- pub fn to_expr(self, pre: &mut Expression) {
- match self {
- MathExpression::Mi(Mi(x)) => {
- // Process unary minus operation.
- if !pre.args.is_empty() {
- // Check the last arg
- let args_last_idx = pre.args.len() - 1;
- if let Expr::Atom(Atom::Operator(Operator::Subtract)) = &pre.args[args_last_idx]
- {
- let neg_identifier = format!("-{x}");
- pre.args[args_last_idx] = Expr::Atom(Atom::Identifier(neg_identifier));
- return;
- }
- }
- // deal with the invisible multiply operator
- if pre.args.len() >= pre.ops.len() {
- pre.ops.push(Operator::Multiply);
- }
- pre.args
- .push(Expr::Atom(Atom::Identifier(x.replace(' ', ""))));
+/// Processes a MathExpression under the type of MathExpressionTree::Atom and appends
+/// the corresponding LaTeX representation to the provided String.
+fn process_atom_expression(expr: &MathExpression, expression: &mut Expr) {
+ match expr {
+ // If it's a Ci variant, recursively process its content
+ MathExpression::Ci(x) => {
+ process_atom_expression(&x.content, expression);
+ }
+ MathExpression::Mi(Mi(id)) => {
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(Expr::Atom(Atom::Identifier(id.replace(' ', ""))));
}
- Mn(x) => {
- insert_explicit_multiplication_operator(pre);
- // Remove redundant whitespace
- pre.args.push(Expr::Atom(Atom::Number(x.replace(' ', ""))));
+ }
+ MathExpression::Mn(number) => {
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(Expr::Atom(Atom::Number(number.replace(' ', ""))));
}
- Mo(x) => {
- // Insert a temporary placeholder identifier to deal with unary minus operation.
- // The placeholder will be removed later.
- if x == Operator::Subtract && pre.ops.len() > pre.args.len() {
- pre.ops.push(x);
- pre.args
- .push(Expr::Atom(Atom::Identifier("place_holder".to_string())));
- } else {
- pre.ops.push(x);
- }
+ }
+ MathExpression::Msqrt(x) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Sqrt);
+ process_atom_expression(x, &mut new_expr);
}
- MathExpression::Mrow(Mrow(xs)) => {
- insert_explicit_multiplication_operator(pre);
- let mut pre_exp = Expression::default();
- pre_exp.ops.push(Operator::Other("".to_string()));
- for x in xs {
- x.to_expr(&mut pre_exp);
- }
- pre.args.push(Expr::Expression {
- ops: pre_exp.ops,
- args: pre_exp.args,
- name: "".to_string(),
- });
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
}
- Msubsup(xs1, xs2, xs3) => {
- insert_explicit_multiplication_operator(pre);
- let mut pre_exp = Expression::default();
- pre_exp.ops.push(Operator::Other("".to_string()));
- pre_exp.ops.push(Operator::Other("_".to_string()));
- xs1.to_expr(&mut pre_exp);
- pre_exp.ops.push(Operator::Other("^".to_string()));
- xs2.to_expr(&mut pre_exp);
- xs3.to_expr(&mut pre_exp);
- pre.args.push(Expr::Expression {
- ops: pre_exp.ops,
- args: pre_exp.args,
- name: "".to_string(),
- });
+ }
+ MathExpression::Mfrac(x1, x2) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("".to_string()));
+ process_atom_expression(x1, &mut new_expr);
}
- Msqrt(xs) => {
- insert_explicit_multiplication_operator(pre);
- let mut pre_exp = Expression::default();
- pre_exp.ops.push(Operator::Sqrt);
- xs.to_expr(&mut pre_exp);
- pre.args.push(Expr::Expression {
- ops: pre_exp.ops,
- args: pre_exp.args,
- name: "".to_string(),
- });
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Divide);
+ process_atom_expression(x2, &mut new_expr);
}
- Mfrac(mut xs1, mut xs2) => {
- insert_explicit_multiplication_operator(pre);
- let mut pre_exp = Expression::default();
- if is_derivative(&mut xs1, &mut xs2) {
- pre_exp.ops.push(Operator::Other("derivative".to_string()));
- } else {
- pre_exp.ops.push(Operator::Other("".to_string()));
- }
- xs1.to_expr(&mut pre_exp);
- pre_exp.ops.push(Operator::Divide);
- xs2.to_expr(&mut pre_exp);
- pre.args.push(Expr::Expression {
- ops: pre_exp.ops,
- args: pre_exp.args,
- name: "".to_string(),
- });
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
}
- Msup(xs1, xs2) => {
- insert_explicit_multiplication_operator(pre);
- let mut pre_exp = Expression::default();
- pre_exp.ops.push(Operator::Other("".to_string()));
- xs1.to_expr(&mut pre_exp);
- pre_exp.ops.push(Operator::Other("^".to_string()));
- xs2.to_expr(&mut pre_exp);
- pre.args.push(Expr::Expression {
- ops: pre_exp.ops,
- args: pre_exp.args,
- name: "".to_string(),
- });
+ }
+ MathExpression::Msup(x1, x2) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("".to_string()));
+ process_atom_expression(x1, &mut new_expr);
}
- Mover(xs1, xs2) => {
- insert_explicit_multiplication_operator(pre);
- let mut pre_exp = Expression::default();
- pre_exp.ops.push(Operator::Other("".to_string()));
- xs1.to_expr(&mut pre_exp);
- xs2.to_expr(&mut pre_exp);
- pre_exp.ops.remove(0);
- pre.args.push(Expr::Expression {
- ops: pre_exp.ops,
- args: pre_exp.args,
- name: "".to_string(),
- });
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("^".to_string()));
+ process_atom_expression(x2, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
+ }
+ }
+ MathExpression::Msub(x1, x2) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("".to_string()));
+ process_atom_expression(x1, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("_".to_string()));
+ process_atom_expression(x2, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
+ }
+ }
+ MathExpression::Msubsup(x1, x2, x3) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("".to_string()));
+ process_atom_expression(x1, &mut new_expr);
}
- _ => {
- panic!("Unhandled type!");
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("_".to_string()));
+ process_atom_expression(x2, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("^".to_string()));
+ process_atom_expression(x3, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
+ }
+ }
+ MathExpression::Munder(x1, x2) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("".to_string()));
+ process_atom_expression(x1, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("under".to_string()));
+ process_atom_expression(x2, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
+ }
+ }
+ MathExpression::Mover(x1, x2) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("".to_string()));
+ process_atom_expression(x1, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("over".to_string()));
+ process_atom_expression(x2, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
+ }
+ }
+ MathExpression::Mtext(x) => {
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(Expr::Atom(Atom::Identifier(x.replace(' ', ""))));
}
}
+ MathExpression::Mspace(x) => {
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(Expr::Atom(Atom::Identifier(x.to_string())));
+ }
+ }
+ MathExpression::AbsoluteSup(x1, x2) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("|.|".to_string()));
+ process_atom_expression(x1, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("_".to_string()));
+ process_atom_expression(x2, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
+ }
+ }
+ MathExpression::Mrow(vec_me) => {
+ for me in vec_me.0.iter() {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ process_atom_expression(me, &mut new_expr);
+ }
+ if let Expr::Expression { ops, args, name } = expression {
+ args.push(new_expr.clone());
+ }
+ }
+ }
+ t => panic!("Unhandled MathExpression: {:?}", t),
}
+}
+impl MathExpressionTree {
+ /// Convert a MathExpressionTree struct to a Expression struct.
+ pub fn to_expr(self, expr: &mut Expr) -> &mut Expr {
+ match self {
+ MathExpressionTree::Atom(a) => {
+ process_atom_expression(&a, expr);
+ }
+ MathExpressionTree::Cons(head, rest) => {
+ let mut new_expr = Expr::Expression {
+ ops: Vec::::new(),
+ args: Vec::::new(),
+ name: String::new(),
+ };
+ if is_unary_operator(&head) || (head == Operator::Subtract && rest.len() == 1) {
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(head);
+ rest[0].clone().to_expr(&mut new_expr);
+ }
+ } else {
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ ops.push(Operator::Other("".to_string()));
+ for (index, r) in rest.iter().enumerate() {
+ if index < rest.len() - 1 {
+ ops.push(head.clone());
+ }
+ }
+ }
+ if let Expr::Expression { ops, args, name } = &mut new_expr {
+ for r in &rest {
+ r.clone().to_expr(&mut new_expr);
+ }
+ }
+ }
+ if let Expr::Expression { ops, args, name } = expr {
+ args.push(new_expr.clone());
+ }
+ }
+ }
+ expr
+ }
pub fn to_graph(self) -> MathExpressionGraph<'static> {
- let mut pre_exp = Expression {
- ops: Vec::::new(),
+ let mut expr = self.clone();
+ let mut pre_exp = Expr::Expression {
+ ops: vec![Operator::Other("root".to_string())],
args: Vec::::new(),
name: "root".to_string(),
};
- pre_exp.ops.push(Operator::Other("root".to_string()));
- self.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.collapse_expr();
- // if need to convert to canonical form, please uncomment the following
- // pre_exp.distribute_expr();
- // pre_exp.group_expr();
- // pre_exp.collapse_expr();
- pre_exp.set_name();
- pre_exp.to_graph()
+ expr.to_expr(&mut pre_exp);
+
+ if let Expr::Expression { ops, args, name } = &mut pre_exp {
+ for mut arg in args {
+ if let Expr::Expression { .. } = arg {
+ arg.group_expr();
+ }
+ }
+ }
+ if let Expr::Expression { ops, args, name } = &mut pre_exp {
+ for mut arg in args {
+ if let Expr::Expression { .. } = arg {
+ arg.collapse_expr();
+ }
+ }
+ }
+ // If you need to convert to canonical form, uncomment the following
+ // if let Expr::Expression {ops, args, name} = &mut pre_exp {
+ // for mut arg in args {
+ // if let Expr::Expression { .. } = arg {
+ // arg.distribute_expr();
+ // }
+ // }
+ // }
+ // if let Expr::Expression {ops, args, name} = &mut pre_exp {
+ // for mut arg in args {
+ // if let Expr::Expression { .. } = arg {
+ // arg.group_expr();
+ // }
+ // }
+ // }
+ // if let Expr::Expression {ops, args, name} = &mut pre_exp {
+ // for mut arg in args {
+ // if let Expr::Expression { .. } = arg {
+ // arg.collapse_expr();
+ // }
+ // }
+ // }
+ if let Expr::Expression { ops, args, name } = &mut pre_exp {
+ for mut arg in args {
+ if let Expr::Expression { .. } = arg {
+ arg.set_name();
+ }
+ }
+ }
+ let mut g = MathExpressionGraph::new();
+ if let Expr::Expression { ops, args, name } = &mut pre_exp {
+ for mut arg in args {
+ if let Expr::Expression { .. } = arg {
+ arg.to_graph(&mut g);
+ }
+ }
+ }
+ g
}
}
@@ -519,7 +654,7 @@ impl Expr {
Atom::Operator(_) => {}
},
Expr::Expression { ops, .. } => {
- let mut string;
+ let mut string = "".to_string();
if ops[0] != Operator::Other("".to_string()) {
string = ops[0].to_string();
string.push('(');
@@ -1008,51 +1143,6 @@ pub fn need_to_distribute(ops: Vec) -> bool {
false
}
-impl Expression {
- pub fn group_expr(&mut self) {
- for arg in &mut self.args {
- if let Expr::Expression { .. } = arg {
- arg.group_expr();
- }
- }
- }
-
- pub fn collapse_expr(&mut self) {
- for arg in &mut self.args {
- if let Expr::Expression { .. } = arg {
- arg.collapse_expr();
- }
- }
- }
-
- #[allow(dead_code)] // used in tests I believe
- fn distribute_expr(&mut self) {
- for arg in &mut self.args {
- if let Expr::Expression { .. } = arg {
- arg.distribute_expr();
- }
- }
- }
-
- pub fn set_name(&mut self) {
- for arg in &mut self.args {
- if let Expr::Expression { .. } = arg {
- arg.set_name();
- }
- }
- }
-
- pub fn to_graph(&mut self) -> MathExpressionGraph {
- let mut g = MathExpressionGraph::new();
- for arg in &mut self.args {
- if let Expr::Expression { .. } = arg {
- arg.to_graph(&mut g);
- }
- }
- g
- }
-}
-
/// Remove redundant parentheses.
pub fn remove_redundant_parens(string: &mut String) -> &mut String {
while contains_redundant_parens(string) {
@@ -1084,1638 +1174,134 @@ pub fn get_node_idx(graph: &mut MathExpressionGraph, name: &mut String) -> NodeI
graph.add_node(name.to_string())
}
-/// Remove redundant mrow next to specific MathML elements. This function will likely be removed
-/// once the img2mml pipeline is fixed.
-pub fn remove_redundant_mrow(mml: String, key_word: String) -> String {
- let mut content = mml;
- let key_words_left = "".to_string() + &*key_word.clone();
- let mut key_word_right = key_word.clone();
- key_word_right.insert(1, '/');
- let key_words_right = key_word_right.clone() + "";
- let locs: Vec<_> = content
- .match_indices(&key_words_left)
- .map(|(i, _)| i)
- .collect();
- for loc in locs.iter().rev() {
- if content[loc + 1..].contains(&key_words_right) {
- let l = content[*loc..].find(&key_word_right).map(|i| i + *loc);
- if let Some(x) = l {
- if content.len() > (x + key_words_right.len())
- && content[x..x + key_words_right.len()] == key_words_right
- {
- content.replace_range(x..x + key_words_right.len(), key_word_right.as_str());
- content.replace_range(*loc..*loc + key_words_left.len(), key_word.as_str());
- }
- }
- }
- }
- content
-}
-
-/// Remove redundant mrows in mathml because some mathml elements don't need mrow to wrap. This
-/// function will likely be removed
-/// once the img2mml pipeline is fixed.
-pub fn remove_redundant_mrows(mathml_content: String) -> String {
- let mut content = mathml_content;
- content = content.replace("", "(");
- content = content.replace("", ")");
- let f = |b: &[u8]| -> Vec {
- let v = (0..)
- .zip(b)
- .scan(vec![], |a, (b, c)| {
- Some(match c {
- 40 => {
- a.push(b);
- None
- }
- 41 => Some((a.pop()?, b)),
- _ => None,
- })
- })
- .flatten()
- .collect::>();
- for k in &v {
- if k.0 == 0 && k.1 == b.len() - 1 {
- return b[1..b.len() - 1].to_vec();
- }
- for l in &v {
- if l.0 == k.0 + 1 && l.1 == k.1 - 1 {
- return [&b[..k.0], &b[l.0..k.1], &b[k.1 + 1..]].concat();
- }
- }
- }
- b.to_vec()
- };
- let g = |mut b: Vec| {
- while f(&b) != b {
- b = f(&b)
- }
- b
- };
- content = std::str::from_utf8(&g(content.bytes().collect()))
- .unwrap()
- .to_string();
- content = content.replace('(', "");
- content = content.replace(')', "");
- content = remove_redundant_mrow(content, "".to_string());
- content = remove_redundant_mrow(content, "".to_string());
- content = remove_redundant_mrow(content, "".to_string());
- content = remove_redundant_mrow(content, "".to_string());
- content
-}
-
-/// Preprocess the content prior to parsing.
-pub fn preprocess_content(content_str: String) -> String {
- let mut pre_string = content_str;
- pre_string = pre_string.replace(' ', "");
- pre_string = pre_string.replace('\n', "");
- pre_string = pre_string.replace('\t', "");
- pre_string = pre_string.replace("(t)", "");
- pre_string = pre_string.replace(",", "");
- pre_string = pre_string.replace("(", "");
- pre_string = pre_string.replace(")", "");
-
- // Unicode to Symbol
- let unicode_locs: Vec<_> = pre_string.match_indices("").map(|(i, _)| i).collect();
- for ul in unicode_locs.iter().rev() {
- let loc = pre_string[*ul..].find('<').map(|i| i + ul);
- match loc {
- None => {}
- Some(_x) => {}
- }
- }
- pre_string = html_escape::decode_html_entities(&pre_string).to_string();
- pre_string = pre_string.replace(
- &html_escape::decode_html_entities("−").to_string(),
- "-",
- );
- pre_string = remove_redundant_mrows(pre_string);
- pre_string
-}
-
-/// Wrap mathml vectors by mrow as a single expression to process
-pub fn wrap_math(math: Math) -> MathExpression {
- let mut math_vec = vec![];
- for con in math.content {
- math_vec.push(con);
- }
-
- MathExpression::Mrow(Mrow(math_vec))
-}
-
-#[test]
-fn test_to_expr() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
-
- if let Expr::Expression { ops, args, .. } = &pre_exp.args[0] {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string())));
- }
-}
-
-#[test]
-fn test_to_expr2() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mrow(Mrow(vec![
- Mn("4".to_string()),
- MathExpression::Mi(Mi("c".to_string())),
- MathExpression::Mi(Mi("d".to_string())),
- ])),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
-
- math_expression.to_expr(&mut pre_exp);
- pre_exp.ops.push(Operator::Other("root".to_string()));
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(ops[2], Operator::Subtract);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string())));
- match &args[2] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(ops[2], Operator::Multiply);
- assert_eq!(args[0], Expr::Atom(Atom::Number("4".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string())));
- assert_eq!(args[2], Expr::Atom(Atom::Identifier("d".to_string())));
- }
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr3() {
- let math_expression = Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- ]))));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
-
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Sqrt);
- match &args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string())));
- }
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr4() {
- let math_expression = Mfrac(
- Box::from(MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- ]))),
- Box::from(MathExpression::Mi(Mi("c".to_string()))),
- );
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
-
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Divide);
- match &args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string())));
- }
- }
- match &args[1] {
- Expr::Atom(_x) => {
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string())));
- }
- Expr::Expression { .. } => {}
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr5() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
-
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- match &args[1] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string())));
- }
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr6() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("d".to_string())),
- Mo(Operator::Divide),
- MathExpression::Mi(Mi("e".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("f".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("g".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("h".to_string())),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
-
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(ops[2], Operator::Subtract);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(args[3], Expr::Atom(Atom::Identifier("h".to_string())));
- match &args[1] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(ops[2], Operator::Multiply);
- assert_eq!(ops[3], Operator::Divide);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string())));
- assert_eq!(args[2], Expr::Atom(Atom::Identifier("d".to_string())));
- assert_eq!(args[3], Expr::Atom(Atom::Identifier("e".to_string())));
- }
- }
- match &args[2] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, .. } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("f".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("g".to_string())));
- }
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr7() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "root".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
-
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(name, "(a+b*c)");
- match &args[1] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string())));
- assert_eq!(name, "b*c");
- }
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr8() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("d".to_string())),
- Mo(Operator::Divide),
- MathExpression::Mi(Mi("e".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("f".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("g".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("h".to_string())),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
-
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(ops[2], Operator::Subtract);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(args[3], Expr::Atom(Atom::Identifier("h".to_string())));
- assert_eq!(name, "(a+b*c*d/e-f*g-h)");
- match &args[1] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(ops[2], Operator::Multiply);
- assert_eq!(ops[3], Operator::Divide);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string())));
- assert_eq!(args[2], Expr::Atom(Atom::Identifier("d".to_string())));
- assert_eq!(args[3], Expr::Atom(Atom::Identifier("e".to_string())));
- assert_eq!(name, "b*c*d/e");
- }
- }
- match &args[2] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("f".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("g".to_string())));
- assert_eq!(name, "f*g");
- }
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr9() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("d".to_string())),
- ])),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "root".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
-
- match &pre_exp.args[0] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Add);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string())));
- assert_eq!(name, "(a+b*(c-d))");
- match &args[1] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Multiply);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string())));
- assert_eq!(name, "b*(c-d)");
- match &args[1] {
- Expr::Atom(_) => {}
- Expr::Expression { ops, args, name } => {
- assert_eq!(ops[0], Operator::Other("".to_string()));
- assert_eq!(ops[1], Operator::Subtract);
- assert_eq!(args[0], Expr::Atom(Atom::Identifier("c".to_string())));
- assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string())));
- assert_eq!(name, "(c-d)");
- }
- }
- }
- }
- }
- }
-}
-
-#[test]
-fn test_to_expr10() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("a".to_string())),
- ])),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "root".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
- let _g = pre_exp.to_graph();
-}
-
-#[test]
-fn test_to_expr11() {
- let math_expression = Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- ])),
- ]))));
-
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "root".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
- let _g = pre_exp.to_graph();
-}
-
-#[test]
-fn test_to_expr12() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("d".to_string())),
- Mo(Operator::Divide),
- MathExpression::Mi(Mi("e".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("f".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("g".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("h".to_string())),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
- let _g = pre_exp.to_graph();
-}
-
-#[test]
-fn test_to_expr13() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Divide),
- MathExpression::Mi(Mi("d".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
- let _g = pre_exp.to_graph();
-}
-
-#[test]
-fn test_to_expr14() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- ])),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
- let _g = pre_exp.to_graph();
-}
-
-#[test]
-fn test_to_expr15() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("c".to_string())),
- Mo(Operator::Subtract),
- Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Add),
- MathExpression::Mi(Mi("b".to_string())),
- ])))),
- ]));
- let mut pre_exp = Expression {
- ops: Vec::::new(),
- args: Vec::::new(),
- name: "".to_string(),
- };
- pre_exp.ops.push(Operator::Other("root".to_string()));
- math_expression.to_expr(&mut pre_exp);
- pre_exp.group_expr();
- pre_exp.set_name();
- let _g = pre_exp.to_graph();
-}
-
-#[test]
-fn test_to_expr16() {
- let math_expression = Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- ])),
- ]))));
- let _g = math_expression.to_graph();
-}
-
-#[test]
-fn test_to_expr17() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("s".to_string())),
- Mo(Operator::Equals),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- ])),
- ]));
- let _g = math_expression.to_graph();
-}
-
-#[test]
-fn test_to_expr18() {
- let math_expression = MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("s".to_string())),
- Mo(Operator::Equals),
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Subtract),
- Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- Mo(Operator::Multiply),
- MathExpression::Mrow(Mrow(vec![
- MathExpression::Mi(Mi("a".to_string())),
- Mo(Operator::Subtract),
- MathExpression::Mi(Mi("b".to_string())),
- ])),
- ])))),
- ]));
- let _g = math_expression.to_graph();
-}
-
-#[test]
-fn test_to_expr19() {
- let input = "tests/sir.xml";
- let contents = std::fs::read_to_string(input)
- .unwrap_or_else(|_| panic!("{}", "Unable to read file {input}!"));
- let mut math = contents
- .parse::
";
let exp = input.parse::().unwrap();
+ let cmml = exp.to_cmml();
+ assert_eq!(cmml, "H");
let s_exp = exp.to_string();
assert_eq!(s_exp, "(Grad H)");
}
@@ -1717,6 +1759,8 @@ fn test_divergence() {
";
let exp = input.parse::().unwrap();
+ let cmml = exp.to_cmml();
+ assert_eq!(cmml, "H");
let s_exp = exp.to_string();
assert_eq!(s_exp, "(Div H)");
}
@@ -2125,12 +2169,12 @@ fn test_equation_with_mtext() {
#[test]
fn new_msqrt_test_function() {
let input = "
-
- 4
- a
- c
-
-";
+
+ 4
+ a
+ c
+
+ ";
let exp = input.parse::().unwrap();
let s_exp = exp.to_string();
assert_eq!(s_exp, "(√ (* (* 4 a) c))");
@@ -2172,3 +2216,478 @@ fn new_quadratic_equation() {
);
assert_eq!(exp.to_latex(), "x=\\frac{(-b)-\\sqrt{b^{2}-(4*a*c)}}{2*a}");
}
+
+#[test]
+fn test_dot_in_derivative() {
+ let input = "
+
+
+ S
+ ˙
+
+
+ (
+ t
+ )
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(D(1, t) S)");
+}
+
+#[test]
+fn test_sidarthe_equation() {
+ let input = "
+
+
+ S
+ ˙
+
+
+ (
+ t
+ )
+ =
+ −
+ S
+ (
+ t
+ )
+ (
+ α
+ I
+ (
+ t
+ )
+ +
+ β
+ D
+ (
+ t
+ )
+ +
+ γ
+ A
+ (
+ t
+ )
+ +
+ δ
+ R
+ (
+ t
+ )
+ )
+";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(
+ s_exp,
+ "(= (D(1, t) S) (* (- S) (+ (+ (+ (* α I) (* β D)) (* γ A)) (* δ R))))"
+ );
+}
+
+#[test]
+fn test_heating_rate() {
+ let input = "
+
+ Q
+ i
+
+ =
+
+ (
+
+ T
+ i
+
+ −
+
+ T
+
+ i
+ −
+ 1
+
+
+ )
+
+ ∕
+
+ (
+
+ C
+ p
+
+ Δ
+ t
+ )
+
+
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(= Q_{i} (/ (- T_{i} T_{i-1}) (* C_{p} Δt)))");
+}
+
+#[test]
+fn test_sum_munderover() {
+ let input = "
+
+ ∑
+
+ l
+ =
+ k
+
+ K
+
+ S
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(∑_{l=k}^{K} S)");
+}
+
+#[test]
+fn test_hydrostatic() {
+ let input = "
+
+ Φ
+ k
+
+ =
+
+ Φ
+ s
+
+ +
+ R
+
+ ∑
+
+ l
+ =
+ k
+
+ K
+
+
+ H
+
+ k
+ l
+
+
+
+ T
+
+ v
+ l
+
+
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ //println!("s_exp={:?}", s_exp);
+ assert_eq!(
+ s_exp,
+ "(= Φ_{k} (+ Φ_{s} (* R (∑_{l=k}^{K} (* H_{kl} T_{vl})))))"
+ );
+}
+
+#[test]
+fn test_temperature_evolution() {
+ let input = "
+
+
+ Δ
+
+ s
+ i
+
+
+
+ Δ
+ t
+
+
+ ∕
+
+ C
+ p
+
+ =
+
+
+ (
+
+ s
+ i
+
+ −
+
+ s
+
+ i
+ −
+ 1
+
+
+ )
+
+
+ Δ
+ t
+
+
+ ∕
+
+ C
+ p
+
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(
+ s_exp,
+ "(= (/ (/ Δs_{i} Δt) C_{p}) (/ (/ (- s_{i} s_{i-1}) Δt) C_{p}))"
+ );
+}
+
+#[test]
+fn test_cross_product() {
+ let input = "
+ f
+ ×
+ u
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(× f u)");
+}
+#[test]
+fn test_dot_product() {
+ let input = "
+ f
+ ⋅
+ u
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(⋅ f u)");
+}
+
+#[test]
+fn test_partial_with_msub_t() {
+ let input = "
+
+ ∂
+ t
+
+ S
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(PD(1, t) S)");
+}
+
+#[test]
+fn test_dry_static_energy() {
+ let input = "
+
+ s
+ i
+
+ =
+
+ s
+
+ i
+ −
+ 1
+
+
+ +
+ (
+ Δ
+ t
+ )
+
+ Q
+ i
+
+
+ (
+
+ s
+
+ i
+ −
+ 1
+
+
+ ,
+
+ T
+
+ i
+ −
+ 1
+
+
+ ,
+
+ Φ
+
+ i
+ −
+ 1
+
+
+ ,
+
+ q
+
+ i
+ −
+ 1
+
+
+ ,
+ …
+ )
+
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(= s_{i} (+ s_{i-1} (* Δt Q_{i})))");
+}
+
+#[test]
+fn test_hat_operator() {
+ let input = "
+ ζ
+
+
+ z
+ ^
+
+
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(Hat(z) ζ)");
+}
+
+#[test]
+fn test_vector_invariant_form() {
+ let input = "
+
+ ∂
+ t
+
+ u
+ +
+ (
+ ζ
+
+
+ z
+ ^
+
+
+ +
+ f
+ )
+ ×
+ u
+ =
+ −
+ ∇
+
+ [
+ g
+ (
+ h
+ +
+ b
+ )
+ +
+
+ 1
+ 2
+
+ u
+ ⋅
+ u
+ ]
+
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(= (+ (PD(1, t) u) (× (+ (Hat(z) ζ) f) u)) (- (Grad (+ (* g (+ h b)) (* (/ 1 2) (⋅ u u))))))");
+}
+
+#[test]
+fn test_mi_dot_gradient() {
+ let input = "
+ (
+ v
+ ⋅
+ ∇
+ )
+ u
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(* (⋅ v Grad) u)");
+}
+
+#[test]
+fn test_momentum_conservation() {
+ let input = "
+
+ ∂
+ t
+
+ u
+ =
+ −
+ (
+ v
+ ⋅
+ ∇
+ )
+ u
+ −
+ f
+ ×
+ u
+ −
+
+ ∇
+ h
+
+ (
+ p
+ +
+ g
+ η
+ )
+ −
+ ∇
+ ⋅
+ τ
+ +
+
+ F
+
+ u
+
+
+ ";
+ let exp = input.parse::().unwrap();
+ let s_exp = exp.to_string();
+ assert_eq!(s_exp, "(= (PD(1, t) u) (+ (- (- (- (* (- (⋅ v Grad)) u) (× f u)) (Grad_h) (+ p (* g η)))) (Div τ)) F_{u}))");
+}
diff --git a/skema/skema-rs/skema/Cargo.toml b/skema/skema-rs/skema/Cargo.toml
index 4c0feeb7a8b..fa404e3bc79 100644
--- a/skema/skema-rs/skema/Cargo.toml
+++ b/skema/skema-rs/skema/Cargo.toml
@@ -12,7 +12,7 @@ path = "src/lib.rs"
serde_json = { version = "1.0.85", features = ["preserve_order"] }
serde = { version = "1.0.1", features = ["derive"] }
strum_macros = "0.24"
-neo4rs = { version = "0.6.2" }
+neo4rs = { version = "0.7.0-rc.1" }
actix-web = "4.2.1"
mathml = { path = "../mathml" }
utoipa = { version = "3.0.3", features = ["actix_extras", "yaml", "debug"] }
@@ -21,4 +21,4 @@ clap = { version = "4.0.26", features = ["derive"] }
utoipa-swagger-ui = { version = "3.0.2", features = ["actix-web"] }
schemars = { version = "0.8.12" }
pretty_env_logger = "0.5.0"
-tokio = {version = "1.34.0", features = ["full", "rt"]}
\ No newline at end of file
+tokio = { version = "1.34.0", features = ["full", "rt"] }
diff --git a/skema/skema-rs/skema/src/bin/morae.rs b/skema/skema-rs/skema/src/bin/morae.rs
index d0f71ad18d9..88c3f72e9d6 100644
--- a/skema/skema-rs/skema/src/bin/morae.rs
+++ b/skema/skema-rs/skema/src/bin/morae.rs
@@ -86,8 +86,9 @@ async fn main() {
}
println!("{:?}", ids.clone());
let math_content = module_id2mathml_MET_ast(ids[ids.len() - 1], config.clone()).await;
- println!("{:?}", math_content.clone());
- println!("\nAMR from code: {:?}", PetriNet::from(math_content));
+ let pn_amr = PetriNet::from(math_content);
+ //println!("{:?}", math_content.clone());
+ //println!("\nAMR from code: {:?}", PetriNet::from(math_content));
//let input_src = "../../data/mml2pn_inputs/testing_eqns/sidarthe_mml.txt";
diff --git a/skema/skema-rs/skema/src/bin/skema_service.rs b/skema/skema-rs/skema/src/bin/skema_service.rs
index f3a7e03e68b..6a218d0575f 100644
--- a/skema/skema-rs/skema/src/bin/skema_service.rs
+++ b/skema/skema-rs/skema/src/bin/skema_service.rs
@@ -58,6 +58,7 @@ async fn main() -> std::io::Result<()> {
gromet::get_model_RN,
gromet::model2PN,
gromet::model2RN,
+ gromet::model2MET,
ping,
version
),
@@ -141,6 +142,7 @@ async fn main() -> std::io::Result<()> {
.service(gromet::get_model_RN)
.service(gromet::model2PN)
.service(gromet::model2RN)
+ .service(gromet::model2MET)
.service(ping)
.service(version)
.service(SwaggerUi::new("/docs/{_:.*}").url("/api-doc/openapi.json", openapi.clone()))
diff --git a/skema/skema-rs/skema/src/config.rs b/skema/skema-rs/skema/src/config.rs
index 43f57a1ef94..95f3f04c309 100644
--- a/skema/skema-rs/skema/src/config.rs
+++ b/skema/skema-rs/skema/src/config.rs
@@ -40,7 +40,7 @@ impl Config {
}
pub async fn graphdb_connection(&self) -> Graph {
let uri = self.create_graphdb_uri();
- println!("skema-rs:memgraph uri:\t{addr}", addr = uri);
+ //println!("skema-rs:memgraph uri:\t{addr}", addr = uri);
let graph_config = ConfigBuilder::new()
.uri(uri)
.user("".to_string())
diff --git a/skema/skema-rs/skema/src/database.rs b/skema/skema-rs/skema/src/database.rs
index b3ee80bf521..075b0b72404 100644
--- a/skema/skema-rs/skema/src/database.rs
+++ b/skema/skema-rs/skema/src/database.rs
@@ -68,12 +68,13 @@ pub struct Node {
pub box_counter: usize, // this indexes the box call for the node one scope up, matches nbox if higher scope is top level
}
-#[derive(Debug, Clone, PartialEq, Ord, Eq, PartialOrd)]
+#[derive(Debug, Clone, PartialEq, Ord, Eq, PartialOrd, Default)]
pub struct Edge {
pub src: String,
pub tgt: String,
pub e_type: String,
pub prop: Option, // option because of opo's and opi's
+ pub refer: Option,
}
#[derive(Debug, Clone)]
@@ -378,7 +379,7 @@ fn create_module(gromet: &ModuleCollection) -> Vec {
src: String::from("mod"),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
let edge_query = format!(
"{} ({})-[e{}{}:{}]->({})",
@@ -391,12 +392,25 @@ fn create_module(gromet: &ModuleCollection) -> Vec {
fn create_graph_queries(gromet: &ModuleCollection, start: u32) -> Vec {
let mut queries: Vec = vec![];
+ let mut only_imports = true;
// if a library module need to walk through gromet differently
if gromet.modules[0].r#fn.bf.is_none() {
queries.append(&mut create_function_net_lib(gromet, start));
} else {
// if executable code
- queries.append(&mut create_function_net(gromet, start));
+ for bf in gromet.modules[0].r#fn.bf.as_ref().unwrap().iter() {
+ if bf.function_type != FunctionType::Imported
+ && bf.function_type != FunctionType::ImportedMethod
+ {
+ only_imports = false;
+ }
+ }
+ println!("{}", only_imports);
+ if only_imports {
+ queries.append(&mut create_function_net_lib(gromet, start));
+ } else {
+ queries.append(&mut create_function_net(gromet, start));
+ }
}
queries
}
@@ -453,6 +467,7 @@ fn create_function_net_lib(gromet: &ModuleCollection, mut start: u32) -> Vec Vec Vec Vec Vec Vec Vec Vec
tgt: format!("n{}", start),
e_type: String::from("Contains"),
prop: Some(boxf.contents.unwrap() as usize),
+ ..Default::default()
};
nodes.push(n1.clone());
edges.push(e1);
@@ -991,7 +1017,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
src: n1.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -1012,7 +1038,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
src: n1.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -1052,7 +1078,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
src: n1.node_id.clone(),
tgt: node.node_id.clone(),
e_type: String::from("Contains"),
- prop: None,
+ ..Default::default()
};
edges.push(e5);
}
@@ -1110,7 +1136,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
src: wfopi_src_tgt[0].clone(),
tgt: wfopi_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e6);
}
@@ -1163,7 +1189,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
src: wfopo_src_tgt[0].clone(),
tgt: wfopo_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e7);
}
@@ -1182,7 +1208,15 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
}
}
FunctionType::Imported => {
- create_import(
+ create_att_primitive(
+ gromet, // gromet for metadata
+ &mut nodes, // nodes
+ &mut edges,
+ &mut meta_nodes,
+ &mut start,
+ c_args.clone(),
+ );
+ /*create_import(
gromet,
&mut nodes,
&mut edges,
@@ -1199,11 +1233,19 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
c_args.att_idx,
c_args.bf_counter,
c_args.parent_node.clone(),
- );
+ );*/
}
FunctionType::ImportedMethod => {
+ create_att_primitive(
+ gromet, // gromet for metadata
+ &mut nodes, // nodes
+ &mut edges,
+ &mut meta_nodes,
+ &mut start,
+ c_args.clone(),
+ );
// basically seems like these are just functions to me.
- c_args.att_idx = boxf.contents.unwrap() as usize;
+ /*c_args.att_idx = boxf.contents.unwrap() as usize;
c_args.att_box = gromet.modules[0].attributes[c_args.att_idx - 1].clone();
create_function(
gromet, // gromet for metadata
@@ -1212,7 +1254,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
&mut meta_nodes,
&mut start,
c_args.clone(),
- );
+ );*/
}
_ => {}
}
@@ -1288,7 +1330,13 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
}
}
// convert every node object into a node query
- let create = String::from("CREATE");
+ queries.append(&mut construct_memgraph_queries(
+ &mut nodes,
+ &mut edges,
+ &mut meta_nodes,
+ &mut queries.clone(),
+ ));
+ /*let create = String::from("CREATE");
for node in nodes.iter() {
let mut name = String::from("a");
if node.name.is_none() {
@@ -1344,7 +1392,11 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec
let set_query = format!("set e{}{}.index={}", edge.src, edge.tgt, edge.prop.unwrap());
queries.push(set_query);
}
- }
+ if edge.refer.is_some() {
+ let set_query = format!("set e{}{}.refer={}", edge.src, edge.tgt, edge.refer.unwrap());
+ queries.push(set_query);
+ }
+ }*/
queries
}
// this method creates an import type function
@@ -1406,7 +1458,7 @@ pub fn create_import(
src: c_args.parent_node.node_id,
tgt: n3.node_id.clone(),
e_type: String::from("Contains"),
- prop: None,
+ ..Default::default()
};
edges.push(e4);
if eboxf.metadata.is_some() {
@@ -1423,7 +1475,7 @@ pub fn create_import(
src: n3.node_id,
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -1484,6 +1536,7 @@ pub fn create_function(
tgt: n1.node_id.clone(),
e_type: String::from("Contains"),
prop: Some(c_args.att_idx),
+ ..Default::default()
};
parent_node = n1.clone();
nodes.push(n1.clone());
@@ -1504,7 +1557,7 @@ pub fn create_function(
src: n1.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -1544,6 +1597,7 @@ pub fn create_function(
new_c_args.box_counter = box_counter;
new_c_args.cur_box = att_sub_box.clone();
new_c_args.att_idx = c_args.att_idx;
+ new_c_args.att_bf_idx = c_args.att_bf_idx;
match att_sub_box.function_type {
FunctionType::Function => {
new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
@@ -1570,6 +1624,7 @@ pub fn create_function(
}
FunctionType::Expression => {
new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
+ new_c_args.att_bf_idx = c_args.att_idx;
create_att_expression(
gromet, // gromet for metadata
nodes, // nodes
@@ -1611,8 +1666,7 @@ pub fn create_function(
}
FunctionType::ImportedMethod => {
// this is a function call, but for some reason is not called a function
- new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
- create_function(
+ create_att_primitive(
gromet, // gromet for metadata
nodes, // nodes
edges,
@@ -1620,9 +1674,26 @@ pub fn create_function(
start,
new_c_args.clone(),
);
+ /*new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
+ create_function(
+ gromet, // gromet for metadata
+ nodes, // nodes
+ edges,
+ meta_nodes,
+ start,
+ new_c_args.clone(),
+ );*/
}
FunctionType::Imported => {
- create_import(gromet, nodes, edges, meta_nodes, start, c_args.clone());
+ create_att_primitive(
+ gromet, // gromet for metadata
+ nodes, // nodes
+ edges,
+ meta_nodes,
+ start,
+ new_c_args.clone(),
+ );
+ /*create_import(gromet, nodes, edges, meta_nodes, start, c_args.clone());
*start += 1;
// now to implement wiring
import_wiring(
@@ -1632,7 +1703,7 @@ pub fn create_function(
c_args.att_idx,
c_args.bf_counter,
c_args.parent_node.clone(),
- );
+ );*/
}
_ => {
println!(
@@ -1767,6 +1838,7 @@ pub fn create_conditional(
tgt: format!("n{}", start),
e_type: String::from("Contains"),
prop: Some(cond_counter as usize),
+ ..Default::default()
};
nodes.push(n1.clone());
edges.push(e1);
@@ -1800,7 +1872,7 @@ pub fn create_conditional(
src: n1.node_id.clone(),
tgt: format!("n{}", start),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
nodes.push(n2.clone());
edges.push(e3);
@@ -1818,7 +1890,7 @@ pub fn create_conditional(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -1853,7 +1925,7 @@ pub fn create_conditional(
src: n1.node_id.clone(),
tgt: format!("n{}", start),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
nodes.push(n3.clone());
edges.push(e5);
@@ -1871,7 +1943,7 @@ pub fn create_conditional(
src: n3.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -2019,7 +2091,7 @@ pub fn create_conditional(
src: wfc_src_tgt[0].clone(),
tgt: wfc_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -2072,7 +2144,7 @@ pub fn create_conditional(
src: cond_src_tgt[0].clone(),
tgt: cond_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e9);
}
@@ -2119,7 +2191,7 @@ pub fn create_conditional(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e10);
}
@@ -2167,7 +2239,7 @@ pub fn create_conditional(
src: else_src_tgt[0].clone(),
tgt: else_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e11);
}
@@ -2219,7 +2291,7 @@ pub fn create_conditional(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e12);
}
@@ -2267,7 +2339,7 @@ pub fn create_conditional(
src: else_src_tgt[0].clone(),
tgt: else_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e13);
}
@@ -2317,7 +2389,7 @@ pub fn create_conditional(
src: cond_src_tgt[0].clone(),
tgt: cond_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e14);
}
@@ -2357,6 +2429,7 @@ pub fn create_for_loop(
tgt: format!("n{}", start),
e_type: String::from("Contains"),
prop: Some(cond_counter as usize),
+ ..Default::default()
};
nodes.push(n1.clone());
edges.push(e1);
@@ -2380,7 +2453,7 @@ pub fn create_for_loop(
src: n1.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -2492,7 +2565,7 @@ pub fn create_for_loop(
src: n1.node_id.clone(),
tgt: format!("n{}", start),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
nodes.push(n2.clone());
edges.push(e3);
@@ -2510,7 +2583,7 @@ pub fn create_for_loop(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -2544,7 +2617,7 @@ pub fn create_for_loop(
src: n1.node_id.clone(),
tgt: format!("n{}", start),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
nodes.push(n3.clone());
edges.push(e5);
@@ -2562,7 +2635,7 @@ pub fn create_for_loop(
src: n3.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -2636,7 +2709,7 @@ pub fn create_for_loop(
src: wfl_src_tgt[0].clone(),
tgt: wfl_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -2685,7 +2758,7 @@ pub fn create_for_loop(
src: cond_src_tgt[0].clone(),
tgt: cond_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e9);
}
@@ -2732,7 +2805,7 @@ pub fn create_for_loop(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e10);
}
@@ -2782,7 +2855,7 @@ pub fn create_for_loop(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e12);
}
@@ -2829,7 +2902,7 @@ pub fn create_for_loop(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e15);
}
@@ -2879,7 +2952,7 @@ pub fn create_for_loop(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e16);
}
@@ -2928,7 +3001,7 @@ pub fn create_for_loop(
src: cond_src_tgt[0].clone(),
tgt: cond_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e14);
}
@@ -2969,6 +3042,7 @@ pub fn create_while_loop(
tgt: format!("n{}", start),
e_type: String::from("Contains"),
prop: Some(cond_counter as usize),
+ ..Default::default()
};
nodes.push(n1.clone());
edges.push(e1);
@@ -2992,7 +3066,7 @@ pub fn create_while_loop(
src: n1.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3081,7 +3155,7 @@ pub fn create_while_loop(
src: n1.node_id.clone(),
tgt: format!("n{}", start),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
nodes.push(n2.clone());
edges.push(e3);
@@ -3099,7 +3173,7 @@ pub fn create_while_loop(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3133,7 +3207,7 @@ pub fn create_while_loop(
src: n1.node_id.clone(),
tgt: format!("n{}", start),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
nodes.push(n3.clone());
edges.push(e5);
@@ -3151,7 +3225,7 @@ pub fn create_while_loop(
src: n3.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3225,7 +3299,7 @@ pub fn create_while_loop(
src: wfl_src_tgt[0].clone(),
tgt: wfl_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -3274,7 +3348,7 @@ pub fn create_while_loop(
src: cond_src_tgt[0].clone(),
tgt: cond_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e9);
}
@@ -3321,7 +3395,7 @@ pub fn create_while_loop(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e10);
}
@@ -3371,7 +3445,7 @@ pub fn create_while_loop(
src: if_src_tgt[0].clone(),
tgt: if_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e12);
}
@@ -3420,7 +3494,7 @@ pub fn create_while_loop(
src: cond_src_tgt[0].clone(),
tgt: cond_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e14);
}
@@ -3456,6 +3530,7 @@ pub fn create_att_expression(
tgt: format!("n{}", start),
e_type: String::from("Contains"),
prop: Some(c_args.att_idx),
+ ..Default::default()
};
nodes.push(n1.clone());
edges.push(e1);
@@ -3476,7 +3551,7 @@ pub fn create_att_expression(
src: n1.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3524,6 +3599,13 @@ pub fn create_att_expression(
}
}
}
+ if opo_name.is_empty() {
+ println!(
+ "Missed Opo at att_idx: {:?} and box_counter: {:?}",
+ c_args.att_idx, c_args.box_counter
+ );
+ println!("parent att box: {:?}", c_args.att_bf_idx);
+ }
if !opo_name.clone().is_empty() {
let mut oport: u32 = 0;
for _op in att_box.opo.as_ref().unwrap().iter() {
@@ -3545,7 +3627,7 @@ pub fn create_att_expression(
src: n1.node_id.clone(),
tgt: n2.node_id.clone(),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
edges.push(e3);
if att_box.opo.clone().as_ref().unwrap()[oport as usize]
@@ -3569,7 +3651,7 @@ pub fn create_att_expression(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3644,7 +3726,7 @@ pub fn create_att_expression(
src: n1.node_id.clone(),
tgt: n2.node_id.clone(),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
edges.push(e3);
if att_box.opi.clone().as_ref().unwrap()[iport as usize]
@@ -3668,7 +3750,7 @@ pub fn create_att_expression(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3685,9 +3767,8 @@ pub fn create_att_expression(
for att_sub_box in att_box.bf.as_ref().unwrap().iter() {
new_c_args.box_counter = box_counter;
new_c_args.cur_box = att_sub_box.clone();
- if att_sub_box.contents.is_some() {
- new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
- }
+ new_c_args.att_idx = c_args.att_idx;
+ new_c_args.att_bf_idx = c_args.att_bf_idx;
match att_sub_box.function_type {
FunctionType::Literal => {
create_att_literal(
@@ -3709,6 +3790,18 @@ pub fn create_att_expression(
new_c_args.clone(),
);
}
+ FunctionType::Expression => {
+ new_c_args.att_idx = att_sub_box.contents.unwrap() as usize;
+ new_c_args.att_bf_idx = c_args.att_idx;
+ create_att_expression(
+ gromet, // gromet for metadata
+ nodes, // nodes
+ edges,
+ meta_nodes,
+ start,
+ new_c_args.clone(),
+ );
+ }
_ => {}
}
box_counter += 1;
@@ -3724,6 +3817,14 @@ pub fn create_att_expression(
c_args.bf_counter,
);
+ cross_att_wiring(
+ att_box.clone(),
+ nodes,
+ edges,
+ c_args.att_idx,
+ c_args.bf_counter,
+ );
+
// Now we also perform wopio wiring in case there is an empty expression
if att_box.wopio.is_some() {
wopio_wiring(att_box, nodes, edges, c_args.att_idx - 1, c_args.bf_counter);
@@ -3759,6 +3860,7 @@ pub fn create_att_predicate(
tgt: format!("n{}", start),
e_type: String::from("Contains"),
prop: Some(c_args.att_idx),
+ ..Default::default()
};
nodes.push(n1.clone());
edges.push(e1);
@@ -3779,7 +3881,7 @@ pub fn create_att_predicate(
src: n1.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3822,7 +3924,7 @@ pub fn create_att_predicate(
src: n1.node_id.clone(),
tgt: n2.node_id.clone(),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
edges.push(e3);
if att_box.opo.clone().as_ref().unwrap()[oport as usize]
@@ -3846,7 +3948,7 @@ pub fn create_att_predicate(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -3883,7 +3985,7 @@ pub fn create_att_predicate(
src: n1.node_id.clone(),
tgt: n2.node_id.clone(),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
edges.push(e3);
if att_box.opi.clone().as_ref().unwrap()[iport as usize]
@@ -3907,7 +4009,7 @@ pub fn create_att_predicate(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -4009,7 +4111,7 @@ pub fn create_att_literal(
src: c_args.parent_node.node_id,
tgt: n3.node_id.clone(),
e_type: String::from("Contains"),
- prop: None,
+ ..Default::default()
};
edges.push(e4);
if lit_box.metadata.is_some() {
@@ -4026,7 +4128,7 @@ pub fn create_att_literal(
src: n3.node_id,
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -4085,7 +4187,7 @@ pub fn create_att_primitive(
src: c_args.parent_node.node_id,
tgt: n3.node_id.clone(),
e_type: String::from("Contains"),
- prop: None,
+ ..Default::default()
};
edges.push(e4);
if c_args.cur_box.metadata.is_some() {
@@ -4102,7 +4204,7 @@ pub fn create_att_primitive(
src: n3.node_id,
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -4123,11 +4225,15 @@ pub fn create_att_abstract(
) {
// first find the pof's for box
let mut pof: Vec = vec![];
+ let mut pof_names: Vec = vec![];
if c_args.att_box.pof.is_some() {
let mut po_idx: u32 = 1;
for port in c_args.att_box.pof.clone().unwrap().iter() {
if port.r#box == c_args.box_counter as u8 {
pof.push(po_idx);
+ if port.name.is_some() {
+ pof_names.push(port.name.clone().unwrap());
+ }
}
po_idx += 1;
}
@@ -4143,11 +4249,26 @@ pub fn create_att_abstract(
pi_idx += 1;
}
}
+ // now to construct an entry of ValueL for abstract port references
+ let mut value_vec = Vec::::new();
+ for name in pof_names.iter() {
+ let val = ValueL {
+ value_type: "String".to_string(),
+ value: format!("{:?}", name.clone()),
+ gromet_type: Some("Name".to_string()),
+ };
+ value_vec.push(val.clone());
+ }
+ let val = ValueL {
+ value_type: "List".to_string(),
+ value: format!("{:?}", value_vec.clone()),
+ gromet_type: Some("Abstract".to_string()),
+ };
// now make the node with the port information
let mut metadata_idx = 0;
let n3 = Node {
- n_type: String::from("Primitive"),
- value: None,
+ n_type: String::from("Abstract"),
+ value: Some(val),
name: c_args.cur_box.name.clone(),
node_id: format!("n{}", start),
out_idx: Some(pof),
@@ -4163,7 +4284,7 @@ pub fn create_att_abstract(
src: c_args.parent_node.node_id,
tgt: n3.node_id.clone(),
e_type: String::from("Contains"),
- prop: None,
+ ..Default::default()
};
edges.push(e4);
if c_args.cur_box.metadata.is_some() {
@@ -4180,7 +4301,7 @@ pub fn create_att_abstract(
src: n3.node_id,
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -4235,7 +4356,7 @@ pub fn create_opo(
src: c_args.parent_node.node_id.clone(),
tgt: n2.node_id.clone(),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
edges.push(e3);
@@ -4261,7 +4382,7 @@ pub fn create_opo(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -4320,7 +4441,7 @@ pub fn create_opi(
src: c_args.parent_node.node_id.clone(),
tgt: n2.node_id.clone(),
e_type: String::from("Port_Of"),
- prop: None,
+ ..Default::default()
};
edges.push(e3);
@@ -4346,7 +4467,7 @@ pub fn create_opi(
src: n2.node_id.clone(),
tgt: format!("m{}", metadata_idx),
e_type: String::from("Metadata"),
- prop: None,
+ ..Default::default()
};
edges.push(me1);
}
@@ -4418,6 +4539,7 @@ pub fn wfopi_wiring(
tgt: wfopi_src_tgt[1].clone(),
e_type: String::from("Wire"),
prop: Some(prop.unwrap() as usize),
+ ..Default::default()
};
edges.push(e6);
}
@@ -4480,7 +4602,7 @@ pub fn wfopo_wiring(
src: wfopo_src_tgt[0].clone(),
tgt: wfopo_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e7);
}
@@ -4500,6 +4622,7 @@ pub fn wff_wiring(
for wire in eboxf.wff.unwrap().iter() {
let mut wff_src_tgt: Vec = vec![];
let mut prop = None;
+ let mut refer = None;
let src_idx = wire.src; // port index
@@ -4533,7 +4656,7 @@ pub fn wff_wiring(
// push the tgt
if (wire.src as u32) == *p {
wff_src_tgt.push(node.node_id.clone());
- prop = Some(i as u32);
+ prop = Some(i);
}
}
}
@@ -4555,10 +4678,13 @@ pub fn wff_wiring(
// exclude opo's
if node.n_type != "Opo" {
// iterate through port to check for tgt
- for p in node.out_idx.as_ref().unwrap().iter() {
+ for (i, p) in node.out_idx.as_ref().unwrap().iter().enumerate() {
// push the tgt
if (wire.tgt as u32) == *p {
wff_src_tgt.push(node.node_id.clone());
+ if node.n_type == "Abstract" {
+ refer = Some(i);
+ }
}
}
}
@@ -4572,7 +4698,8 @@ pub fn wff_wiring(
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: Some(prop.unwrap() as usize),
+ prop: Some(prop.unwrap()),
+ refer,
};
edges.push(e8);
}
@@ -4635,7 +4762,7 @@ pub fn wopio_wiring(
src: wopio_src_tgt[0].clone(),
tgt: wopio_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e7);
}
@@ -4785,7 +4912,7 @@ pub fn import_wiring(
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -4873,7 +5000,7 @@ pub fn import_wiring(
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -5007,7 +5134,7 @@ pub fn wfopi_cross_att_wiring(
src: wfopi_src_tgt[0].clone(),
tgt: wfopi_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -5091,7 +5218,7 @@ pub fn wfopo_cross_att_wiring(
src: wfopo_src_tgt[0].clone(),
tgt: wfopo_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -5099,7 +5226,7 @@ pub fn wfopo_cross_att_wiring(
}
}
// this will construct connections from the sub function modules opi's to another sub module opo's, tracing data inside the function
-// opi(sub)->opo(sub)
+// opi(sub)->opo(sub) or pif(current) -> opo(sub) or opi(sub) -> pof(current)
#[allow(unused_assignments)]
pub fn wff_cross_att_wiring(
eboxf: FunctionNet, // This is the current attribute, should be the function if in a function
@@ -5109,6 +5236,7 @@ pub fn wff_cross_att_wiring(
bf_counter: u8, // this is the current box
) {
for wire in eboxf.wff.as_ref().unwrap().iter() {
+ let mut prop = None;
// collect info to identify the opi src node
let src_idx = wire.src; // port index
let src_pif = eboxf.pif.as_ref().unwrap()[(src_idx - 1) as usize].clone(); // src port
@@ -5195,7 +5323,7 @@ pub fn wff_cross_att_wiring(
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
@@ -5232,7 +5360,10 @@ pub fn wff_cross_att_wiring(
&& (tgt_box as u32) == node.box_counter as u32
{
// only opo's
- if node.n_type == "Primitive" || node.n_type == "Literal" {
+ if node.n_type == "Primitive"
+ || node.n_type == "Literal"
+ || node.n_type == "Abstract"
+ {
// iterate through port to check for tgt
for p in node.out_idx.as_ref().unwrap().iter() {
// push the src first, being pif
@@ -5263,12 +5394,13 @@ pub fn wff_cross_att_wiring(
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e8);
}
}
} else {
+ // This should be pif -> opo
let src_nbox = bf_counter; // nbox value of src opi
// collect info to identify the opo tgt node
let tgt_idx = wire.tgt; // port index
@@ -5298,12 +5430,13 @@ pub fn wff_cross_att_wiring(
&& (src_box as u32) == node.box_counter as u32
{
// only opo's
- if node.n_type == "Primitive" {
+ if node.n_type == "Primitive" || node.n_type == "Abstract" {
// iterate through port to check for tgt
- for p in node.in_indx.as_ref().unwrap().iter() {
+ for (i, p) in node.in_indx.as_ref().unwrap().iter().enumerate() {
// push the src first, being pif
- if (src_opi_idx as u32) == *p {
+ if (src_idx as u32) == *p {
wff_src_tgt.push(node.node_id.clone());
+ prop = Some(i);
}
}
}
@@ -5335,7 +5468,8 @@ pub fn wff_cross_att_wiring(
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ prop: Some(prop.unwrap()),
+ ..Default::default()
};
edges.push(e8);
}
@@ -5449,7 +5583,7 @@ pub fn external_wiring(gromet: &ModuleCollection, nodes: &mut [Node], edges: &mu
src: wff_src_tgt[0].clone(),
tgt: wff_src_tgt[1].clone(),
e_type: String::from("Wire"),
- prop: None,
+ ..Default::default()
};
edges.push(e9);
}
@@ -5467,3 +5601,120 @@ pub fn parse_gromet_queries(gromet: ModuleCollection) -> Vec {
queries
}
+
+// convert every node object into a node query
+pub fn construct_memgraph_queries(
+ nodes: &mut Vec,
+ edges: &mut Vec,
+ meta_nodes: &mut Vec,
+ queries: &mut Vec,
+) -> Vec {
+ // convert every node object into a node query
+ let create = String::from("CREATE");
+ for node in nodes.iter() {
+ let mut name = String::from("a");
+ if node.name.is_none() {
+ name = node.n_type.clone();
+ } else {
+ name = node.name.as_ref().unwrap().to_string();
+ }
+ // better parsing of values for inference later on.
+ // handles case of parsing a list as a proper list object, only depth one though
+ // would need recursive function for arbitrary depth. To be done at some point.
+ let value = match &node.value {
+ Some(val) => {
+ if val.value_type == *"List" && &val.value[0..1] == "[" && &val.value[1..2] != "]" {
+ let val_type = val.value_type.clone();
+ let val_grom_type = val.gromet_type.as_ref().unwrap();
+ let val_len = val.value[..].len();
+ let val_val: Vec = val.value[1..val_len]
+ .split("}, ")
+ .map(|x| x.to_string())
+ .collect();
+
+ let mut val_vec = Vec::::new();
+ for (i, val) in val_val.iter().enumerate() {
+ if i == val_val.len() - 1 {
+ let val_string = format!("{}}}", &val[7..(val.len() - 2)]);
+ val_vec.push(val_string.clone());
+ } else {
+ let val_string = format!("{}}}", &val[7..]);
+ val_vec.push(val_string.clone());
+ }
+ }
+ let mut final_val_vec = Vec::::new();
+ for val_str in val_vec.iter() {
+ let val_fields: Vec =
+ val_str.split(", ").map(|x| x.to_string()).collect();
+ let cor_val: Vec =
+ val_fields[1].split(": ").map(|x| x.to_string()).collect();
+ let final_val = cor_val[1].replace("\\\"", "");
+ final_val_vec.push(final_val.clone());
+ }
+ format!(
+ "{{ value_type:{:?}, value:{:?}, gromet_type:{:?} }}",
+ val_type, final_val_vec, val_grom_type
+ )
+ .replace("\\\"", "")
+ } else {
+ format!(
+ "{{ value_type:{:?}, value:{:?}, gromet_type:{:?} }}",
+ val.value_type,
+ val.value,
+ val.gromet_type.as_ref().unwrap()
+ )
+ }
+ }
+ None => String::from("\"\""),
+ };
+
+ // NOTE: The format of value has changed to represent a literal Cypher map {field:value}.
+ // We no longer need to format value with the debug :? parameter
+ let node_query = format!(
+ "{} ({}:{} {{name:{:?},value:{},order_box:{:?},order_att:{:?}}})",
+ create, node.node_id, node.n_type, name, value, node.nbox, node.contents
+ );
+ queries.push(node_query);
+ }
+ for node in meta_nodes.iter() {
+ queries.append(&mut create_metadata_node_query(node.clone()));
+ }
+
+ // convert every edge object into an edge query
+ let init_edges = edges.len();
+ edges.sort();
+ edges.dedup();
+ let edges_clone = edges.clone();
+ // also dedup if edge prop is different
+ for (i, edge) in edges_clone.iter().enumerate().rev() {
+ if i != 0 && edge.src == edges_clone[i - 1].src && edge.tgt == edges_clone[i - 1].tgt {
+ edges.remove(i);
+ }
+ }
+ let fin_edges = edges.len();
+ if init_edges != fin_edges {
+ println!("Duplicated Edges Removed, check for bugs");
+ }
+ for edge in edges.iter() {
+ let edge_query = format!(
+ "{} ({})-[e{}{}:{}]->({})",
+ create, edge.src, edge.src, edge.tgt, edge.e_type, edge.tgt
+ );
+ queries.push(edge_query);
+
+ if edge.prop.is_some() {
+ let set_query = format!("set e{}{}.index={}", edge.src, edge.tgt, edge.prop.unwrap());
+ queries.push(set_query);
+ }
+ if edge.refer.is_some() {
+ let set_query = format!(
+ "set e{}{}.refer={}",
+ edge.src,
+ edge.tgt,
+ edge.refer.unwrap()
+ );
+ queries.push(set_query);
+ }
+ }
+ queries.to_vec()
+}
diff --git a/skema/skema-rs/skema/src/model_extraction.rs b/skema/skema-rs/skema/src/model_extraction.rs
index 9f50b7336f1..44b824e20e3 100644
--- a/skema/skema-rs/skema/src/model_extraction.rs
+++ b/skema/skema-rs/skema/src/model_extraction.rs
@@ -1,6 +1,9 @@
use crate::config::Config;
+use crate::ValueL;
+
use mathml::ast::operator::Operator;
pub use mathml::mml2pn::{ACSet, Term};
+
use petgraph::prelude::*;
use std::string::ToString;
@@ -17,11 +20,51 @@ use neo4rs;
use neo4rs::{query, Error};
use std::sync::Arc;
+/// This struct is the node struct for the constructed petgraph
+#[derive(Clone, Debug)]
+pub struct ModelNode {
+ id: i64,
+ label: String,
+ name: Option,
+ value: Option,
+}
+
+/// This struct is the edge struct for the constructed petgraph
+#[derive(Clone, Debug)]
+pub struct ModelEdge {
+ id: i64,
+ src_id: i64,
+ tgt_id: i64,
+ index: Option,
+ refer: Option,
+}
+
+/**
+ * This is the main function call for model extraction.
+ *
+ * Parameters:
+ * - module_id: i64 -> This is the top level id of the gromet module in memgraph.
+ * - config: Config -> This is a config struct for connecting to memgraph
+ *
+ * Returns:
+ * - Vector of FirstOrderODE -> This vector of structs is used to construct a PetriNet or RegNet further down the pipeline
+ *
+ * Assumptions:
+ * - As of right now, we can always assume the code has been sliced to only one relevant function which contains the
+ * core dynamics in it somewhere
+ *
+ * Notes:
+ * - FirstOrderODE is primarily composed of a LHS and a RHS,
+ * - LHS is just a Mi object of the state being differentiated. There are additional fields for the LHS but only the
+ * content field is used in downstream inference for now.
+ * - RHS is where the bulk of the inference happens, it produces an expression tree, hence the MET -> Math Expression Tree.
+ * Every operator has a vector of arguments. (order matters)
+ */
#[allow(non_snake_case)]
pub async fn module_id2mathml_MET_ast(module_id: i64, config: Config) -> Vec {
let mut core_dynamics_ast = Vec::::new();
- let core_id = find_pn_dynamics(module_id, config.clone()).await; // gives back list of function nodes that might contain the dynamics
+ let core_id = find_pn_dynamics(module_id, config.clone()).await;
if core_id.is_empty() {
let deriv = Ci {
@@ -48,24 +91,31 @@ pub async fn module_id2mathml_MET_ast(module_id: i64, config: Config) -> Vec Vec {
let graph = subgraph2petgraph(module_id, config.clone()).await;
// 1. find each function node
let mut function_nodes = Vec::::new();
for node in graph.node_indices() {
- if graph[node].labels()[0] == *"Function" {
+ if graph[node].label == *"Function" {
function_nodes.push(node);
}
}
// 2. check and make sure only expressions in function
// 3. check number of expressions and decide off that
- let mut functions = Vec::>::new();
+ let mut functions = Vec::>::new();
for i in 0..function_nodes.len() {
// grab the subgraph of the given expression
- functions.push(subgraph2petgraph(graph[function_nodes[i]].id(), config.clone()).await);
+ functions.push(subgraph2petgraph(graph[function_nodes[i]].id, config.clone()).await);
}
// get a sense of the number of expressions in each function
let mut func_counter = 0;
@@ -74,17 +124,17 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec {
let mut expression_counter = 0;
let mut primitive_counter = 0;
for node in func.node_indices() {
- if func[node].labels()[0] == *"Expression" {
+ if func[node].label == *"Expression" {
expression_counter += 1;
}
- if func[node].labels()[0] == *"Primitive" {
- if func[node].get::("name").unwrap() == *"ast.Mult" {
+ if func[node].label == *"Primitive" {
+ if *func[node].name.as_ref().unwrap() == "ast.Mult".to_string() {
primitive_counter += 1;
- } else if func[node].get::("name").unwrap() == *"ast.Add" {
+ } else if *func[node].name.as_ref().unwrap() == "ast.Add".to_string() {
primitive_counter += 1;
- } else if func[node].get::("name").unwrap() == *"ast.Sub" {
+ } else if *func[node].name.as_ref().unwrap() == "ast.Sub".to_string() {
primitive_counter += 1;
- } else if func[node].get::("name").unwrap() == *"ast.USub" {
+ } else if *func[node].name.as_ref().unwrap() == "ast.USub".to_string() {
primitive_counter += 1;
}
}
@@ -98,8 +148,8 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec {
let mut core_id = Vec::::new();
for c_func in core_func.iter() {
for node in functions[*c_func].node_indices() {
- if functions[*c_func][node].labels()[0] == *"Function" {
- core_id.push(functions[*c_func][node].id());
+ if functions[*c_func][node].label == *"Function" {
+ core_id.push(functions[*c_func][node].id);
}
}
}
@@ -107,6 +157,11 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec {
core_id
}
+/**
+ * Once the function node has been identified, this function takes it from there to extract the vector of FirstOrderODE's
+ *
+ * This is based heavily on the assumption that each equation is in a separate expression which breaks for the vector case.
+ */
#[allow(non_snake_case)]
pub async fn subgrapg2_core_dyn_MET_ast(
root_node_id: i64,
@@ -118,7 +173,7 @@ pub async fn subgrapg2_core_dyn_MET_ast(
// find all the expressions
let mut expression_nodes = Vec::::new();
for node in graph.node_indices() {
- if graph[node].labels()[0] == *"Expression" {
+ if graph[node].label == *"Expression" {
expression_nodes.push(node);
}
}
@@ -128,19 +183,35 @@ pub async fn subgrapg2_core_dyn_MET_ast(
// initialize vector to collect all expression wiring graphs
for i in 0..expression_nodes.len() {
// grab the wiring subgraph of the given expression
- let mut sub_w = subgraph_wiring(graph[expression_nodes[i]].id(), config.clone())
+ let mut sub_w = subgraph_wiring(graph[expression_nodes[i]].id, config.clone())
.await
.unwrap();
- if sub_w.node_count() > 3 {
- let expr = trim_un_named(&mut sub_w, config.clone()).await;
+ let mut prim_counter = 0;
+ let mut has_call = false;
+ for node_index in sub_w.node_indices() {
+ if sub_w[node_index].label == *"Primitive" {
+ prim_counter += 1;
+ if *sub_w[node_index].name.as_ref().unwrap() == "_call" {
+ has_call = true;
+ }
+ }
+ }
+ if sub_w.node_count() > 3 && !(prim_counter == 1 && has_call) && prim_counter != 0 {
+ println!("--------------------");
+ println!("expression: {}", graph[expression_nodes[i]].id);
+ // the call expressions get referenced by multiple top level expressions, so deleting the nodes in it breaks the other graphs. Need to pass clone of expression subgraph so references to original has all the nodes.
+ if has_call {
+ sub_w = trim_calls(sub_w.clone())
+ }
+ let expr = trim_un_named(&mut sub_w);
let mut root_node = Vec::::new();
for node_index in expr.node_indices() {
- if expr[node_index].labels()[0].clone() == *"Opo" {
+ if expr[node_index].label.clone() == *"Opo" {
root_node.push(node_index);
}
}
if root_node.len() >= 2 {
- // println!("More than one Opo! Skipping Expression!");
+ println!("More than one Opo! Skipping Expression!");
} else {
core_dynamics.push(tree_2_MET_ast(expr, root_node[0]).unwrap());
}
@@ -150,17 +221,20 @@ pub async fn subgrapg2_core_dyn_MET_ast(
Ok(core_dynamics)
}
+/**
+ * This function is designed to take in a petgraph instance of a wires only expression subgraph and output a FirstOrderODE equations representing it.
+ */
#[allow(non_snake_case)]
fn tree_2_MET_ast(
- graph: &mut petgraph::Graph,
+ graph: &mut petgraph::Graph,
root_node: NodeIndex,
) -> Result {
let mut fo_eq_vec = Vec::::new();
let _math_vec = Vec::::new();
let mut lhs = Vec::::new();
- if graph[root_node].labels()[0] == *"Opo" {
+ if graph[root_node].label == *"Opo" {
// we first construct the derivative of the first node
- let deriv_name: &str = &graph[root_node].get::("name").unwrap();
+ let deriv_name: &str = graph[root_node].name.as_ref().unwrap();
// this will let us know if additional trimming is needed to handle the code implementation of the equations
// let mut step_impl = false; this will be used for step implementaion for later
// This is very bespoke right now
@@ -183,7 +257,7 @@ fn tree_2_MET_ast(
lhs.push(deriv);
}
for node in graph.neighbors_directed(root_node, Outgoing) {
- if graph[node].labels()[0].clone() == *"Primitive" {
+ if graph[node].label.clone() == *"Primitive" {
let operate = get_operator_MET(graph, node); // output -> Operator
let rhs_arg = get_args_MET(graph, node); // output -> Vec
let rhs = MathExpressionTree::Cons(operate, rhs_arg); // MathExpressionTree
@@ -204,9 +278,10 @@ fn tree_2_MET_ast(
Ok(fo_eq_vec[0].clone())
}
+/// This is a recursive function that walks along the wired subgraph of an expression to construct the expression tree
#[allow(non_snake_case)]
pub fn get_args_MET(
- graph: &petgraph::Graph,
+ graph: &petgraph::Graph,
root_node: NodeIndex,
) -> Vec {
let mut args = Vec::::new();
@@ -219,30 +294,34 @@ pub fn get_args_MET(
// construct vecs
for node in graph.neighbors_directed(root_node, Outgoing) {
// first need to check for operator
- if graph[node].labels()[0].clone() == *"Primitive" {
+ if graph[node].label.clone() == *"Primitive" {
let operate = get_operator_MET(graph, node); // output -> Operator
let rhs_arg = get_args_MET(graph, node); // output -> Vec
let rhs = MathExpressionTree::Cons(operate, rhs_arg); // MathExpressionTree
args.push(rhs.clone());
} else {
// asummption it is atomic
- let temp_string = graph[node].get::("name").unwrap().clone();
- let arg2 = MathExpressionTree::Atom(MathExpression::Mi(Mi(temp_string.clone())));
- args.push(arg2.clone());
+ if graph[node].label.clone() == *"Literal" {
+ let temp_string = graph[node].value.clone().unwrap().value.replace('\"', "");
+ let arg2 = MathExpressionTree::Atom(MathExpression::Mi(Mi(temp_string.clone())));
+ args.push(arg2.clone());
+ } else {
+ let temp_string = graph[node].name.as_ref().unwrap().clone();
+ let arg2 = MathExpressionTree::Atom(MathExpression::Mi(Mi(temp_string.clone())));
+ args.push(arg2.clone());
+ }
}
// construct order of args
let x = graph
.edge_weight(graph.find_edge(root_node, node).unwrap())
.unwrap()
- .get::("index")
+ .index
.unwrap();
arg_order.push(x);
}
-
// fix order of args
let mut ordered_args = args.clone();
-
for (i, ind) in arg_order.iter().enumerate() {
// the ind'th element of order_args is the ith element of the unordered args
if ordered_args.len() > *ind as usize {
@@ -253,102 +332,92 @@ pub fn get_args_MET(
ordered_args
}
-// this gets the operator from the node name
+/// This gets the operator from the node name
#[allow(non_snake_case)]
#[allow(clippy::if_same_then_else)]
pub fn get_operator_MET(
- graph: &petgraph::Graph,
+ graph: &petgraph::Graph,
root_node: NodeIndex,
) -> Operator {
let mut op = Vec::::new();
- if graph[root_node].get::("name").unwrap() == *"ast.Mult" {
+ if *graph[root_node].name.as_ref().unwrap() == "ast.Mult".to_string() {
op.push(Operator::Multiply);
- } else if graph[root_node].get::("name").unwrap() == *"ast.Add" {
+ } else if *graph[root_node].name.as_ref().unwrap() == "ast.Add" {
op.push(Operator::Add);
- } else if graph[root_node].get::("name").unwrap() == *"ast.Sub" {
+ } else if *graph[root_node].name.as_ref().unwrap() == "ast.Sub" {
op.push(Operator::Subtract);
- } else if graph[root_node].get::("name").unwrap() == *"ast.USub" {
+ } else if *graph[root_node].name.as_ref().unwrap() == "ast.USub" {
op.push(Operator::Subtract);
- } else if graph[root_node].get::("name").unwrap() == *"ast.Div" {
+ } else if *graph[root_node].name.as_ref().unwrap() == "ast.Div" {
op.push(Operator::Divide);
} else {
- op.push(Operator::Other(
- graph[root_node].get::("name").unwrap(),
- ));
+ op.push(Operator::Other(graph[root_node].name.clone().unwrap()));
}
op[0].clone()
}
-// this currently only works for un-named nodes that are not chained or have multiple incoming/outgoing edges
-async fn trim_un_named(
- graph: &mut petgraph::Graph,
- config: Config,
-) -> &mut petgraph::Graph {
+/**
+ * This function takes in a wiring only petgraph of an expression and trims off the un-named nodes and unpack nodes.
+ *
+ * This is done by creating new edges that bypass the un-named nodes and then deleting them from the graph.
+ * For deleting the unpacks, the assumption is they are always terminal in the subgraph and can be deleted freely.
+ *
+ * Concerns:
+ * - I don't think this will work if there are multiple un-named nodes chained together. I haven't seen this in practice,
+ * but I think it's possible. So something to keep in mind.
+ */
+fn trim_un_named(
+ graph: &mut petgraph::Graph,
+) -> &mut petgraph::Graph {
// first create a cloned version of the graph we can modify while iterating over it.
- let graph_call = Arc::new(config.graphdb_connection().await);
-
// iterate over the graph and add a new edge to bypass the un-named nodes
for node_index in graph.node_indices() {
- if graph[node_index].get::("name").unwrap().clone() == *"un-named" {
+ if graph[node_index].clone().name.unwrap().clone() == *"un-named" {
let mut bypass = Vec::::new();
+ let mut outgoing_bypass = Vec::::new();
for node1 in graph.neighbors_directed(node_index, Incoming) {
bypass.push(node1);
}
for node2 in graph.neighbors_directed(node_index, Outgoing) {
- bypass.push(node2);
+ outgoing_bypass.push(node2);
}
// one incoming one outgoing
- if bypass.len() == 2 {
+ if bypass.len() == 1 && outgoing_bypass.len() == 1 {
// annoyingly have to pull the edge/Relation to insert into graph
- let mut edge_list = Vec::::new();
- let query_string = format!(
- "MATCH (n)-[r:Wire]->(m) WHERE id(n) = {} AND id(m) = {} RETURN r",
- graph[bypass[0]].id(),
- graph[node_index].id()
+ graph.add_edge(
+ bypass[0],
+ outgoing_bypass[0],
+ graph
+ .edge_weight(graph.find_edge(bypass[0], node_index).unwrap())
+ .unwrap()
+ .clone(),
);
- let mut result = graph_call.execute(query(&query_string[..])).await.unwrap();
- while let Ok(Some(row)) = result.next().await {
- let edge: neo4rs::Relation = row.get("r").unwrap();
- edge_list.push(edge);
- }
- // add the bypass edge
- for edge in edge_list {
- graph.add_edge(bypass[0], bypass[1], edge);
- }
- } else if bypass.len() > 2 {
+ } else if bypass.len() >= 2 && outgoing_bypass.len() == 1 {
// this operates on the assumption that there maybe multiple references to the port
// (incoming arrows) but only one outgoing arrow, this seems to be the case based on
// data too.
- let end_node_idx = bypass.len() - 1;
- for (i, _ent) in bypass[0..end_node_idx].iter().enumerate() {
+ for (i, _ent) in bypass.iter().enumerate() {
// this iterates over all but the last entry in the bypass vec
- let mut edge_list = Vec::::new();
- let query_string = format!(
- "MATCH (n)-[r:Wire]->(m) WHERE id(n) = {} AND id(m) = {} RETURN r",
- graph[bypass[i]].id(),
- graph[node_index].id()
+ graph.add_edge(
+ bypass[i],
+ outgoing_bypass[0],
+ graph
+ .edge_weight(graph.find_edge(bypass[i], node_index).unwrap())
+ .unwrap()
+ .clone(),
);
- let mut result = graph_call.execute(query(&query_string[..])).await.unwrap();
- while let Ok(Some(row)) = result.next().await {
- let edge: neo4rs::Relation = row.get("r").unwrap();
- edge_list.push(edge);
- }
-
- for edge in edge_list {
- graph.add_edge(bypass[i], bypass[end_node_idx], edge);
- }
}
}
}
}
- // now we perform a filter_map to remove the un-named nodes and only the bypass edge will remain to connect the nodes
+ // now we remove the un-named nodes and only the bypass edge will remain to connect the nodes
// we also remove the unpack node if it is present here as well
for node_index in graph.node_indices().rev() {
- if graph[node_index].get::("name").unwrap().clone() == *"un-named"
- || graph[node_index].get::("name").unwrap().clone() == *"unpack"
+ if graph[node_index].name.clone().unwrap() == *"un-named"
+ || graph[node_index].name.clone().unwrap() == *"unpack"
{
graph.remove_node(node_index);
}
@@ -357,12 +426,14 @@ async fn trim_un_named(
graph
}
+/// This function takes in a node id (typically that of an expression subgraph) and returns a
+/// petgraph subgraph of only the wire type edges
async fn subgraph_wiring(
module_id: i64,
config: Config,
-) -> Result, Error> {
- let mut node_list = Vec::::new();
- let mut edge_list = Vec::::new();
+) -> Result, Error> {
+ let mut node_list = Vec::::new();
+ let mut edge_list = Vec::::new();
// Connect to Memgraph.
let graph = Arc::new(config.graphdb_connection().await);
@@ -382,7 +453,13 @@ async fn subgraph_wiring(
.await?;
while let Ok(Some(row)) = result1.next().await {
let node: neo4rs::Node = row.get("nodes2").unwrap();
- node_list.push(node);
+ let modelnode = ModelNode {
+ id: node.id(),
+ label: node.labels()[0].to_string(),
+ name: node.get::("name").ok(),
+ value: node.get::("value").ok(),
+ };
+ node_list.push(modelnode.clone());
}
// edge query
let mut result2 = graph
@@ -400,10 +477,17 @@ async fn subgraph_wiring(
.await?;
while let Ok(Some(row)) = result2.next().await {
let edge: neo4rs::Relation = row.get("edges2").unwrap();
- edge_list.push(edge);
+ let modeledge = ModelEdge {
+ id: edge.id(),
+ src_id: edge.start_node_id(),
+ tgt_id: edge.end_node_id(),
+ index: edge.get::("index").ok(),
+ refer: edge.get::("refer").ok(),
+ };
+ edge_list.push(modeledge);
}
- let mut graph: petgraph::Graph = Graph::new();
+ let mut graph: petgraph::Graph = Graph::new();
// Add nodes to the petgraph graph and collect their indexes
let mut nodes = Vec::::new();
@@ -417,10 +501,10 @@ async fn subgraph_wiring(
let mut src = Vec::::new();
let mut tgt = Vec::::new();
for node_idx in &nodes {
- if graph[*node_idx].id() == edge.start_node_id() {
+ if graph[*node_idx].id == edge.src_id {
src.push(*node_idx);
}
- if graph[*node_idx].id() == edge.end_node_id() {
+ if graph[*node_idx].id == edge.tgt_id {
tgt.push(*node_idx);
}
}
@@ -431,14 +515,15 @@ async fn subgraph_wiring(
Ok(graph)
}
+/// This function takes in a node id and returns a petgraph representation of the memgraph graph
async fn subgraph2petgraph(
module_id: i64,
config: Config,
-) -> petgraph::Graph {
+) -> petgraph::Graph {
let (x, y) = get_subgraph(module_id, config.clone()).await.unwrap();
// Create a petgraph graph
- let mut graph: petgraph::Graph = Graph::new();
+ let mut graph: petgraph::Graph = Graph::new();
// Add nodes to the petgraph graph and collect their indexes
let mut nodes = Vec::::new();
@@ -452,10 +537,10 @@ async fn subgraph2petgraph(
let mut src = Vec::::new();
let mut tgt = Vec::::new();
for node_idx in &nodes {
- if graph[*node_idx].id() == edge.start_node_id() {
+ if graph[*node_idx].id == edge.src_id {
src.push(*node_idx);
}
- if graph[*node_idx].id() == edge.end_node_id() {
+ if graph[*node_idx].id == edge.tgt_id {
tgt.push(*node_idx);
}
}
@@ -466,14 +551,13 @@ async fn subgraph2petgraph(
graph
}
+/// This function takes in a node id and returns the nodes and edges in it
pub async fn get_subgraph(
module_id: i64,
config: Config,
-) -> Result<(Vec, Vec), Error> {
- // construct the query that will delete the module with a given unique identifier
-
- let mut node_list = Vec::::new();
- let mut edge_list = Vec::::new();
+) -> Result<(Vec, Vec), Error> {
+ let mut node_list = Vec::::new();
+ let mut edge_list = Vec::::new();
// Connect to Memgraph.
let graph = Arc::new(config.graphdb_connection().await);
@@ -492,7 +576,13 @@ pub async fn get_subgraph(
.await?;
while let Ok(Some(row)) = result1.next().await {
let node: neo4rs::Node = row.get("nodes2").unwrap();
- node_list.push(node);
+ let modelnode = ModelNode {
+ id: node.id(),
+ label: node.labels()[0].to_string(),
+ name: node.get::("name").ok(),
+ value: node.get::("value").ok(),
+ };
+ node_list.push(modelnode);
}
// edge query
let mut result2 = graph
@@ -509,8 +599,125 @@ pub async fn get_subgraph(
.await?;
while let Ok(Some(row)) = result2.next().await {
let edge: neo4rs::Relation = row.get("edges2").unwrap();
- edge_list.push(edge);
+ let modeledge = ModelEdge {
+ id: edge.id(),
+ src_id: edge.start_node_id(),
+ tgt_id: edge.end_node_id(),
+ index: edge.get::("index").ok(),
+ refer: edge.get::("refer").ok(),
+ };
+ edge_list.push(modeledge);
}
Ok((node_list, edge_list))
}
+
+// this does special trimming to handle function calls
+pub fn trim_calls(
+ graph: petgraph::Graph,
+) -> petgraph::Graph {
+ let mut graph_clone = graph.clone();
+
+ // This will be all the nodes to be deleted
+ let mut inner_nodes = Vec::::new();
+ // find the call nodes
+ for node_index in graph.node_indices() {
+ if graph[node_index].clone().name.unwrap().clone() == *"_call" {
+ // we now trace up the incoming path until we hit a primitive,
+ // this will be the start node for the new edge.
+
+ // initialize trackers
+ let mut node_start = node_index;
+ let mut node_end = node_index;
+ let mut i_inner_nodes = Vec::::new();
+
+ // find end node and track path
+ for node in graph.neighbors_directed(node_index, Outgoing) {
+ if graph
+ .edge_weight(graph.find_edge(node_index, node).unwrap())
+ .unwrap()
+ .index
+ .unwrap()
+ == 0
+ {
+ let mut temp = to_terminal(graph.clone(), node);
+ node_end = temp.0;
+ i_inner_nodes.append(&mut temp.1);
+ }
+ }
+
+ // find start primitive node and track path
+ for node in graph.neighbors_directed(node_index, Incoming) {
+ let mut temp = to_primitive(graph.clone(), node);
+ node_start = temp.0;
+ i_inner_nodes.append(&mut temp.1);
+ }
+
+ // add edge from start to end node, with weight taken from the start node to a matching outgoing node from it
+ for node in graph.clone().neighbors_directed(node_start, Outgoing) {
+ for node_p in i_inner_nodes.iter() {
+ if node == *node_p {
+ graph_clone.add_edge(
+ node_start,
+ node_end,
+ graph
+ .clone()
+ .edge_weight(graph.clone().find_edge(node_start, node).unwrap())
+ .unwrap()
+ .clone(),
+ );
+ }
+ }
+ }
+ // we keep track all the node indexes we found while tracing the path and delete all
+ // intermediate nodes.
+ i_inner_nodes.push(node_index);
+ inner_nodes.append(&mut i_inner_nodes.clone());
+ }
+ }
+ inner_nodes.sort();
+ for node in inner_nodes.iter().rev() {
+ graph_clone.remove_node(*node);
+ }
+ graph_clone
+}
+
+pub fn to_terminal(
+ graph: petgraph::Graph,
+ node_index: NodeIndex,
+) -> (NodeIndex, Vec) {
+ let mut node_vec = Vec::::new();
+ let mut end_node = node_index;
+ // if there is another node deeper
+ // else pass original input node out and an empty path vector
+ if graph.neighbors_directed(node_index, Outgoing).count() != 0 {
+ node_vec.push(node_index); // add current node to path list
+ for node in graph.neighbors_directed(node_index, Outgoing) {
+ // pass next node forward
+ let mut temp = to_terminal(graph.clone(), node);
+ end_node = temp.0; // make end_node
+ node_vec.append(&mut temp.1); // append previous path nodes
+ }
+ }
+ (end_node, node_vec)
+}
+
+// incoming walker to first primitive (NOTE: assumes input is not a primitive)
+pub fn to_primitive(
+ graph: petgraph::Graph,
+ node_index: NodeIndex,
+) -> (NodeIndex, Vec) {
+ let mut node_vec = Vec::::new();
+ let mut end_node = node_index;
+ node_vec.push(node_index);
+ for node in graph.neighbors_directed(node_index, Incoming) {
+ if graph[node].label.clone() != *"Primitive" {
+ let mut temp = to_primitive(graph.clone(), node);
+ end_node = temp.0;
+ node_vec.append(&mut temp.1);
+ } else {
+ end_node = node;
+ }
+ }
+ (end_node, node_vec)
+}
diff --git a/skema/skema-rs/skema/src/services/gromet.rs b/skema/skema-rs/skema/src/services/gromet.rs
index 79479c00a31..747f2cdf717 100644
--- a/skema/skema-rs/skema/src/services/gromet.rs
+++ b/skema/skema-rs/skema/src/services/gromet.rs
@@ -7,6 +7,8 @@ use actix_web::web::ServiceConfig;
use actix_web::{delete, get, post, put, web, HttpResponse};
use mathml::acset::{PetriNet, RegNet};
+use mathml::ast::MathExpression;
+use mathml::parsers::math_expression_tree::MathExpressionTree;
use neo4rs;
use neo4rs::{query, Error, Node};
use std::collections::HashMap;
@@ -327,3 +329,39 @@ pub async fn model2RN(
model_to_RN(payload.into_inner(), config1).await.unwrap(),
))
}
+
+/// This returns a MET vector from a gromet.
+#[allow(non_snake_case)]
+#[utoipa::path(
+ request_body = ModuleCollection,
+ responses(
+ (
+ status = 200, description = "Successfully retrieved MET"
+ )
+ )
+)]
+#[put("/models/MET")]
+pub async fn model2MET(
+ payload: web::Json,
+ config: web::Data,
+) -> HttpResponse {
+ let config1 = Config {
+ db_host: config.db_host.clone(),
+ db_port: config.db_port,
+ db_protocol: config.db_protocol.clone(),
+ };
+ let module_id = push_model_to_db(payload.into_inner(), config1.clone()).await; // pushes model to db and gets id
+ let ref_module_id1 = module_id.as_ref();
+ let ref_module_id2 = module_id.as_ref();
+ let mathml_ast = module_id2mathml_MET_ast(*ref_module_id1.unwrap(), config1.clone()).await; // turns model into mathml ast equations
+ let _del_response = delete_module(*ref_module_id2.unwrap(), config1.clone()).await; // deletes model from db
+ let mut mets = Vec::::new();
+ for equation in mathml_ast.iter() {
+ let mut equal_args = Vec::::new();
+ equal_args.push(MathExpressionTree::Atom(MathExpression::Ci(equation.lhs_var.clone())));
+ equal_args.push(equation.rhs.clone());
+ let met = MathExpressionTree::Cons(mathml::ast::operator::Operator::Equals, equal_args.clone());
+ mets.push(met.clone());
+ }
+ HttpResponse::Ok().json(web::Json(mets))
+}
diff --git a/skema/skema-rs/skema/src/services/mathml.rs b/skema/skema-rs/skema/src/services/mathml.rs
index 468bcfd2482..c3bf7ff693c 100644
--- a/skema/skema-rs/skema/src/services/mathml.rs
+++ b/skema/skema-rs/skema/src/services/mathml.rs
@@ -11,8 +11,6 @@ use mathml::parsers::math_expression_tree::{
use mathml::{
acset::{AMRmathml, PetriNet, RegNet},
- ast::Math,
- expression::{preprocess_content, wrap_math},
parsers::first_order_ode::{first_order_ode, FirstOrderODE},
};
use petgraph::dot::{Config, Dot};
@@ -57,11 +55,8 @@ pub async fn get_ast_graph(payload: String) -> String {
#[put("/mathml/math-exp-graph")]
pub async fn get_math_exp_graph(payload: String) -> String {
let mut contents = payload;
- contents = preprocess_content(contents);
- let mut math = contents.parse::().unwrap();
- math.normalize();
- let new_math = wrap_math(math);
- let g = new_math.to_graph();
+ let exp = contents.parse::().unwrap();
+ let g = exp.to_graph();
let dot_representation = Dot::new(&g);
dot_representation.to_string()
}
diff --git a/skema/skema_py/server.py b/skema/skema_py/server.py
index 46683eb5851..960331015db 100644
--- a/skema/skema_py/server.py
+++ b/skema/skema_py/server.py
@@ -8,8 +8,7 @@
from typing import List, Dict, Optional
from io import BytesIO
from zipfile import ZipFile
-from urllib.request import urlopen
-from fastapi import APIRouter, FastAPI, Body, File, UploadFile
+from fastapi import APIRouter, FastAPI, status, Body, File, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
@@ -140,11 +139,11 @@ async def system_to_enriched_system(system: System) -> System:
comments = {"files": {}}
for file_path, result in zip(file_paths, results):
comments["files"][str(file_path)] = result
- system.comments = MultiFileCommentResponse.parse_obj(comments)
+ system.comments = MultiFileCommentResponse(**comments)
return system
-
+# returns an abbreviated Dict representing a GrometFNModuleCollection
async def system_to_gromet(system: System):
"""Convert a System to Gromet JSON"""
@@ -235,9 +234,14 @@ async def system_to_gromet(system: System):
router = APIRouter()
-@router.get("/ping", summary="Ping endpoint to test health of service")
-def ping() -> int:
- return 200
+@router.get(
+ "/healthcheck",
+ summary="Ping endpoint to test health of service",
+ status_code=status.HTTP_200_OK,
+ response_model=int
+)
+def healthcheck() -> int:
+ return status.HTTP_200_OK
@router.get(
diff --git a/skema/skema_py/tests/test_server.py b/skema/skema_py/tests/test_server.py
index d0602baf90f..2d62b8b179f 100644
--- a/skema/skema_py/tests/test_server.py
+++ b/skema/skema_py/tests/test_server.py
@@ -11,9 +11,9 @@
client = TestClient(app)
-def test_ping():
+def test_healthcheck():
"""Test case for /code2fn/ping endpoint."""
- response = client.get("/code2fn/ping")
+ response = client.get("/code2fn/healthcheck")
assert response.status_code == 200
diff --git a/skema/utils/script_functions.py b/skema/utils/script_functions.py
index 2f58051cbf8..afbfaa57b4b 100644
--- a/skema/utils/script_functions.py
+++ b/skema/utils/script_functions.py
@@ -254,6 +254,7 @@ def ann_cast_pipeline(
pdf_file_name = f"{f_name}-AnnCast.pdf"
agraph.to_pdf(pdf_file_name)
+
print("\nCalling GrfnVarCreationPass-------------------")
GrfnVarCreationPass(pipeline_state)