From f31262dd9344893ee78903c233a0e9eec840b3c7 Mon Sep 17 00:00:00 2001 From: Justin Lieffers <76677555+Free-Quarks@users.noreply.github.com> Date: Thu, 7 Dec 2023 11:51:35 -0800 Subject: [PATCH 01/22] Scenario-3-Prep (#716) ## Summary of Changes This fixed some wiring bugs for outputting to memgraph and cleaned up / commented the model extraction code some more. Prep work for better support of scenario 3. ### Related issues Resolves ??? --------- Co-authored-by: Justin --- skema/skema-rs/skema/src/database.rs | 403 ++++++++++++++----- skema/skema-rs/skema/src/model_extraction.rs | 261 +++++++----- 2 files changed, 474 insertions(+), 190 deletions(-) diff --git a/skema/skema-rs/skema/src/database.rs b/skema/skema-rs/skema/src/database.rs index b3ee80bf521..4a6f3c24bfc 100644 --- a/skema/skema-rs/skema/src/database.rs +++ b/skema/skema-rs/skema/src/database.rs @@ -68,12 +68,13 @@ pub struct Node { pub box_counter: usize, // this indexes the box call for the node one scope up, matches nbox if higher scope is top level } -#[derive(Debug, Clone, PartialEq, Ord, Eq, PartialOrd)] +#[derive(Debug, Clone, PartialEq, Ord, Eq, PartialOrd, Default)] pub struct Edge { pub src: String, pub tgt: String, pub e_type: String, pub prop: Option, // option because of opo's and opi's + pub refer: Option, } #[derive(Debug, Clone)] @@ -378,7 +379,7 @@ fn create_module(gromet: &ModuleCollection) -> Vec { src: String::from("mod"), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; let edge_query = format!( "{} ({})-[e{}{}:{}]->({})", @@ -453,6 +454,7 @@ fn create_function_net_lib(gromet: &ModuleCollection, mut start: u32) -> Vec Vec Vec Vec Vec Vec Vec Vec tgt: format!("n{}", start), e_type: String::from("Contains"), prop: Some(boxf.contents.unwrap() as usize), + ..Default::default() }; nodes.push(n1.clone()); edges.push(e1); @@ -991,7 +1004,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec src: n1.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -1012,7 +1025,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec src: n1.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -1052,7 +1065,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec src: n1.node_id.clone(), tgt: node.node_id.clone(), e_type: String::from("Contains"), - prop: None, + ..Default::default() }; edges.push(e5); } @@ -1110,7 +1123,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec src: wfopi_src_tgt[0].clone(), tgt: wfopi_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e6); } @@ -1163,7 +1176,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec src: wfopo_src_tgt[0].clone(), tgt: wfopo_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e7); } @@ -1182,7 +1195,15 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec } } FunctionType::Imported => { - create_import( + create_att_primitive( + gromet, // gromet for metadata + &mut nodes, // nodes + &mut edges, + &mut meta_nodes, + &mut start, + c_args.clone(), + ); + /*create_import( gromet, &mut nodes, &mut edges, @@ -1199,11 +1220,19 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec c_args.att_idx, c_args.bf_counter, c_args.parent_node.clone(), - ); + );*/ } FunctionType::ImportedMethod => { + create_att_primitive( + gromet, // gromet for metadata + &mut nodes, // nodes + &mut edges, + &mut meta_nodes, + &mut start, + c_args.clone(), + ); // basically seems like these are just functions to me. - c_args.att_idx = boxf.contents.unwrap() as usize; + /*c_args.att_idx = boxf.contents.unwrap() as usize; c_args.att_box = gromet.modules[0].attributes[c_args.att_idx - 1].clone(); create_function( gromet, // gromet for metadata @@ -1212,7 +1241,7 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec &mut meta_nodes, &mut start, c_args.clone(), - ); + );*/ } _ => {} } @@ -1288,7 +1317,13 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec } } // convert every node object into a node query - let create = String::from("CREATE"); + queries.append(&mut construct_memgraph_queries( + &mut nodes, + &mut edges, + &mut meta_nodes, + &mut queries.clone(), + )); + /*let create = String::from("CREATE"); for node in nodes.iter() { let mut name = String::from("a"); if node.name.is_none() { @@ -1344,7 +1379,11 @@ fn create_function_net(gromet: &ModuleCollection, mut start: u32) -> Vec let set_query = format!("set e{}{}.index={}", edge.src, edge.tgt, edge.prop.unwrap()); queries.push(set_query); } - } + if edge.refer.is_some() { + let set_query = format!("set e{}{}.refer={}", edge.src, edge.tgt, edge.refer.unwrap()); + queries.push(set_query); + } + }*/ queries } // this method creates an import type function @@ -1406,7 +1445,7 @@ pub fn create_import( src: c_args.parent_node.node_id, tgt: n3.node_id.clone(), e_type: String::from("Contains"), - prop: None, + ..Default::default() }; edges.push(e4); if eboxf.metadata.is_some() { @@ -1423,7 +1462,7 @@ pub fn create_import( src: n3.node_id, tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -1484,6 +1523,7 @@ pub fn create_function( tgt: n1.node_id.clone(), e_type: String::from("Contains"), prop: Some(c_args.att_idx), + ..Default::default() }; parent_node = n1.clone(); nodes.push(n1.clone()); @@ -1504,7 +1544,7 @@ pub fn create_function( src: n1.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -1611,8 +1651,7 @@ pub fn create_function( } FunctionType::ImportedMethod => { // this is a function call, but for some reason is not called a function - new_c_args.att_idx = att_sub_box.contents.unwrap() as usize; - create_function( + create_att_primitive( gromet, // gromet for metadata nodes, // nodes edges, @@ -1620,9 +1659,26 @@ pub fn create_function( start, new_c_args.clone(), ); + /*new_c_args.att_idx = att_sub_box.contents.unwrap() as usize; + create_function( + gromet, // gromet for metadata + nodes, // nodes + edges, + meta_nodes, + start, + new_c_args.clone(), + );*/ } FunctionType::Imported => { - create_import(gromet, nodes, edges, meta_nodes, start, c_args.clone()); + create_att_primitive( + gromet, // gromet for metadata + nodes, // nodes + edges, + meta_nodes, + start, + new_c_args.clone(), + ); + /*create_import(gromet, nodes, edges, meta_nodes, start, c_args.clone()); *start += 1; // now to implement wiring import_wiring( @@ -1632,7 +1688,7 @@ pub fn create_function( c_args.att_idx, c_args.bf_counter, c_args.parent_node.clone(), - ); + );*/ } _ => { println!( @@ -1767,6 +1823,7 @@ pub fn create_conditional( tgt: format!("n{}", start), e_type: String::from("Contains"), prop: Some(cond_counter as usize), + ..Default::default() }; nodes.push(n1.clone()); edges.push(e1); @@ -1800,7 +1857,7 @@ pub fn create_conditional( src: n1.node_id.clone(), tgt: format!("n{}", start), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; nodes.push(n2.clone()); edges.push(e3); @@ -1818,7 +1875,7 @@ pub fn create_conditional( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -1853,7 +1910,7 @@ pub fn create_conditional( src: n1.node_id.clone(), tgt: format!("n{}", start), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; nodes.push(n3.clone()); edges.push(e5); @@ -1871,7 +1928,7 @@ pub fn create_conditional( src: n3.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -2019,7 +2076,7 @@ pub fn create_conditional( src: wfc_src_tgt[0].clone(), tgt: wfc_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -2072,7 +2129,7 @@ pub fn create_conditional( src: cond_src_tgt[0].clone(), tgt: cond_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e9); } @@ -2119,7 +2176,7 @@ pub fn create_conditional( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e10); } @@ -2167,7 +2224,7 @@ pub fn create_conditional( src: else_src_tgt[0].clone(), tgt: else_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e11); } @@ -2219,7 +2276,7 @@ pub fn create_conditional( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e12); } @@ -2267,7 +2324,7 @@ pub fn create_conditional( src: else_src_tgt[0].clone(), tgt: else_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e13); } @@ -2317,7 +2374,7 @@ pub fn create_conditional( src: cond_src_tgt[0].clone(), tgt: cond_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e14); } @@ -2357,6 +2414,7 @@ pub fn create_for_loop( tgt: format!("n{}", start), e_type: String::from("Contains"), prop: Some(cond_counter as usize), + ..Default::default() }; nodes.push(n1.clone()); edges.push(e1); @@ -2380,7 +2438,7 @@ pub fn create_for_loop( src: n1.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -2492,7 +2550,7 @@ pub fn create_for_loop( src: n1.node_id.clone(), tgt: format!("n{}", start), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; nodes.push(n2.clone()); edges.push(e3); @@ -2510,7 +2568,7 @@ pub fn create_for_loop( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -2544,7 +2602,7 @@ pub fn create_for_loop( src: n1.node_id.clone(), tgt: format!("n{}", start), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; nodes.push(n3.clone()); edges.push(e5); @@ -2562,7 +2620,7 @@ pub fn create_for_loop( src: n3.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -2636,7 +2694,7 @@ pub fn create_for_loop( src: wfl_src_tgt[0].clone(), tgt: wfl_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -2685,7 +2743,7 @@ pub fn create_for_loop( src: cond_src_tgt[0].clone(), tgt: cond_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e9); } @@ -2732,7 +2790,7 @@ pub fn create_for_loop( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e10); } @@ -2782,7 +2840,7 @@ pub fn create_for_loop( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e12); } @@ -2829,7 +2887,7 @@ pub fn create_for_loop( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e15); } @@ -2879,7 +2937,7 @@ pub fn create_for_loop( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e16); } @@ -2928,7 +2986,7 @@ pub fn create_for_loop( src: cond_src_tgt[0].clone(), tgt: cond_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e14); } @@ -2969,6 +3027,7 @@ pub fn create_while_loop( tgt: format!("n{}", start), e_type: String::from("Contains"), prop: Some(cond_counter as usize), + ..Default::default() }; nodes.push(n1.clone()); edges.push(e1); @@ -2992,7 +3051,7 @@ pub fn create_while_loop( src: n1.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3081,7 +3140,7 @@ pub fn create_while_loop( src: n1.node_id.clone(), tgt: format!("n{}", start), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; nodes.push(n2.clone()); edges.push(e3); @@ -3099,7 +3158,7 @@ pub fn create_while_loop( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3133,7 +3192,7 @@ pub fn create_while_loop( src: n1.node_id.clone(), tgt: format!("n{}", start), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; nodes.push(n3.clone()); edges.push(e5); @@ -3151,7 +3210,7 @@ pub fn create_while_loop( src: n3.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3225,7 +3284,7 @@ pub fn create_while_loop( src: wfl_src_tgt[0].clone(), tgt: wfl_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -3274,7 +3333,7 @@ pub fn create_while_loop( src: cond_src_tgt[0].clone(), tgt: cond_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e9); } @@ -3321,7 +3380,7 @@ pub fn create_while_loop( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e10); } @@ -3371,7 +3430,7 @@ pub fn create_while_loop( src: if_src_tgt[0].clone(), tgt: if_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e12); } @@ -3420,7 +3479,7 @@ pub fn create_while_loop( src: cond_src_tgt[0].clone(), tgt: cond_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e14); } @@ -3456,6 +3515,7 @@ pub fn create_att_expression( tgt: format!("n{}", start), e_type: String::from("Contains"), prop: Some(c_args.att_idx), + ..Default::default() }; nodes.push(n1.clone()); edges.push(e1); @@ -3476,7 +3536,7 @@ pub fn create_att_expression( src: n1.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3545,7 +3605,7 @@ pub fn create_att_expression( src: n1.node_id.clone(), tgt: n2.node_id.clone(), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; edges.push(e3); if att_box.opo.clone().as_ref().unwrap()[oport as usize] @@ -3569,7 +3629,7 @@ pub fn create_att_expression( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3644,7 +3704,7 @@ pub fn create_att_expression( src: n1.node_id.clone(), tgt: n2.node_id.clone(), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; edges.push(e3); if att_box.opi.clone().as_ref().unwrap()[iport as usize] @@ -3668,7 +3728,7 @@ pub fn create_att_expression( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3759,6 +3819,7 @@ pub fn create_att_predicate( tgt: format!("n{}", start), e_type: String::from("Contains"), prop: Some(c_args.att_idx), + ..Default::default() }; nodes.push(n1.clone()); edges.push(e1); @@ -3779,7 +3840,7 @@ pub fn create_att_predicate( src: n1.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3822,7 +3883,7 @@ pub fn create_att_predicate( src: n1.node_id.clone(), tgt: n2.node_id.clone(), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; edges.push(e3); if att_box.opo.clone().as_ref().unwrap()[oport as usize] @@ -3846,7 +3907,7 @@ pub fn create_att_predicate( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -3883,7 +3944,7 @@ pub fn create_att_predicate( src: n1.node_id.clone(), tgt: n2.node_id.clone(), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; edges.push(e3); if att_box.opi.clone().as_ref().unwrap()[iport as usize] @@ -3907,7 +3968,7 @@ pub fn create_att_predicate( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -4009,7 +4070,7 @@ pub fn create_att_literal( src: c_args.parent_node.node_id, tgt: n3.node_id.clone(), e_type: String::from("Contains"), - prop: None, + ..Default::default() }; edges.push(e4); if lit_box.metadata.is_some() { @@ -4026,7 +4087,7 @@ pub fn create_att_literal( src: n3.node_id, tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -4085,7 +4146,7 @@ pub fn create_att_primitive( src: c_args.parent_node.node_id, tgt: n3.node_id.clone(), e_type: String::from("Contains"), - prop: None, + ..Default::default() }; edges.push(e4); if c_args.cur_box.metadata.is_some() { @@ -4102,7 +4163,7 @@ pub fn create_att_primitive( src: n3.node_id, tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -4123,11 +4184,15 @@ pub fn create_att_abstract( ) { // first find the pof's for box let mut pof: Vec = vec![]; + let mut pof_names: Vec = vec![]; if c_args.att_box.pof.is_some() { let mut po_idx: u32 = 1; for port in c_args.att_box.pof.clone().unwrap().iter() { if port.r#box == c_args.box_counter as u8 { pof.push(po_idx); + if port.name.is_some() { + pof_names.push(port.name.clone().unwrap()); + } } po_idx += 1; } @@ -4143,11 +4208,26 @@ pub fn create_att_abstract( pi_idx += 1; } } + // now to construct an entry of ValueL for abstract port references + let mut value_vec = Vec::::new(); + for name in pof_names.iter() { + let val = ValueL { + value_type: "String".to_string(), + value: format!("{:?}", name.clone()), + gromet_type: Some("Name".to_string()), + }; + value_vec.push(val.clone()); + } + let val = ValueL { + value_type: "List".to_string(), + value: format!("{:?}", value_vec.clone()), + gromet_type: Some("Abstract".to_string()), + }; // now make the node with the port information let mut metadata_idx = 0; let n3 = Node { - n_type: String::from("Primitive"), - value: None, + n_type: String::from("Abstract"), + value: Some(val), name: c_args.cur_box.name.clone(), node_id: format!("n{}", start), out_idx: Some(pof), @@ -4163,7 +4243,7 @@ pub fn create_att_abstract( src: c_args.parent_node.node_id, tgt: n3.node_id.clone(), e_type: String::from("Contains"), - prop: None, + ..Default::default() }; edges.push(e4); if c_args.cur_box.metadata.is_some() { @@ -4180,7 +4260,7 @@ pub fn create_att_abstract( src: n3.node_id, tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -4235,7 +4315,7 @@ pub fn create_opo( src: c_args.parent_node.node_id.clone(), tgt: n2.node_id.clone(), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; edges.push(e3); @@ -4261,7 +4341,7 @@ pub fn create_opo( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -4320,7 +4400,7 @@ pub fn create_opi( src: c_args.parent_node.node_id.clone(), tgt: n2.node_id.clone(), e_type: String::from("Port_Of"), - prop: None, + ..Default::default() }; edges.push(e3); @@ -4346,7 +4426,7 @@ pub fn create_opi( src: n2.node_id.clone(), tgt: format!("m{}", metadata_idx), e_type: String::from("Metadata"), - prop: None, + ..Default::default() }; edges.push(me1); } @@ -4418,6 +4498,7 @@ pub fn wfopi_wiring( tgt: wfopi_src_tgt[1].clone(), e_type: String::from("Wire"), prop: Some(prop.unwrap() as usize), + ..Default::default() }; edges.push(e6); } @@ -4480,7 +4561,7 @@ pub fn wfopo_wiring( src: wfopo_src_tgt[0].clone(), tgt: wfopo_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e7); } @@ -4500,6 +4581,7 @@ pub fn wff_wiring( for wire in eboxf.wff.unwrap().iter() { let mut wff_src_tgt: Vec = vec![]; let mut prop = None; + let mut refer = None; let src_idx = wire.src; // port index @@ -4533,7 +4615,7 @@ pub fn wff_wiring( // push the tgt if (wire.src as u32) == *p { wff_src_tgt.push(node.node_id.clone()); - prop = Some(i as u32); + prop = Some(i); } } } @@ -4555,10 +4637,13 @@ pub fn wff_wiring( // exclude opo's if node.n_type != "Opo" { // iterate through port to check for tgt - for p in node.out_idx.as_ref().unwrap().iter() { + for (i, p) in node.out_idx.as_ref().unwrap().iter().enumerate() { // push the tgt if (wire.tgt as u32) == *p { wff_src_tgt.push(node.node_id.clone()); + if node.n_type == "Abstract" { + refer = Some(i); + } } } } @@ -4572,7 +4657,8 @@ pub fn wff_wiring( src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: Some(prop.unwrap() as usize), + prop: Some(prop.unwrap()), + refer, }; edges.push(e8); } @@ -4635,7 +4721,7 @@ pub fn wopio_wiring( src: wopio_src_tgt[0].clone(), tgt: wopio_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e7); } @@ -4785,7 +4871,7 @@ pub fn import_wiring( src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -4873,7 +4959,7 @@ pub fn import_wiring( src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -5007,7 +5093,7 @@ pub fn wfopi_cross_att_wiring( src: wfopi_src_tgt[0].clone(), tgt: wfopi_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -5091,7 +5177,7 @@ pub fn wfopo_cross_att_wiring( src: wfopo_src_tgt[0].clone(), tgt: wfopo_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -5099,7 +5185,7 @@ pub fn wfopo_cross_att_wiring( } } // this will construct connections from the sub function modules opi's to another sub module opo's, tracing data inside the function -// opi(sub)->opo(sub) +// opi(sub)->opo(sub) or pif(current) -> opo(sub) or opi(sub) -> pof(current) #[allow(unused_assignments)] pub fn wff_cross_att_wiring( eboxf: FunctionNet, // This is the current attribute, should be the function if in a function @@ -5195,7 +5281,7 @@ pub fn wff_cross_att_wiring( src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -5232,7 +5318,10 @@ pub fn wff_cross_att_wiring( && (tgt_box as u32) == node.box_counter as u32 { // only opo's - if node.n_type == "Primitive" || node.n_type == "Literal" { + if node.n_type == "Primitive" + || node.n_type == "Literal" + || node.n_type == "Abstract" + { // iterate through port to check for tgt for p in node.out_idx.as_ref().unwrap().iter() { // push the src first, being pif @@ -5263,12 +5352,13 @@ pub fn wff_cross_att_wiring( src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } } } else { + // This should be pif -> opo let src_nbox = bf_counter; // nbox value of src opi // collect info to identify the opo tgt node let tgt_idx = wire.tgt; // port index @@ -5298,11 +5388,11 @@ pub fn wff_cross_att_wiring( && (src_box as u32) == node.box_counter as u32 { // only opo's - if node.n_type == "Primitive" { + if node.n_type == "Primitive" || node.n_type == "Abstract" { // iterate through port to check for tgt for p in node.in_indx.as_ref().unwrap().iter() { // push the src first, being pif - if (src_opi_idx as u32) == *p { + if (src_idx as u32) == *p { wff_src_tgt.push(node.node_id.clone()); } } @@ -5335,7 +5425,7 @@ pub fn wff_cross_att_wiring( src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e8); } @@ -5449,7 +5539,7 @@ pub fn external_wiring(gromet: &ModuleCollection, nodes: &mut [Node], edges: &mu src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), - prop: None, + ..Default::default() }; edges.push(e9); } @@ -5467,3 +5557,120 @@ pub fn parse_gromet_queries(gromet: ModuleCollection) -> Vec { queries } + +// convert every node object into a node query +pub fn construct_memgraph_queries( + nodes: &mut Vec, + edges: &mut Vec, + meta_nodes: &mut Vec, + queries: &mut Vec, +) -> Vec { + // convert every node object into a node query + let create = String::from("CREATE"); + for node in nodes.iter() { + let mut name = String::from("a"); + if node.name.is_none() { + name = node.n_type.clone(); + } else { + name = node.name.as_ref().unwrap().to_string(); + } + // better parsing of values for inference later on. + // handles case of parsing a list as a proper list object, only depth one though + // would need recursive function for aritrary depth. To be done at somepoint. + let value = match &node.value { + Some(val) => { + if val.value_type == *"List" && &val.value[0..1] == "[" && &val.value[1..2] != "]" { + let val_type = val.value_type.clone(); + let val_grom_type = val.gromet_type.as_ref().unwrap(); + let val_len = val.value[..].len(); + let val_val: Vec = val.value[1..val_len] + .split("}, ") + .map(|x| x.to_string()) + .collect(); + + let mut val_vec = Vec::::new(); + for (i, val) in val_val.iter().enumerate() { + if i == val_val.len() - 1 { + let val_string = format!("{}}}", &val[7..(val.len() - 2)]); + val_vec.push(val_string.clone()); + } else { + let val_string = format!("{}}}", &val[7..]); + val_vec.push(val_string.clone()); + } + } + let mut final_val_vec = Vec::::new(); + for val_str in val_vec.iter() { + let val_fields: Vec = + val_str.split(", ").map(|x| x.to_string()).collect(); + let cor_val: Vec = + val_fields[1].split(": ").map(|x| x.to_string()).collect(); + let final_val = cor_val[1].replace("\\\"", ""); + final_val_vec.push(final_val.clone()); + } + format!( + "{{ value_type:{:?}, value:{:?}, gromet_type:{:?} }}", + val_type, final_val_vec, val_grom_type + ) + .replace("\\\"", "") + } else { + format!( + "{{ value_type:{:?}, value:{:?}, gromet_type:{:?} }}", + val.value_type, + val.value, + val.gromet_type.as_ref().unwrap() + ) + } + } + None => String::from("\"\""), + }; + + // NOTE: The format of value has changed to represent a literal Cypher map {field:value}. + // We no longer need to format value with the debug :? parameter + let node_query = format!( + "{} ({}:{} {{name:{:?},value:{},order_box:{:?},order_att:{:?}}})", + create, node.node_id, node.n_type, name, value, node.nbox, node.contents + ); + queries.push(node_query); + } + for node in meta_nodes.iter() { + queries.append(&mut create_metadata_node_query(node.clone())); + } + + // convert every edge object into an edge query + let init_edges = edges.len(); + edges.sort(); + edges.dedup(); + let edges_clone = edges.clone(); + // also dedup if edge prop is different + for (i, edge) in edges_clone.iter().enumerate().rev() { + if i != 0 && edge.src == edges_clone[i - 1].src && edge.tgt == edges_clone[i - 1].tgt { + edges.remove(i); + } + } + let fin_edges = edges.len(); + if init_edges != fin_edges { + println!("Duplicated Edges Removed, check for bugs"); + } + for edge in edges.iter() { + let edge_query = format!( + "{} ({})-[e{}{}:{}]->({})", + create, edge.src, edge.src, edge.tgt, edge.e_type, edge.tgt + ); + queries.push(edge_query); + + if edge.prop.is_some() { + let set_query = format!("set e{}{}.index={}", edge.src, edge.tgt, edge.prop.unwrap()); + queries.push(set_query); + } + if edge.refer.is_some() { + let set_query = format!( + "set e{}{}.refer={}", + edge.src, + edge.tgt, + edge.refer.unwrap() + ); + queries.push(set_query); + } + } + queries.to_vec() +} diff --git a/skema/skema-rs/skema/src/model_extraction.rs b/skema/skema-rs/skema/src/model_extraction.rs index 9f50b7336f1..abd9dd674b0 100644 --- a/skema/skema-rs/skema/src/model_extraction.rs +++ b/skema/skema-rs/skema/src/model_extraction.rs @@ -17,11 +17,51 @@ use neo4rs; use neo4rs::{query, Error}; use std::sync::Arc; +/// This struct is the node struct for the constructed petgraph +#[derive(Clone, Debug)] +pub struct ModelNode { + id: i64, + label: String, + name: Option, + value: Option, +} + +/// This struct is the edge struct for the constructed petgraph +#[derive(Clone, Debug)] +pub struct ModelEdge { + id: i64, + src_id: i64, + tgt_id: i64, + index: Option, + refer: Option, +} + +/** + * This is the main function call for model extraction. + * + * Parameters: + * - module_id: i64 -> This is the top level id of the gromet module in memgraph. + * - config: Config -> This is a config struct for connecting to memgraph + * + * Returns: + * - Vector of FirstOrderODE -> This vector of structs is used to construct a PetriNet or RegNet further down the pipeline + * + * Assumptions: + * - As of right now, we can always assume the code has been sliced to only one relevant function which contains the + * core dynamics in it somewhere + * + * Notes: + * - FirstOrderODE is primarily composed of a LHS and a RHS, + * - LHS is just a Mi object of the state being differentiated. There are additional fields for the LHS but only the + * content field is used in downstream inference for now. + * - RHS is where the bulk of the inference happens, it produces an expression tree, hence the MET -> Math Expression Tree. + * Every operator has a vector of arguments. (order matters) + */ #[allow(non_snake_case)] pub async fn module_id2mathml_MET_ast(module_id: i64, config: Config) -> Vec { let mut core_dynamics_ast = Vec::::new(); - let core_id = find_pn_dynamics(module_id, config.clone()).await; // gives back list of function nodes that might contain the dynamics + let core_id = find_pn_dynamics(module_id, config.clone()).await; if core_id.is_empty() { let deriv = Ci { @@ -48,24 +88,31 @@ pub async fn module_id2mathml_MET_ast(module_id: i64, config: Config) -> Vec Vec { let graph = subgraph2petgraph(module_id, config.clone()).await; // 1. find each function node let mut function_nodes = Vec::::new(); for node in graph.node_indices() { - if graph[node].labels()[0] == *"Function" { + if graph[node].label == *"Function" { function_nodes.push(node); } } // 2. check and make sure only expressions in function // 3. check number of expressions and decide off that - let mut functions = Vec::>::new(); + let mut functions = Vec::>::new(); for i in 0..function_nodes.len() { // grab the subgraph of the given expression - functions.push(subgraph2petgraph(graph[function_nodes[i]].id(), config.clone()).await); + functions.push(subgraph2petgraph(graph[function_nodes[i]].id, config.clone()).await); } // get a sense of the number of expressions in each function let mut func_counter = 0; @@ -74,17 +121,17 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec { let mut expression_counter = 0; let mut primitive_counter = 0; for node in func.node_indices() { - if func[node].labels()[0] == *"Expression" { + if func[node].label == *"Expression" { expression_counter += 1; } - if func[node].labels()[0] == *"Primitive" { - if func[node].get::("name").unwrap() == *"ast.Mult" { + if func[node].label == *"Primitive" { + if *func[node].name.as_ref().unwrap() == "ast.Mult".to_string() { primitive_counter += 1; - } else if func[node].get::("name").unwrap() == *"ast.Add" { + } else if *func[node].name.as_ref().unwrap() == "ast.Add".to_string() { primitive_counter += 1; - } else if func[node].get::("name").unwrap() == *"ast.Sub" { + } else if *func[node].name.as_ref().unwrap() == "ast.Sub".to_string() { primitive_counter += 1; - } else if func[node].get::("name").unwrap() == *"ast.USub" { + } else if *func[node].name.as_ref().unwrap() == "ast.USub".to_string() { primitive_counter += 1; } } @@ -98,8 +145,8 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec { let mut core_id = Vec::::new(); for c_func in core_func.iter() { for node in functions[*c_func].node_indices() { - if functions[*c_func][node].labels()[0] == *"Function" { - core_id.push(functions[*c_func][node].id()); + if functions[*c_func][node].label == *"Function" { + core_id.push(functions[*c_func][node].id); } } } @@ -107,6 +154,12 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec { core_id } + +/** + * Once the function node has been identified, this function takes it from there to extract the vector of FirstOrderODE's + * + * This is based heavily on the assumption that each equation is in a seperate expression which breaks for the vector case. + */ #[allow(non_snake_case)] pub async fn subgrapg2_core_dyn_MET_ast( root_node_id: i64, @@ -118,7 +171,7 @@ pub async fn subgrapg2_core_dyn_MET_ast( // find all the expressions let mut expression_nodes = Vec::::new(); for node in graph.node_indices() { - if graph[node].labels()[0] == *"Expression" { + if graph[node].label == *"Expression" { expression_nodes.push(node); } } @@ -128,14 +181,14 @@ pub async fn subgrapg2_core_dyn_MET_ast( // initialize vector to collect all expression wiring graphs for i in 0..expression_nodes.len() { // grab the wiring subgraph of the given expression - let mut sub_w = subgraph_wiring(graph[expression_nodes[i]].id(), config.clone()) + let mut sub_w = subgraph_wiring(graph[expression_nodes[i]].id, config.clone()) .await .unwrap(); if sub_w.node_count() > 3 { - let expr = trim_un_named(&mut sub_w, config.clone()).await; + let expr = trim_un_named(&mut sub_w).await; let mut root_node = Vec::::new(); for node_index in expr.node_indices() { - if expr[node_index].labels()[0].clone() == *"Opo" { + if expr[node_index].label.clone() == *"Opo" { root_node.push(node_index); } } @@ -150,17 +203,21 @@ pub async fn subgrapg2_core_dyn_MET_ast( Ok(core_dynamics) } + +/** + * This function is designed to take in a petgraph instance of a wires only expression subgraph and output a FirstOrderODE equations representing it. + */ #[allow(non_snake_case)] fn tree_2_MET_ast( - graph: &mut petgraph::Graph, + graph: &mut petgraph::Graph, root_node: NodeIndex, ) -> Result { let mut fo_eq_vec = Vec::::new(); let _math_vec = Vec::::new(); let mut lhs = Vec::::new(); - if graph[root_node].labels()[0] == *"Opo" { + if graph[root_node].label == *"Opo" { // we first construct the derivative of the first node - let deriv_name: &str = &graph[root_node].get::("name").unwrap(); + let deriv_name: &str = graph[root_node].name.as_ref().unwrap(); // this will let us know if additional trimming is needed to handle the code implementation of the equations // let mut step_impl = false; this will be used for step implementaion for later // This is very bespoke right now @@ -183,7 +240,7 @@ fn tree_2_MET_ast( lhs.push(deriv); } for node in graph.neighbors_directed(root_node, Outgoing) { - if graph[node].labels()[0].clone() == *"Primitive" { + if graph[node].label.clone() == *"Primitive" { let operate = get_operator_MET(graph, node); // output -> Operator let rhs_arg = get_args_MET(graph, node); // output -> Vec let rhs = MathExpressionTree::Cons(operate, rhs_arg); // MathExpressionTree @@ -204,9 +261,10 @@ fn tree_2_MET_ast( Ok(fo_eq_vec[0].clone()) } +/// This is a recursive function that walks along the wired subgraph of an expression to construct the expression tree #[allow(non_snake_case)] pub fn get_args_MET( - graph: &petgraph::Graph, + graph: &petgraph::Graph, root_node: NodeIndex, ) -> Vec { let mut args = Vec::::new(); @@ -219,14 +277,14 @@ pub fn get_args_MET( // construct vecs for node in graph.neighbors_directed(root_node, Outgoing) { // first need to check for operator - if graph[node].labels()[0].clone() == *"Primitive" { + if graph[node].label.clone() == *"Primitive" { let operate = get_operator_MET(graph, node); // output -> Operator let rhs_arg = get_args_MET(graph, node); // output -> Vec let rhs = MathExpressionTree::Cons(operate, rhs_arg); // MathExpressionTree args.push(rhs.clone()); } else { // asummption it is atomic - let temp_string = graph[node].get::("name").unwrap().clone(); + let temp_string = graph[node].name.as_ref().unwrap().clone(); let arg2 = MathExpressionTree::Atom(MathExpression::Mi(Mi(temp_string.clone()))); args.push(arg2.clone()); } @@ -235,7 +293,7 @@ pub fn get_args_MET( let x = graph .edge_weight(graph.find_edge(root_node, node).unwrap()) .unwrap() - .get::("index") + .index .unwrap(); arg_order.push(x); } @@ -253,44 +311,48 @@ pub fn get_args_MET( ordered_args } -// this gets the operator from the node name +/// This gets the operator from the node name #[allow(non_snake_case)] #[allow(clippy::if_same_then_else)] pub fn get_operator_MET( - graph: &petgraph::Graph, + graph: &petgraph::Graph, root_node: NodeIndex, ) -> Operator { let mut op = Vec::::new(); - if graph[root_node].get::("name").unwrap() == *"ast.Mult" { + if *graph[root_node].name.as_ref().unwrap() == "ast.Mult".to_string() { op.push(Operator::Multiply); - } else if graph[root_node].get::("name").unwrap() == *"ast.Add" { + } else if *graph[root_node].name.as_ref().unwrap() == "ast.Add" { op.push(Operator::Add); - } else if graph[root_node].get::("name").unwrap() == *"ast.Sub" { + } else if *graph[root_node].name.as_ref().unwrap() == "ast.Sub" { op.push(Operator::Subtract); - } else if graph[root_node].get::("name").unwrap() == *"ast.USub" { + } else if *graph[root_node].name.as_ref().unwrap() == "ast.USub" { op.push(Operator::Subtract); - } else if graph[root_node].get::("name").unwrap() == *"ast.Div" { + } else if *graph[root_node].name.as_ref().unwrap() == "ast.Div" { op.push(Operator::Divide); } else { - op.push(Operator::Other( - graph[root_node].get::("name").unwrap(), - )); + op.push(Operator::Other(graph[root_node].name.clone().unwrap())); } op[0].clone() } -// this currently only works for un-named nodes that are not chained or have multiple incoming/outgoing edges +/** + * This function takes in a wiring only petgraph of an expression and trims off the un-named nodes and unpack nodes. + * + * This is done by creating new edges that bypass the un-named nodes and then deleting them from the graph. + * For deleting the unpacks, the assumption is they are always terminal in the subgraph and can be deleted freely. + * + * Concerns: + * - I don't think this will work if there are multiple un-named nodes changed together. I haven't seen this in practice, + * but I think it's possible. So something to keep in mind. + */ async fn trim_un_named( - graph: &mut petgraph::Graph, - config: Config, -) -> &mut petgraph::Graph { + graph: &mut petgraph::Graph, +) -> &mut petgraph::Graph { // first create a cloned version of the graph we can modify while iterating over it. - let graph_call = Arc::new(config.graphdb_connection().await); - // iterate over the graph and add a new edge to bypass the un-named nodes for node_index in graph.node_indices() { - if graph[node_index].get::("name").unwrap().clone() == *"un-named" { + if graph[node_index].clone().name.unwrap().clone() == *"un-named" { let mut bypass = Vec::::new(); for node1 in graph.neighbors_directed(node_index, Incoming) { bypass.push(node1); @@ -301,21 +363,14 @@ async fn trim_un_named( // one incoming one outgoing if bypass.len() == 2 { // annoyingly have to pull the edge/Relation to insert into graph - let mut edge_list = Vec::::new(); - let query_string = format!( - "MATCH (n)-[r:Wire]->(m) WHERE id(n) = {} AND id(m) = {} RETURN r", - graph[bypass[0]].id(), - graph[node_index].id() + graph.add_edge( + bypass[0], + bypass[1], + graph + .edge_weight(graph.find_edge(bypass[0], node_index).unwrap()) + .unwrap() + .clone(), ); - let mut result = graph_call.execute(query(&query_string[..])).await.unwrap(); - while let Ok(Some(row)) = result.next().await { - let edge: neo4rs::Relation = row.get("r").unwrap(); - edge_list.push(edge); - } - // add the bypass edge - for edge in edge_list { - graph.add_edge(bypass[0], bypass[1], edge); - } } else if bypass.len() > 2 { // this operates on the assumption that there maybe multiple references to the port // (incoming arrows) but only one outgoing arrow, this seems to be the case based on @@ -324,31 +379,24 @@ async fn trim_un_named( let end_node_idx = bypass.len() - 1; for (i, _ent) in bypass[0..end_node_idx].iter().enumerate() { // this iterates over all but the last entry in the bypass vec - let mut edge_list = Vec::::new(); - let query_string = format!( - "MATCH (n)-[r:Wire]->(m) WHERE id(n) = {} AND id(m) = {} RETURN r", - graph[bypass[i]].id(), - graph[node_index].id() + graph.add_edge( + bypass[i], + bypass[end_node_idx], + graph + .edge_weight(graph.find_edge(bypass[i], node_index).unwrap()) + .unwrap() + .clone(), ); - let mut result = graph_call.execute(query(&query_string[..])).await.unwrap(); - while let Ok(Some(row)) = result.next().await { - let edge: neo4rs::Relation = row.get("r").unwrap(); - edge_list.push(edge); - } - - for edge in edge_list { - graph.add_edge(bypass[i], bypass[end_node_idx], edge); - } } } } } - // now we perform a filter_map to remove the un-named nodes and only the bypass edge will remain to connect the nodes + // now we remove the un-named nodes and only the bypass edge will remain to connect the nodes // we also remove the unpack node if it is present here as well for node_index in graph.node_indices().rev() { - if graph[node_index].get::("name").unwrap().clone() == *"un-named" - || graph[node_index].get::("name").unwrap().clone() == *"unpack" + if graph[node_index].name.clone().unwrap() == *"un-named" + || graph[node_index].name.clone().unwrap() == *"unpack" { graph.remove_node(node_index); } @@ -357,12 +405,14 @@ async fn trim_un_named( graph } +/// This function takes in a node id (typically that of an expression subgraph) and returns a +/// petgraph subgraph of only the wire type edges async fn subgraph_wiring( module_id: i64, config: Config, -) -> Result, Error> { - let mut node_list = Vec::::new(); - let mut edge_list = Vec::::new(); +) -> Result, Error> { + let mut node_list = Vec::::new(); + let mut edge_list = Vec::::new(); // Connect to Memgraph. let graph = Arc::new(config.graphdb_connection().await); @@ -382,7 +432,13 @@ async fn subgraph_wiring( .await?; while let Ok(Some(row)) = result1.next().await { let node: neo4rs::Node = row.get("nodes2").unwrap(); - node_list.push(node); + let modelnode = ModelNode { + id: node.id(), + label: node.labels()[0].clone(), + name: node.get::("name"), + value: node.get::("value"), + }; + node_list.push(modelnode); } // edge query let mut result2 = graph @@ -400,10 +456,17 @@ async fn subgraph_wiring( .await?; while let Ok(Some(row)) = result2.next().await { let edge: neo4rs::Relation = row.get("edges2").unwrap(); - edge_list.push(edge); + let modeledge = ModelEdge { + id: edge.id(), + src_id: edge.start_node_id(), + tgt_id: edge.end_node_id(), + index: edge.get::("index"), + refer: edge.get::("refer"), + }; + edge_list.push(modeledge); } - let mut graph: petgraph::Graph = Graph::new(); + let mut graph: petgraph::Graph = Graph::new(); // Add nodes to the petgraph graph and collect their indexes let mut nodes = Vec::::new(); @@ -417,10 +480,10 @@ async fn subgraph_wiring( let mut src = Vec::::new(); let mut tgt = Vec::::new(); for node_idx in &nodes { - if graph[*node_idx].id() == edge.start_node_id() { + if graph[*node_idx].id == edge.src_id { src.push(*node_idx); } - if graph[*node_idx].id() == edge.end_node_id() { + if graph[*node_idx].id == edge.tgt_id { tgt.push(*node_idx); } } @@ -431,14 +494,15 @@ async fn subgraph_wiring( Ok(graph) } +/// This function takes in a node id and returns a petgraph represention of the memgraph graph async fn subgraph2petgraph( module_id: i64, config: Config, -) -> petgraph::Graph { +) -> petgraph::Graph { let (x, y) = get_subgraph(module_id, config.clone()).await.unwrap(); // Create a petgraph graph - let mut graph: petgraph::Graph = Graph::new(); + let mut graph: petgraph::Graph = Graph::new(); // Add nodes to the petgraph graph and collect their indexes let mut nodes = Vec::::new(); @@ -452,10 +516,10 @@ async fn subgraph2petgraph( let mut src = Vec::::new(); let mut tgt = Vec::::new(); for node_idx in &nodes { - if graph[*node_idx].id() == edge.start_node_id() { + if graph[*node_idx].id == edge.src_id { src.push(*node_idx); } - if graph[*node_idx].id() == edge.end_node_id() { + if graph[*node_idx].id == edge.tgt_id { tgt.push(*node_idx); } } @@ -466,14 +530,14 @@ async fn subgraph2petgraph( graph } +/// This function takes in a node id and returns the nodes and edges in it pub async fn get_subgraph( module_id: i64, config: Config, -) -> Result<(Vec, Vec), Error> { - // construct the query that will delete the module with a given unique identifier +) -> Result<(Vec, Vec), Error> { - let mut node_list = Vec::::new(); - let mut edge_list = Vec::::new(); + let mut node_list = Vec::::new(); + let mut edge_list = Vec::::new(); // Connect to Memgraph. let graph = Arc::new(config.graphdb_connection().await); @@ -492,7 +556,13 @@ pub async fn get_subgraph( .await?; while let Ok(Some(row)) = result1.next().await { let node: neo4rs::Node = row.get("nodes2").unwrap(); - node_list.push(node); + let modelnode = ModelNode { + id: node.id(), + label: node.labels()[0].clone(), + name: node.get::("name"), + value: node.get::("value"), + }; + node_list.push(modelnode); } // edge query let mut result2 = graph @@ -509,7 +579,14 @@ pub async fn get_subgraph( .await?; while let Ok(Some(row)) = result2.next().await { let edge: neo4rs::Relation = row.get("edges2").unwrap(); - edge_list.push(edge); + let modeledge = ModelEdge { + id: edge.id(), + src_id: edge.start_node_id(), + tgt_id: edge.end_node_id(), + index: edge.get::("index"), + refer: edge.get::("refer"), + }; + edge_list.push(modeledge); } Ok((node_list, edge_list)) From 094b7eced45a3a64a1fc2754fd1cc21fca4df852 Mon Sep 17 00:00:00 2001 From: Justin Lieffers <76677555+Free-Quarks@users.noreply.github.com> Date: Thu, 7 Dec 2023 12:54:48 -0800 Subject: [PATCH 02/22] Initial Function as Argument support for code2amr (#723) ## Summary of Changes This updates the database.rs to support function calls as arguments, this required adding support for sub-expressions inside expressions. These sub-expressions are related to the calling of the argument functions This also updates model_extraction.rs to add support for handling these sub-expressions to get the full AMR's out. This includes new trimming to delete the sub-expressions while connecting the primitives directly to the argument functions. ### Related issues Resolves #622 --------- Co-authored-by: Justin --- skema/skema-rs/skema/src/database.rs | 28 ++- skema/skema-rs/skema/src/model_extraction.rs | 206 +++++++++++++++---- 2 files changed, 189 insertions(+), 45 deletions(-) diff --git a/skema/skema-rs/skema/src/database.rs b/skema/skema-rs/skema/src/database.rs index 4a6f3c24bfc..9e3fbdd8746 100644 --- a/skema/skema-rs/skema/src/database.rs +++ b/skema/skema-rs/skema/src/database.rs @@ -3745,9 +3745,7 @@ pub fn create_att_expression( for att_sub_box in att_box.bf.as_ref().unwrap().iter() { new_c_args.box_counter = box_counter; new_c_args.cur_box = att_sub_box.clone(); - if att_sub_box.contents.is_some() { - new_c_args.att_idx = att_sub_box.contents.unwrap() as usize; - } + new_c_args.att_idx = c_args.att_idx; match att_sub_box.function_type { FunctionType::Literal => { create_att_literal( @@ -3769,6 +3767,17 @@ pub fn create_att_expression( new_c_args.clone(), ); } + FunctionType::Expression => { + new_c_args.att_idx = att_sub_box.contents.unwrap() as usize; + create_att_expression( + gromet, // gromet for metadata + nodes, // nodes + edges, + meta_nodes, + start, + new_c_args.clone(), + ); + } _ => {} } box_counter += 1; @@ -3784,6 +3793,14 @@ pub fn create_att_expression( c_args.bf_counter, ); + cross_att_wiring( + att_box.clone(), + nodes, + edges, + c_args.att_idx, + c_args.bf_counter, + ); + // Now we also perform wopio wiring in case there is an empty expression if att_box.wopio.is_some() { wopio_wiring(att_box, nodes, edges, c_args.att_idx - 1, c_args.bf_counter); @@ -5195,6 +5212,7 @@ pub fn wff_cross_att_wiring( bf_counter: u8, // this is the current box ) { for wire in eboxf.wff.as_ref().unwrap().iter() { + let mut prop = None; // collect info to identify the opi src node let src_idx = wire.src; // port index let src_pif = eboxf.pif.as_ref().unwrap()[(src_idx - 1) as usize].clone(); // src port @@ -5390,10 +5408,11 @@ pub fn wff_cross_att_wiring( // only opo's if node.n_type == "Primitive" || node.n_type == "Abstract" { // iterate through port to check for tgt - for p in node.in_indx.as_ref().unwrap().iter() { + for (i, p) in node.in_indx.as_ref().unwrap().iter().enumerate() { // push the src first, being pif if (src_idx as u32) == *p { wff_src_tgt.push(node.node_id.clone()); + prop = Some(i); } } } @@ -5425,6 +5444,7 @@ pub fn wff_cross_att_wiring( src: wff_src_tgt[0].clone(), tgt: wff_src_tgt[1].clone(), e_type: String::from("Wire"), + prop: Some(prop.unwrap()), ..Default::default() }; edges.push(e8); diff --git a/skema/skema-rs/skema/src/model_extraction.rs b/skema/skema-rs/skema/src/model_extraction.rs index abd9dd674b0..505227309a0 100644 --- a/skema/skema-rs/skema/src/model_extraction.rs +++ b/skema/skema-rs/skema/src/model_extraction.rs @@ -1,7 +1,10 @@ use crate::config::Config; + use mathml::ast::operator::Operator; pub use mathml::mml2pn::{ACSet, Term}; + use petgraph::prelude::*; +use petgraph::visit::IntoNeighborsDirected; use std::string::ToString; @@ -38,22 +41,22 @@ pub struct ModelEdge { /** * This is the main function call for model extraction. - * + * * Parameters: - * - module_id: i64 -> This is the top level id of the gromet module in memgraph. + * - module_id: i64 -> This is the top level id of the gromet module in memgraph. * - config: Config -> This is a config struct for connecting to memgraph - * + * * Returns: * - Vector of FirstOrderODE -> This vector of structs is used to construct a PetriNet or RegNet further down the pipeline - * + * * Assumptions: - * - As of right now, we can always assume the code has been sliced to only one relevant function which contains the + * - As of right now, we can always assume the code has been sliced to only one relevant function which contains the * core dynamics in it somewhere - * - * Notes: + * + * Notes: * - FirstOrderODE is primarily composed of a LHS and a RHS, * - LHS is just a Mi object of the state being differentiated. There are additional fields for the LHS but only the - * content field is used in downstream inference for now. + * content field is used in downstream inference for now. * - RHS is where the bulk of the inference happens, it produces an expression tree, hence the MET -> Math Expression Tree. * Every operator has a vector of arguments. (order matters) */ @@ -88,14 +91,14 @@ pub async fn module_id2mathml_MET_ast(module_id: i64, config: Config) -> Vec Vec { @@ -154,11 +157,10 @@ pub async fn find_pn_dynamics(module_id: i64, config: Config) -> Vec { core_id } - /** * Once the function node has been identified, this function takes it from there to extract the vector of FirstOrderODE's - * - * This is based heavily on the assumption that each equation is in a seperate expression which breaks for the vector case. + * + * This is based heavily on the assumption that each equation is in a seperate expression which breaks for the vector case. */ #[allow(non_snake_case)] pub async fn subgrapg2_core_dyn_MET_ast( @@ -184,8 +186,23 @@ pub async fn subgrapg2_core_dyn_MET_ast( let mut sub_w = subgraph_wiring(graph[expression_nodes[i]].id, config.clone()) .await .unwrap(); - if sub_w.node_count() > 3 { - let expr = trim_un_named(&mut sub_w).await; + let mut prim_counter = 0; + let mut has_call = false; + for node_index in sub_w.node_indices() { + if sub_w[node_index].label == *"Primitive" { + prim_counter += 1; + if *sub_w[node_index].name.as_ref().unwrap() == "_call" { + has_call = true; + } + } + } + if sub_w.node_count() > 3 && !(prim_counter == 1 && has_call) && prim_counter != 0 { + println!("expression: {}", graph[expression_nodes[i]].id); + // the call expressions get referenced by multiple top level expressions, so deleting the nodes in it breaks the other graphs. Need to pass clone of expression subgraph so references to original has all the nodes. + if has_call { + sub_w = trim_calls(sub_w.clone()) + } + let expr = trim_un_named(&mut sub_w); let mut root_node = Vec::::new(); for node_index in expr.node_indices() { if expr[node_index].label.clone() == *"Opo" { @@ -193,7 +210,7 @@ pub async fn subgrapg2_core_dyn_MET_ast( } } if root_node.len() >= 2 { - // println!("More than one Opo! Skipping Expression!"); + println!("More than one Opo! Skipping Expression!"); } else { core_dynamics.push(tree_2_MET_ast(expr, root_node[0]).unwrap()); } @@ -203,9 +220,8 @@ pub async fn subgrapg2_core_dyn_MET_ast( Ok(core_dynamics) } - /** - * This function is designed to take in a petgraph instance of a wires only expression subgraph and output a FirstOrderODE equations representing it. + * This function is designed to take in a petgraph instance of a wires only expression subgraph and output a FirstOrderODE equations representing it. */ #[allow(non_snake_case)] fn tree_2_MET_ast( @@ -336,16 +352,16 @@ pub fn get_operator_MET( } /** - * This function takes in a wiring only petgraph of an expression and trims off the un-named nodes and unpack nodes. - * - * This is done by creating new edges that bypass the un-named nodes and then deleting them from the graph. - * For deleting the unpacks, the assumption is they are always terminal in the subgraph and can be deleted freely. - * - * Concerns: - * - I don't think this will work if there are multiple un-named nodes changed together. I haven't seen this in practice, - * but I think it's possible. So something to keep in mind. + * This function takes in a wiring only petgraph of an expression and trims off the un-named nodes and unpack nodes. + * + * This is done by creating new edges that bypass the un-named nodes and then deleting them from the graph. + * For deleting the unpacks, the assumption is they are always terminal in the subgraph and can be deleted freely. + * + * Concerns: + * - I don't think this will work if there are multiple un-named nodes changed together. I haven't seen this in practice, + * but I think it's possible. So something to keep in mind. */ -async fn trim_un_named( +fn trim_un_named( graph: &mut petgraph::Graph, ) -> &mut petgraph::Graph { // first create a cloned version of the graph we can modify while iterating over it. @@ -354,34 +370,34 @@ async fn trim_un_named( for node_index in graph.node_indices() { if graph[node_index].clone().name.unwrap().clone() == *"un-named" { let mut bypass = Vec::::new(); + let mut outgoing_bypass = Vec::::new(); for node1 in graph.neighbors_directed(node_index, Incoming) { bypass.push(node1); } for node2 in graph.neighbors_directed(node_index, Outgoing) { - bypass.push(node2); + outgoing_bypass.push(node2); } // one incoming one outgoing - if bypass.len() == 2 { + if bypass.len() == 1 && outgoing_bypass.len() == 1 { // annoyingly have to pull the edge/Relation to insert into graph graph.add_edge( bypass[0], - bypass[1], + outgoing_bypass[0], graph .edge_weight(graph.find_edge(bypass[0], node_index).unwrap()) .unwrap() .clone(), ); - } else if bypass.len() > 2 { + } else if bypass.len() >= 2 && outgoing_bypass.len() == 1 { // this operates on the assumption that there maybe multiple references to the port // (incoming arrows) but only one outgoing arrow, this seems to be the case based on // data too. - let end_node_idx = bypass.len() - 1; - for (i, _ent) in bypass[0..end_node_idx].iter().enumerate() { + for (i, _ent) in bypass.iter().enumerate() { // this iterates over all but the last entry in the bypass vec graph.add_edge( bypass[i], - bypass[end_node_idx], + outgoing_bypass[0], graph .edge_weight(graph.find_edge(bypass[i], node_index).unwrap()) .unwrap() @@ -405,7 +421,7 @@ async fn trim_un_named( graph } -/// This function takes in a node id (typically that of an expression subgraph) and returns a +/// This function takes in a node id (typically that of an expression subgraph) and returns a /// petgraph subgraph of only the wire type edges async fn subgraph_wiring( module_id: i64, @@ -535,7 +551,6 @@ pub async fn get_subgraph( module_id: i64, config: Config, ) -> Result<(Vec, Vec), Error> { - let mut node_list = Vec::::new(); let mut edge_list = Vec::::new(); @@ -591,3 +606,112 @@ pub async fn get_subgraph( Ok((node_list, edge_list)) } + +// this does special trimming to handle function calls +pub fn trim_calls( + graph: petgraph::Graph, +) -> petgraph::Graph { + let mut graph_clone = graph.clone(); + + // This will be all the nodes to be deleted + let mut inner_nodes = Vec::::new(); + // find the call nodes + for node_index in graph.node_indices() { + if graph[node_index].clone().name.unwrap().clone() == *"_call" { + // we now trace up the incoming path until we hit a primitive, + // this will be the start node for the new edge. + + // initialize trackers + let mut node_start = node_index; + let mut node_end = node_index; + + // find end node and track path + for node in graph.neighbors_directed(node_index, Outgoing) { + if graph + .edge_weight(graph.find_edge(node_index, node).unwrap()) + .unwrap() + .index + .unwrap() + == 0 + { + let mut temp = to_terminal(graph.clone(), node); + node_end = temp.0; + inner_nodes.append(&mut temp.1); + } + } + + // find start primtive node and track path + for node in graph.neighbors_directed(node_index, Incoming) { + let mut temp = to_primitive(graph.clone(), node); + node_start = temp.0; + inner_nodes.append(&mut temp.1); + } + + // add edge from start to end node, with weight from start node a matching outgoing node form it + for node in graph.clone().neighbors_directed(node_start, Outgoing) { + for node_p in inner_nodes.iter() { + if node == *node_p { + graph_clone.add_edge( + node_start, + node_end, + graph + .clone() + .edge_weight(graph.clone().find_edge(node_start, node).unwrap()) + .unwrap() + .clone(), + ); + } + } + } + // we keep track all the node indexes we found while tracing the path and delete all + // intermediate nodes. + inner_nodes.push(node_index); + } + } + inner_nodes.sort(); + for node in inner_nodes.iter().rev() { + graph_clone.remove_node(*node); + } + + graph_clone +} + +pub fn to_terminal( + graph: petgraph::Graph, + node_index: NodeIndex, +) -> (NodeIndex, Vec) { + let mut node_vec = Vec::::new(); + let mut end_node = node_index; + // if there another node deeper + // else pass original input node out and an empty path vector + if graph.neighbors_directed(node_index, Outgoing).count() != 0 { + node_vec.push(node_index); // add current node to path list + for node in graph.neighbors_directed(node_index, Outgoing) { + // pass next node forward + let mut temp = to_terminal(graph.clone(), node); + end_node = temp.0; // make end_node + node_vec.append(&mut temp.1); // append previous path nodes + } + } + (end_node, node_vec) +} + +// incoming walker to first primitive (NOTE: assumes input is not a primitive) +pub fn to_primitive( + graph: petgraph::Graph, + node_index: NodeIndex, +) -> (NodeIndex, Vec) { + let mut node_vec = Vec::::new(); + let mut end_node = node_index; + node_vec.push(node_index); + for node in graph.neighbors_directed(node_index, Incoming) { + if graph[node].label.clone() != *"Primitive" { + let mut temp = to_primitive(graph.clone(), node); + end_node = temp.0; + node_vec.append(&mut temp.1); + } else { + end_node = node; + } + } + (end_node, node_vec) +} From 44f3d23f8274358614012df85de0ae30c2d98ef9 Mon Sep 17 00:00:00 2001 From: Justin Lieffers <76677555+Free-Quarks@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:03:17 -0800 Subject: [PATCH 03/22] SIDARTHE Bug fixes (#726) ## Summary of Changes Fixed 2 bugs code2amr: - One was a bug related Opo creation in memgraph for expressions inside expressions - One was a bug related to pulling the value of literal from memgraph. This required updated to neo4rs version 0.7.0-rc.1 to support deserialization of non-basic types. ### Related issues Resolves ??? --- skema/skema-rs/mathml/src/acset.rs | 2 - skema/skema-rs/skema/Cargo.toml | 4 +- skema/skema-rs/skema/src/bin/morae.rs | 5 +- skema/skema-rs/skema/src/config.rs | 2 +- skema/skema-rs/skema/src/database.rs | 11 +++++ skema/skema-rs/skema/src/model_extraction.rs | 52 +++++++++++--------- 6 files changed, 46 insertions(+), 30 deletions(-) diff --git a/skema/skema-rs/mathml/src/acset.rs b/skema/skema-rs/mathml/src/acset.rs index 37e95470879..82f78d62fcd 100644 --- a/skema/skema-rs/mathml/src/acset.rs +++ b/skema/skema-rs/mathml/src/acset.rs @@ -367,7 +367,6 @@ impl From> for PetriNet { terms.push(term.clone()); } } - for term in terms.iter() { for param in &term.parameters { let parameters = Parameter { @@ -425,7 +424,6 @@ impl From> for PetriNet { for i in paired_term_indices.iter().rev() { terms.remove(*i); } - // Now we replace unpaired terms with subterms, by their subterms and repeat the process // but first we need to inherit the dynamic state to each sub term diff --git a/skema/skema-rs/skema/Cargo.toml b/skema/skema-rs/skema/Cargo.toml index 4c0feeb7a8b..fa404e3bc79 100644 --- a/skema/skema-rs/skema/Cargo.toml +++ b/skema/skema-rs/skema/Cargo.toml @@ -12,7 +12,7 @@ path = "src/lib.rs" serde_json = { version = "1.0.85", features = ["preserve_order"] } serde = { version = "1.0.1", features = ["derive"] } strum_macros = "0.24" -neo4rs = { version = "0.6.2" } +neo4rs = { version = "0.7.0-rc.1" } actix-web = "4.2.1" mathml = { path = "../mathml" } utoipa = { version = "3.0.3", features = ["actix_extras", "yaml", "debug"] } @@ -21,4 +21,4 @@ clap = { version = "4.0.26", features = ["derive"] } utoipa-swagger-ui = { version = "3.0.2", features = ["actix-web"] } schemars = { version = "0.8.12" } pretty_env_logger = "0.5.0" -tokio = {version = "1.34.0", features = ["full", "rt"]} \ No newline at end of file +tokio = { version = "1.34.0", features = ["full", "rt"] } diff --git a/skema/skema-rs/skema/src/bin/morae.rs b/skema/skema-rs/skema/src/bin/morae.rs index d0f71ad18d9..88c3f72e9d6 100644 --- a/skema/skema-rs/skema/src/bin/morae.rs +++ b/skema/skema-rs/skema/src/bin/morae.rs @@ -86,8 +86,9 @@ async fn main() { } println!("{:?}", ids.clone()); let math_content = module_id2mathml_MET_ast(ids[ids.len() - 1], config.clone()).await; - println!("{:?}", math_content.clone()); - println!("\nAMR from code: {:?}", PetriNet::from(math_content)); + let pn_amr = PetriNet::from(math_content); + //println!("{:?}", math_content.clone()); + //println!("\nAMR from code: {:?}", PetriNet::from(math_content)); //let input_src = "../../data/mml2pn_inputs/testing_eqns/sidarthe_mml.txt"; diff --git a/skema/skema-rs/skema/src/config.rs b/skema/skema-rs/skema/src/config.rs index 43f57a1ef94..95f3f04c309 100644 --- a/skema/skema-rs/skema/src/config.rs +++ b/skema/skema-rs/skema/src/config.rs @@ -40,7 +40,7 @@ impl Config { } pub async fn graphdb_connection(&self) -> Graph { let uri = self.create_graphdb_uri(); - println!("skema-rs:memgraph uri:\t{addr}", addr = uri); + //println!("skema-rs:memgraph uri:\t{addr}", addr = uri); let graph_config = ConfigBuilder::new() .uri(uri) .user("".to_string()) diff --git a/skema/skema-rs/skema/src/database.rs b/skema/skema-rs/skema/src/database.rs index 9e3fbdd8746..8ea7cec9f90 100644 --- a/skema/skema-rs/skema/src/database.rs +++ b/skema/skema-rs/skema/src/database.rs @@ -1584,6 +1584,7 @@ pub fn create_function( new_c_args.box_counter = box_counter; new_c_args.cur_box = att_sub_box.clone(); new_c_args.att_idx = c_args.att_idx; + new_c_args.att_bf_idx = c_args.att_bf_idx; match att_sub_box.function_type { FunctionType::Function => { new_c_args.att_idx = att_sub_box.contents.unwrap() as usize; @@ -1610,6 +1611,7 @@ pub fn create_function( } FunctionType::Expression => { new_c_args.att_idx = att_sub_box.contents.unwrap() as usize; + new_c_args.att_bf_idx = c_args.att_idx; create_att_expression( gromet, // gromet for metadata nodes, // nodes @@ -3584,6 +3586,13 @@ pub fn create_att_expression( } } } + if opo_name.is_empty() { + println!( + "Missed Opo at att_idx: {:?} and box_counter: {:?}", + c_args.att_idx, c_args.box_counter + ); + println!("parent att box: {:?}", c_args.att_bf_idx); + } if !opo_name.clone().is_empty() { let mut oport: u32 = 0; for _op in att_box.opo.as_ref().unwrap().iter() { @@ -3746,6 +3755,7 @@ pub fn create_att_expression( new_c_args.box_counter = box_counter; new_c_args.cur_box = att_sub_box.clone(); new_c_args.att_idx = c_args.att_idx; + new_c_args.att_bf_idx = c_args.att_bf_idx; match att_sub_box.function_type { FunctionType::Literal => { create_att_literal( @@ -3769,6 +3779,7 @@ pub fn create_att_expression( } FunctionType::Expression => { new_c_args.att_idx = att_sub_box.contents.unwrap() as usize; + new_c_args.att_bf_idx = c_args.att_idx; create_att_expression( gromet, // gromet for metadata nodes, // nodes diff --git a/skema/skema-rs/skema/src/model_extraction.rs b/skema/skema-rs/skema/src/model_extraction.rs index 505227309a0..44b824e20e3 100644 --- a/skema/skema-rs/skema/src/model_extraction.rs +++ b/skema/skema-rs/skema/src/model_extraction.rs @@ -1,10 +1,10 @@ use crate::config::Config; +use crate::ValueL; use mathml::ast::operator::Operator; pub use mathml::mml2pn::{ACSet, Term}; use petgraph::prelude::*; -use petgraph::visit::IntoNeighborsDirected; use std::string::ToString; @@ -26,7 +26,7 @@ pub struct ModelNode { id: i64, label: String, name: Option, - value: Option, + value: Option, } /// This struct is the edge struct for the constructed petgraph @@ -197,6 +197,7 @@ pub async fn subgrapg2_core_dyn_MET_ast( } } if sub_w.node_count() > 3 && !(prim_counter == 1 && has_call) && prim_counter != 0 { + println!("--------------------"); println!("expression: {}", graph[expression_nodes[i]].id); // the call expressions get referenced by multiple top level expressions, so deleting the nodes in it breaks the other graphs. Need to pass clone of expression subgraph so references to original has all the nodes. if has_call { @@ -300,9 +301,15 @@ pub fn get_args_MET( args.push(rhs.clone()); } else { // asummption it is atomic - let temp_string = graph[node].name.as_ref().unwrap().clone(); - let arg2 = MathExpressionTree::Atom(MathExpression::Mi(Mi(temp_string.clone()))); - args.push(arg2.clone()); + if graph[node].label.clone() == *"Literal" { + let temp_string = graph[node].value.clone().unwrap().value.replace('\"', ""); + let arg2 = MathExpressionTree::Atom(MathExpression::Mi(Mi(temp_string.clone()))); + args.push(arg2.clone()); + } else { + let temp_string = graph[node].name.as_ref().unwrap().clone(); + let arg2 = MathExpressionTree::Atom(MathExpression::Mi(Mi(temp_string.clone()))); + args.push(arg2.clone()); + } } // construct order of args @@ -313,10 +320,8 @@ pub fn get_args_MET( .unwrap(); arg_order.push(x); } - // fix order of args let mut ordered_args = args.clone(); - for (i, ind) in arg_order.iter().enumerate() { // the ind'th element of order_args is the ith element of the unordered args if ordered_args.len() > *ind as usize { @@ -450,11 +455,11 @@ async fn subgraph_wiring( let node: neo4rs::Node = row.get("nodes2").unwrap(); let modelnode = ModelNode { id: node.id(), - label: node.labels()[0].clone(), - name: node.get::("name"), - value: node.get::("value"), + label: node.labels()[0].to_string(), + name: node.get::("name").ok(), + value: node.get::("value").ok(), }; - node_list.push(modelnode); + node_list.push(modelnode.clone()); } // edge query let mut result2 = graph @@ -476,8 +481,8 @@ async fn subgraph_wiring( id: edge.id(), src_id: edge.start_node_id(), tgt_id: edge.end_node_id(), - index: edge.get::("index"), - refer: edge.get::("refer"), + index: edge.get::("index").ok(), + refer: edge.get::("refer").ok(), }; edge_list.push(modeledge); } @@ -573,9 +578,9 @@ pub async fn get_subgraph( let node: neo4rs::Node = row.get("nodes2").unwrap(); let modelnode = ModelNode { id: node.id(), - label: node.labels()[0].clone(), - name: node.get::("name"), - value: node.get::("value"), + label: node.labels()[0].to_string(), + name: node.get::("name").ok(), + value: node.get::("value").ok(), }; node_list.push(modelnode); } @@ -598,8 +603,8 @@ pub async fn get_subgraph( id: edge.id(), src_id: edge.start_node_id(), tgt_id: edge.end_node_id(), - index: edge.get::("index"), - refer: edge.get::("refer"), + index: edge.get::("index").ok(), + refer: edge.get::("refer").ok(), }; edge_list.push(modeledge); } @@ -624,6 +629,7 @@ pub fn trim_calls( // initialize trackers let mut node_start = node_index; let mut node_end = node_index; + let mut i_inner_nodes = Vec::::new(); // find end node and track path for node in graph.neighbors_directed(node_index, Outgoing) { @@ -636,7 +642,7 @@ pub fn trim_calls( { let mut temp = to_terminal(graph.clone(), node); node_end = temp.0; - inner_nodes.append(&mut temp.1); + i_inner_nodes.append(&mut temp.1); } } @@ -644,12 +650,12 @@ pub fn trim_calls( for node in graph.neighbors_directed(node_index, Incoming) { let mut temp = to_primitive(graph.clone(), node); node_start = temp.0; - inner_nodes.append(&mut temp.1); + i_inner_nodes.append(&mut temp.1); } // add edge from start to end node, with weight from start node a matching outgoing node form it for node in graph.clone().neighbors_directed(node_start, Outgoing) { - for node_p in inner_nodes.iter() { + for node_p in i_inner_nodes.iter() { if node == *node_p { graph_clone.add_edge( node_start, @@ -665,14 +671,14 @@ pub fn trim_calls( } // we keep track all the node indexes we found while tracing the path and delete all // intermediate nodes. - inner_nodes.push(node_index); + i_inner_nodes.push(node_index); + inner_nodes.append(&mut i_inner_nodes.clone()); } } inner_nodes.sort(); for node in inner_nodes.iter().rev() { graph_clone.remove_node(*node); } - graph_clone } From 6fc5b32cd85100d2cdef497a4146bd04924a7ad8 Mon Sep 17 00:00:00 2001 From: titomeister Date: Tue, 12 Dec 2023 09:26:13 -0700 Subject: [PATCH 04/22] Python tree-sitter to CAST porting: Conditionals (#711) This PR introduces support for generating CAST for Conditionals using tree-sitter, as part of the ongoing effort to port over the Python AST to CAST generation to using tree-sitter. ### Python Tree Sitter - Adds additional handlers to support conditional (if/elif/else) statements, along with basic comparison support. - Updates the NodeHelper() class to use a more optimized version of get_identifier() that was written by Vincent. ### Testing - Adds a pytest script to test the CAST structure of Conditional statements. Resolves #499 --------- Co-authored-by: Vincent Raymond --- .../CAST/python/node_helper.py | 41 ++-- skema/program_analysis/CAST/python/ts2cast.py | 83 +++++++- skema/program_analysis/CAST/python/util.py | 12 +- .../tests/test_conditional_cast.py | 201 ++++++++++++++++++ 4 files changed, 306 insertions(+), 31 deletions(-) create mode 100644 skema/program_analysis/tests/test_conditional_cast.py diff --git a/skema/program_analysis/CAST/python/node_helper.py b/skema/program_analysis/CAST/python/node_helper.py index abef0b50bcc..0c4b4304cd9 100644 --- a/skema/program_analysis/CAST/python/node_helper.py +++ b/skema/program_analysis/CAST/python/node_helper.py @@ -29,6 +29,23 @@ def __init__(self, source: str, source_file_name: str): self.source = source self.source_file_name = source_file_name + # get_identifier optimization variables + self.source_lines = source.splitlines(keepends=True) + self.line_lengths = [len(line) for line in self.source_lines] + self.line_length_sums = [sum(self.line_lengths[:i+1]) for i in range(len(self.source_lines))] + + def get_identifier(self, node: Node) -> str: + """Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point""" + start_line, start_column = node.start_point + end_line, end_column = node.end_point + + start_index = self.line_length_sums[start_line-1] + start_column + if start_line == end_line: + end_index = start_index + (end_column-start_column) + else: + end_index = self.line_length_sums[end_line] + end_column + + return self.source[start_index:end_index] def get_source_ref(self, node: Node) -> SourceRef: """Given a node and file name, return a CAST SourceRef object.""" @@ -36,30 +53,6 @@ def get_source_ref(self, node: Node) -> SourceRef: row_end, col_end = node.end_point return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end) - - def get_identifier(self, node: Node) -> str: - """Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point""" - line_num = 0 - column_num = 0 - in_identifier = False - identifier = "" - for i, char in enumerate(self.source): - if line_num == node.start_point[0] and column_num == node.start_point[1]: - in_identifier = True - elif line_num == node.end_point[0] and column_num == node.end_point[1]: - break - - if char == "\n": - line_num += 1 - column_num = 0 - else: - column_num += 1 - - if in_identifier: - identifier += char - - return identifier - def get_operator(self, node: Node) -> str: """Given a unary/binary operator node, return the operator it contains""" return node.type diff --git a/skema/program_analysis/CAST/python/ts2cast.py b/skema/program_analysis/CAST/python/ts2cast.py index 720c1f40569..b498166b217 100644 --- a/skema/program_analysis/CAST/python/ts2cast.py +++ b/skema/program_analysis/CAST/python/ts2cast.py @@ -107,6 +107,10 @@ def visit(self, node: Node): return self.visit_return(node) elif node.type == "call": return self.visit_call(node) + elif node.type == "if_statement": + return self.visit_if_statement(node) + elif node.type == "comparison_operator": + return self.visit_comparison_op(node) elif node.type == "assignment": return self.visit_assignment(node) elif node.type == "identifier": @@ -227,6 +231,84 @@ def visit_call(self, node: Node) -> Call: source_refs=[ref] ) + def visit_comparison_op(self, node: Node): + ref = self.node_helper.get_source_ref(node) + op = get_op(self.node_helper.get_operator(node.children[1])) + left, _, right = node.children + + left_cast = get_name_node(self.visit(left)) + right_cast = get_name_node(self.visit(right)) + + return Operator( + op=op, + operands=[left_cast, right_cast], + source_refs=[ref] + ) + + def visit_if_statement(self, node: Node) -> ModelIf: + if_condition = self.visit(get_first_child_by_type(node, "comparison_operator")) + + # Get the body of the if true part + if_true = get_children_by_types(node, "block")[0].children + + # Because in tree-sitter the else if, and else aren't nested, but they're + # in a flat level order, we need to do some arranging of the pieces + # in order to get the correct CAST nested structure that we use + # Visit all the alternatives, generate CAST for each one + # and then join them all together + alternatives = get_children_by_types(node, ["elif_clause","else_clause"]) + + if_true_cast = [] + for node in if_true: + cast = self.visit(node) + if isinstance(cast, List): + if_true_cast.extend(cast) + elif isinstance(cast, AstNode): + if_true_cast.append(cast) + + # If we have ts nodes in alternatives, then we're guaranteed + # at least an else at the end of the if-statement construct + # We generate the cast for the final else statement, and then + # reverse the rest of the if-elses that we have, so we can + # create the CAST correctly + final_else_cast = [] + if len(alternatives) > 0: + final_else = alternatives.pop() + alternatives.reverse() + final_else_body = get_children_by_types(final_else, "block")[0].children + for node in final_else_body: + cast = self.visit(node) + if isinstance(cast, List): + final_else_cast.extend(cast) + elif isinstance(cast, AstNode): + final_else_cast.append(cast) + + # We go through any additional if-else nodes that we may have, + # generating their ModelIf CAST and appending the tail of the + # overall if-else construct, starting with the else at the very end + # We do this tail appending so that when we finish generating CAST the + # resulting ModelIf CAST is in the correct order + alternatives_cast = None + for ts_node in alternatives: + assert ts_node.type == "elif_clause" + temp_cast = self.visit_if_statement(ts_node) + if alternatives_cast == None: + temp_cast.orelse = final_else_cast + else: + temp_cast.orelse = [alternatives_cast] + alternatives_cast = temp_cast + + if alternatives_cast == None: + if_false_cast = final_else_cast + else: + if_false_cast = [alternatives_cast] + + return ModelIf( + expr=if_condition, + body=if_true_cast, + orelse=if_false_cast, + source_refs=[self.node_helper.get_source_ref(node)] + ) def visit_assignment(self, node: Node) -> Assignment: left, _, right = node.children @@ -275,7 +357,6 @@ def visit_binary_op(self, node: Node) -> Operator: Binary Ops left OP right where left and right can either be operators or literals - """ ref = self.node_helper.get_source_ref(node) op = get_op(self.node_helper.get_operator(node.children[1])) diff --git a/skema/program_analysis/CAST/python/util.py b/skema/program_analysis/CAST/python/util.py index f315c44f2a4..ceb12c60a5e 100644 --- a/skema/program_analysis/CAST/python/util.py +++ b/skema/program_analysis/CAST/python/util.py @@ -26,6 +26,12 @@ def get_op(operator): '-': 'ast.Sub', '*': 'ast.Mult', '/': 'ast.Div', + '==' : 'ast.Eq', + '!=' : 'ast.NotEq', + '<' : 'ast.Lt', + '<=' : 'ast.LtE', + '>' : 'ast.Gt', + '>=' : 'ast.GtE', # ast.UAdd: 'ast.UAdd', # ast.USub: 'ast.USub', # ast.FloorDiv: 'ast.FloorDiv', @@ -38,12 +44,6 @@ def get_op(operator): # ast.BitXor: 'ast.BitXor', # ast.And: 'ast.And', # ast.Or: 'ast.Or', - # ast.Eq: 'ast.Eq', - # ast.NotEq: 'ast.NotEq', - # ast.Lt: 'ast.Lt', - # ast.LtE: 'ast.LtE', - # ast.Gt: 'ast.Gt', - # ast.GtE: 'ast.GtE', # ast.In: 'ast.In', # ast.NotIn: 'ast.NotIn', # ast.Not: 'ast.Not', diff --git a/skema/program_analysis/tests/test_conditional_cast.py b/skema/program_analysis/tests/test_conditional_cast.py new file mode 100644 index 00000000000..ad0ba999976 --- /dev/null +++ b/skema/program_analysis/tests/test_conditional_cast.py @@ -0,0 +1,201 @@ +# import json NOTE: json and Path aren't used right now, +# from pathlib import Path but will be used in the future +from skema.program_analysis.CAST.python.ts2cast import TS2CAST +from skema.program_analysis.CAST2FN.model.cast import ( + Assignment, + Var, + Name, + LiteralValue, + ModelIf, + Operator +) + +def cond1(): + return """ +x = 2 + +if x < 5: + x = x + 1 +else: + x = x - 3 + """ + +def cond2(): + return """ +x = 2 +y = 3 + +if x < 5: + x = 1 + y = 2 + x = x * y +else: + x = x - 3 + """ + +def cond3(): + return """ +x = 2 +y = 4 + +if x < 5: + x = x + y + y = 1 +elif x > 10: + y = x + 2 + x = 1 +elif x == 30: + x = 1 + y = 2 + z = x * y +else: + x = 0 + y = x - 2 + """ + +def generate_cast(test_file_string): + # use Python to CAST + out_cast = TS2CAST(test_file_string, from_file=False).out_cast + + return out_cast + +def test_cond1(): + exp_cast = generate_cast(cond1()) + + # Test basic conditional + asg_node = exp_cast.nodes[0].body[0] + cond_node = exp_cast.nodes[0].body[1] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '2' + + assert isinstance(cond_node, ModelIf) + cond_expr = cond_node.expr + cond_body = cond_node.body + cond_else = cond_node.orelse + + assert isinstance(cond_expr, Operator) + assert cond_expr.op == "ast.Lt" + assert isinstance(cond_expr.operands[0], Name) + assert isinstance(cond_expr.operands[1], LiteralValue) + + assert len(cond_body) == 1 + assert isinstance(cond_body[0], Assignment) + assert isinstance(cond_body[0].left, Var) + assert isinstance(cond_body[0].right, Operator) + assert cond_body[0].right.op == "ast.Add" + + assert len(cond_else) == 1 + assert isinstance(cond_else[0], Assignment) + assert isinstance(cond_else[0].left, Var) + assert isinstance(cond_else[0].right, Operator) + assert cond_else[0].right.op == "ast.Sub" + + +def test_cond2(): + exp_cast = generate_cast(cond2()) + + # Test multiple variable conditional + asg_node = exp_cast.nodes[0].body[0] + cond_node = exp_cast.nodes[0].body[2] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + assert asg_node.left.val.id == 0 + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '2' + + asg_node = exp_cast.nodes[0].body[1] + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "y" + assert asg_node.left.val.id == 1 + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '3' + + assert isinstance(cond_node, ModelIf) + cond_expr = cond_node.expr + cond_body = cond_node.body + cond_else = cond_node.orelse + + assert isinstance(cond_expr, Operator) + assert cond_expr.op == "ast.Lt" + assert isinstance(cond_expr.operands[0], Name) + assert cond_expr.operands[0].name == "x" + assert isinstance(cond_expr.operands[1], LiteralValue) + assert cond_expr.operands[1].value_type == "Integer" + assert cond_expr.operands[1].value == "5" + + assert len(cond_body) == 3 + assert isinstance(cond_body[0], Assignment) + assert isinstance(cond_body[0].left, Var) + assert cond_body[0].left.val.name == "x" + assert isinstance(cond_body[0].right, LiteralValue) + assert cond_body[0].right.value == "1" + + assert isinstance(cond_body[1], Assignment) + assert isinstance(cond_body[1].left, Var) + assert cond_body[1].left.val.name == "y" + assert isinstance(cond_body[1].right, LiteralValue) + assert cond_body[1].right.value == "2" + + assert isinstance(cond_body[2], Assignment) + assert isinstance(cond_body[2].left, Var) + assert isinstance(cond_body[2].right, Operator) + + assert cond_body[2].right.op == "ast.Mult" + + assert isinstance(cond_body[2].right.operands[0], Name) + assert cond_body[2].right.operands[0].name == "x" + assert cond_body[2].right.operands[0].id == 0 + assert isinstance(cond_body[2].right.operands[1], Name) + assert cond_body[2].right.operands[1].name == "y" + assert cond_body[2].right.operands[1].id == 1 + + assert len(cond_else) == 1 + assert isinstance(cond_else[0], Assignment) + assert isinstance(cond_else[0].left, Var) + assert isinstance(cond_else[0].right, Operator) + assert cond_else[0].right.op == "ast.Sub" + +def test_cond3(): + exp_cast = generate_cast(cond3()) + + # Test nested ifs + cond_node = exp_cast.nodes[0].body[2] + + assert isinstance(cond_node, ModelIf) + cond_body = cond_node.body + cond_else = cond_node.orelse + + assert len(cond_body) == 2 + assert len(cond_else) == 1 + assert isinstance(cond_else[0], ModelIf) + nested_if = cond_else[0] + cond_body = nested_if.body + cond_else = nested_if.orelse + + assert len(cond_body) == 2 + assert len(cond_else) == 1 + assert isinstance(cond_else[0], ModelIf) + nested_if = cond_else[0] + cond_body = nested_if.body + cond_else = nested_if.orelse + + assert len(cond_body) == 3 + assert len(cond_else) == 2 + assert isinstance(cond_else[0], Assignment) + assert isinstance(cond_else[1], Assignment) From 643f2fb352aa3afcc1cb22a4d093cef4171d8549 Mon Sep 17 00:00:00 2001 From: Justin Lieffers <76677555+Free-Quarks@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:00:32 -0800 Subject: [PATCH 05/22] Updates to LLM endpoints (#727) ## Summary of Changes This updates the LLM endpoints (linespan and code2amr) to support including imports and also returning the linespan for the given AMR. ### Related issues Resolves #717 Resolves #715 --- skema/rest/llm_proxy.py | 10 ++ skema/rest/tests/test_model_to_amr.py | 147 +++++++++++++++++++++++--- skema/rest/workflows.py | 22 +++- 3 files changed, 163 insertions(+), 16 deletions(-) diff --git a/skema/rest/llm_proxy.py b/skema/rest/llm_proxy.py index 35fea8f90fb..191f97e84c4 100644 --- a/skema/rest/llm_proxy.py +++ b/skema/rest/llm_proxy.py @@ -145,8 +145,18 @@ async def get_lines_of_model(zip_file: UploadFile = File()) -> List[Dynamics]: line_begin = 0 line_end = 0 + # if the line_begin of meta entry 2 (base 0) and meta entry 3 (base 0) are we add a slice from [meta2.line_begin, meta3.line_begin) + # to capture all the imports, return a Dynamics.block with 2 entries, both of which need to be concatenated to pass forward + file_line_begin = response_zip.json()['modules'][0]['metadata_collection'][2][0]['line_begin'] + + code_line_begin = response_zip.json()['modules'][0]['metadata_collection'][3][0]['line_begin'] - 1 + + if file_line_begin != code_line_begin: + block.append(f"L{file_line_begin}-L{code_line_begin}") + block.append(f"L{line_begin}-L{line_end}") + output = Dynamics(name=file, description=description, block=block) outputs.append(output) block = [] diff --git a/skema/rest/tests/test_model_to_amr.py b/skema/rest/tests/test_model_to_amr.py index 51c240818ed..00a88bcbd17 100644 --- a/skema/rest/tests/test_model_to_amr.py +++ b/skema/rest/tests/test_model_to_amr.py @@ -19,6 +19,10 @@ "https://artifacts.askem.lum.ai/askem/data/models/zip-archives/CHIME-SIR-model.zip" ) +SIDARTHE_URL = ( + "https://artifacts.askem.lum.ai/askem/data/models/zip-archives/SIDARTHE.zip" +) + def test_any_amr_chime_sir(): """ Unit test for checking that Chime-SIR model produces any AMR. This test zip contains 4 versions of CHIME SIR. @@ -36,16 +40,26 @@ def test_any_amr_chime_sir(): llm_mock_output = [dyn1, dyn2, dyn3, dyn4] line_begin = [] + import_begin = [] line_end = [] + import_end = [] files = [] blobs = [] amrs = [] + for linespan in llm_mock_output: - lines = linespan.block[0].split("-") + blocks = len(linespan.block) + lines = linespan.block[blocks-1].split("-") line_begin.append( max(int(lines[0][1:]) - 1, 0) ) # Normalizing the 1-index response from llm_proxy line_end.append(int(lines[1][1:])) + if blocks == 2: + lines = linespan.block[0].split("-") + import_begin.append( + max(int(lines[0][1:]) - 1, 0) + ) # Normalizing the 1-index response from llm_proxy + import_end.append(int(lines[1][1:])) # So we are required to do the same when slicing the source code using its output. with ZipFile(zip_bytes, "r") as zip: @@ -62,7 +76,11 @@ def test_any_amr_chime_sir(): if line_begin[i] == line_end[i]: print("failed linespan") else: - blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) + if blocks == 2: + temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[i]:import_end[i]]) + blobs[i] = temp + "\n" + "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) + else: + blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) try: time.sleep(0.5) code_snippet_response = asyncio.run( @@ -74,27 +92,128 @@ def test_any_amr_chime_sir(): ) ) if "model" in code_snippet_response: + code_snippet_response["header"]["name"] = "LLM-assisted code to amr model" + code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}" + code_snippet_response["header"]["linespan"] = f"{llm_mock_output[i]}" amrs.append(code_snippet_response) else: print("snippets failure") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") except: - print("except hit") + print("Hit except to snippets failure") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") # we will return the amr with most states, in assumption it is the most "correct" # by default it returns the first entry - print(f"amrs: {amrs}\n") - amr = amrs[0] - print(f"initial amr: {amr}\n") - for temp_amr in amrs: - try: - temp_len = len(temp_amr["model"]["states"]) - amr_len = len(amr["model"]["states"]) - if temp_len > amr_len: - amr = temp_amr - except: - continue + print(f"{amrs}") + try: + amr = amrs[0] + for temp_amr in amrs: + try: + temp_len = len(temp_amr["model"]["states"]) + amr_len = len(amr["model"]["states"]) + if temp_len > amr_len: + amr = temp_amr + except: + continue + except: + amr = logging print(f"final amr: {amr}\n") # For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy. assert "model" in amr, f"'model' should be in AMR response, but got {amr}" +def test_any_amr_sidarthe(): + """ + Unit test for checking that Chime-SIR model produces any AMR. This test zip contains 4 versions of CHIME SIR. + This will test if just the core dynamics works, the whole script, and also rewritten scripts work. + """ + response = requests.get(SIDARTHE_URL) + zip_bytes = BytesIO(response.content) + + # NOTE: For CI we are unable to use the LLM assisted functions due to API keys + # So, we will instead mock the output for those functions instead + dyn1 = Dynamics(name="commented_Evaluation_Scenario_2.1.a.ii-Code_Version_A.py", description=None, block=["L1-L6","L7-L59"]) + dyn2 = Dynamics(name="Evaluation_Scenario_2.1.a.ii-Code_Version_A.py", description=None, block=["L1-L6","L7-L18"]) + llm_mock_output = [dyn1, dyn2] + + line_begin = [] + import_begin = [] + line_end = [] + import_end = [] + files = [] + blobs = [] + amrs = [] + + + for linespan in llm_mock_output: + blocks = len(linespan.block) + lines = linespan.block[blocks-1].split("-") + line_begin.append( + max(int(lines[0][1:]) - 1, 0) + ) # Normalizing the 1-index response from llm_proxy + line_end.append(int(lines[1][1:])) + if blocks == 2: + lines = linespan.block[0].split("-") + import_begin.append( + max(int(lines[0][1:]) - 1, 0) + ) # Normalizing the 1-index response from llm_proxy + import_end.append(int(lines[1][1:])) + + # So we are required to do the same when slicing the source code using its output. + with ZipFile(zip_bytes, "r") as zip: + for file in zip.namelist(): + file_obj = Path(file) + if file_obj.suffix in [".py"]: + files.append(file) + blobs.append(zip.open(file).read().decode("utf-8")) + + # The source code is a string, so to slice using the line spans, we must first convert it to a list. + # Then we can convert it back to a string using .join + logging = [] + for i in range(len(blobs)): + if line_begin[i] == line_end[i]: + print("failed linespan") + else: + if blocks == 2: + temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[i]:import_end[i]]) + blobs[i] = temp + "\n" + "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) + else: + blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) + try: + time.sleep(0.5) + code_snippet_response = asyncio.run( + code_snippets_to_pn_amr( + System( + files=[files[i]], + blobs=[blobs[i]], + ) + ) + ) + if "model" in code_snippet_response: + code_snippet_response["header"]["name"] = "LLM-assisted code to amr model" + code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}" + code_snippet_response["header"]["linespan"] = f"{llm_mock_output[i]}" + amrs.append(code_snippet_response) + else: + print("snippets failure") + logging.append(f"{files[i]} failed to parse an AMR from the dynamics") + except: + print("Hit except to snippets failure") + logging.append(f"{files[i]} failed to parse an AMR from the dynamics") + # we will return the amr with most states, in assumption it is the most "correct" + # by default it returns the first entry + print(f"{amrs}") + try: + amr = amrs[0] + for temp_amr in amrs: + try: + temp_len = len(temp_amr["model"]["states"]) + amr_len = len(amr["model"]["states"]) + if temp_len > amr_len: + amr = temp_amr + except: + continue + except: + amr = logging + print(f"final amr: {amr}\n") + # For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy. + assert "model" in amr, f"'model' should be in AMR response, but got {amr}" \ No newline at end of file diff --git a/skema/rest/workflows.py b/skema/rest/workflows.py index 7ef76c6eaf8..2600e753d30 100644 --- a/skema/rest/workflows.py +++ b/skema/rest/workflows.py @@ -238,16 +238,27 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()): print(f"Time response linespan: {time.time()}") line_begin = [] + import_begin = [] line_end = [] + import_end = [] files = [] blobs = [] amrs = [] + + # There could now be multiple blocks that we need to handle and adjoin together for linespan in linespans: - lines = linespan.block[0].split("-") + blocks = len(linespan.block) + lines = linespan.block[blocks-1].split("-") line_begin.append( max(int(lines[0][1:]) - 1, 0) ) # Normalizing the 1-index response from llm_proxy line_end.append(int(lines[1][1:])) + if blocks == 2: + lines = linespan.block[0].split("-") + import_begin.append( + max(int(lines[0][1:]) - 1, 0) + ) # Normalizing the 1-index response from llm_proxy + import_end.append(int(lines[1][1:])) # So we are required to do the same when slicing the source code using its output. with ZipFile(BytesIO(zip_file.file.read()), "r") as zip: @@ -264,7 +275,11 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()): if line_begin[i] == line_end[i]: print("failed linespan") else: - blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) + if blocks == 2: + temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[i]:import_end[i]]) + blobs[i] = temp + "\n" + "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) + else: + blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) try: time.sleep(0.5) print(f"Time call code-snippets: {time.time()}") @@ -276,6 +291,9 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()): ) print(f"Time response code-snippets: {time.time()}") if "model" in code_snippet_response: + code_snippet_response["header"]["name"] = "LLM-assisted code to amr model" + code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}" + code_snippet_response["header"]["linespan"] = f"{linespans[i]}" amrs.append(code_snippet_response) else: print("snippets failure") From 2a4f166ac02fe35489b3123c3213933805e8e78a Mon Sep 17 00:00:00 2001 From: Justin Lieffers <76677555+Free-Quarks@users.noreply.github.com> Date: Tue, 12 Dec 2023 13:51:16 -0800 Subject: [PATCH 06/22] Fixed small bug for imports (#729) ## Summary of Changes Fixed a bug for when we add imports to code2amr ### Related issues Resolves ??? --- skema/rest/workflows.py | 1 + skema/skema-rs/skema/src/database.rs | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/skema/rest/workflows.py b/skema/rest/workflows.py index 2600e753d30..0411814f18b 100644 --- a/skema/rest/workflows.py +++ b/skema/rest/workflows.py @@ -283,6 +283,7 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()): try: time.sleep(0.5) print(f"Time call code-snippets: {time.time()}") + print(blobs[i]) code_snippet_response = await code_snippets_to_pn_amr( code2fn.System( files=[files[i]], diff --git a/skema/skema-rs/skema/src/database.rs b/skema/skema-rs/skema/src/database.rs index 8ea7cec9f90..075b0b72404 100644 --- a/skema/skema-rs/skema/src/database.rs +++ b/skema/skema-rs/skema/src/database.rs @@ -392,12 +392,25 @@ fn create_module(gromet: &ModuleCollection) -> Vec { fn create_graph_queries(gromet: &ModuleCollection, start: u32) -> Vec { let mut queries: Vec = vec![]; + let mut only_imports = true; // if a library module need to walk through gromet differently if gromet.modules[0].r#fn.bf.is_none() { queries.append(&mut create_function_net_lib(gromet, start)); } else { // if executable code - queries.append(&mut create_function_net(gromet, start)); + for bf in gromet.modules[0].r#fn.bf.as_ref().unwrap().iter() { + if bf.function_type != FunctionType::Imported + && bf.function_type != FunctionType::ImportedMethod + { + only_imports = false; + } + } + println!("{}", only_imports); + if only_imports { + queries.append(&mut create_function_net_lib(gromet, start)); + } else { + queries.append(&mut create_function_net(gromet, start)); + } } queries } From cfdf54f2121f25b6b3aba1a752a8f343bd76bd13 Mon Sep 17 00:00:00 2001 From: Vincent Raymond Date: Wed, 13 Dec 2023 13:15:15 -0500 Subject: [PATCH 07/22] [fortran] Further bug fixes and speed improvements (#722) ## Summary of Changes ### Fortran->CAST efficiency updates - Using itertools.accumulate to generate overlapping sums in NodeHelper class - Creating a single dummy source reference node instead of creating a new instance every time we come across a missing source reference. ### CAST->Gromet efficiency updates - Updates the port_id generation logic ### Fortran->CAST - Adds `get_children_except_types` function to node_helper.py, which is the inverse of `get_children_by_types` function. - Adds `generate_cast_body` function to TS2CAST which handles the pattern of visiting every node in a List. - Adds support for unary operators in expressions (a = +b) - Adds basic support for print statements - Adds support for function derived type member expressions (b = a%function()) - Refactors the visitor for the if conditional to increase accuracy and reduce code - Refactors return logic to properly differentiate subroutine and function ### Related issues Resolves #724 Resolves #721 Resolves #703 --- .../CAST/fortran/node_helper.py | 9 +- .../program_analysis/CAST/fortran/ts2cast.py | 280 +++++++++--------- skema/program_analysis/CAST/fortran/util.py | 6 +- .../CAST/fortran/variable_context.py | 36 ++- .../CAST2FN/ann_cast/to_gromet_pass.py | 18 +- skema/utils/script_functions.py | 1 + 6 files changed, 196 insertions(+), 154 deletions(-) diff --git a/skema/program_analysis/CAST/fortran/node_helper.py b/skema/program_analysis/CAST/fortran/node_helper.py index f3d83853dd7..51f614a3586 100644 --- a/skema/program_analysis/CAST/fortran/node_helper.py +++ b/skema/program_analysis/CAST/fortran/node_helper.py @@ -1,8 +1,10 @@ +import itertools from typing import List, Dict -from skema.program_analysis.CAST2FN.model.cast import SourceRef from tree_sitter import Node +from skema.program_analysis.CAST2FN.model.cast import SourceRef + CONTROL_CHARACTERS = [ ",", "=", @@ -41,7 +43,7 @@ def __init__(self, source: str, source_file_name: str): # get_identifier optimization variables self.source_lines = source.splitlines(keepends=True) self.line_lengths = [len(line) for line in self.source_lines] - self.line_length_sums = [sum(self.line_lengths[:i+1]) for i in range(len(self.source_lines))] + self.line_length_sums = list(itertools.accumulate(self.line_lengths))#[sum(self.line_lengths[:i+1]) for i in range(len(self.source_lines))] def get_source_ref(self, node: Node) -> SourceRef: """Given a node and file name, return a CAST SourceRef object.""" @@ -96,6 +98,9 @@ def get_children_by_types(node: Node, types: List): """Takes in a node and a list of types as inputs and returns all children matching those types. Otherwise, return an empty list""" return [child for child in node.children if child.type in types] +def get_children_except_types(node: Node, types: List): + """Takes in a node and a list of types as inputs and returns all children not matching those types. Otherwise, return an empty list""" + return [child for child in node.children if child.type not in types] def get_first_child_index(node, type: str): """Get the index of the first child of node with type type.""" diff --git a/skema/program_analysis/CAST/fortran/ts2cast.py b/skema/program_analysis/CAST/fortran/ts2cast.py index 542a7ece5be..02b325d54c9 100644 --- a/skema/program_analysis/CAST/fortran/ts2cast.py +++ b/skema/program_analysis/CAST/fortran/ts2cast.py @@ -10,6 +10,7 @@ from skema.program_analysis.CAST2FN.model.cast import ( Module, SourceRef, + ModelBreak, Assignment, LiteralValue, Var, @@ -33,6 +34,7 @@ NodeHelper, remove_comments, get_children_by_types, + get_children_except_types, get_first_child_by_type, get_control_children, get_non_control_children, @@ -44,6 +46,16 @@ from skema.program_analysis.CAST.fortran.preprocessor.preprocess import preprocess from skema.program_analysis.tree_sitter_parsers.build_parsers import INSTALLED_LANGUAGES_FILEPATH +builtin_statements = set( + [ + "read_statement", + "write_statement", + "rewind_statement", + "open_statement", + "common_statement", + "print_statement" + ] +) class TS2CAST(object): def __init__(self, source_file_path: str): # Prepare source with preprocessor @@ -68,7 +80,7 @@ def __init__(self, source_file_path: str): # Start visiting self.out_cast = self.generate_cast() - #print(self.out_cast[0].to_json_str()) + print(self.out_cast[0].to_json_str()) def generate_cast(self) -> List[CAST]: '''Interface for generating CAST.''' @@ -126,7 +138,7 @@ def visit(self, node: Node): return self.visit_identifier(node) elif node.type == "name": return self.visit_name(node) - elif node.type in ["math_expression", "relational_expression"]: + elif node.type in ["unary_expression", "math_expression", "relational_expression"]: return self.visit_math_expression(node) elif node.type in [ "number_literal", @@ -137,11 +149,13 @@ def visit(self, node: Node): return self.visit_literal(node) elif node.type == "keyword_statement": return self.visit_keyword_statement(node) + elif node.type in builtin_statements: + return self.visit_fortran_builtin_statement(node) elif node.type == "extent_specifier": return self.visit_extent_specifier(node) - elif node.type == "do_loop_statement": + elif node.type in ["do_loop_statement"]: return self.visit_do_loop_statement(node) - elif node.type == "if_statement": + elif node.type in ["if_statement", "else_if_clause", "else_clause"]: return self.visit_if_statement(node) elif node.type == "logical_expression": return self.visit_logical_expression(node) @@ -214,15 +228,16 @@ def visit_function_def(self, node): # Top level statement node statement_node = get_children_by_types(node, ["subroutine_statement", "function_statement"])[0] + name_node = get_first_child_by_type(statement_node, "name") name = self.visit( name_node ) # Visit the name node to add it to the variable context # If this is a function, check for return type and return value - intrinsic_type = None - return_value = None if node.type == "function": + intrinsic_type = None + return_value = None signature_qualifiers = get_children_by_types( statement_node, ["intrinsic_type", "function_result"] ) @@ -235,20 +250,21 @@ def visit_function_def(self, node): elif qualifier.type == "function_result": return_value = self.visit( get_first_child_by_type(qualifier, "identifier") - ) # TODO: UPDATE NODES - self.variable_context.add_return_value(return_value.val.name) - - # #TODO: What happens if function doesn't return anything? - # If this is a function, and there is no explicit results variable, then we will assume the return value is the name of the function - if not return_value: - self.variable_context.add_return_value( - self.node_helper.get_identifier(name_node) - ) + ).val + self.variable_context.add_return_value(return_value.name) + + + # NOTE: In the case of a function specifically, if there is no explicit return value, the return value will be the name of the function + # TODO: Should this be a node instead + if not return_value: + self.variable_context.add_return_value( + self.node_helper.get_identifier(name_node) + ) + return_value = self.visit(name_node) - # If funciton has both, then we also need to update the type of the return value in the variable context - # It does not explicity have to be declared - if return_value and intrinsic_type: - self.variable_context.update_type(return_value.val.name, intrinsic_type) + # If funciton has both an explicit intrinsic type, then we also need to update the type of the return value in the variable context + if intrinsic_type: + self.variable_context.update_type(return_value.name, intrinsic_type) # Generating the function arguments by walking the parameters node func_args = [] @@ -260,7 +276,7 @@ def visit_function_def(self, node): self.node_helper.get_identifier(parameter) ) func_args.append(self.visit(parameter)) - + # The first child of function will be the function statement, the rest will be body nodes body = [] for body_node in node.children[1:]: @@ -301,14 +317,10 @@ def visit_function_call(self, node): # A subroutine and function won't neccessarily have an arguments node. # So we should be careful about trying to access it. + function_node = get_children_by_types(node, ["unary_expression", "subroutine", "identifier", "derived_type_member_expression"])[0] - if function_node.type == "derived_type_member_expression": - func = Attribute( - value=None, - attr=None - ) - return None + return self.visit_derived_type_member_expression(function_node) arguments_node = get_first_child_by_type(node, "argument_list") @@ -350,6 +362,7 @@ def visit_function_call(self, node): source_refs=[self.node_helper.get_source_ref(node)], ) + def visit_keyword_statement(self, node): # NOTE: RETURN is not the only Fortran keyword. GO TO and CONTINUE are also considered keywords. # TODO: Handle GO TO and CONTINUE @@ -357,6 +370,10 @@ def visit_keyword_statement(self, node): if node.type == "keyword_statement": if "continue" in identifier or "go to" in identifier: return self._visit_no_op(node) + if "exit" in identifier: + return ModelBreak( + source_refs = [self.node_helper.get_source_ref(node)] + ) # In Fortran the return statement doesn't return a value (there is the obsolete "alternative return") # We keep track of values that need to be returned in the variable context @@ -365,30 +382,52 @@ def visit_keyword_statement(self, node): ] # TODO: Make function for this if len(return_values) == 1: - # TODO: Fix this case value = self.variable_context.get_node(list(return_values)[0]) elif len(return_values) > 1: value = LiteralValue( value_type="Tuple", value=[ - Var( - val=self.variable_context.get_node(ret), - type=self.variable_context.get_type(ret), - default_value=None, - source_refs=None, - ) - for ret in return_values + self.variable_context.get_node(ret) for ret in return_values ], - source_code_data_type=None, # TODO: REFACTOR + source_code_data_type=None, source_refs=None, ) else: - value = LiteralValue(val=None, type=None, source_refs=None) + value = LiteralValue(value=None, value_type=None, source_refs=None) return ModelReturn( value=value, source_refs=[self.node_helper.get_source_ref(node)] ) + def visit_fortran_builtin_statement(self, node): + """Visitor for Fortran keywords that are not classified as keyword_statement by tree-sitter""" + # All of the node types that fall into this category end with _statment. + # So the function name will be the node type with _statement removed (write, read, open, ...) + func = self.get_gromet_function_node(node.type.replace("_statement", "")) + + + arguments = [] + + return Call( + func=func, + arguments=arguments, + source_language="Fortran", + source_language_version=None, + source_refs=[self.node_helper.get_source_ref(node)] + ) + + def visit_print_statement(self, node): + func = self.get_gromet_function_node("print") + + arguments = [] + + return Call( + func=func, + arguments=arguments, + source_language=None, + source_language_version=None + ) + def visit_use_statement(self, node): # (use) # (use) @@ -410,7 +449,7 @@ def visit_use_statement(self, node): alias=import_alias, all=import_all, symbol=None, - source_refs=None, + source_refs=[self.node_helper.get_source_ref(node)], ) else: imports = [] @@ -447,16 +486,14 @@ def visit_do_loop_statement(self, node) -> Loop: (...) ... (body) ... """ + - # First check for - # TODO: Add do until Loop support - while_statement_node = get_first_child_by_type(node, "while_statement") - if while_statement_node: + loop_control_node= get_first_child_by_type(node, "loop_contrel_expression") + if not loop_control_node: return self._visit_while(node) # If there is a loop control expression, the first body node will be the node after the loop_control_expression # It is valid Fortran to have a single itteration do loop as well. - # TODO: Add support for single itteration do-loop # NOTE: This code is for the creation of the main body. The do loop will still add some additional nodes at the end of this body. body = [] body_start_index = 1 + get_first_child_index(node, "loop_control_expression") @@ -580,77 +617,35 @@ def visit_if_statement(self, node): # (else_clause) # (end_if_statement) - if_condition = self.visit(get_first_child_by_type(node, "parenthesized_expression")) - - child_types = [child.type for child in node.children] - - try: - elseif_index = child_types.index("elseif_clause") - except ValueError: - elseif_index = -1 - - try: - else_index = child_types.index("else_clause") - except ValueError: - else_index = -1 - - if elseif_index != -1: - body_stop_index = elseif_index - else: - body_stop_index = else_index - - # Single line if conditions don't have a 'then' or 'end if' clause. - # So the starting index for the body can either be 2 or 3. - then_index = get_first_child_index(node, "then") - if then_index: - body_start_index = then_index+1 - else: - body_start_index = 2 - body_stop_index = len(node.children) - - prev = None - orelse = None - # If there are else_if statements, they need - if elseif_index != -1: - orelse = ModelIf() - prev = orelse - for condition in node.children[elseif_index:else_index]: - if condition.type == "comment": - continue - elseif_expr = self.visit(condition.children[2]) - elseif_body = [self.visit(child) for child in condition.children[4:]] + #TODO: Can you have a parenthesized expression as a body node + body_nodes = get_children_except_types(node, ["if", "elseif", "else", "then", "parenthesized_expression", "elseif_clause", "else_clause", "end_if_statement"]) + body = self.generate_cast_body(body_nodes) + + expr_node = get_first_child_by_type(node, "parenthesized_expression") + expr = None + if expr_node: + expr = self.visit(expr_node) + + elseif_nodes = get_children_by_types(node, ["elseif_clause"]) + elseif_cast = [self.visit(elseif_clause) for elseif_clause in elseif_nodes] + for i in range(len(elseif_cast)-1): + elseif_cast[i].orelse = [elseif_cast[i+1]] - prev.orelse = ModelIf(elseif_expr, elseif_body, []) - prev = prev.orelse - - if else_index != -1: - else_body = [ - self.visit(child) for child in node.children[else_index].children[1:] - ] - if prev: - prev.orelse = else_body - else: - orelse = else_body - - # TODO: This orelse logic has gotten a little complex, we might want to refactor this. - if isinstance(orelse, ModelIf): - orelse = orelse.orelse - if orelse: - if isinstance(orelse, ModelIf): - orelse = [orelse] + else_node = get_first_child_by_type(node, "else_clause") + else_cast = None + if else_node: + else_cast = self.visit(else_node) + + orelse = [] + if len(elseif_cast) > 0: + orelse = [elseif_cast[0]] + elif else_cast: + orelse = else_cast.body - body = [] - for child in node.children[body_start_index:body_stop_index]: - child_cast = self.visit(child) - if isinstance(child_cast, AstNode): - body.append(child_cast) - elif isinstance(child_cast, List): - body.extend(child_cast) - return ModelIf( - expr=self.visit(node.children[1]), + expr=expr, body=body, - orelse=orelse if orelse else [], + orelse=orelse ) def visit_logical_expression(self, node): @@ -671,7 +666,6 @@ def visit_logical_expression(self, node): is_or = "or" in operator.type top_if = ModelIf() - top_if_expr = self.visit(left) top_if.expr = top_if_expr @@ -777,6 +771,7 @@ def visit_identifier(self, node): default_value=default_value, source_refs=[self.node_helper.get_source_ref(node)], ) + def visit_math_expression(self, node): op = self.node_helper.get_identifier( @@ -785,6 +780,11 @@ def visit_math_expression(self, node): operands = [] for operand in get_non_control_children(node): operands.append(self.visit(operand)) + + # For operators, we will only need the name node since we are not allocating space + if operand.type == "identifier": + operands[-1] = operands[-1].val + return Operator( source_language="Fortran", @@ -822,8 +822,8 @@ def visit_variable_declaration(self, node) -> List: type_map = { "integer": "Integer", "real": "AbstractFloat", - "double precision": None, - "complex": None, + "double precision": "AbstractFloat", + "complex": "Tuple", # Complex is a Tuple (rational,irrational), "logical": "Boolean", "character": "String", } @@ -871,9 +871,9 @@ def visit_variable_declaration(self, node) -> List: ], ) ) - vars[-1].left.type = "dimension" + vars[-1].left.type = "List" self.variable_context.update_type( - vars[-1].left.val.name, "dimension" + vars[-1].left.val.name, "List" ) else: # If its a regular assignment, we can update the type normally @@ -892,8 +892,8 @@ def visit_variable_declaration(self, node) -> List: # Declaring a dimension variable using the x(1:5) format. It will look like a call expression in tree-sitter. # We treat it like an identifier by visiting its identifier node. Then the type gets overridden by "dimension" vars.append(self.visit(get_first_child_by_type(variable, "identifier"))) - vars[-1].type = "dimension" - self.variable_context.update_type(vars[-1].val.name, "dimension") + vars[-1].type = "List" + self.variable_context.update_type(vars[-1].val.name, "List") # By default, all variables are added to a function's list of return values # If the intent is actually in, then we need to remove them from the list @@ -964,7 +964,7 @@ def visit_derived_type(self, node: Node) -> RecordDef: # If we tell the variable context we are in a record definition, it will append the type name as a prefix to all defined variables. self.variable_context.enter_record_definition(record_name) - # TODO: Full support for this requires handling the contains statement generally + # Note: funcs = [] derived_type_procedures_node = get_first_child_by_type( node, "derived_type_procedures" @@ -1118,23 +1118,25 @@ def _visit_while(self, node) -> Loop: """ while_statement_node = get_first_child_by_type(node, "while_statement") - # The first body node will be the node after the while_statement - body = [] - body_start_index = 1 + get_first_child_index(node, "while_statement") - for body_node in node.children[body_start_index:]: - child_cast = self.visit(body_node) - if isinstance(child_cast, List): - body.extend(child_cast) - elif isinstance(child_cast, AstNode): - body.append(child_cast) - - # We don't have explicit handling for parenthesized_expression, but the passthrough handler will make sure that we visit the expression correctly. - expr = self.visit( - get_first_child_by_type(while_statement_node, "parenthesized_expression") - ) + # Fortran has certain while(True) constructs that won't contain a while_statement node + if not while_statement_node: + body_start_index = 0 + expr = LiteralValue( + value_type="Boolean", + value="True", + ) + else: + body_start_index = 1 + get_first_child_index(node, "while_statement") + # We don't have explicit handling for parenthesized_expression, but the passthrough handler will make sure that we visit the expression correctly. + expr = self.visit( + get_first_child_by_type(while_statement_node, "parenthesized_expression") + ) + # The first body node will be the node after the while_statement + body = self.generate_cast_body(node.children[body_start_index:]) + return Loop( - pre=[], # TODO: Should pre and post contain anything? + pre=[], expr=expr, body=body, post=[], @@ -1196,3 +1198,17 @@ def get_gromet_function_node(self, func_name: str) -> Name: return self.variable_context.get_node(func_name) return self.variable_context.add_variable(func_name, "function", None) + + def generate_cast_body(self, body_nodes: List): + body = [] + for node in body_nodes: + cast = self.visit(node) + if isinstance(cast, AstNode): + body.append(cast) + elif isinstance(cast, List): + body.extend(cast) + return body + +#TS2CAST("drotmg.f") +#import cProfile +#cProfile.run("TS2CAST('he_coef0_dres.F')", sort="tottime") \ No newline at end of file diff --git a/skema/program_analysis/CAST/fortran/util.py b/skema/program_analysis/CAST/fortran/util.py index 07c12ee809c..b4a9bc72d39 100644 --- a/skema/program_analysis/CAST/fortran/util.py +++ b/skema/program_analysis/CAST/fortran/util.py @@ -1,13 +1,15 @@ from typing import List from skema.program_analysis.CAST2FN.model.cast import AstNode, LiteralValue, SourceRef +DUMMY_SOURCE_REF = [SourceRef("", -1, -1, -1, -1)] +DUMMY_SOURCE_CODE_DATA_TYPE = ["Fortran", "Fotran95", "None"] def generate_dummy_source_refs(node: AstNode) -> AstNode: """Walks a tree of AstNodes replacing any null SourceRefs with a dummy value""" if isinstance(node, LiteralValue) and not node.source_code_data_type: - node.source_code_data_type = ["Fortran", "Fotran95", "None"] + node.source_code_data_type = DUMMY_SOURCE_CODE_DATA_TYPE if not node.source_refs: - node.source_refs = [SourceRef("", -1, -1, -1, -1)] + node.source_refs = DUMMY_SOURCE_REF for attribute_str in node.attribute_map: attribute = getattr(node, attribute_str) diff --git a/skema/program_analysis/CAST/fortran/variable_context.py b/skema/program_analysis/CAST/fortran/variable_context.py index eca184bee97..cfdd8a69b69 100644 --- a/skema/program_analysis/CAST/fortran/variable_context.py +++ b/skema/program_analysis/CAST/fortran/variable_context.py @@ -8,14 +8,17 @@ class VariableContext(object): def __init__(self): self.context = [{}] # Stack of context dictionaries self.context_return_values = [set()] # Stack of context return values + + # All symbols will use a seperate naming convention to prevent two scopes using the same symbol name + # The name will be a dot notation list of scopes i.e. scope1.scope2.symbol + self.all_symbols_scopes = [] self.all_symbols = {} - self.record_definitions = {} - + # The prefix is used to handle adding Record types to the variable context. # This gives each symbol a unqique name. For example "a" would become "type_name.a" # For nested type definitions (derived type in a module), multiple prefixes can be added. self.prefix = [] - + # Flag neccessary to declare if a function is internal or external self.internal = False @@ -36,7 +39,7 @@ def push_context(self): def pop_context(self): """Pop the current variable context off of the stack and remove any references to those symbols.""" - + # If the internal flag is set, then all new scopes will use the top-level context if self.internal: return None @@ -45,7 +48,10 @@ def pop_context(self): # Remove symbols from all_symbols variable for symbol in context: - self.all_symbols.pop(symbol) + if isinstance(self.all_symbols[symbol], List): + self.all_symbols[symbol].pop() + else: + self.all_symbols.pop(symbol) self.context_return_values.pop() @@ -68,8 +74,13 @@ def add_variable(self, symbol: str, type: str, source_refs: List) -> Name: } # Add reference to all_symbols - self.all_symbols[full_symbol_name] = self.context[-1][full_symbol_name] - + if full_symbol_name in self.all_symbols: + if isinstance(self.all_symbols[full_symbol_name], List): + self.all_symbols[full_symbol_name].append(self.context[-1][full_symbol_name]) + else: + self.all_symbols[full_symbol_name] = [self.all_symbols[full_symbol_name], self.context[-1][full_symbol_name]] + else: + self.all_symbols[full_symbol_name] = self.context[-1][full_symbol_name] return cast_name def is_variable(self, symbol: str) -> bool: @@ -77,16 +88,25 @@ def is_variable(self, symbol: str) -> bool: return symbol in self.all_symbols def get_node(self, symbol: str) -> Dict: + if isinstance(self.all_symbols[symbol], List): + return self.all_symbols[symbol][-1]["node"] + return self.all_symbols[symbol]["node"] def get_type(self, symbol: str) -> str: + if isinstance(self.all_symbols[symbol], List): + return self.all_symbols[symbol][-1]["type"] + return self.all_symbols[symbol]["type"] def update_type(self, symbol: str, type: str): """Update the type associated with a given symbol""" # Generate the full symbol name using the prefix full_symbol_name = ".".join(self.prefix + [symbol]) - self.all_symbols[full_symbol_name]["type"] = type + if isinstance(self.all_symbols[full_symbol_name], List): + self.all_symbols[full_symbol_name][-1]["type"] = type + else: + self.all_symbols[full_symbol_name]["type"] = type def add_return_value(self, symbol): self.context_return_values[-1].add(symbol) diff --git a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py index 4e81daf28d7..33e9ed71661 100644 --- a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py +++ b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py @@ -84,20 +84,19 @@ def insert_gromet_object(t: list, obj): If the table we're trying to insert into doesn't already exist, then we first create it, and then insert the value. """ + + if t == None: + t = [] # Logic for generating port ids if isinstance(obj, GrometPort): - if t == None: - obj.id = 1 - else: - current_box = obj.box - current_box_ports = [port for port in t if port.box == current_box] - obj.id = len(current_box_ports) + 1 + obj.id = 1 + for port in reversed(t): + if port.box == obj.box: + obj.id = port.id + 1 + break - if t == None: - t = [] t.append(obj) - return t @@ -3014,7 +3013,6 @@ def handle_function_def( # can clear the local variable environment var_environment["local"] = deepcopy(prev_local_env) - @_visit.register def visit_function_def( self, node: AnnCastFunctionDef, parent_gromet_fn, parent_cast_node diff --git a/skema/utils/script_functions.py b/skema/utils/script_functions.py index 2f58051cbf8..afbfaa57b4b 100644 --- a/skema/utils/script_functions.py +++ b/skema/utils/script_functions.py @@ -254,6 +254,7 @@ def ann_cast_pipeline( pdf_file_name = f"{f_name}-AnnCast.pdf" agraph.to_pdf(pdf_file_name) + print("\nCalling GrfnVarCreationPass-------------------") GrfnVarCreationPass(pipeline_state) From 3383088440c225558352e15acca9ebc406f9cab7 Mon Sep 17 00:00:00 2001 From: Deepsana Shahi Date: Thu, 14 Dec 2023 11:22:00 -0700 Subject: [PATCH 08/22] [Equations] Handles Sidharthe derivative (#730) ## Summary of Changes The parser handles Sidarthe derivative properly ### Related issues Resolves ??? --------- Co-authored-by: Deepsana Shahi --- .../mathml/src/parsers/first_order_ode.rs | 2 +- .../mathml/src/parsers/interpreted_mathml.rs | 39 +++++++++-- .../src/parsers/math_expression_tree.rs | 70 +++++++++++++++++++ 3 files changed, 105 insertions(+), 6 deletions(-) diff --git a/skema/skema-rs/mathml/src/parsers/first_order_ode.rs b/skema/skema-rs/mathml/src/parsers/first_order_ode.rs index 3756fe61b93..71facecb083 100644 --- a/skema/skema-rs/mathml/src/parsers/first_order_ode.rs +++ b/skema/skema-rs/mathml/src/parsers/first_order_ode.rs @@ -1108,7 +1108,7 @@ fn test_first_order_ode() { assert_eq!(lhs_var.to_string(), "S"); assert_eq!(func_of[0].to_string(), ""); - assert_eq!(with_respect_to.to_string(), ""); + assert_eq!(with_respect_to.to_string(), "t"); assert_eq!(rhs.to_string(), "(* (* (- β) I) (/ S N))"); } diff --git a/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs b/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs index eea4f13e634..7a6fd5a3dee 100644 --- a/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs +++ b/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs @@ -436,7 +436,7 @@ pub fn newtonian_derivative(input: Span) -> IResult<(Derivative, Ci)> { ); let (s, (func, order)) = delimited( - stag!("mover"), + alt((tag(""), tag(""))), pair( map( ci_unknown, @@ -452,7 +452,7 @@ pub fn newtonian_derivative(input: Span) -> IResult<(Derivative, Ci)> { ), n_dots, ), - etag!("mover"), + alt((tag(""), etag!("mover"))), )(input)?; let (s, mi_func_of) = alt((parenthesized_identifier, empty_parenthesis))(s)?; let mut ci_func_of: Vec = Vec::new(); @@ -460,9 +460,10 @@ pub fn newtonian_derivative(input: Span) -> IResult<(Derivative, Ci)> { let b = Ci::new(Some(Type::Real), Box::new(MathExpression::Mi(bvar)), None); ci_func_of.push(b.clone()); } + let func_mi = ci_func_of.get(0).unwrap().content.clone(); let new_with_respect_to: Box = Box::new(MathExpression::Ci(Ci { r#type: None, - content: Box::new(MathExpression::Mi(Mi("".to_string()))), + content: Box::new(MathExpression::Mi(Mi(func_mi.to_string()))), func_of: None, })); @@ -597,7 +598,7 @@ pub fn math_expression(input: Span) -> IResult { ws(alt(( map(div, MathExpression::Mo), alt((ws(absolute_with_msup), ws(paren_as_msup))), - sqrt, + //sqrt, map( grad_func, |( @@ -676,6 +677,34 @@ pub fn math_expression(input: Span) -> IResult { }) }, ), + map( + newtonian_derivative, + |( + Derivative { + order, + var_index, + bound_var, + }, + Ci { + r#type, + content, + func_of, + }, + )| { + MathExpression::Differential(Differential { + diff: Box::new(MathExpression::Mo(Operator::Derivative(Derivative { + order, + var_index, + bound_var, + }))), + func: Box::new(MathExpression::Ci(Ci { + r#type, + content, + func_of, + })), + }) + }, + ), map( ci_univariate_with_bounds, |Ci { @@ -771,7 +800,7 @@ pub fn math_expression(input: Span) -> IResult { }) }, ), - absolute, + alt((absolute, sqrt)), alt(( map(operator, MathExpression::Mo), map(gradient, MathExpression::Mo), diff --git a/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs b/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs index 98577a59f2c..c745f4cc396 100644 --- a/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs +++ b/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs @@ -2172,3 +2172,73 @@ fn new_quadratic_equation() { ); assert_eq!(exp.to_latex(), "x=\\frac{(-b)-\\sqrt{b^{2}-(4*a*c)}}{2*a}"); } + +#[test] +fn test_dot_in_derivative() { + let input = " + + + S + ˙ + + +( + t + ) +"; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(D(1, t) S)"); +} + +#[test] +fn test_sidarthe_equation() { + let input = " + + + S + ˙ + + + ( + t + ) + = + + S + ( + t + ) + ( + α + I + ( + t + ) + + + β + D + ( + t + ) + + + γ + A + ( + t + ) + + + δ + R + ( + t + ) + ) +"; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!( + s_exp, + "(= (D(1, t) S) (* (- S) (+ (+ (+ (* α I) (* β D)) (* γ A)) (* δ R))))" + ); +} From 0b4935feacbdf30c2a23c81a9e5a1bc639a29751 Mon Sep 17 00:00:00 2001 From: Joseph Astier Date: Thu, 14 Dec 2023 13:30:53 -0700 Subject: [PATCH 09/22] Jastier/loops (#676) --- .../program_analysis/CAST/matlab/__init__.py | 3 - .../CAST/matlab/matlab_to_cast.py | 214 ++++++++++++++---- .../CAST/matlab/tests/__init__.py | 0 .../CAST/matlab/tests/test_assignment.py | 22 +- .../CAST/matlab/tests/test_file_ingest.py | 11 +- .../CAST/matlab/tests/test_loop.py | 122 +++++++++- .../CAST/matlab/tests/test_switch.py | 34 ++- .../CAST/matlab/tests/utils.py | 7 +- skema/program_analysis/CAST/matlab/tokens.py | 12 +- .../CAST/matlab/variable_context.py | 15 ++ 10 files changed, 376 insertions(+), 64 deletions(-) create mode 100644 skema/program_analysis/CAST/matlab/tests/__init__.py diff --git a/skema/program_analysis/CAST/matlab/__init__.py b/skema/program_analysis/CAST/matlab/__init__.py index 3c4351b21c3..e69de29bb2d 100644 --- a/skema/program_analysis/CAST/matlab/__init__.py +++ b/skema/program_analysis/CAST/matlab/__init__.py @@ -1,3 +0,0 @@ -__pdoc__ = { - 'tests': False -} diff --git a/skema/program_analysis/CAST/matlab/matlab_to_cast.py b/skema/program_analysis/CAST/matlab/matlab_to_cast.py index abb7408ff19..22cf8015679 100644 --- a/skema/program_analysis/CAST/matlab/matlab_to_cast.py +++ b/skema/program_analysis/CAST/matlab/matlab_to_cast.py @@ -110,11 +110,10 @@ def visit(self, node): ]:return self.visit_identifier(node) elif node.type == "if_statement": return self.visit_if_statement(node) -# elif node.type in [ -# "for_statement", -# "iterator", -# "while_statement" -# ]: return self.visit_loop(node) + elif node.type == "iterator": + return self.visit_iterator(node) + elif node.type == "for_statement": + return self.visit_for_statement(node) elif node.type in [ "cell", "matrix" @@ -139,6 +138,8 @@ def visit(self, node): ]: return self.visit_operator(node) elif node.type == "string": return self.visit_string(node) + elif node.type == "range": + return self.visit_range(node) elif node.type == "switch_statement": return self.visit_switch_statement(node) else: @@ -210,7 +211,6 @@ def visit_identifier(self, node): val = self.visit_name(node), type = self.variable_context.get_type(identifier) if self.variable_context.is_variable(identifier) else "Unknown", - default_value = "LiteralValue", source_refs = [self.node_helper.get_source_ref(node)], ) @@ -248,12 +248,135 @@ def get_conditional(conditional_node): return first - # General loop translator for all MATLAB loop types - # def visit_loop(self, node) -> Loop: - # """ Translate Tree-sitter for_loop node into CAST Loop node """ - # return Loop ( - # source_refs = [self.node_helper.get_source_ref(node)] - # ) + # CAST has no Iterator node, so we return a partially + # completed Loop object + # MATLAB iterators are either matrices or ranges. + def visit_iterator(self, node) -> Loop: + + itr_var = self.visit(get_first_child_by_type(node, "identifier")) + source_ref = self.node_helper.get_source_ref(node) + + # process matrix iterator + matrix_node = get_first_child_by_type(node, "matrix") + if matrix_node is not None: + row_node = get_first_child_by_type(matrix_node, "row") + if row_node is not None: + mat = [self.visit(child) for child in + get_keyword_children(row_node)] + mat_idx = 0 + mat_len = len(mat) + + + return Loop( + pre = [ + Assignment( + left = "_mat", + right = mat, + source_refs = [source_ref] + ), + Assignment( + left = "_mat_len", + right = mat_len, + source_refs = [source_ref] + ), + Assignment( + left = "_mat_idx", + right = mat_idx, + source_refs = [source_ref] + ), + Assignment( + left = itr_var, + right = mat[mat_idx], + source_refs = [source_ref] + ) + ], + expr = self.get_operator( + op = "<", + operands = ["_mat_idx", "_mat_len"], + source_refs = [source_ref] + ), + body = [ + Assignment( + left = "_mat_idx", + right = self.get_operator( + op = "+", + operands = ["_mat_idx", 1], + source_refs = [source_ref] + ), + source_refs = [source_ref] + ), + Assignment( + left = itr_var, + right = "_mat[_mat_idx]", + source_refs = [source_ref] + ) + ], + post = [] + ) + + + + # process range iterator + range_node = get_first_child_by_type(node, "range") + if range_node is not None: + numbers = [self.visit(child) for child in + get_children_by_types(range_node, ["number"])] + start = numbers[0] + step = 1 + stop = 0 + if len(numbers) == 2: + stop = numbers[1] + + elif len(numbers) == 3: + step = numbers[1] + stop = numbers[2] + + range_name_node = self.variable_context.get_gromet_function_node("range") + iter_name_node = self.variable_context.get_gromet_function_node("iter") + next_name_node = self.variable_context.get_gromet_function_node("next") + generated_iter_name_node = self.variable_context.generate_iterator() + stop_condition_name_node = self.variable_context.generate_stop_condition() + + return Loop( + pre = [ + Assignment( + left = itr_var, + right = start, + source_refs = [source_ref] + ) + ], + expr = self.get_operator( + op = "<=", + operands = [itr_var, stop], + source_refs = [source_ref] + ), + body = [ + Assignment( + left = itr_var, + right = self.get_operator( + op = "+", + operands = [itr_var, step], + source_refs = [source_ref] + ), + source_refs = [source_ref] + ) + ], + post = [] + ) + + + def visit_range(self, node): + return None + + def visit_for_statement(self, node) -> Loop: + """ Translate Tree-sitter for loop node into CAST Loop node """ + + loop = self.visit(get_first_child_by_type(node, "iterator")) + loop.source_refs=[self.node_helper.get_source_ref(node)] + loop.body = self.get_block(node) + loop.body + + return loop + def visit_matrix(self, node): """ Translate the Tree-sitter cell node into a List """ @@ -328,17 +451,14 @@ def visit_number(self, node) -> LiteralValue: ) def visit_operator(self, node): - """return an Operator based on the Tree-sitter node """ + """return an operator based on the Tree-sitter node """ # The operator will be the first control character op = self.node_helper.get_identifier( get_control_children(node)[0] ) # the operands will be the keyword children operands=[self.visit(child) for child in get_keyword_children(node)] - return Operator( - source_language="matlab", - interpreter=INTERPRETER, - version=MATLAB_VERSION, + return self.get_operator( op = op, operands = operands, source_refs=[self.node_helper.get_source_ref(node)], @@ -363,20 +483,9 @@ def visit_switch_statement(self, node): "string", "unary_operator" ] - - def get_operator(op, operands, source_refs): - """ return an Operator representing the case test """ - return Operator( - source_language = "matlab", - interpreter = INTERPRETER, - version = MATLAB_VERSION, - op = op, - operands = operands, - source_refs = source_refs - ) - def get_case_expression(case_node, identifier): - """ return an Operator representing the case test """ + def get_case_expression(case_node, switch_var): + """ return an operator representing the case test """ source_refs=[self.node_helper.get_source_ref(case_node)] cell_node = get_first_child_by_type(case_node, "cell") # multiple case arguments @@ -387,27 +496,41 @@ def get_case_expression(case_node, identifier): source_code_data_type=["matlab", MATLAB_VERSION, "unknown"], source_refs=[self.node_helper.get_source_ref(cell_node)] ) - return get_operator("in", [identifier, operand], source_refs) + return self.get_operator( + op = "in", + operands = [switch_var, operand], + source_refs = source_refs + ) # single case argument operand = [self.visit(node) for node in get_children_by_types(case_node, case_node_types)][0] - return get_operator("==", [identifier, operand], source_refs) + return self.get_operator( + op = "==", + operands = [switch_var, operand], + source_refs = source_refs + ) - def get_model_if(case_node, identifier): + def get_model_if(case_node, switch_var): """ return conditional logic representing the case """ return ModelIf( - expr = get_case_expression(case_node, identifier), + expr = get_case_expression(case_node, switch_var), body = self.get_block(case_node), orelse = [], source_refs=[self.node_helper.get_source_ref(case_node)] ) - # switch statement identifier - identifier = self.visit(get_first_child_by_type(node, "identifier")) - + # switch variable is usually an identifier + switch_var = get_first_child_by_type(node, "identifier") + if switch_var is not None: + switch_var = self.visit(switch_var) + + # however it can be a function call + else: + switch_var = self.visit(get_first_child_by_type(node, "function_call")) + # n case clauses as 'if then' nodes case_nodes = get_children_by_types(node, ["case_clause"]) - model_ifs = [get_model_if(node, identifier) for node in case_nodes] + model_ifs = [get_model_if(node, switch_var) for node in case_nodes] for i, model_if in enumerate(model_ifs[1:]): model_ifs[i].orelse = [model_if] @@ -426,6 +549,21 @@ def get_block(self, node) -> List[AstNode]: return [self.visit(child) for child in get_keyword_children(block)] + def get_operator(self, op, operands, source_refs): + """ return an operator representing the arguments """ + return Operator( + source_language = "matlab", + interpreter = INTERPRETER, + version = MATLAB_VERSION, + op = op, + operands = operands, + source_refs = source_refs + ) + + def get_gromet_function_node(self, func_name: str) -> Name: + if self.variable_context.is_variable(func_name): + return self.variable_context.get_node(func_name) + # skip control nodes and other junk def _visit_passthrough(self, node): if len(node.children) == 0: diff --git a/skema/program_analysis/CAST/matlab/tests/__init__.py b/skema/program_analysis/CAST/matlab/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/skema/program_analysis/CAST/matlab/tests/test_assignment.py b/skema/program_analysis/CAST/matlab/tests/test_assignment.py index 90393c927bd..633e4dce64d 100644 --- a/skema/program_analysis/CAST/matlab/tests/test_assignment.py +++ b/skema/program_analysis/CAST/matlab/tests/test_assignment.py @@ -7,9 +7,8 @@ def test_boolean(): """ Test assignment of literal boolean types. """ # we translate these MATLAB keywords into capitalized strings for Python - nodes = cast("x = true; y = false") - check(nodes[0], Assignment(left = "x", right = "True")) - check(nodes[1], Assignment(left = "y", right = "False")) + check(cast("x = true")[0], Assignment(left = "x", right = "True")) + check(cast("y = false")[0], Assignment(left = "y", right = "False")) def test_number_zero_integer(): """ Test assignment of integer and real numbers.""" @@ -42,9 +41,20 @@ def test_identifier(): def test_operator(): """ Test assignment of operator""" check( - cast("Vtot = V1PF+V1AZ;")[0], + cast("x = x + 1")[0], Assignment( - left = "Vtot", - right = Operator(op = "+",operands = ["V1PF", "V1AZ"]) + left = "x", + right = Operator(op = "+",operands = ["x", 1]) ) ) + +def test_matrix(): + """ Test assignment of matrix""" + check( + cast("x = [1 cat 'dog' ]")[0], + Assignment( + left = "x", + right = [1, 'cat', "'dog'"] + ) + ) + diff --git a/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py b/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py index 4d32d563ff6..0d988111795 100644 --- a/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py +++ b/skema/program_analysis/CAST/matlab/tests/test_file_ingest.py @@ -1,11 +1,18 @@ +import os.path from skema.program_analysis.CAST.matlab.matlab_to_cast import MatlabToCast from skema.program_analysis.CAST.matlab.tests.utils import (check, cast) from skema.program_analysis.CAST2FN.model.cast import Assignment def test_file_ingest(): """ Test the ability of the CAST translator to read from a file""" - filename = "skema/program_analysis/CAST/matlab/tests/data/matlab.m" - cast = MatlabToCast(source_path = filename).out_cast + + filepath = "skema/program_analysis/CAST/matlab/tests/data/matlab.m" + if not os.path.exists(filepath): + filepath = "data/matlab.m" + + + + cast = MatlabToCast(source_path = filepath).out_cast module = cast.nodes[0] nodes = module.body check(nodes[0], Assignment(left = "y", right = "b")) diff --git a/skema/program_analysis/CAST/matlab/tests/test_loop.py b/skema/program_analysis/CAST/matlab/tests/test_loop.py index 354ce0582e8..dcd409691e8 100644 --- a/skema/program_analysis/CAST/matlab/tests/test_loop.py +++ b/skema/program_analysis/CAST/matlab/tests/test_loop.py @@ -1,13 +1,123 @@ from skema.program_analysis.CAST.matlab.tests.utils import (check, cast) -from skema.program_analysis.CAST2FN.model.cast import Loop +from skema.program_analysis.CAST2FN.model.cast import ( + Assignment, + Call, + Loop, + Operator +) -# Test the for loop and others -def no_test_for_loop(): +# Test the for loop incrementing by 1 +def test_implicit_step(): """ Test the MATLAB for loop syntax elements""" source = """ - for n = 1:10 - x = do_something(n) + for n = 0:10 + disp(n) end """ nodes = cast(source) - check(nodes[0], Loop()) + check(nodes[0], + Loop( + pre = [Assignment(left = "n", right = 0)], + expr = Operator(op = "<=", operands = ["n", 10]), + body = [ + Call( + func = "disp", + arguments = ["n"] + ), + Assignment( + left = "n", + right = Operator( + op = "+", + operands = ["n", 1] + ) + ) + ], + post = [] + ) + ) + +# Test the for loop incrementing by n +def test_explicit_step(): + """ Test the MATLAB for loop syntax elements""" + source = """ + for n = 0:2:10 + disp(n) + end + """ + nodes = cast(source) + check(nodes[0], + Loop( + pre = [Assignment(left = "n", right = 0)], + expr = Operator(op = "<=", operands = ["n", 10]), + body = [ + Call( + func = "disp", + arguments = ["n"] + ), + Assignment( + left = "n", + right = Operator( + op = "+", + operands = ["n", 2] + ) + ) + ], + post = [] + ) + ) + + + + +# Test the for loop using matrix steps +def test_matrix(): + """ Test the MATLAB for loop syntax elements""" + source = """ + for k = [10 3 5 6] + disp(k) + end + """ + nodes = cast(source) + check(nodes[0], + Loop( + pre = [ + Assignment( + left = "_mat", + right = [10, 3, 5, 6] + ), + Assignment( + left = "_mat_len", + right = 4 + ), + Assignment( + left = "_mat_idx", + right = 0 + ), + Assignment( + left = "k", + right = 10 + ) + ], + expr = Operator(op = "<", operands = ["_mat_idx", "_mat_len"]), + body = [ + Call( + func = "disp", + arguments = ["k"] + ), + Assignment( + left = "_mat_idx", + right = Operator( + op = "+", + operands = ["_mat_idx", 1] + ) + ), + Assignment( + left = "k", + right = "_mat[_mat_idx]" + ) + ], + post = [] + + ) + ) + diff --git a/skema/program_analysis/CAST/matlab/tests/test_switch.py b/skema/program_analysis/CAST/matlab/tests/test_switch.py index 07582d91f39..e01454cad95 100644 --- a/skema/program_analysis/CAST/matlab/tests/test_switch.py +++ b/skema/program_analysis/CAST/matlab/tests/test_switch.py @@ -1,11 +1,12 @@ from skema.program_analysis.CAST.matlab.tests.utils import (check, cast) from skema.program_analysis.CAST2FN.model.cast import ( Assignment, + Call, ModelIf, Operator ) -def test_case_clause_1_argument(): +def test_1_argument(): """ Test CAST from single argument case clause.""" source = """ switch s @@ -37,7 +38,7 @@ def test_case_clause_1_argument(): ) ) -def test_case_clause_n_arguments(): +def test_n_arguments(): """ Test CAST from multipe argument case clause.""" source = """ @@ -60,3 +61,32 @@ def test_case_clause_n_arguments(): orelse = [Assignment(left="n", right = 0)] ) ) + +def test_call_argument(): + """ Test CAST using the value of a function call """ + + source = """ + switch fd(i,j) + case 0 + x = 5 + end + + """ + # switch statement translated into conditional + check( + cast(source)[0], + ModelIf( + expr = Operator( + op = "==", + operands = [ + Call ( + func = "fd", + arguments = ["i","j"] + ), + 0 + ] + ), + body = [Assignment(left="x", right = 5)], + orelse = [] + ) + ) diff --git a/skema/program_analysis/CAST/matlab/tests/utils.py b/skema/program_analysis/CAST/matlab/tests/utils.py index 8345715cee8..4fc9eba06f3 100644 --- a/skema/program_analysis/CAST/matlab/tests/utils.py +++ b/skema/program_analysis/CAST/matlab/tests/utils.py @@ -37,6 +37,11 @@ def check(result, expected = None): check(result.expr, expected.expr) check(result.body, expected.body) check(result.orelse, expected.orelse) + elif isinstance(result, Loop): + check(result.pre, expected.pre) + check(result.expr, expected.expr) + check(result.body, expected.body) + check(result.post, expected.post) elif isinstance(result, LiteralValue): check(result.value, expected) elif isinstance(result, Var): @@ -48,7 +53,7 @@ def check(result, expected = None): # every CAST node has a source_refs element if isinstance(result, AstNode): - assert not result.source_refs == None + assert result.source_refs is not None # we curently produce a CAST object with a single Module in the nodes list. def cast(source): diff --git a/skema/program_analysis/CAST/matlab/tokens.py b/skema/program_analysis/CAST/matlab/tokens.py index 7dea817f1cb..25b060f09e1 100644 --- a/skema/program_analysis/CAST/matlab/tokens.py +++ b/skema/program_analysis/CAST/matlab/tokens.py @@ -26,6 +26,7 @@ 'function_arguments', 'function_call', 'function_definition', + 'function_output', 'identifier', 'if', 'if_statement', @@ -46,20 +47,19 @@ 'switch_statement', 'unary_operator', - # keywords to be supported + # keywords currently being added 'break_statement', 'continue_statement', - 'field_expression', 'for', 'for_statement', - 'function_output', 'iterator', + 'range', + + # keywords to be supported + 'field_expression', 'lambda', 'line_continuation', 'multioutput_variable', - 'range', - 'while', - 'while_statement' ] """ anything not a keyword """ diff --git a/skema/program_analysis/CAST/matlab/variable_context.py b/skema/program_analysis/CAST/matlab/variable_context.py index 4bc8486b9f2..3d6db267e6a 100644 --- a/skema/program_analysis/CAST/matlab/variable_context.py +++ b/skema/program_analysis/CAST/matlab/variable_context.py @@ -72,3 +72,18 @@ def get_node(self, symbol: str) -> Dict: def get_type(self, symbol: str) -> str: return self.all_symbols[symbol]["type"] + + def get_gromet_function_node(self, func_name: str) -> Name: + if self.is_variable(func_name): + return self.get_node(func_name) + + def generate_iterator(self): + symbol = f"generated_iter_{self.iterator_id}" + self.iterator_id += 1 + return self.add_variable(symbol, "iterator", None) + + def generate_stop_condition(self): + symbol = f"sc_{self.stop_condition_id}" + self.stop_condition_id += 1 + return self.add_variable(symbol, "boolean", None) + From 398bfdf641b6a191d3da4c8e4685a3b8787025f0 Mon Sep 17 00:00:00 2001 From: Liang Zhang <68933075+ualiangzhang@users.noreply.github.com> Date: Tue, 19 Dec 2023 14:15:37 -0700 Subject: [PATCH 10/22] [ISA] Add the ISA service and reorganize the code implementation (#735) ## Summary of Changes - Modified `ISA` based on `MathExpressionTree` instead of `MathExpression` - Added the `ISA` service - Optimize the runtime of the code --- skema/isa/isa_service.py | 19 +- skema/isa/lib.py | 6 +- skema/rest/api.py | 7 + skema/rest/tests/test_isa.py | 78 + skema/skema-rs/mathml/src/expression.rs | 2264 ++++--------------- skema/skema-rs/skema/src/services/mathml.rs | 9 +- 6 files changed, 528 insertions(+), 1855 deletions(-) create mode 100644 skema/rest/tests/test_isa.py diff --git a/skema/isa/isa_service.py b/skema/isa/isa_service.py index 011bac98f96..76a8dbc3c31 100644 --- a/skema/isa/isa_service.py +++ b/skema/isa/isa_service.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- -from fastapi import FastAPI, File +from fastapi import FastAPI, APIRouter from skema.isa.lib import align_mathml_eqs from pydantic import BaseModel +import requests -# Create a web app using FastAPI +from skema.rest.proxies import SKEMA_RS_ADDESS -app = FastAPI() +router = APIRouter() # Model for ISA_Result @@ -15,12 +16,12 @@ class ISA_Result(BaseModel): union_graph: str = None -@app.get("/ping", summary="Ping endpoint to test health of service") -def ping(): - return "The ISA service is running." +@router.get("/healthcheck", summary="Status of ISA service") +async def healthcheck() -> int: + return requests.get(f"{SKEMA_RS_ADDESS}/ping").status_code -@app.put("/align-eqns", summary="Align two MathML equations") +@router.put("/align-eqns", summary="Align two MathML equations") async def align_eqns( file1: str, file2: str, mention_json1: str = "", mention_json2: str = "" ) -> ISA_Result: @@ -41,3 +42,7 @@ async def align_eqns( ir.matching_ratio = matching_ratio ir.union_graph = union_graph.to_string() return ir + + +app = FastAPI() +app.include_router(router) diff --git a/skema/isa/lib.py b/skema/isa/lib.py index 12cce55c017..195e0c97532 100644 --- a/skema/isa/lib.py +++ b/skema/isa/lib.py @@ -2,13 +2,14 @@ """ All the functions required by performing incremental structure alignment (ISA) Author: Liang Zhang (liangzh@arizona.edu) -Updated date: August 24, 2023 +Updated date: December 18, 2023 """ import json import warnings from typing import List, Any, Union, Dict from numpy import ndarray from pydot import Dot +from skema.rest.proxies import SKEMA_RS_ADDESS warnings.filterwarnings("ignore") import requests @@ -173,8 +174,9 @@ def generate_graph(file: str = "", render: bool = False) -> pydot.Dot: content = f.read() digraph = requests.put( - "http://localhost:8080/mathml/math-exp-graph", data=content.encode("utf-8") + f"{SKEMA_RS_ADDESS}/mathml/math-exp-graph", data=content.encode("utf-8") ) + if render: src = Source(digraph.text) src.render("doctest-output/mathml_exp_tree", view=True) diff --git a/skema/rest/api.py b/skema/rest/api.py index 03fe324ae4b..09fd461eb9a 100644 --- a/skema/rest/api.py +++ b/skema/rest/api.py @@ -13,6 +13,7 @@ metal_proxy, llm_proxy, ) +from skema.isa import isa_service from skema.img2mml import eqn2mml from skema.skema_py import server as code2fn from skema.gromet.execution_engine import server as execution_engine @@ -139,6 +140,12 @@ tags=["metal"] ) +app.include_router( + isa_service.router, + prefix="/isa", + tags=["isa"] +) + @app.get("/version", tags=["core"], summary="API version") async def version() -> str: diff --git a/skema/rest/tests/test_isa.py b/skema/rest/tests/test_isa.py new file mode 100644 index 00000000000..7aa36605e42 --- /dev/null +++ b/skema/rest/tests/test_isa.py @@ -0,0 +1,78 @@ +import json + +from fastapi.testclient import TestClient +from skema.isa.isa_service import app +import pytest + +client = TestClient(app) + + +@pytest.mark.ci_only +def test_align_eqns(): + """Test case for /align-eqns endpoint.""" + + halfar_dome_eqn = """ + + + + H + + + + t + + + = + + + ( + Γ + + H + + n + + + 2 + + + | + + H + + | + + n + + 1 + + + + H + ) + + """ + mention_json1_content = "" + mention_json2_content = "" + data = { + "file1": halfar_dome_eqn, + "file2": halfar_dome_eqn, + "mention_json1": mention_json1_content, + "mention_json2": mention_json2_content, + } + + endpoint = "/align-eqns" + response = client.put(endpoint, params=data) + expected = 'digraph G {\n0 [color=blue, label="Div(Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H))"];\n1 [color=blue, label="D(1, t)(H)"];\n2 [color=blue, label="Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H)"];\n3 [color=blue, label="Γ"];\n4 [color=blue, label="H^(n+2)"];\n5 [color=blue, label="H"];\n6 [color=blue, label="n+2"];\n7 [color=blue, label="n"];\n8 [color=blue, label="2"];\n9 [color=blue, label="Abs(Grad(H))^(n-1)"];\n10 [color=blue, label="Abs(Grad(H))"];\n11 [color=blue, label="Grad(H)"];\n12 [color=blue, label="n-1"];\n13 [color=blue, label="1"];\n1 -> 0 [color=blue, label="="];\n2 -> 0 [color=blue, label="Div"];\n3 -> 2 [color=blue, label="*"];\n4 -> 2 [color=blue, label="*"];\n5 -> 4 [color=blue, label="^"];\n6 -> 4 [color=blue, label="^"];\n7 -> 6 [color=blue, label="+"];\n8 -> 6 [color=blue, label="+"];\n9 -> 2 [color=blue, label="*"];\n10 -> 9 [color=blue, label="^"];\n11 -> 10 [color=blue, label="Abs"];\n5 -> 11 [color=blue, label="Grad"];\n12 -> 9 [color=blue, label="^"];\n7 -> 12 [color=blue, label="+"];\n13 -> 12 [color=blue, label="-"];\n11 -> 2 [color=blue, label="*"];\n}\n' + + # check status code + assert ( + response.status_code == 200 + ), f"Request was unsuccessful (status code was {response.status_code} instead of 200)" + # check response of matching_ratio + assert ( + json.loads(response.text)["matching_ratio"] == 1.0 + ), f"Response should be 1.0, but instead received {response.text}" + # check response of union_graph + assert ( + json.loads(response.text)["union_graph"] == expected + ), f"Response should be {expected}, but instead received {response.text}" diff --git a/skema/skema-rs/mathml/src/expression.rs b/skema/skema-rs/mathml/src/expression.rs index 410aa0308db..7a438e0f279 100644 --- a/skema/skema-rs/mathml/src/expression.rs +++ b/skema/skema-rs/mathml/src/expression.rs @@ -1,18 +1,12 @@ -use crate::{ - ast::{ - operator::Operator, - Math, MathExpression, - MathExpression::{Mfrac, Mn, Mo, Mover, Msqrt, Msubsup, Msup}, - Mi, Mrow, - }, - petri_net::recognizers::recognize_leibniz_differential_operator, -}; +use crate::ast::{operator::Operator, MathExpression, Mi}; +use crate::parsers::math_expression_tree::MathExpressionTree; use petgraph::{graph::NodeIndex, Graph}; use std::{clone::Clone, collections::VecDeque}; /// Struct for representing mathematical expressions in order to align with source code. pub type MathExpressionGraph<'a> = Graph; +use petgraph::dot::Dot; use std::string::ToString; #[derive(Debug, PartialEq, Eq, Clone)] @@ -23,13 +17,6 @@ pub enum Atom { } /// Intermediate data structure to support the generation of graphs of mathematical expressions -#[derive(Debug, Default, PartialEq, Clone)] -pub struct Expression { - pub ops: Vec, - pub args: Vec, - pub name: String, -} - #[derive(Debug, PartialEq, Clone)] pub enum Expr { Atom(Atom), @@ -40,177 +27,325 @@ pub enum Expr { }, } -/// Check if the fraction is a derivative expressed in Leibniz notation. If yes, mutate it to -/// remove the 'd' prefixes. -pub fn is_derivative( - numerator: &mut Box, - denominator: &mut Box, -) -> bool { - if recognize_leibniz_differential_operator(numerator, denominator).is_ok() { - if let MathExpression::Mrow(Mrow(x)) = &mut **numerator { - x.remove(0); - } - - if let MathExpression::Mrow(Mrow(x)) = &mut **denominator { - x.remove(0); - } - return true; - } - false -} - -/// Identify if there is an implicit multiplication operator, and if so, add an -/// explicit multiplication operator. -fn insert_explicit_multiplication_operator(pre: &mut Expression) { - if pre.args.len() >= pre.ops.len() { - pre.ops.push(Operator::Multiply); +fn is_unary_operator(op: &Operator) -> bool { + match op { + Operator::Sqrt + | Operator::Factorial + | Operator::Exp + | Operator::Grad + | Operator::Div + | Operator::Abs + | Operator::Derivative(_) + | Operator::Sin + | Operator::Cos + | Operator::Tan + | Operator::Sec + | Operator::Csc + | Operator::Cot + | Operator::Arcsin + | Operator::Arccos + | Operator::Arctan + | Operator::Arcsec + | Operator::Arccsc + | Operator::Arccot + | Operator::Mean => true, + _ => false, } } -impl MathExpression { - /// Convert a MathExpression struct to a Expression struct. - pub fn to_expr(self, pre: &mut Expression) { - match self { - MathExpression::Mi(Mi(x)) => { - // Process unary minus operation. - if !pre.args.is_empty() { - // Check the last arg - let args_last_idx = pre.args.len() - 1; - if let Expr::Atom(Atom::Operator(Operator::Subtract)) = &pre.args[args_last_idx] - { - let neg_identifier = format!("-{x}"); - pre.args[args_last_idx] = Expr::Atom(Atom::Identifier(neg_identifier)); - return; - } - } - // deal with the invisible multiply operator - if pre.args.len() >= pre.ops.len() { - pre.ops.push(Operator::Multiply); - } - pre.args - .push(Expr::Atom(Atom::Identifier(x.replace(' ', "")))); +/// Processes a MathExpression under the type of MathExpressionTree::Atom and appends +/// the corresponding LaTeX representation to the provided String. +fn process_atom_expression(expr: &MathExpression, expression: &mut Expr) { + match expr { + // If it's a Ci variant, recursively process its content + MathExpression::Ci(x) => { + process_atom_expression(&x.content, expression); + } + MathExpression::Mi(Mi(id)) => { + if let Expr::Expression { ops, args, name } = expression { + args.push(Expr::Atom(Atom::Identifier(id.replace(' ', "")))); } - Mn(x) => { - insert_explicit_multiplication_operator(pre); - // Remove redundant whitespace - pre.args.push(Expr::Atom(Atom::Number(x.replace(' ', "")))); + } + MathExpression::Mn(number) => { + if let Expr::Expression { ops, args, name } = expression { + args.push(Expr::Atom(Atom::Number(number.replace(' ', "")))); } - Mo(x) => { - // Insert a temporary placeholder identifier to deal with unary minus operation. - // The placeholder will be removed later. - if x == Operator::Subtract && pre.ops.len() > pre.args.len() { - pre.ops.push(x); - pre.args - .push(Expr::Atom(Atom::Identifier("place_holder".to_string()))); - } else { - pre.ops.push(x); - } + } + MathExpression::Msqrt(x) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Sqrt); + process_atom_expression(x, &mut new_expr); } - MathExpression::Mrow(Mrow(xs)) => { - insert_explicit_multiplication_operator(pre); - let mut pre_exp = Expression::default(); - pre_exp.ops.push(Operator::Other("".to_string())); - for x in xs { - x.to_expr(&mut pre_exp); - } - pre.args.push(Expr::Expression { - ops: pre_exp.ops, - args: pre_exp.args, - name: "".to_string(), - }); + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); } - Msubsup(xs1, xs2, xs3) => { - insert_explicit_multiplication_operator(pre); - let mut pre_exp = Expression::default(); - pre_exp.ops.push(Operator::Other("".to_string())); - pre_exp.ops.push(Operator::Other("_".to_string())); - xs1.to_expr(&mut pre_exp); - pre_exp.ops.push(Operator::Other("^".to_string())); - xs2.to_expr(&mut pre_exp); - xs3.to_expr(&mut pre_exp); - pre.args.push(Expr::Expression { - ops: pre_exp.ops, - args: pre_exp.args, - name: "".to_string(), - }); + } + MathExpression::Mfrac(x1, x2) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("".to_string())); + process_atom_expression(x1, &mut new_expr); } - Msqrt(xs) => { - insert_explicit_multiplication_operator(pre); - let mut pre_exp = Expression::default(); - pre_exp.ops.push(Operator::Sqrt); - xs.to_expr(&mut pre_exp); - pre.args.push(Expr::Expression { - ops: pre_exp.ops, - args: pre_exp.args, - name: "".to_string(), - }); + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Divide); + process_atom_expression(x2, &mut new_expr); } - Mfrac(mut xs1, mut xs2) => { - insert_explicit_multiplication_operator(pre); - let mut pre_exp = Expression::default(); - if is_derivative(&mut xs1, &mut xs2) { - pre_exp.ops.push(Operator::Other("derivative".to_string())); - } else { - pre_exp.ops.push(Operator::Other("".to_string())); - } - xs1.to_expr(&mut pre_exp); - pre_exp.ops.push(Operator::Divide); - xs2.to_expr(&mut pre_exp); - pre.args.push(Expr::Expression { - ops: pre_exp.ops, - args: pre_exp.args, - name: "".to_string(), - }); + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); } - Msup(xs1, xs2) => { - insert_explicit_multiplication_operator(pre); - let mut pre_exp = Expression::default(); - pre_exp.ops.push(Operator::Other("".to_string())); - xs1.to_expr(&mut pre_exp); - pre_exp.ops.push(Operator::Other("^".to_string())); - xs2.to_expr(&mut pre_exp); - pre.args.push(Expr::Expression { - ops: pre_exp.ops, - args: pre_exp.args, - name: "".to_string(), - }); + } + MathExpression::Msup(x1, x2) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("".to_string())); + process_atom_expression(x1, &mut new_expr); } - Mover(xs1, xs2) => { - insert_explicit_multiplication_operator(pre); - let mut pre_exp = Expression::default(); - pre_exp.ops.push(Operator::Other("".to_string())); - xs1.to_expr(&mut pre_exp); - xs2.to_expr(&mut pre_exp); - pre_exp.ops.remove(0); - pre.args.push(Expr::Expression { - ops: pre_exp.ops, - args: pre_exp.args, - name: "".to_string(), - }); + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("^".to_string())); + process_atom_expression(x2, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); + } + } + MathExpression::Msub(x1, x2) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("".to_string())); + process_atom_expression(x1, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("_".to_string())); + process_atom_expression(x2, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); + } + } + MathExpression::Msubsup(x1, x2, x3) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("".to_string())); + process_atom_expression(x1, &mut new_expr); } - _ => { - panic!("Unhandled type!"); + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("_".to_string())); + process_atom_expression(x2, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("^".to_string())); + process_atom_expression(x3, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); + } + } + MathExpression::Munder(x1, x2) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("".to_string())); + process_atom_expression(x1, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("under".to_string())); + process_atom_expression(x2, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); + } + } + MathExpression::Mover(x1, x2) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("".to_string())); + process_atom_expression(x1, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("over".to_string())); + process_atom_expression(x2, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); + } + } + MathExpression::Mtext(x) => { + if let Expr::Expression { ops, args, name } = expression { + args.push(Expr::Atom(Atom::Identifier(x.replace(' ', "")))); } } + MathExpression::Mspace(x) => { + if let Expr::Expression { ops, args, name } = expression { + args.push(Expr::Atom(Atom::Identifier(x.to_string()))); + } + } + MathExpression::AbsoluteSup(x1, x2) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("|.|".to_string())); + process_atom_expression(x1, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("_".to_string())); + process_atom_expression(x2, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); + } + } + MathExpression::Mrow(vec_me) => { + for me in vec_me.0.iter() { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if let Expr::Expression { ops, args, name } = &mut new_expr { + process_atom_expression(me, &mut new_expr); + } + if let Expr::Expression { ops, args, name } = expression { + args.push(new_expr.clone()); + } + } + } + t => panic!("Unhandled MathExpression: {:?}", t), } +} +impl MathExpressionTree { + /// Convert a MathExpressionTree struct to a Expression struct. + pub fn to_expr(self, expr: &mut Expr) -> &mut Expr { + match self { + MathExpressionTree::Atom(a) => { + process_atom_expression(&a, expr); + } + MathExpressionTree::Cons(head, rest) => { + let mut new_expr = Expr::Expression { + ops: Vec::::new(), + args: Vec::::new(), + name: String::new(), + }; + if is_unary_operator(&head) || (head == Operator::Subtract && rest.len() == 1) { + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(head); + rest[0].clone().to_expr(&mut new_expr); + } + } else { + if let Expr::Expression { ops, args, name } = &mut new_expr { + ops.push(Operator::Other("".to_string())); + for (index, r) in rest.iter().enumerate() { + if index < rest.len() - 1 { + ops.push(head.clone()); + } + } + } + if let Expr::Expression { ops, args, name } = &mut new_expr { + for r in &rest { + r.clone().to_expr(&mut new_expr); + } + } + } + if let Expr::Expression { ops, args, name } = expr { + args.push(new_expr.clone()); + } + } + } + expr + } pub fn to_graph(self) -> MathExpressionGraph<'static> { - let mut pre_exp = Expression { - ops: Vec::::new(), + let mut expr = self.clone(); + let mut pre_exp = Expr::Expression { + ops: vec![Operator::Other("root".to_string())], args: Vec::::new(), name: "root".to_string(), }; - pre_exp.ops.push(Operator::Other("root".to_string())); - self.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - // if need to convert to canonical form, please uncomment the following - // pre_exp.distribute_expr(); - // pre_exp.group_expr(); - // pre_exp.collapse_expr(); - pre_exp.set_name(); - pre_exp.to_graph() + expr.to_expr(&mut pre_exp); + + if let Expr::Expression { ops, args, name } = &mut pre_exp { + for mut arg in args { + if let Expr::Expression { .. } = arg { + arg.group_expr(); + } + } + } + if let Expr::Expression { ops, args, name } = &mut pre_exp { + for mut arg in args { + if let Expr::Expression { .. } = arg { + arg.collapse_expr(); + } + } + } + /// if need to convert to canonical form, please uncomment the following + // if let Expr::Expression {ops, args, name} = &mut pre_exp { + // for mut arg in args { + // if let Expr::Expression { .. } = arg { + // arg.distribute_expr(); + // } + // } + // } + // if let Expr::Expression {ops, args, name} = &mut pre_exp { + // for mut arg in args { + // if let Expr::Expression { .. } = arg { + // arg.group_expr(); + // } + // } + // } + // if let Expr::Expression {ops, args, name} = &mut pre_exp { + // for mut arg in args { + // if let Expr::Expression { .. } = arg { + // arg.collapse_expr(); + // } + // } + // } + if let Expr::Expression { ops, args, name } = &mut pre_exp { + for mut arg in args { + if let Expr::Expression { .. } = arg { + arg.set_name(); + } + } + } + let mut g = MathExpressionGraph::new(); + if let Expr::Expression { ops, args, name } = &mut pre_exp { + for mut arg in args { + if let Expr::Expression { .. } = arg { + arg.to_graph(&mut g); + } + } + } + g } } @@ -519,7 +654,7 @@ impl Expr { Atom::Operator(_) => {} }, Expr::Expression { ops, .. } => { - let mut string; + let mut string = "".to_string(); if ops[0] != Operator::Other("".to_string()) { string = ops[0].to_string(); string.push('('); @@ -1008,51 +1143,6 @@ pub fn need_to_distribute(ops: Vec) -> bool { false } -impl Expression { - pub fn group_expr(&mut self) { - for arg in &mut self.args { - if let Expr::Expression { .. } = arg { - arg.group_expr(); - } - } - } - - pub fn collapse_expr(&mut self) { - for arg in &mut self.args { - if let Expr::Expression { .. } = arg { - arg.collapse_expr(); - } - } - } - - #[allow(dead_code)] // used in tests I believe - fn distribute_expr(&mut self) { - for arg in &mut self.args { - if let Expr::Expression { .. } = arg { - arg.distribute_expr(); - } - } - } - - pub fn set_name(&mut self) { - for arg in &mut self.args { - if let Expr::Expression { .. } = arg { - arg.set_name(); - } - } - } - - pub fn to_graph(&mut self) -> MathExpressionGraph { - let mut g = MathExpressionGraph::new(); - for arg in &mut self.args { - if let Expr::Expression { .. } = arg { - arg.to_graph(&mut g); - } - } - g - } -} - /// Remove redundant parentheses. pub fn remove_redundant_parens(string: &mut String) -> &mut String { while contains_redundant_parens(string) { @@ -1084,1638 +1174,134 @@ pub fn get_node_idx(graph: &mut MathExpressionGraph, name: &mut String) -> NodeI graph.add_node(name.to_string()) } -/// Remove redundant mrow next to specific MathML elements. This function will likely be removed -/// once the img2mml pipeline is fixed. -pub fn remove_redundant_mrow(mml: String, key_word: String) -> String { - let mut content = mml; - let key_words_left = "".to_string() + &*key_word.clone(); - let mut key_word_right = key_word.clone(); - key_word_right.insert(1, '/'); - let key_words_right = key_word_right.clone() + ""; - let locs: Vec<_> = content - .match_indices(&key_words_left) - .map(|(i, _)| i) - .collect(); - for loc in locs.iter().rev() { - if content[loc + 1..].contains(&key_words_right) { - let l = content[*loc..].find(&key_word_right).map(|i| i + *loc); - if let Some(x) = l { - if content.len() > (x + key_words_right.len()) - && content[x..x + key_words_right.len()] == key_words_right - { - content.replace_range(x..x + key_words_right.len(), key_word_right.as_str()); - content.replace_range(*loc..*loc + key_words_left.len(), key_word.as_str()); - } - } - } - } - content -} - -/// Remove redundant mrows in mathml because some mathml elements don't need mrow to wrap. This -/// function will likely be removed -/// once the img2mml pipeline is fixed. -pub fn remove_redundant_mrows(mathml_content: String) -> String { - let mut content = mathml_content; - content = content.replace("", "("); - content = content.replace("", ")"); - let f = |b: &[u8]| -> Vec { - let v = (0..) - .zip(b) - .scan(vec![], |a, (b, c)| { - Some(match c { - 40 => { - a.push(b); - None - } - 41 => Some((a.pop()?, b)), - _ => None, - }) - }) - .flatten() - .collect::>(); - for k in &v { - if k.0 == 0 && k.1 == b.len() - 1 { - return b[1..b.len() - 1].to_vec(); - } - for l in &v { - if l.0 == k.0 + 1 && l.1 == k.1 - 1 { - return [&b[..k.0], &b[l.0..k.1], &b[k.1 + 1..]].concat(); - } - } - } - b.to_vec() - }; - let g = |mut b: Vec| { - while f(&b) != b { - b = f(&b) - } - b - }; - content = std::str::from_utf8(&g(content.bytes().collect())) - .unwrap() - .to_string(); - content = content.replace('(', ""); - content = content.replace(')', ""); - content = remove_redundant_mrow(content, "".to_string()); - content = remove_redundant_mrow(content, "".to_string()); - content = remove_redundant_mrow(content, "".to_string()); - content = remove_redundant_mrow(content, "".to_string()); - content -} - -/// Preprocess the content prior to parsing. -pub fn preprocess_content(content_str: String) -> String { - let mut pre_string = content_str; - pre_string = pre_string.replace(' ', ""); - pre_string = pre_string.replace('\n', ""); - pre_string = pre_string.replace('\t', ""); - pre_string = pre_string.replace("(t)", ""); - pre_string = pre_string.replace(",", ""); - pre_string = pre_string.replace("(", ""); - pre_string = pre_string.replace(")", ""); - - // Unicode to Symbol - let unicode_locs: Vec<_> = pre_string.match_indices("&#").map(|(i, _)| i).collect(); - for ul in unicode_locs.iter().rev() { - let loc = pre_string[*ul..].find('<').map(|i| i + ul); - match loc { - None => {} - Some(_x) => {} - } - } - pre_string = html_escape::decode_html_entities(&pre_string).to_string(); - pre_string = pre_string.replace( - &html_escape::decode_html_entities("−").to_string(), - "-", - ); - pre_string = remove_redundant_mrows(pre_string); - pre_string -} - -/// Wrap mathml vectors by mrow as a single expression to process -pub fn wrap_math(math: Math) -> MathExpression { - let mut math_vec = vec![]; - for con in math.content { - math_vec.push(con); - } - - MathExpression::Mrow(Mrow(math_vec)) -} - -#[test] -fn test_to_expr() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - - if let Expr::Expression { ops, args, .. } = &pre_exp.args[0] { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - } -} - -#[test] -fn test_to_expr2() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Subtract), - MathExpression::Mrow(Mrow(vec![ - Mn("4".to_string()), - MathExpression::Mi(Mi("c".to_string())), - MathExpression::Mi(Mi("d".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - - math_expression.to_expr(&mut pre_exp); - pre_exp.ops.push(Operator::Other("root".to_string())); - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(ops[2], Operator::Subtract); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - match &args[2] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Number("4".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } - } - } -} - -#[test] -fn test_to_expr3() { - let math_expression = Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])))); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Sqrt); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - } - } - } - } -} - -#[test] -fn test_to_expr4() { - let math_expression = Mfrac( - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ]))), - Box::from(MathExpression::Mi(Mi("c".to_string()))), - ); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Divide); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - } - } - match &args[1] { - Expr::Atom(_x) => { - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - } - Expr::Expression { .. } => {} - } - } - } -} - -#[test] -fn test_to_expr5() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - } - } - } - } -} - -#[test] -fn test_to_expr6() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("d".to_string())), - Mo(Operator::Divide), - MathExpression::Mi(Mi("e".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("f".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("g".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("h".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(ops[2], Operator::Subtract); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[3], Expr::Atom(Atom::Identifier("h".to_string()))); - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!(ops[3], Operator::Divide); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("d".to_string()))); - assert_eq!(args[3], Expr::Atom(Atom::Identifier("e".to_string()))); - } - } - match &args[2] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("f".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("g".to_string()))); - } - } - } - } -} - -#[test] -fn test_to_expr7() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(name, "(a+b*c)"); - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(name, "b*c"); - } - } - } - } -} - -#[test] -fn test_to_expr8() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("d".to_string())), - Mo(Operator::Divide), - MathExpression::Mi(Mi("e".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("f".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("g".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("h".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(ops[2], Operator::Subtract); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[3], Expr::Atom(Atom::Identifier("h".to_string()))); - assert_eq!(name, "(a+b*c*d/e-f*g-h)"); - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!(ops[3], Operator::Divide); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("d".to_string()))); - assert_eq!(args[3], Expr::Atom(Atom::Identifier("e".to_string()))); - assert_eq!(name, "b*c*d/e"); - } - } - match &args[2] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("f".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("g".to_string()))); - assert_eq!(name, "f*g"); - } - } - } - } -} - -#[test] -fn test_to_expr9() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("d".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(name, "(a+b*(c-d))"); - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(name, "b*(c-d)"); - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, name } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string()))); - assert_eq!(name, "(c-d)"); - } - } - } - } - } - } -} - -#[test] -fn test_to_expr10() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("a".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - let _g = pre_exp.to_graph(); -} - -#[test] -fn test_to_expr11() { - let math_expression = Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - ])), - ])))); - - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - let _g = pre_exp.to_graph(); -} - -#[test] -fn test_to_expr12() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("d".to_string())), - Mo(Operator::Divide), - MathExpression::Mi(Mi("e".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("f".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("g".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("h".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - let _g = pre_exp.to_graph(); -} - -#[test] -fn test_to_expr13() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Divide), - MathExpression::Mi(Mi("d".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - let _g = pre_exp.to_graph(); -} - -#[test] -fn test_to_expr14() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - let _g = pre_exp.to_graph(); -} - -#[test] -fn test_to_expr15() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Subtract), - Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])))), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.set_name(); - let _g = pre_exp.to_graph(); -} - -#[test] -fn test_to_expr16() { - let math_expression = Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - ])), - ])))); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr17() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("s".to_string())), - Mo(Operator::Equals), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - ])), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr18() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("s".to_string())), - Mo(Operator::Equals), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Subtract), - Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - ])), - ])))), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr19() { - let input = "tests/sir.xml"; - let contents = std::fs::read_to_string(input) - .unwrap_or_else(|_| panic!("{}", "Unable to read file {input}!")); - let mut math = contents - .parse::() - .unwrap_or_else(|_| panic!("{}", "Unable to parse file {input}!")); - math.normalize(); - let _g = &mut math.content[0].clone().to_graph(); -} - -#[test] -fn test_to_expr20() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("s".to_string())), - Mo(Operator::Equals), - Mfrac( - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ]))), - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - MathExpression::Mi(Mi("d".to_string())), - Msqrt(Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("d".to_string())), - ])))), - ]))), - ), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr21() { - let math_expression = Msup( - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ]))), - Box::from(MathExpression::Mi(Mi("c".to_string()))), - ); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Other("^".to_string())); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - } - } - match &args[1] { - Expr::Atom(_x) => { - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - } - Expr::Expression { .. } => {} - } - } - } -} - -#[test] -fn test_to_expr22() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - Msup( - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ]))), - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("d".to_string())), - ]))), - ), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr23() { - let math_expression = MathExpression::Mrow(Mrow(vec![Msubsup( - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ]))), - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("d".to_string())), - ]))), - Box::from(MathExpression::Mi(Mi("c".to_string()))), - )])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr24() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - Mo(Operator::Subtract), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("c".to_string())), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr25() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - Mo(Operator::Subtract), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr26() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - Mo(Operator::Subtract), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("d".to_string())), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr27() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - Mo(Operator::Subtract), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr28() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - Mo(Operator::Subtract), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])); - let _g = math_expression.to_graph(); -} - -#[test] -fn test_to_expr29() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - Mo(Operator::Subtract), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - Msup( - Box::from(MathExpression::Mrow(Mrow(vec![ - Mo(Operator::Subtract), - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ]))), - Box::from(MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("d".to_string())), - ]))), - ), - ])); - let _g = math_expression.to_graph(); -} - -#[cfg(test)] -fn get_preprocessed_normalized_math_from_file(filename: &str) -> Math { - let mut contents = std::fs::read_to_string(filename) - .unwrap_or_else(|_| panic!("{}", "Unable to read file {input}!")); - contents = preprocess_content(contents); - let math = &mut contents - .parse::() - .unwrap_or_else(|_| panic!("{}", "Unable to parse file {input}!")); - math.normalize(); - math.clone() -} -#[test] -fn test_to_expr30() { - let math = get_preprocessed_normalized_math_from_file("tests/seir_eq1.xml"); - let mut math_vec = vec![]; - for con in math.content { - math_vec.push(con); - } - let new_math = MathExpression::Mrow(Mrow(math_vec)); - let _g = new_math.to_graph(); -} - #[test] -fn test_to_expr32() { - let math = get_preprocessed_normalized_math_from_file("tests/seirdv_eq7.xml"); - let new_math = wrap_math(math); - let _g = new_math.to_graph(); -} - -#[test] -fn test_to_expr33() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("c".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - } - } - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - } - } - } - } +fn test_plus_to_graph() { + let input = " + + + a + + + b + + + "; + let exp = input.parse::().unwrap(); + let g = exp.to_graph(); + let dot_representation = Dot::new(&g); + assert_eq!( + dot_representation + .to_string() + .replace("\n", "") + .replace(" ", ""), + "digraph{0[label=\"a+b\"]1[label=\"a\"]2[label=\"b\"]1->0[label=\"+\"]2->0[label=\"+\"]}" + ) } #[test] -fn test_to_expr34() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])), - Mo(Operator::Divide), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("d".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Divide); - - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - } - } - - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } - } - } +fn test_equation_halfar_dome_8_1_to_graph() { + let input = " + + + + + H + + + + t + + + = + + + ( + Γ + + H + + n + + + 2 + + + | + + H + + | + + n + + 1 + + + + H + ) + + "; + + let exp = input.parse::().unwrap(); + let g = exp.to_graph(); + let dot_representation = Dot::new(&g); + assert_eq!(dot_representation.to_string() + .replace("\n", "") + .replace(" ", ""), + "digraph{0[label=\"Div(Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H))\"]1[label=\"D(1,t)(H)\"]2[label=\"Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H)\"]3[label=\"Γ\"]4[label=\"H^(n+2)\"]5[label=\"H\"]6[label=\"n+2\"]7[label=\"n\"]8[label=\"2\"]9[label=\"Abs(Grad(H))^(n-1)\"]10[label=\"Abs(Grad(H))\"]11[label=\"Grad(H)\"]12[label=\"n-1\"]13[label=\"1\"]1->0[label=\"=\"]2->0[label=\"Div\"]3->2[label=\"*\"]4->2[label=\"*\"]5->4[label=\"^\"]6->4[label=\"^\"]7->6[label=\"+\"]8->6[label=\"+\"]9->2[label=\"*\"]10->9[label=\"^\"]11->10[label=\"Abs\"]5->11[label=\"Grad\"]12->9[label=\"^\"]7->12[label=\"+\"]13->12[label=\"-\"]11->2[label=\"*\"]}"); } #[test] -fn test_to_expr35() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("c".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("c".to_string()))); - } - } -} - -#[test] -fn test_to_expr36() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("b".to_string())), - ])), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("c".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Subtract); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("c".to_string()))); - } - } -} - -#[test] -fn test_to_expr37() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])), - Mo(Operator::Subtract), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("d".to_string())), - ])), - Mo(Operator::Add), - MathExpression::Mi(Mi("e".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Add); - assert_eq!(ops[2], Operator::Subtract); - assert_eq!(ops[3], Operator::Subtract); - assert_eq!(ops[4], Operator::Add); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[3], Expr::Atom(Atom::Identifier("d".to_string()))); - assert_eq!(args[4], Expr::Atom(Atom::Identifier("e".to_string()))); - } - } -} - -#[test] -fn test_to_expr38() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("c".to_string())), - ])), - Mo(Operator::Add), - MathExpression::Mi(Mi("d".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Add); - assert_eq!(ops[3], Operator::Subtract); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[3], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } -} - -#[test] -fn test_to_expr39() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Divide), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Divide), - MathExpression::Mi(Mi("c".to_string())), - ])), - Mo(Operator::Multiply), - MathExpression::Mi(Mi("d".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Divide); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!(ops[3], Operator::Divide); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[3], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } -} - -#[test] -fn test_to_expr40() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Divide), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("c".to_string())), - ])), - Mo(Operator::Divide), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("d".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("e".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Divide); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Add); - assert_eq!(ops[3], Operator::Subtract); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("e".to_string()))); - } - } - match &args[2] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } - match &args[3] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("e".to_string()))); - } - } - } - } - } - } -} - -#[test] -fn test_to_expr41() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("b".to_string())), - ])), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("c".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("d".to_string())), - ])), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Add); - assert_eq!(ops[3], Operator::Subtract); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - } - } - match &args[1] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } - match &args[2] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("c".to_string()))); - } - } - match &args[3] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("b".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string()))); - } - } - } - } -} - -#[test] -fn test_to_expr42() { - let math_expression = MathExpression::Mrow(Mrow(vec![ - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("a".to_string())), - Mo(Operator::Subtract), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("b".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("c".to_string())), - ])), - ])), - Mo(Operator::Divide), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("f".to_string())), - Mo(Operator::Add), - MathExpression::Mi(Mi("g".to_string())), - ])), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("d".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("e".to_string())), - ])), - Mo(Operator::Divide), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("h".to_string())), - Mo(Operator::Subtract), - MathExpression::Mi(Mi("i".to_string())), - ])), - Mo(Operator::Multiply), - MathExpression::Mrow(Mrow(vec![ - MathExpression::Mi(Mi("j".to_string())), - Mo(Operator::Divide), - MathExpression::Mi(Mi("k".to_string())), - ])), - Mo(Operator::Add), - MathExpression::Mi(Mi("l".to_string())), - ])); - let mut pre_exp = Expression { - ops: Vec::::new(), - args: Vec::::new(), - name: "root".to_string(), - }; - pre_exp.ops.push(Operator::Other("root".to_string())); - math_expression.to_expr(&mut pre_exp); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.distribute_expr(); - pre_exp.group_expr(); - pre_exp.collapse_expr(); - pre_exp.set_name(); - - match &pre_exp.args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Subtract); - assert_eq!(ops[3], Operator::Add); - assert_eq!(ops[4], Operator::Subtract); - assert_eq!(ops[5], Operator::Add); - assert_eq!(ops[6], Operator::Add); - assert_eq!(args[6], Expr::Atom(Atom::Identifier("l".to_string()))); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!(ops[3], Operator::Divide); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("a".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("d".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("j".to_string()))); - match &args[3] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Add); - assert_eq!(ops[3], Operator::Subtract); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!( - args[0], - Expr::Atom(Atom::Identifier("f".to_string())) - ); - assert_eq!( - args[1], - Expr::Atom(Atom::Identifier("h".to_string())) - ); - assert_eq!( - args[2], - Expr::Atom(Atom::Identifier("k".to_string())) - ); - } - } - match &args[3] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!( - args[0], - Expr::Atom(Atom::Identifier("g".to_string())) - ); - assert_eq!( - args[1], - Expr::Atom(Atom::Identifier("i".to_string())) - ); - assert_eq!( - args[2], - Expr::Atom(Atom::Identifier("k".to_string())) - ); - } - } - } - } - } - } - match &args[5] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!(ops[3], Operator::Divide); - assert_eq!(args[0], Expr::Atom(Atom::Identifier("c".to_string()))); - assert_eq!(args[1], Expr::Atom(Atom::Identifier("e".to_string()))); - assert_eq!(args[2], Expr::Atom(Atom::Identifier("j".to_string()))); - match &args[3] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Subtract); - assert_eq!(ops[2], Operator::Add); - assert_eq!(ops[3], Operator::Subtract); - match &args[0] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!( - args[0], - Expr::Atom(Atom::Identifier("f".to_string())) - ); - assert_eq!( - args[1], - Expr::Atom(Atom::Identifier("h".to_string())) - ); - assert_eq!( - args[2], - Expr::Atom(Atom::Identifier("k".to_string())) - ); - } - } - match &args[3] { - Expr::Atom(_) => {} - Expr::Expression { ops, args, .. } => { - assert_eq!(ops[0], Operator::Other("".to_string())); - assert_eq!(ops[1], Operator::Multiply); - assert_eq!(ops[2], Operator::Multiply); - assert_eq!( - args[0], - Expr::Atom(Atom::Identifier("g".to_string())) - ); - assert_eq!( - args[1], - Expr::Atom(Atom::Identifier("i".to_string())) - ); - assert_eq!( - args[2], - Expr::Atom(Atom::Identifier("k".to_string())) - ); - } - } - } - } - } - } - } - } +fn test_equation_sidarthe_1_to_graph() { + let input = " + + + + S + ˙ + + + ( + t + ) + = + + S + ( + t + ) + ( + α + I + ( + t + ) + + + β + D + ( + t + ) + + + γ + A + ( + t + ) + + + δ + R + ( + t + ) + ) + + "; + + let exp = input.parse::().unwrap(); + let g = exp.to_graph(); + let dot_representation = Dot::new(&g); + assert_eq!(dot_representation.to_string() + .replace("\n", "") + .replace(" ", ""), + "digraph{0[label=\"-(S)*(α*I+β*D+γ*A+δ*R)\"]1[label=\"D(1,t)(S)\"]2[label=\"-(S)\"]3[label=\"S\"]4[label=\"α*I+β*D+γ*A+δ*R\"]5[label=\"α*I\"]6[label=\"α\"]7[label=\"I\"]8[label=\"β*D\"]9[label=\"β\"]10[label=\"D\"]11[label=\"γ*A\"]12[label=\"γ\"]13[label=\"A\"]14[label=\"δ*R\"]15[label=\"δ\"]16[label=\"R\"]1->0[label=\"=\"]2->0[label=\"*\"]3->2[label=\"-\"]4->0[label=\"*\"]5->4[label=\"+\"]6->5[label=\"*\"]7->5[label=\"*\"]8->4[label=\"+\"]9->8[label=\"*\"]10->8[label=\"*\"]11->4[label=\"+\"]12->11[label=\"*\"]13->11[label=\"*\"]14->4[label=\"+\"]15->14[label=\"*\"]16->14[label=\"*\"]}"); } diff --git a/skema/skema-rs/skema/src/services/mathml.rs b/skema/skema-rs/skema/src/services/mathml.rs index 468bcfd2482..c3bf7ff693c 100644 --- a/skema/skema-rs/skema/src/services/mathml.rs +++ b/skema/skema-rs/skema/src/services/mathml.rs @@ -11,8 +11,6 @@ use mathml::parsers::math_expression_tree::{ use mathml::{ acset::{AMRmathml, PetriNet, RegNet}, - ast::Math, - expression::{preprocess_content, wrap_math}, parsers::first_order_ode::{first_order_ode, FirstOrderODE}, }; use petgraph::dot::{Config, Dot}; @@ -57,11 +55,8 @@ pub async fn get_ast_graph(payload: String) -> String { #[put("/mathml/math-exp-graph")] pub async fn get_math_exp_graph(payload: String) -> String { let mut contents = payload; - contents = preprocess_content(contents); - let mut math = contents.parse::().unwrap(); - math.normalize(); - let new_math = wrap_math(math); - let g = new_math.to_graph(); + let exp = contents.parse::().unwrap(); + let g = exp.to_graph(); let dot_representation = Dot::new(&g); dot_representation.to_string() } From 3e38c4810dec121f4d8296d5bbe5eea021eb4c1c Mon Sep 17 00:00:00 2001 From: titomeister Date: Thu, 21 Dec 2023 21:54:02 -0700 Subject: [PATCH 11/22] Skema Docs CAST Development Notes (#736) ## Summary of Changes This PR makes a small modification to the development notes that live in the skema docs website. It adds a new small section detailing notes on how to use Var and Name constructs. More things related to CAST development will be added to this section when the need arises. Resolves #733 --- docs/dev/cast_frontend.md | 9 +++++++++ mkdocs.yml | 1 + 2 files changed, 10 insertions(+) create mode 100644 docs/dev/cast_frontend.md diff --git a/docs/dev/cast_frontend.md b/docs/dev/cast_frontend.md new file mode 100644 index 00000000000..b49e01d4364 --- /dev/null +++ b/docs/dev/cast_frontend.md @@ -0,0 +1,9 @@ +## CAST FrontEnd Generation Notes +### Using Var vs Name nodes +Currently in the CAST generation we have a convention on when to use Var and Name nodes. +The GroMEt generation depends on these being conistent, otherwise there will be errors in the generation. +In the future this convention might change, or be eliminated altogether, but for now this is the current set of rules. + +- If the variable in question is being stored into (i.e. as the result of an assignment), then we use Var. Even if it's a variable that has already been defined. +- If the variable in question is being read from (i.e. being used in an expression), then we use Name. +- Whenever we're creating a function call Call() node, the name of the function is specified using the Name node. diff --git a/mkdocs.yml b/mkdocs.yml index 84d88cd8272..d338cf679bf 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -45,6 +45,7 @@ nav: - Generating code2fn model coverage reports: "dev/generating_code2fn_model_coverage.md" - Using code ingestion frontends: "dev/using_code_ingestion_frontends.md" - Using tree-sitter preprocessor: "dev/using_tree_sitter_preprocessor.md" + - CAST Front-end generation: "dev/cast_frontend.md" - Coverage: - Code2fn coverage reports: "coverage/code2fn_coverage/report.html" - TA1 Integration Dashboard: "https://integration-dashboard.terarium.ai/TA1" From 97f7e508361107f2c3511d0e0a51c69b06c6d538 Mon Sep 17 00:00:00 2001 From: titomeister Date: Thu, 21 Dec 2023 22:14:31 -0700 Subject: [PATCH 12/22] Gromet Wiring Detector PR (#732) This PR adds a new script `skema/program_analysis/gromet_wire_diagnosis.py` that can be used to do some simple analysis and error detecting in the wires of GroMEt FNs. It currently checks the ports of all types of wires, and detects whether the ports are out of bounds (in either negative or positive indices) within their respective port tables. It also attempts to find the most relevant SourceCodeReference metadata that is associated with the wires and displays the line number information contained within it. ## Summary of Changes - Adds `skema/program_analysis/gromet_wire_diagnosis.py` script - Modifies `skema/program_analysis/JSON2GroMEt/json2gromet.py` script so that it can ingest newer GroMEt JSON that uses the updated Gromet metadata fields. - Fixes a small issue with the incorrect SourceCodeReference metadata type being used in `skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py` - Adds a test script `skema/program_analysis/tests/test_wiring_diagnosis.py` that can test the consistency of the individual wire checker utility without needing a GroMEt JSON. ### Potential Next steps - Need to determine more things that can be easily analyzed in a GroMEt JSON - Come up with a more robust way of determining what line numbers go with the wires. - Determine what to do with the Metadata fields that don't have an "is_metadatum" field attached to them. (NOTE: A solution to this has been currently proposed.) Resolves #697 --------- Co-authored-by: Vincent Raymond --- skema/gromet/fn/gromet_fn_module.py | 4 +- .../CAST2FN/ann_cast/to_gromet_pass.py | 1 - .../JSON2GroMEt/json2gromet.py | 12 +- .../program_analysis/gromet_wire_diagnosis.py | 208 ++++++++++++++++++ .../tests/test_wiring_diagnosis.py | 28 +++ 5 files changed, 245 insertions(+), 8 deletions(-) create mode 100644 skema/program_analysis/gromet_wire_diagnosis.py create mode 100644 skema/program_analysis/tests/test_wiring_diagnosis.py diff --git a/skema/gromet/fn/gromet_fn_module.py b/skema/gromet/fn/gromet_fn_module.py index ec1a94f11e3..2d36d86e703 100644 --- a/skema/gromet/fn/gromet_fn_module.py +++ b/skema/gromet/fn/gromet_fn_module.py @@ -191,7 +191,7 @@ def fn_array(self, fn_array): def metadata_collection(self): """Gets the metadata_collection of this GrometFNModule. # noqa: E501 - Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GroMEt object. # noqa: E501 + Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GrometFNModule object. # noqa: E501 :return: The metadata_collection of this GrometFNModule. # noqa: E501 :rtype: list[list[Metadata]] @@ -202,7 +202,7 @@ def metadata_collection(self): def metadata_collection(self, metadata_collection): """Sets the metadata_collection of this GrometFNModule. - Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GroMEt object. # noqa: E501 + Table (array) of lists (arrays) of metadata, where each list in the Table-array represents the collection of metadata associated with a GrometFNModule object. # noqa: E501 :param metadata_collection: The metadata_collection of this GrometFNModule. # noqa: E501 :type: list[list[Metadata]] diff --git a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py index 33e9ed71661..f3dc35ea75d 100644 --- a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py +++ b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py @@ -3239,7 +3239,6 @@ def visit_literal_value( ) code_data_metadata = SourceCodeDataType( - gromet_type="source_code_data_type", provenance=generate_provenance(), source_language=ref[0], source_language_version=ref[1], diff --git a/skema/program_analysis/JSON2GroMEt/json2gromet.py b/skema/program_analysis/JSON2GroMEt/json2gromet.py index 2c174e72120..59f8df020c2 100644 --- a/skema/program_analysis/JSON2GroMEt/json2gromet.py +++ b/skema/program_analysis/JSON2GroMEt/json2gromet.py @@ -32,8 +32,10 @@ def json_to_gromet(path: str) -> GrometFNModuleCollection: sys.modules["skema.gromet.metadata"], inspect.isclass ): instance = metadata_object() - if "metadata_type" in instance.attribute_map: - gromet_metadata_map[instance.metadata_type] = metadata_object + if "is_metadatum" in instance.attribute_map and instance.is_metadatum: + gromet_metadata_map[metadata_name] = metadata_object + else: + gromet_fn_map[metadata_name] = metadata_object def get_obj_type(obj: Dict) -> Any: """Given a dictionary representing a Gromet object (i.e. BoxFunction), return an instance of that object. @@ -42,10 +44,10 @@ def get_obj_type(obj: Dict) -> Any: # First check if we already have a mapping to a data-class memeber. All Gromet FN and most Gromet Metadata classes will fall into this category. # There are a few Gromet Metadata fields such as Provenance that do not have a "metadata_type" field - if "gromet_type" in obj: + if "gromet_type" in obj and ("is_metadatum" not in obj or obj["is_metadatum"] != True): return gromet_fn_map[obj["gromet_type"]]() - elif "metadata_type" in obj: - return gromet_metadata_map[obj["metadata_type"]]() + elif obj["is_metadatum"]: + return gromet_metadata_map[obj["gromet_type"]]() # If there is not a mapping to an object, we will check the fields to see if they match an existing class in the data-model. # For example: (id, box, metadata) would map to GrometPort diff --git a/skema/program_analysis/gromet_wire_diagnosis.py b/skema/program_analysis/gromet_wire_diagnosis.py new file mode 100644 index 00000000000..b4ae814551e --- /dev/null +++ b/skema/program_analysis/gromet_wire_diagnosis.py @@ -0,0 +1,208 @@ +import argparse +from skema.program_analysis.JSON2GroMEt import json2gromet +from skema.gromet.metadata import SourceCodeReference + +# Ways to expand +# Check loop, condition FN indices +# Check bf call FN indices +# Boxes associated with ports + +def disp_wire(wire): + return f"src:{wire.src}<-->tgt:{wire.tgt}" + +def get_length(gromet_item): + # For any gromet object we can generically retrieve the length, since they all exist + # in lists + return len(gromet_item) if gromet_item != None else 0 + +def check_wire(gromet_wire, src_port_count, tgt_port_count, wire_type = "", metadata=None): + # The current wiring checks are + # Checking if the ports on both ends of the wire are below or over the bounds + error_detected = False + if gromet_wire.src < 0: + error_detected = True + print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has negative src port.") + if gromet_wire.src == 0: + error_detected = True + print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has zero src port.") + if gromet_wire.src > src_port_count: + error_detected = True + print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has a src port that goes over the boundary of {src_port_count} src ports.") + + if gromet_wire.tgt < 0: + error_detected = True + print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has negative tgt port.") + if gromet_wire.tgt == 0: + error_detected = True + print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has zero tgt port.") + if gromet_wire.tgt > tgt_port_count: + error_detected = True + print(f"Gromet Wire {wire_type} {disp_wire(gromet_wire)} has a tgt port that goes over the boundary of {tgt_port_count} tgt ports.") + + + if error_detected: + if metadata == None: + print("No line number information exists for this particular wire!") + else: + print(f"Wire is associated with source code lines start:{metadata.line_begin} end:{metadata.line_end}") + print() + + return error_detected + +def find_metadata_idx(gromet_fn): + """ + Attempts to find a metadata associated with this fn + If it finds something, return it, otherwise return None + """ + if gromet_fn.b != None: + for b in gromet_fn.b: + if b.metadata != None: + return b.metadata + + if gromet_fn.bf != None: + for bf in gromet_fn.bf: + if bf.metadata != None: + return bf.metadata + + return None + +def analyze_fn_wiring(gromet_fn, metadata_collection): + # Acquire information for all the ports, if they exist + pif_length = get_length(gromet_fn.pif) + pof_length = get_length(gromet_fn.pof) + opi_length = get_length(gromet_fn.opi) + opo_length = get_length(gromet_fn.opo) + pil_length = get_length(gromet_fn.pil) + pol_length = get_length(gromet_fn.pol) + pic_length = get_length(gromet_fn.pic) + poc_length = get_length(gromet_fn.poc) + + # Find a SourceCodeReference metadata that we can extract line number information for + # so we can display some line number information about potential errors in the wiring + # NOTE: Can we make this extraction more accurate? + metadata_idx = find_metadata_idx(gromet_fn) + metadata = None + if metadata_idx != None: + for md in metadata_collection[metadata_idx - 1]: + if isinstance(md, SourceCodeReference): + metadata = md + + wopio_length = get_length(gromet_fn.wopio) + if wopio_length > 0: + for wire in gromet_fn.wff: + check_wire(wire, opo_length, opi_length, "wff", metadata) + + ######################## loop (bl) wiring + + wlopi_length = get_length(gromet_fn.wlopi) + if wlopi_length > 0: + for wire in gromet_fn.wlopi: + check_wire(wire, pil_length, opi_length, "wlopi", metadata) + + wll_length = get_length(gromet_fn.wll) + if wll_length > 0: + for wire in gromet_fn.wll: + check_wire(wire, pil_length, pol_length, "wll", metadata) + + wlf_length = get_length(gromet_fn.wlf) + if wlf_length > 0: + for wire in gromet_fn.wlf: + check_wire(wire, pif_length, pol_length, "wlf", metadata) + + wlc_length = get_length(gromet_fn.wlc) + if wlc_length > 0: + for wire in gromet_fn.wlc: + check_wire(wire, pic_length, pol_length, "wlc", metadata) + + wlopo_length = get_length(gromet_fn.wlopo) + if wlopo_length > 0: + for wire in gromet_fn.wlopo: + check_wire(wire, opo_length, pol_length, "wlopo", metadata) + + ######################## function (bf) wiring + wfopi_length = get_length(gromet_fn.wfopi) + if wfopi_length > 0: + for wire in gromet_fn.wfopi: + check_wire(wire, pif_length, opi_length, "wfopi", metadata) + + wfl_length = get_length(gromet_fn.wfl) + if wfl_length > 0: + for wire in gromet_fn.wfl: + check_wire(wire, pil_length, pof_length, "wfl", metadata) + + wff_length = get_length(gromet_fn.wff) + if wff_length > 0: + for wire in gromet_fn.wff: + check_wire(wire, pif_length, pof_length, "wff", metadata) + + wfc_length = get_length(gromet_fn.wfc) + if wfc_length > 0: + for wire in gromet_fn.wfc: + check_wire(wire, pic_length, pof_length, "wfc", metadata) + + wfopo_length = get_length(gromet_fn.wfopo) + if wfopo_length > 0: + for wire in gromet_fn.wfopo: + check_wire(wire, opo_length, pof_length, "wfopo", metadata) + + ######################## condition (bc) wiring + wcopi_length = get_length(gromet_fn.wcopi) + if wcopi_length > 0: + for wire in gromet_fn.wcopi: + check_wire(wire, pic_length, opi_length, "wcopi", metadata) + + wcl_length = get_length(gromet_fn.wcl) + if wcl_length > 0: + for wire in gromet_fn.wcl: + check_wire(wire, pil_length, poc_length, "wcl", metadata) + + wcf_length = get_length(gromet_fn.wcf) + if wcf_length > 0: + for wire in gromet_fn.wcf: + check_wire(wire, pif_length, poc_length, "wcf", metadata) + + wcc_length = get_length(gromet_fn.wcc) + if wcc_length > 0: + for wire in gromet_fn.wcc: + check_wire(wire, pic_length, poc_length, "wcc", metadata) + + wcopo_length = get_length(gromet_fn.wcopo) + if wcopo_length > 0: + for wire in gromet_fn.wcopo: + check_wire(wire, opo_length, poc_length, "wcopo", metadata) + + +def wiring_analyzer(gromet_obj): + # TODO: Multifiles + + for module in gromet_obj.modules: + # first_module = gromet_obj.modules[0] + metadata = [] + # Analyze base FN + print(f"Analyzing {module.name}") + analyze_fn_wiring(module.fn, module.metadata_collection) + + # Analyze the rest of the FN_array + for fn in module.fn_array: + analyze_fn_wiring(fn, module.metadata_collection) + +def get_args(): + parser = argparse.ArgumentParser( + "Attempts to analyize GroMEt JSON for issues" + ) + parser.add_argument( + "gromet_file_path", + help="input GroMEt JSON file" + ) + + options = parser.parse_args() + return options + +if __name__ == "__main__": + args = get_args() + gromet_obj = json2gromet.json_to_gromet(args.gromet_file_path) + + wiring_analyzer(gromet_obj) + + + diff --git a/skema/program_analysis/tests/test_wiring_diagnosis.py b/skema/program_analysis/tests/test_wiring_diagnosis.py new file mode 100644 index 00000000000..0f23252ed06 --- /dev/null +++ b/skema/program_analysis/tests/test_wiring_diagnosis.py @@ -0,0 +1,28 @@ +from skema.program_analysis.gromet_wire_diagnosis import check_wire +from skema.gromet.fn import GrometWire + + +def test_correct_wire(): + correct_wire = GrometWire(src=1, tgt=1) + result = check_wire(correct_wire, 1, 1, "wff") + assert not result + + correct_wire = GrometWire(src=3, tgt=4) + result = check_wire(correct_wire, 4, 5, "wlc") + assert not result + + correct_wire = GrometWire(src=2, tgt=1) + result = check_wire(correct_wire, 2, 1, "wff") + +def test_wrong_wire(): + wrong_wire = GrometWire(src=0, tgt=-1) + result = check_wire(wrong_wire, 1, 1, "wff") + assert result + + wrong_wire = GrometWire(src=20, tgt=2) + result = check_wire(wrong_wire, 19, 2, "wff") + assert result + + wrong_wire = GrometWire(src=-1, tgt=2) + result = check_wire(wrong_wire, 1, 1, "wlc") + assert result From 594e12ac4f6fcc453c589ae73f2941b9bdf48f90 Mon Sep 17 00:00:00 2001 From: Gus Hahn-Powell Date: Fri, 22 Dec 2023 17:01:29 -0700 Subject: [PATCH 13/22] [REST] Refactor for improved consistency (#741) ## Summary of Changes - Explicitly return a status code of 200 for calls to /version. - Support head calls to /version. - Changed ISA PUT request to a POST - Renamed file1 and file2 to `mml1` and `mml2` (in the REST API, these will **always** be "file" contents (i.e., MML). - Added ISA test case data to `isa/data.py` for ease of use between ISA tests and API docs. - Refactor status codes to use constants for better consistency. - Temporarily disable ISA router (deployment debugging) --- skema/img2mml/eqn2mml.py | 10 +++--- skema/isa/data.py | 44 ++++++++++++++++++++++++ skema/isa/isa_service.py | 48 ++++++++++++++++++++------ skema/isa/lib.py | 10 +++--- skema/rest/api.py | 30 ++++++++++++++--- skema/rest/tests/test_isa.py | 52 ++++------------------------- skema/skema_py/server.py | 13 +++++--- skema/skema_py/tests/test_server.py | 4 +-- 8 files changed, 135 insertions(+), 76 deletions(-) create mode 100644 skema/isa/data.py diff --git a/skema/img2mml/eqn2mml.py b/skema/img2mml/eqn2mml.py index 9c7e3260d0b..e523f4521e5 100644 --- a/skema/img2mml/eqn2mml.py +++ b/skema/img2mml/eqn2mml.py @@ -7,7 +7,7 @@ from typing import Text from typing_extensions import Annotated -from fastapi import APIRouter, FastAPI, Response, Request, Query, UploadFile +from fastapi import APIRouter, FastAPI, status, Response, Request, Query, UploadFile from skema.rest.proxies import SKEMA_MATHJAX_ADDRESS from skema.img2mml.api import ( get_mathml_from_bytes, @@ -86,23 +86,23 @@ def process_latex_equation(eqn: Text) -> Response: "/img2mml/healthcheck", summary="Check health of eqn2mml service", response_model=int, - status_code=200, + status_code=status.HTTP_200_OK, ) def img2mml_healthcheck() -> int: - return 200 + return status.HTTP_200_OK @router.get( "/latex2mml/healthcheck", summary="Check health of mathjax service", response_model=int, - status_code=200, + status_code=status.HTTP_200_OK, ) def latex2mml_healthcheck() -> int: try: return int(requests.get(f"{SKEMA_MATHJAX_ADDRESS}/healthcheck").status_code) except: - return 500 + return status.HTTP_500_INTERNAL_SERVER_ERROR @router.post("/image/mml", summary="Get MathML representation of an equation image") diff --git a/skema/isa/data.py b/skema/isa/data.py new file mode 100644 index 00000000000..8e5ecbc1334 --- /dev/null +++ b/skema/isa/data.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +mml = """ + + + + H + + + + t + + + = + + + ( + Γ + + H + + n + + + 2 + + + | + + H + + | + + n + + 1 + + + + H + ) + + """ + +expected = 'digraph G {\n0 [color=blue, label="Div(Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H))"];\n1 [color=blue, label="D(1, t)(H)"];\n2 [color=blue, label="Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H)"];\n3 [color=blue, label="Γ"];\n4 [color=blue, label="H^(n+2)"];\n5 [color=blue, label="H"];\n6 [color=blue, label="n+2"];\n7 [color=blue, label="n"];\n8 [color=blue, label="2"];\n9 [color=blue, label="Abs(Grad(H))^(n-1)"];\n10 [color=blue, label="Abs(Grad(H))"];\n11 [color=blue, label="Grad(H)"];\n12 [color=blue, label="n-1"];\n13 [color=blue, label="1"];\n1 -> 0 [color=blue, label="="];\n2 -> 0 [color=blue, label="Div"];\n3 -> 2 [color=blue, label="*"];\n4 -> 2 [color=blue, label="*"];\n5 -> 4 [color=blue, label="^"];\n6 -> 4 [color=blue, label="^"];\n7 -> 6 [color=blue, label="+"];\n8 -> 6 [color=blue, label="+"];\n9 -> 2 [color=blue, label="*"];\n10 -> 9 [color=blue, label="^"];\n11 -> 10 [color=blue, label="Abs"];\n5 -> 11 [color=blue, label="Grad"];\n12 -> 9 [color=blue, label="^"];\n7 -> 12 [color=blue, label="+"];\n13 -> 12 [color=blue, label="-"];\n11 -> 2 [color=blue, label="*"];\n}\n' \ No newline at end of file diff --git a/skema/isa/isa_service.py b/skema/isa/isa_service.py index 76a8dbc3c31..a1a5d6d912e 100644 --- a/skema/isa/isa_service.py +++ b/skema/isa/isa_service.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from fastapi import FastAPI, APIRouter +from fastapi import FastAPI, APIRouter, status from skema.isa.lib import align_mathml_eqs +import skema.isa.data as isa_data from pydantic import BaseModel import requests @@ -16,17 +17,38 @@ class ISA_Result(BaseModel): union_graph: str = None -@router.get("/healthcheck", summary="Status of ISA service") +@router.get( + "/healthcheck", + summary="Status of ISA service", + response_model=int, + status_code=status.HTTP_200_OK +) async def healthcheck() -> int: return requests.get(f"{SKEMA_RS_ADDESS}/ping").status_code -@router.put("/align-eqns", summary="Align two MathML equations") +@router.post( + "/align-eqns", + summary="Align two MathML equations" +) async def align_eqns( - file1: str, file2: str, mention_json1: str = "", mention_json2: str = "" + mml1: str, mml2: str, mention_json1: str = "", mention_json2: str = "" ) -> ISA_Result: - """ + f""" Endpoint for align two MathML equations. + + ### Python example + + ``` + import requests + + request = {{ + "mml1": {isa_data.mml}, + "mml2": {isa_data.mml} + }} + + response=client.post("/isa/align-eqns", json=request) + res = response.json() """ ( matching_ratio, @@ -37,12 +59,16 @@ async def align_eqns( aligned_indices2, union_graph, perfectly_matched_indices1, - ) = align_mathml_eqs(file1, file2, mention_json1, mention_json2) - ir = ISA_Result() - ir.matching_ratio = matching_ratio - ir.union_graph = union_graph.to_string() - return ir + ) = align_mathml_eqs(mml1, mml2, mention_json1, mention_json2) + return ISA_Result( + matching_ratio = matching_ratio, + union_graph = union_graph.to_string() + ) app = FastAPI() -app.include_router(router) +app.include_router( + router, + prefix="/isa", + tags=["isa"], +) diff --git a/skema/isa/lib.py b/skema/isa/lib.py index 195e0c97532..a6590b8dc92 100644 --- a/skema/isa/lib.py +++ b/skema/isa/lib.py @@ -673,8 +673,8 @@ def check_square_array(arr: np.ndarray) -> List[int]: def align_mathml_eqs( - file1: str = "", - file2: str = "", + mml1: str = "", + mml2: str = "", mention_json1: str = "", mention_json2: str = "", mode: int = 2, @@ -687,7 +687,7 @@ def align_mathml_eqs( [1] Fishkind, D. E., Adali, S., Patsolic, H. G., Meng, L., Singh, D., Lyzinski, V., & Priebe, C. E. (2019). Seeded graph matching. Pattern recognition, 87, 203-215. - Input: the paths of the two equation MathMLs; mention_json1: the mention file of paper 1; mention_json1: the mention file of paper 2; + Input: mml1 & mml2: the file path or contents of the two equation MathMLs; mention_json1: the mention file of paper 1; mention_json1: the mention file of paper 2; mode 0: without considering any priors; mode 1: having a heuristic prior with the similarity of node labels; mode 2: using the variable definitions Output: @@ -700,8 +700,8 @@ def align_mathml_eqs( union_graph: the visualization of the alignment result perfectly_matched_indices1: strictly matched node indices in Graph 1 """ - graph1 = generate_graph(file1) - graph2 = generate_graph(file2) + graph1 = generate_graph(mml1) + graph2 = generate_graph(mml2) amatrix1, node_labels1 = generate_amatrix(graph1) amatrix2, node_labels2 = generate_amatrix(graph2) diff --git a/skema/rest/api.py b/skema/rest/api.py index 09fd461eb9a..69ba06ab95c 100644 --- a/skema/rest/api.py +++ b/skema/rest/api.py @@ -63,15 +63,27 @@ }, { "name": "morae", - "description": "", + "description": "Operations to MORAE.", "externalDocs": { "description": "Issues", "url": "https://github.com/ml4ai/skema/issues?q=is%3Aopen+is%3Aissue+label%3AMORAE", }, }, + { + "name": "isa", + "description": "Operations to ISA", + "externalDocs": { + "description": "Issues", + "url": "https://github.com/ml4ai/skema/issues?q=is%3Aopen+is%3Aissue+label%3AISA", + }, + }, { "name": "text reading", "description": "Unified proxy and integration code for MIT and SKEMA TR pipelines", + "externalDocs": { + "description": "Issues", + "url": "https://github.com/ml4ai/skema/issues?q=is%3Aopen+is%3Aissue+label%3AText%20Reading", + }, }, { "name": "metal", @@ -146,8 +158,18 @@ tags=["isa"] ) - -@app.get("/version", tags=["core"], summary="API version") +@app.head( + "/version", + tags=["core"], + summary="API version", + status_code=status.HTTP_200_OK +) +@app.get( + "/version", + tags=["core"], + summary="API version", + status_code=status.HTTP_200_OK +) async def version() -> str: return PlainTextResponse(VERSION) @@ -172,7 +194,7 @@ async def healthcheck(response: Response) -> schema.HealthStatus: morae_status = await morae_proxy.healthcheck() mathjax_status = eqn2mml.latex2mml_healthcheck() eqn2mml_status = eqn2mml.img2mml_healthcheck() - code2fn_status = code2fn.ping() + code2fn_status = code2fn.healthcheck() text_reading_status = integrated_text_reading_proxy.healthcheck() metal_status = metal_proxy.healthcheck() # check if any services failing and alter response status code accordingly diff --git a/skema/rest/tests/test_isa.py b/skema/rest/tests/test_isa.py index 7aa36605e42..96ace972810 100644 --- a/skema/rest/tests/test_isa.py +++ b/skema/rest/tests/test_isa.py @@ -2,6 +2,7 @@ from fastapi.testclient import TestClient from skema.isa.isa_service import app +import skema.isa.data as isa_data import pytest client = TestClient(app) @@ -11,58 +12,19 @@ def test_align_eqns(): """Test case for /align-eqns endpoint.""" - halfar_dome_eqn = """ - - - - H - - - - t - - - = - - - ( - Γ - - H - - n - + - 2 - - - | - - H - - | - - n - - 1 - - - - H - ) - - """ + halfar_dome_eqn = isa_data.mml mention_json1_content = "" mention_json2_content = "" data = { - "file1": halfar_dome_eqn, - "file2": halfar_dome_eqn, + "mml1": halfar_dome_eqn, + "mml2": halfar_dome_eqn, "mention_json1": mention_json1_content, "mention_json2": mention_json2_content, } - endpoint = "/align-eqns" - response = client.put(endpoint, params=data) - expected = 'digraph G {\n0 [color=blue, label="Div(Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H))"];\n1 [color=blue, label="D(1, t)(H)"];\n2 [color=blue, label="Γ*(H^(n+2))*(Abs(Grad(H))^(n-1))*Grad(H)"];\n3 [color=blue, label="Γ"];\n4 [color=blue, label="H^(n+2)"];\n5 [color=blue, label="H"];\n6 [color=blue, label="n+2"];\n7 [color=blue, label="n"];\n8 [color=blue, label="2"];\n9 [color=blue, label="Abs(Grad(H))^(n-1)"];\n10 [color=blue, label="Abs(Grad(H))"];\n11 [color=blue, label="Grad(H)"];\n12 [color=blue, label="n-1"];\n13 [color=blue, label="1"];\n1 -> 0 [color=blue, label="="];\n2 -> 0 [color=blue, label="Div"];\n3 -> 2 [color=blue, label="*"];\n4 -> 2 [color=blue, label="*"];\n5 -> 4 [color=blue, label="^"];\n6 -> 4 [color=blue, label="^"];\n7 -> 6 [color=blue, label="+"];\n8 -> 6 [color=blue, label="+"];\n9 -> 2 [color=blue, label="*"];\n10 -> 9 [color=blue, label="^"];\n11 -> 10 [color=blue, label="Abs"];\n5 -> 11 [color=blue, label="Grad"];\n12 -> 9 [color=blue, label="^"];\n7 -> 12 [color=blue, label="+"];\n13 -> 12 [color=blue, label="-"];\n11 -> 2 [color=blue, label="*"];\n}\n' + endpoint = "/isa/align-eqns" + response = client.post(endpoint, params=data) + expected = isa_data.expected # check status code assert ( diff --git a/skema/skema_py/server.py b/skema/skema_py/server.py index 2d70196417c..77ebdb1e755 100644 --- a/skema/skema_py/server.py +++ b/skema/skema_py/server.py @@ -8,7 +8,7 @@ from io import BytesIO from zipfile import ZipFile from urllib.request import urlopen -from fastapi import APIRouter, FastAPI, Body, File, UploadFile +from fastapi import APIRouter, FastAPI, status, Body, File, UploadFile from fastapi.responses import JSONResponse from pydantic import BaseModel, Field @@ -226,9 +226,14 @@ async def system_to_gromet(system: System): router = APIRouter() -@router.get("/ping", summary="Ping endpoint to test health of service") -def ping() -> int: - return 200 +@router.get( + "/healthcheck", + summary="Ping endpoint to test health of service", + status_code=status.HTTP_200_OK, + response_model=int +) +def healthcheck() -> int: + return status.HTTP_200_OK @router.get( diff --git a/skema/skema_py/tests/test_server.py b/skema/skema_py/tests/test_server.py index d0602baf90f..2d62b8b179f 100644 --- a/skema/skema_py/tests/test_server.py +++ b/skema/skema_py/tests/test_server.py @@ -11,9 +11,9 @@ client = TestClient(app) -def test_ping(): +def test_healthcheck(): """Test case for /code2fn/ping endpoint.""" - response = client.get("/code2fn/ping") + response = client.get("/code2fn/healthcheck") assert response.status_code == 200 From 563e3de5213ba654eacc3ee11ad317cce54cdf21 Mon Sep 17 00:00:00 2001 From: Vincent Raymond Date: Tue, 9 Jan 2024 11:02:18 -0500 Subject: [PATCH 14/22] [fortran] CISM coverage updates (#739) ## Summary of Changes ## Tree-Sitter Parsers - Adds a new `--ci` flag to build_parsers.py so that the .so is only copied to site packages when running on CI. When running locally this can cause missing import issues. - Updates Github workflow to add --ci flag to build_parsers call. ## Preprocessor - Updates arguments to GCC invocation to explicitly specify the source language. This prevents GCC from mixing up C style comments and Fortran concatenation (//) ## TS2CAST - Updates generate_cast_body to add a no_op to output List if it is empty. - Updates Loop visitor to use generate_cast_body to prevent Null values from appearing in CAST. ## CAST->GROMET - Adds error handling to find_func_in_module function to prevent it from crashing on Fortran source code. ### Related issues Resolves #700 Resolves #719 --- .github/workflows/tests-and-docs.yml | 2 +- .../CAST/fortran/preprocessor/preprocess.py | 18 +- .../program_analysis/CAST/fortran/ts2cast.py | 221 +++++++++--------- .../CAST/pythonAST/modules_list.py | 7 +- .../tree_sitter_parsers/build_parsers.py | 5 +- 5 files changed, 139 insertions(+), 114 deletions(-) diff --git a/.github/workflows/tests-and-docs.yml b/.github/workflows/tests-and-docs.yml index ba08f5970a3..a338b47ddd8 100644 --- a/.github/workflows/tests-and-docs.yml +++ b/.github/workflows/tests-and-docs.yml @@ -86,7 +86,7 @@ jobs: # Install tree-sitter parser (for Python component unit tests) - name: Install tree-sitter parsers working-directory: . - run: python skema/program_analysis/tree_sitter_parsers/build_parsers.py --all + run: python skema/program_analysis/tree_sitter_parsers/build_parsers.py --ci --all # docs (API) diff --git a/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py b/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py index 5bd6942017f..2e6a608a4a3 100644 --- a/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py +++ b/skema/program_analysis/CAST/fortran/preprocessor/preprocess.py @@ -34,7 +34,7 @@ def preprocess( """ # NOTE: The order of preprocessing steps does matter. We have to run the GCC preprocessor before correcting the continuation lines or there could be issues - # TODO: Create single location for generating include base path + # TODO: Create single location for generating include base path source = source_path.read_text() # Get paths for intermediate products @@ -67,7 +67,7 @@ def preprocess( # Step 2: Correct include directives to remove system references source = fix_include_directives(source) - + # Step 3: Process with gcc c-preprocessor include_base_directory = Path(source_path.parent, f"include_{source_path.stem}") if not include_base_directory.exists(): @@ -75,13 +75,13 @@ def preprocess( source = run_c_preprocessor(source, include_base_directory) if out_gcc: gcc_path.write_text(source) - + # Step 4: Prepare for tree-sitter # This step removes any additional preprocessor directives added or not removed by GCC source = "\n".join( ["!" + line if line.startswith("#") else line for line in source.splitlines()] ) - + # Step 5: Check for unsupported idioms if out_unsupported: unsupported_path.write_text( @@ -173,7 +173,7 @@ def fix_include_directives(source: str) -> str: def run_c_preprocessor(source: str, include_base_path: Path) -> str: """Run the gcc c-preprocessor. Its run from the context of the include_base_path, so that it can find all included files""" result = run( - ["gcc", "-cpp", "-E", "-"], + ["gcc", "-cpp", "-E", "-x", "f95", "-"], input=source, text=True, capture_output=True, @@ -183,8 +183,14 @@ def run_c_preprocessor(source: str, include_base_path: Path) -> str: return result.stdout +def convert_assigned(source: str) -> str: + """Convered ASSIGNED GO TO to COMPUTED GO TO""" + pass + + def convert_to_free_form(source: str) -> str: """If fixed-form Fortran source, convert to free-form""" + def validate_parse_tree(source: str) -> bool: """Parse source with tree-sitter and check if an error is returned.""" language = Language(INSTALLED_LANGUAGES_FILEPATH, "fortran") @@ -204,7 +210,7 @@ def validate_parse_tree(source: str) -> bool: ) if validate_parse_tree(free_source): return free_source - + return source diff --git a/skema/program_analysis/CAST/fortran/ts2cast.py b/skema/program_analysis/CAST/fortran/ts2cast.py index 02b325d54c9..8f26983ed32 100644 --- a/skema/program_analysis/CAST/fortran/ts2cast.py +++ b/skema/program_analysis/CAST/fortran/ts2cast.py @@ -44,7 +44,9 @@ from skema.program_analysis.CAST.fortran.util import generate_dummy_source_refs from skema.program_analysis.CAST.fortran.preprocessor.preprocess import preprocess -from skema.program_analysis.tree_sitter_parsers.build_parsers import INSTALLED_LANGUAGES_FILEPATH +from skema.program_analysis.tree_sitter_parsers.build_parsers import ( + INSTALLED_LANGUAGES_FILEPATH, +) builtin_statements = set( [ @@ -52,43 +54,41 @@ "write_statement", "rewind_statement", "open_statement", - "common_statement", - "print_statement" + "print_statement", ] ) + + class TS2CAST(object): def __init__(self, source_file_path: str): # Prepare source with preprocessor self.path = Path(source_file_path) self.source_file_name = self.path.name self.source = preprocess(self.path) - + # Run tree-sitter on preprocessor output to generate parse tree parser = Parser() - parser.set_language( - Language( - INSTALLED_LANGUAGES_FILEPATH, - "fortran" - ) - ) + parser.set_language(Language(INSTALLED_LANGUAGES_FILEPATH, "fortran")) self.tree = parser.parse(bytes(self.source, "utf8")) self.root_node = remove_comments(self.tree.root_node) - + # Walking data self.variable_context = VariableContext() self.node_helper = NodeHelper(self.source, self.source_file_name) # Start visiting self.out_cast = self.generate_cast() - print(self.out_cast[0].to_json_str()) - + # print(self.out_cast[0].to_json_str()) + def generate_cast(self) -> List[CAST]: - '''Interface for generating CAST.''' + """Interface for generating CAST.""" modules = self.run(self.root_node) - return [CAST([generate_dummy_source_refs(module)], "Fortran") for module in modules] - + return [ + CAST([generate_dummy_source_refs(module)], "Fortran") for module in modules + ] + def run(self, root) -> List[Module]: - '''Top level visitor function. Will return between 1-3 Module objects.''' + """Top level visitor function. Will return between 1-3 Module objects.""" # A program can have between 1-3 modules # 1. A module body # 2. A program body @@ -110,17 +110,18 @@ def run(self, root) -> List[Module]: body.extend(child_cast) elif isinstance(child_cast, AstNode): body.append(child_cast) - modules.append(Module( - name=None, - body=body, - source_refs=[self.node_helper.get_source_ref(root)] - )) - + modules.append( + Module( + name=None, + body=body, + source_refs=[self.node_helper.get_source_ref(root)], + ) + ) return modules def visit(self, node: Node): - if node.type in ["program", "module"] : + if node.type in ["program", "module"]: return self.visit_module(node) elif node.type == "internal_procedures": return self.visit_internal_procedures(node) @@ -138,7 +139,11 @@ def visit(self, node: Node): return self.visit_identifier(node) elif node.type == "name": return self.visit_name(node) - elif node.type in ["unary_expression", "math_expression", "relational_expression"]: + elif node.type in [ + "unary_expression", + "math_expression", + "relational_expression", + ]: return self.visit_math_expression(node) elif node.type in [ "number_literal", @@ -167,9 +172,9 @@ def visit(self, node: Node): return self._visit_passthrough(node) def visit_module(self, node: Node) -> Module: - '''Visitor for program and module statement. Returns a Module object''' + """Visitor for program and module statement. Returns a Module object""" self.variable_context.push_context() - + program_body = [] for child in node.children[1:-1]: # Ignore the start and end program statement child_cast = self.visit(child) @@ -177,17 +182,17 @@ def visit_module(self, node: Node) -> Module: program_body.extend(child_cast) elif isinstance(child_cast, AstNode): program_body.append(child_cast) - + self.variable_context.pop_context() - + return Module( - name=None, #TODO: Fill out name field + name=None, # TODO: Fill out name field body=program_body, - source_refs = [self.node_helper.get_source_ref(node)] + source_refs=[self.node_helper.get_source_ref(node)], ) def visit_internal_procedures(self, node: Node) -> List[FunctionDef]: - '''Visitor for internal procedures. Returns list of FunctionDef''' + """Visitor for internal procedures. Returns list of FunctionDef""" internal_procedures = get_children_by_types(node, ["function", "subroutine"]) return [self.visit(procedure) for procedure in internal_procedures] @@ -227,8 +232,11 @@ def visit_function_def(self, node): self.variable_context.push_context() # Top level statement node - statement_node = get_children_by_types(node, ["subroutine_statement", "function_statement"])[0] - + + statement_node = get_children_by_types( + node, ["subroutine_statement", "function_statement"] + )[0] + name_node = get_first_child_by_type(statement_node, "name") name = self.visit( name_node @@ -253,7 +261,6 @@ def visit_function_def(self, node): ).val self.variable_context.add_return_value(return_value.name) - # NOTE: In the case of a function specifically, if there is no explicit return value, the return value will be the name of the function # TODO: Should this be a node instead if not return_value: @@ -262,7 +269,7 @@ def visit_function_def(self, node): ) return_value = self.visit(name_node) - # If funciton has both an explicit intrinsic type, then we also need to update the type of the return value in the variable context + # If funciton has both an explicit intrinsic type, then we also need to update the type of the return value in the variable context if intrinsic_type: self.variable_context.update_type(return_value.name, intrinsic_type) @@ -276,7 +283,7 @@ def visit_function_def(self, node): self.node_helper.get_identifier(parameter) ) func_args.append(self.visit(parameter)) - + # The first child of function will be the function statement, the rest will be body nodes body = [] for body_node in node.children[1:]: @@ -317,13 +324,20 @@ def visit_function_call(self, node): # A subroutine and function won't neccessarily have an arguments node. # So we should be careful about trying to access it. - - function_node = get_children_by_types(node, ["unary_expression", "subroutine", "identifier", "derived_type_member_expression"])[0] + function_node = get_children_by_types( + node, + [ + "unary_expression", + "subroutine", + "identifier", + "derived_type_member_expression", + ], + )[0] if function_node.type == "derived_type_member_expression": return self.visit_derived_type_member_expression(function_node) - + arguments_node = get_first_child_by_type(node, "argument_list") - + # If this is a unary expression (+foo()) the identifier will be nested. # TODO: If this is a non '+' unary expression, how do we add it to the CAST? if function_node.type == "unary_expression": @@ -362,7 +376,6 @@ def visit_function_call(self, node): source_refs=[self.node_helper.get_source_ref(node)], ) - def visit_keyword_statement(self, node): # NOTE: RETURN is not the only Fortran keyword. GO TO and CONTINUE are also considered keywords. # TODO: Handle GO TO and CONTINUE @@ -371,10 +384,8 @@ def visit_keyword_statement(self, node): if "continue" in identifier or "go to" in identifier: return self._visit_no_op(node) if "exit" in identifier: - return ModelBreak( - source_refs = [self.node_helper.get_source_ref(node)] - ) - + return ModelBreak(source_refs=[self.node_helper.get_source_ref(node)]) + # In Fortran the return statement doesn't return a value (there is the obsolete "alternative return") # We keep track of values that need to be returned in the variable context return_values = self.variable_context.context_return_values[ @@ -386,9 +397,7 @@ def visit_keyword_statement(self, node): elif len(return_values) > 1: value = LiteralValue( value_type="Tuple", - value=[ - self.variable_context.get_node(ret) for ret in return_values - ], + value=[self.variable_context.get_node(ret) for ret in return_values], source_code_data_type=None, source_refs=None, ) @@ -404,7 +413,6 @@ def visit_fortran_builtin_statement(self, node): # All of the node types that fall into this category end with _statment. # So the function name will be the node type with _statement removed (write, read, open, ...) func = self.get_gromet_function_node(node.type.replace("_statement", "")) - arguments = [] @@ -413,7 +421,7 @@ def visit_fortran_builtin_statement(self, node): arguments=arguments, source_language="Fortran", source_language_version=None, - source_refs=[self.node_helper.get_source_ref(node)] + source_refs=[self.node_helper.get_source_ref(node)], ) def visit_print_statement(self, node): @@ -425,7 +433,7 @@ def visit_print_statement(self, node): func=func, arguments=arguments, source_language=None, - source_language_version=None + source_language_version=None, ) def visit_use_statement(self, node): @@ -486,23 +494,16 @@ def visit_do_loop_statement(self, node) -> Loop: (...) ... (body) ... """ - - loop_control_node= get_first_child_by_type(node, "loop_contrel_expression") + loop_control_node = get_first_child_by_type(node, "loop_contrel_expression") if not loop_control_node: return self._visit_while(node) # If there is a loop control expression, the first body node will be the node after the loop_control_expression # It is valid Fortran to have a single itteration do loop as well. # NOTE: This code is for the creation of the main body. The do loop will still add some additional nodes at the end of this body. - body = [] body_start_index = 1 + get_first_child_index(node, "loop_control_expression") - for body_node in node.children[body_start_index:]: - child_cast = self.visit(body_node) - if isinstance(child_cast, List): - body.extend(child_cast) - elif isinstance(child_cast, AstNode): - body.append(child_cast) + body = self.generate_cast_body(node.children[body_start_index:]) # For the init and expression fields, we first need to determine if we are in a regular "do" or a "do while" loop # PRE: @@ -617,10 +618,22 @@ def visit_if_statement(self, node): # (else_clause) # (end_if_statement) - #TODO: Can you have a parenthesized expression as a body node - body_nodes = get_children_except_types(node, ["if", "elseif", "else", "then", "parenthesized_expression", "elseif_clause", "else_clause", "end_if_statement"]) + # TODO: Can you have a parenthesized expression as a body node + body_nodes = get_children_except_types( + node, + [ + "if", + "elseif", + "else", + "then", + "parenthesized_expression", + "elseif_clause", + "else_clause", + "end_if_statement", + ], + ) body = self.generate_cast_body(body_nodes) - + expr_node = get_first_child_by_type(node, "parenthesized_expression") expr = None if expr_node: @@ -628,47 +641,43 @@ def visit_if_statement(self, node): elseif_nodes = get_children_by_types(node, ["elseif_clause"]) elseif_cast = [self.visit(elseif_clause) for elseif_clause in elseif_nodes] - for i in range(len(elseif_cast)-1): - elseif_cast[i].orelse = [elseif_cast[i+1]] - + for i in range(len(elseif_cast) - 1): + elseif_cast[i].orelse = [elseif_cast[i + 1]] + else_node = get_first_child_by_type(node, "else_clause") else_cast = None if else_node: else_cast = self.visit(else_node) - + orelse = [] if len(elseif_cast) > 0: orelse = [elseif_cast[0]] elif else_cast: orelse = else_cast.body - return ModelIf( - expr=expr, - body=body, - orelse=orelse - ) + return ModelIf(expr=expr, body=body, orelse=orelse) def visit_logical_expression(self, node): """Visitior for logical expression (i.e. true and false) which is used in compound conditional""" # If this is a .not. operator, we need to pass it on to the math_expression visitor if len(node.children) < 3: return self.visit_math_expression(node) - + literal_value_false = LiteralValue("Boolean", False) literal_value_true = LiteralValue("Boolean", True) - + # AND: Right side goes in body if, left side in condition - # OR: Right side goes in body else, left side in condition + # OR: Right side goes in body else, left side in condition left, operator, right = node.children - + # First we need to check if this is logical and or a logical or # The tehcnical types for these are \.or\. and \.and\. so to simplify things we can use the in keyword - is_or = "or" in operator.type - + is_or = "or" in operator.type + top_if = ModelIf() top_if_expr = self.visit(left) top_if.expr = top_if_expr - + bottom_if_expr = self.visit(right) if is_or: top_if.orelse = [bottom_if_expr] @@ -771,7 +780,6 @@ def visit_identifier(self, node): default_value=default_value, source_refs=[self.node_helper.get_source_ref(node)], ) - def visit_math_expression(self, node): op = self.node_helper.get_identifier( @@ -780,12 +788,11 @@ def visit_math_expression(self, node): operands = [] for operand in get_non_control_children(node): operands.append(self.visit(operand)) - + # For operators, we will only need the name node since we are not allocating space if operand.type == "identifier": operands[-1] = operands[-1].val - return Operator( source_language="Fortran", interpreter=None, @@ -823,12 +830,14 @@ def visit_variable_declaration(self, node) -> List: "integer": "Integer", "real": "AbstractFloat", "double precision": "AbstractFloat", - "complex": "Tuple", # Complex is a Tuple (rational,irrational), + "complex": "Tuple", # Complex is a Tuple (rational,irrational), "logical": "Boolean", "character": "String", } # NOTE: Identifiers are case sensitive, so we always need to make sure we are comparing to the lower() version - variable_type = type_map[self.node_helper.get_identifier(intrinsic_type_node).lower()] + variable_type = type_map[ + self.node_helper.get_identifier(intrinsic_type_node).lower() + ] elif derived_type_node: variable_type = self.node_helper.get_identifier( get_first_child_by_type(derived_type_node, "type_name", recurse=True), @@ -866,15 +875,11 @@ def visit_variable_declaration(self, node) -> List: ) ), right=self.visit(variable.children[2]), - source_refs=[ - self.node_helper.get_source_ref(variable) - ], + source_refs=[self.node_helper.get_source_ref(variable)], ) ) vars[-1].left.type = "List" - self.variable_context.update_type( - vars[-1].left.val.name, "List" - ) + self.variable_context.update_type(vars[-1].left.val.name, "List") else: # If its a regular assignment, we can update the type normally vars.append(self.visit(variable)) @@ -964,7 +969,7 @@ def visit_derived_type(self, node: Node) -> RecordDef: # If we tell the variable context we are in a record definition, it will append the type name as a prefix to all defined variables. self.variable_context.enter_record_definition(record_name) - # Note: + # Note: funcs = [] derived_type_procedures_node = get_first_child_by_type( node, "derived_type_procedures" @@ -1034,15 +1039,17 @@ def visit_derived_type_member_expression(self, node) -> Attribute: else: # We shouldn't be accessing get_node directly, since it may not exist in the case of an import. # Instead, we should visit the identifier node which will add it to the variable context automatically if it doesn't exist. - value = self.visit(get_first_child_by_type(node, "identifier", recurse=True)) + value = self.visit( + get_first_child_by_type(node, "identifier", recurse=True) + ) # NOTE: Attribue should be a Name node, NOT a string or Var node - #attr = self.node_helper.get_identifier( + # attr = self.node_helper.get_identifier( # get_first_child_by_type(node, "type_member", recurse=True) - #) - #print(self.node_helper.get_identifier(get_first_child_by_type(node, "type_member", recurse=True))) + # ) + # print(self.node_helper.get_identifier(get_first_child_by_type(node, "type_member", recurse=True))) attr = self.visit_name(get_first_child_by_type(node, "type_member")) - + return Attribute( value=value, attr=attr, @@ -1129,12 +1136,14 @@ def _visit_while(self, node) -> Loop: body_start_index = 1 + get_first_child_index(node, "while_statement") # We don't have explicit handling for parenthesized_expression, but the passthrough handler will make sure that we visit the expression correctly. expr = self.visit( - get_first_child_by_type(while_statement_node, "parenthesized_expression") + get_first_child_by_type( + while_statement_node, "parenthesized_expression" + ) ) # The first body node will be the node after the while_statement body = self.generate_cast_body(node.children[body_start_index:]) - + return Loop( pre=[], expr=expr, @@ -1188,9 +1197,9 @@ def _visit_no_op(self, node): func=self.get_gromet_function_node("no_op"), source_language=None, source_language_version=None, - arguments=[] + arguments=[], ) - + def get_gromet_function_node(self, func_name: str) -> Name: # Idealy, we would be able to create a dummy node and just call the name visitor. # However, tree-sitter does not allow you to create or modify nodes, so we have to recreate the logic here. @@ -1207,8 +1216,10 @@ def generate_cast_body(self, body_nodes: List): body.append(cast) elif isinstance(cast, List): body.extend(cast) - return body -#TS2CAST("drotmg.f") -#import cProfile -#cProfile.run("TS2CAST('he_coef0_dres.F')", sort="tottime") \ No newline at end of file + # Gromet doesn't support empty bodies, so we should create a no_op instead + if len(body) == 0: + body.append(self._visit_no_op(None)) + + # TODO: How to add more support for source references + return body diff --git a/skema/program_analysis/CAST/pythonAST/modules_list.py b/skema/program_analysis/CAST/pythonAST/modules_list.py index a26cbfcdaf2..04f504a98c9 100644 --- a/skema/program_analysis/CAST/pythonAST/modules_list.py +++ b/skema/program_analysis/CAST/pythonAST/modules_list.py @@ -321,7 +321,12 @@ def find_func_in_module(module_name, func_name): import sys sys.path.append(os.getcwd()) - module_import = importlib.import_module(module_name) + # TODO: Support find_func_in_module for Fortran source code as well + try: + module_import = importlib.import_module(module_name) + except: + return False + funcs = list(dir(module_import)) return func_name in funcs diff --git a/skema/program_analysis/tree_sitter_parsers/build_parsers.py b/skema/program_analysis/tree_sitter_parsers/build_parsers.py index 2c73cbabd95..cf87c8d9bcf 100644 --- a/skema/program_analysis/tree_sitter_parsers/build_parsers.py +++ b/skema/program_analysis/tree_sitter_parsers/build_parsers.py @@ -74,6 +74,7 @@ def copy_to_site_packages(): flag = f"--{language}" help_text = f"Include {language} language" parser.add_argument(flag, action="store_true", help=help_text) + parser.add_argument("--ci", action="store_true", help="Copy to site packages if running on ci") args = parser.parse_args() if args.all: @@ -82,4 +83,6 @@ def copy_to_site_packages(): selected_languages = [language for language, value in vars(args).items() if value] build_parsers(selected_languages) - copy_to_site_packages() + + if args.ci: + copy_to_site_packages() From f2b27d41f2660834d3fc514d6b79189e73b2c818 Mon Sep 17 00:00:00 2001 From: titomeister Date: Fri, 12 Jan 2024 11:32:00 -0700 Subject: [PATCH 15/22] Python tree-sitter to CAST porting: Loops (#745) This PR introduces support for generating CAST for Loops (for/while) using tree-sitter, as part of the ongoing effort to port over the Python AST to CAST generation to using tree-sitter. ### Python Tree Sitter - Added support for generating CAST Loop nodes using Python tree sitter. - Specifically, we added support for Python's For and While loop CAST generation. - Added some support for tree-sitter "patterns: "list_pattern", "tuple_pattern", "list_pattern". These are used primarily in the For loop syntax as the item we're iterating over. - Also added support for generating Iterator function calls, as used by the Python For Loops. - Updated the CAST to AGraph visualizer to better visualize tuples. ### Testing - Added some small unit tests for loops (for/while) to determine consistency. - Added a small unit test for detecting missing identifiers in the first line of a Python program. ### Other Fixes - Fixes an issue with missing identifier names when they appeared on the first line of the Python program. This was done by adding some additional handling that is specific to the first line of the program. Resolves #498 Resolves #740 --- .../CAST/python/node_helper.py | 37 ++- skema/program_analysis/CAST/python/ts2cast.py | 225 +++++++++++++- .../visitors/cast_to_agraph_visitor.py | 5 +- .../tests/test_expression_cast.py | 3 + skema/program_analysis/tests/test_for_cast.py | 288 ++++++++++++++++++ .../program_analysis/tests/test_identifier.py | 39 +++ .../program_analysis/tests/test_while_cast.py | 146 +++++++++ 7 files changed, 734 insertions(+), 9 deletions(-) create mode 100644 skema/program_analysis/tests/test_for_cast.py create mode 100644 skema/program_analysis/tests/test_identifier.py create mode 100644 skema/program_analysis/tests/test_while_cast.py diff --git a/skema/program_analysis/CAST/python/node_helper.py b/skema/program_analysis/CAST/python/node_helper.py index 0c4b4304cd9..5e66f0fb567 100644 --- a/skema/program_analysis/CAST/python/node_helper.py +++ b/skema/program_analysis/CAST/python/node_helper.py @@ -1,3 +1,4 @@ +import itertools from typing import List, Dict from skema.program_analysis.CAST2FN.model.cast import SourceRef @@ -24,6 +25,34 @@ "not" ] +# Whatever constructs we see in the left +# part of the for loop construct +# for LEFT in RIGHT: +FOR_LOOP_LEFT_TYPES = [ + "identifier", + "tuple_pattern", + "pattern_list", + "list_pattern" +] + +# Whatever constructs we see in the right +# part of the for loop construct +# for LEFT in RIGHT: +FOR_LOOP_RIGHT_TYPES = [ + "call", + "identifier", + "list", + "tuple" +] + +# Whatever constructs we see in the conditional +# part of the while loop +WHILE_COND_TYPES = [ + "boolean_operator", + "call", + "comparison_operator" +] + class NodeHelper(): def __init__(self, source: str, source_file_name: str): self.source = source @@ -32,14 +61,16 @@ def __init__(self, source: str, source_file_name: str): # get_identifier optimization variables self.source_lines = source.splitlines(keepends=True) self.line_lengths = [len(line) for line in self.source_lines] - self.line_length_sums = [sum(self.line_lengths[:i+1]) for i in range(len(self.source_lines))] - + self.line_length_sums = [0] + list(itertools.accumulate(self.line_lengths)) + def get_identifier(self, node: Node) -> str: """Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point""" start_line, start_column = node.start_point end_line, end_column = node.end_point - start_index = self.line_length_sums[start_line-1] + start_column + # Edge case for when an identifier is on the very first line of the code + # We can't index into the line_length_sums + start_index = self.line_length_sums[start_line] + start_column if start_line == end_line: end_index = start_index + (end_column-start_column) else: diff --git a/skema/program_analysis/CAST/python/ts2cast.py b/skema/program_analysis/CAST/python/ts2cast.py index b498166b217..cf51f6b6e55 100644 --- a/skema/program_analysis/CAST/python/ts2cast.py +++ b/skema/program_analysis/CAST/python/ts2cast.py @@ -25,7 +25,8 @@ ModelIf, RecordDef, Attribute, - ScalarType + ScalarType, + StructureType ) from skema.program_analysis.CAST.python.node_helper import ( @@ -35,7 +36,10 @@ get_first_child_index, get_last_child_index, get_control_children, - get_non_control_children + get_non_control_children, + FOR_LOOP_LEFT_TYPES, + FOR_LOOP_RIGHT_TYPES, + WHILE_COND_TYPES ) from skema.program_analysis.CAST.python.util import ( generate_dummy_source_refs, @@ -71,6 +75,9 @@ def __init__(self, source_file_path: str, from_file = True): ) ) + # Additional variables used in generation + self.var_count = 0 + # Tree walking structures self.variable_context = VariableContext() self.node_helper = NodeHelper(self.source, self.source_file_name) @@ -82,6 +89,7 @@ def __init__(self, source_file_path: str, from_file = True): def generate_cast(self) -> List[CAST]: '''Interface for generating CAST.''' module = self.run(self.tree.root_node) + module.name = self.source_file_name return CAST([generate_dummy_source_refs(module)], "Python") def run(self, root) -> List[Module]: @@ -115,12 +123,18 @@ def visit(self, node: Node): return self.visit_assignment(node) elif node.type == "identifier": return self.visit_identifier(node) - elif node.type =="unary_operator": + elif node.type == "unary_operator": return self.visit_unary_op(node) - elif node.type =="binary_operator": + elif node.type == "binary_operator": return self.visit_binary_op(node) - elif node.type in ["integer"]: + elif node.type in ["integer", "list"]: return self.visit_literal(node) + elif node.type in ["list_pattern", "pattern_list", "tuple_pattern"]: + return self.visit_pattern(node) + elif node.type == "while_statement": + return self.visit_while(node) + elif node.type == "for_statement": + return self.visit_for(node) else: return self._visit_passthrough(node) @@ -224,6 +238,21 @@ def visit_call(self, node: Node) -> Call: elif isinstance(cast, AstNode): func_args.append(cast) + if func_name.val.name == "range": + start_step_value = LiteralValue( + ScalarType.INTEGER, + value="1", + source_code_data_type=["Python", PYTHON_VERSION, str(type(1))], + source_refs=[ref] + ) + # Add a step value + if len(func_args) == 2: + func_args.append(start_step_value) + # Add a start and step value + elif len(func_args) == 1: + func_args.insert(0, start_step_value) + func_args.append(start_step_value) + # Function calls only want the 'Name' part of the 'Var' that the visit returns return Call( func=func_name.val, @@ -371,6 +400,17 @@ def visit_binary_op(self, node: Node) -> Operator: source_refs=[ref] ) + def visit_pattern(self, node: Node): + pattern_cast = [] + for elem in node.children: + cast = self.visit(elem) + if isinstance(cast, List): + pattern_cast.extend(cast) + elif isinstance(cast, AstNode): + pattern_cast.append(cast) + + return LiteralValue(value_type=StructureType.TUPLE, value=pattern_cast) + def visit_identifier(self, node: Node) -> Var: identifier = self.node_helper.get_identifier(node) @@ -417,6 +457,173 @@ def visit_literal(self, node: Node) -> Any: source_code_data_type=["Python", PYTHON_VERSION, str(type(True))], source_refs=[literal_source_ref] ) + elif literal_type == "list": + list_items = [] + for elem in node.children: + cast = self.visit(elem) + if isinstance(cast, List): + list_items.extend(cast) + elif isinstance(cast, AstNode): + list_items.append(cast) + + return LiteralValue( + value_type=StructureType.LIST, + value = list_items, + source_code_data_type=["Python", PYTHON_VERSION, str(type([0]))], + source_refs=[literal_source_ref] + ) + elif literal_type == "tuple": + tuple_items = [] + for elem in node.children: + cast = self.visit(cast) + if isinstance(cast, List): + tuple_items.extend(cast) + elif isinstance(cast, AstNode): + tuple_items.append(cast) + + return LiteralValue( + value_type=StructureType.LIST, + value = tuple_items, + source_code_data_type=["Python", PYTHON_VERSION, str(type((0)))], + source_refs=[literal_source_ref] + ) + + + + def visit_while(self, node: Node) -> Loop: + ref = self.node_helper.get_source_ref(node) + + # Push a variable context since a loop + # can create variables that only it can see + self.variable_context.push_context() + + loop_cond_node = get_children_by_types(node, WHILE_COND_TYPES)[0] + loop_body_node = get_children_by_types(node, "block")[0].children + + loop_cond = self.visit(loop_cond_node) + + loop_body = [] + for node in loop_body_node: + cast = self.visit(node) + if isinstance(cast, List): + loop_body.extend(cast) + elif isinstance(cast, AstNode): + loop_body.append(cast) + + self.variable_context.pop_context() + + return Loop( + pre=[], + expr=loop_cond, + body=loop_body, + post=[], + source_refs = ref + ) + + def visit_for(self, node: Node) -> Loop: + ref = self.node_helper.get_source_ref(node) + + # Pre: left, right + loop_cond_left = get_children_by_types(node, FOR_LOOP_LEFT_TYPES)[0] + loop_cond_right = get_children_by_types(node, FOR_LOOP_RIGHT_TYPES)[-1] + + # Construct pre and expr value using left and right as needed + # need calls to "_Iterator" + + self.variable_context.push_context() + iterator_name = self.variable_context.generate_iterator() + stop_cond_name = self.variable_context.generate_stop_condition() + iter_func = self.get_gromet_function_node("iter") + next_func = self.get_gromet_function_node("next") + + loop_cond_left_cast = self.visit(loop_cond_left) + loop_cond_right_cast = self.visit(loop_cond_right) + + loop_pre = [] + loop_pre.append( + Assignment( + left = Var(iterator_name, "Iterator"), + right = Call( + iter_func, + arguments=[loop_cond_right_cast] + ) + ) + ) + + loop_pre.append( + Assignment( + left=LiteralValue( + "Tuple", + [ + loop_cond_left_cast, + Var(iterator_name, "Iterator"), + Var(stop_cond_name, "Boolean"), + ], + source_code_data_type = ["Python",PYTHON_VERSION,"Tuple"], + source_refs=ref + ), + right=Call( + next_func, + arguments=[Var(iterator_name, "Iterator")], + ), + ) + + ) + + loop_expr = Operator( + source_language="Python", + interpreter="Python", + version=PYTHON_VERSION, + op="ast.Eq", + operands=[ + stop_cond_name, + LiteralValue( + ScalarType.BOOLEAN, + False, + ["Python", PYTHON_VERSION, "boolean"], + source_refs=ref, + ) + ], + source_refs=ref + ) + + loop_body_node = get_children_by_types(node, "block")[0].children + loop_body = [] + for node in loop_body_node: + cast = self.visit(node) + if isinstance(cast, List): + loop_body.extend(cast) + elif isinstance(cast, AstNode): + loop_body.append(cast) + + # Insert an additional call to 'next' at the end of the loop body, + # to facilitate looping in GroMEt + loop_body.append( + Assignment( + left=LiteralValue( + "Tuple", + [ + loop_cond_left_cast, + Var(iterator_name, "Iterator"), + Var(stop_cond_name, "Boolean"), + ], + ), + right=Call( + next_func, + arguments=[Var(iterator_name, "Iterator")], + ), + ) + ) + + self.variable_context.pop_context() + return Loop( + pre=loop_pre, + expr=loop_expr, + body=loop_body, + post=[], + source_refs = ref + ) + def visit_name(self, node): # First, we will check if this name is already defined, and if it is return the name node generated previously @@ -436,6 +643,14 @@ def _visit_passthrough(self, node): child_cast = self.visit(child) if child_cast: return child_cast + + def get_gromet_function_node(self, func_name: str) -> Name: + # Idealy, we would be able to create a dummy node and just call the name visitor. + # However, tree-sitter does not allow you to create or modify nodes, so we have to recreate the logic here. + if self.variable_context.is_variable(func_name): + return self.variable_context.get_node(func_name) + + return self.variable_context.add_variable(func_name, "function", None) def get_name_node(node): # Given a CAST node, if it's type Var, then we extract the name node out of it diff --git a/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py b/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py index 1c4fe73d0ed..5bffaaa8ef4 100644 --- a/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py +++ b/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py @@ -592,7 +592,10 @@ def _(self, node: LiteralValue): return node_uid elif node.value_type == StructureType.TUPLE: node_uid = uuid.uuid4() - self.G.add_node(node_uid, label=f"Tuple (...)") + self.G.add_node(node_uid, label=f"Tuple") + tuple_elems = self.visit_list(node.value) + for elem_uid in tuple_elems: + self.G.add_edge(node_uid, elem_uid) return node_uid elif node.value_type == None: node_uid = uuid.uuid4() diff --git a/skema/program_analysis/tests/test_expression_cast.py b/skema/program_analysis/tests/test_expression_cast.py index 5d1d1e72b1a..c0f7459caa9 100644 --- a/skema/program_analysis/tests/test_expression_cast.py +++ b/skema/program_analysis/tests/test_expression_cast.py @@ -66,3 +66,6 @@ def test_exp1(): assert isinstance(asg_node.right, LiteralValue) assert asg_node.right.value_type == "Integer" assert asg_node.right.value == '3' + +if __name__ == "__main__": + cast = generate_cast(exp0()) \ No newline at end of file diff --git a/skema/program_analysis/tests/test_for_cast.py b/skema/program_analysis/tests/test_for_cast.py new file mode 100644 index 00000000000..3693b8bbe75 --- /dev/null +++ b/skema/program_analysis/tests/test_for_cast.py @@ -0,0 +1,288 @@ +# import json NOTE: json and Path aren't used right now, +# from pathlib import Path but will be used in the future +from skema.program_analysis.CAST.python.ts2cast import TS2CAST +from skema.program_analysis.CAST2FN.model.cast import ( + Assignment, + Var, + Call, + Name, + LiteralValue, + ModelIf, + Loop, + Operator +) + +def for1(): + return """ +x = 7 +for i in range(10): + x = x + i + """ + +def for2(): + return """ +x = 1 +for a,b in range(10): + x = x + a + b + """ + +def for3(): + return """ +x = 1 +L = [1,2,3] + +for i in L: + x = x + i + """ + + +def generate_cast(test_file_string): + # use Python to CAST + out_cast = TS2CAST(test_file_string, from_file=False).out_cast + + return out_cast + +def test_for1(): + cast = generate_cast(for1()) + + asg_node = cast.nodes[0].body[0] + loop_node = cast.nodes[0].body[1] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '7' + + assert isinstance(loop_node, Loop) + assert len(loop_node.pre) == 2 + + # Loop Pre + loop_pre = loop_node.pre + assert isinstance(loop_pre[0], Assignment) + assert isinstance(loop_pre[0].left, Var) + assert loop_pre[0].left.val.name == "generated_iter_0" + + assert isinstance(loop_pre[0].right, Call) + assert loop_pre[0].right.func.name == "iter" + iter_args = loop_pre[0].right.arguments + + assert len(iter_args) == 1 + assert isinstance(iter_args[0], Call) + assert iter_args[0].func.name == "range" + assert len(iter_args[0].arguments) == 3 + + assert isinstance(iter_args[0].arguments[0], LiteralValue) + assert iter_args[0].arguments[0].value == "1" + assert isinstance(iter_args[0].arguments[1], LiteralValue) + assert iter_args[0].arguments[1].value == "10" + assert isinstance(iter_args[0].arguments[2], LiteralValue) + assert iter_args[0].arguments[2].value == "1" + + assert isinstance(loop_pre[1], Assignment) + assert isinstance(loop_pre[1].left, LiteralValue) + assert loop_pre[1].left.value_type == "Tuple" + + assert isinstance(loop_pre[1].left.value[0], Var) + assert loop_pre[1].left.value[0].val.name == "i" + assert isinstance(loop_pre[1].left.value[1], Var) + assert loop_pre[1].left.value[1].val.name == "generated_iter_0" + assert isinstance(loop_pre[1].left.value[2], Var) + assert loop_pre[1].left.value[2].val.name == "sc_0" + + assert isinstance(loop_pre[1].right, Call) + assert loop_pre[1].right.func.name == "next" + assert len(loop_pre[1].right.arguments) == 1 + assert loop_pre[1].right.arguments[0].val.name == "generated_iter_0" + + # Loop Test + loop_test = loop_node.expr + assert isinstance(loop_test, Operator) + assert loop_test.op == "ast.Eq" + assert isinstance(loop_test.operands[0], Name) + assert loop_test.operands[0].name == "sc_0" + + assert isinstance(loop_test.operands[1], LiteralValue) + assert loop_test.operands[1].value_type == "Boolean" + + # Loop Body + loop_body = loop_node.body + next_call = loop_body[-1] + assert isinstance(next_call, Assignment) + assert isinstance(next_call.right, Call) + assert next_call.right.func.name == "next" + assert next_call.right.arguments[0].val.name == "generated_iter_0" + + +def test_for2(): + cast = generate_cast(for2()) + + asg_node = cast.nodes[0].body[0] + loop_node = cast.nodes[0].body[1] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '1' + + assert isinstance(loop_node, Loop) + assert len(loop_node.pre) == 2 + + # Loop Pre + loop_pre = loop_node.pre + assert isinstance(loop_pre[0], Assignment) + assert isinstance(loop_pre[0].left, Var) + assert loop_pre[0].left.val.name == "generated_iter_0" + + assert isinstance(loop_pre[0].right, Call) + assert loop_pre[0].right.func.name == "iter" + iter_args = loop_pre[0].right.arguments + + assert len(iter_args) == 1 + assert isinstance(iter_args[0], Call) + assert iter_args[0].func.name == "range" + assert len(iter_args[0].arguments) == 3 + + assert isinstance(iter_args[0].arguments[0], LiteralValue) + assert iter_args[0].arguments[0].value == "1" + assert isinstance(iter_args[0].arguments[1], LiteralValue) + assert iter_args[0].arguments[1].value == "10" + assert isinstance(iter_args[0].arguments[2], LiteralValue) + assert iter_args[0].arguments[2].value == "1" + + assert isinstance(loop_pre[1], Assignment) + assert isinstance(loop_pre[1].left, LiteralValue) + assert loop_pre[1].left.value_type == "Tuple" + + assert isinstance(loop_pre[1].left.value[0], LiteralValue) + assert loop_pre[1].left.value[0].value_type == "Tuple" + + assert isinstance(loop_pre[1].left.value[0].value[0], Var) + assert loop_pre[1].left.value[0].value[0].val.name == "a" + assert isinstance(loop_pre[1].left.value[0].value[1], Var) + assert loop_pre[1].left.value[0].value[1].val.name == "b" + + assert isinstance(loop_pre[1].left.value[1], Var) + assert loop_pre[1].left.value[1].val.name == "generated_iter_0" + assert isinstance(loop_pre[1].left.value[2], Var) + assert loop_pre[1].left.value[2].val.name == "sc_0" + + assert isinstance(loop_pre[1].right, Call) + assert loop_pre[1].right.func.name == "next" + assert len(loop_pre[1].right.arguments) == 1 + assert loop_pre[1].right.arguments[0].val.name == "generated_iter_0" + + # Loop Test + loop_test = loop_node.expr + assert isinstance(loop_test, Operator) + assert loop_test.op == "ast.Eq" + assert isinstance(loop_test.operands[0], Name) + assert loop_test.operands[0].name == "sc_0" + + assert isinstance(loop_test.operands[1], LiteralValue) + assert loop_test.operands[1].value_type == "Boolean" + + # Loop Body + loop_body = loop_node.body + body_asg = loop_body[0] + assert isinstance(body_asg, Assignment) + + assert isinstance(body_asg.right, Operator) + assert isinstance(body_asg.right.operands[0], Operator) + assert isinstance(body_asg.right.operands[0].operands[0], Name) + assert body_asg.right.operands[0].operands[0].name == "x" + + assert isinstance(body_asg.right.operands[0].operands[1], Name) + assert body_asg.right.operands[0].operands[1].name == "a" + + assert isinstance(body_asg.right.operands[1], Name) + assert body_asg.right.operands[1].name == "b" + + next_call = loop_body[-1] + assert isinstance(next_call, Assignment) + assert isinstance(next_call.right, Call) + assert next_call.right.func.name == "next" + assert next_call.right.arguments[0].val.name == "generated_iter_0" + + +def test_for3(): + cast = generate_cast(for3()) + + asg_node = cast.nodes[0].body[0] + list_node = cast.nodes[0].body[1] + loop_node = cast.nodes[0].body[2] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '1' + + assert isinstance(loop_node, Loop) + assert len(loop_node.pre) == 2 + + assert isinstance(list_node, Assignment) + assert isinstance(list_node.left, Var) + assert list_node.left.val.name == "L" + + assert isinstance(list_node.right, LiteralValue) + assert list_node.right.value_type == "List" + + # Loop Pre + loop_pre = loop_node.pre + assert isinstance(loop_pre[0], Assignment) + assert isinstance(loop_pre[0].left, Var) + assert loop_pre[0].left.val.name == "generated_iter_0" + + assert isinstance(loop_pre[0].right, Call) + assert loop_pre[0].right.func.name == "iter" + iter_args = loop_pre[0].right.arguments + + assert len(iter_args) == 1 + assert isinstance(iter_args[0], Var) + assert iter_args[0].val.name == "L" + + assert isinstance(loop_pre[1], Assignment) + assert isinstance(loop_pre[1].left, LiteralValue) + assert loop_pre[1].left.value_type == "Tuple" + + assert isinstance(loop_pre[1].left.value[0], Var) + assert loop_pre[1].left.value[0].val.name == "i" + assert isinstance(loop_pre[1].left.value[1], Var) + assert loop_pre[1].left.value[1].val.name == "generated_iter_0" + assert isinstance(loop_pre[1].left.value[2], Var) + assert loop_pre[1].left.value[2].val.name == "sc_0" + + assert isinstance(loop_pre[1].right, Call) + assert loop_pre[1].right.func.name == "next" + assert len(loop_pre[1].right.arguments) == 1 + assert loop_pre[1].right.arguments[0].val.name == "generated_iter_0" + + # Loop Test + loop_test = loop_node.expr + assert isinstance(loop_test, Operator) + assert loop_test.op == "ast.Eq" + assert isinstance(loop_test.operands[0], Name) + assert loop_test.operands[0].name == "sc_0" + + assert isinstance(loop_test.operands[1], LiteralValue) + assert loop_test.operands[1].value_type == "Boolean" + + # Loop Body + loop_body = loop_node.body + next_call = loop_body[-1] + + assert isinstance(next_call, Assignment) + assert isinstance(next_call.right, Call) + assert next_call.right.func.name == "next" + assert next_call.right.arguments[0].val.name == "generated_iter_0" diff --git a/skema/program_analysis/tests/test_identifier.py b/skema/program_analysis/tests/test_identifier.py new file mode 100644 index 00000000000..46e976172b9 --- /dev/null +++ b/skema/program_analysis/tests/test_identifier.py @@ -0,0 +1,39 @@ +# import json NOTE: json and Path aren't used right now, +# from pathlib import Path but will be used in the future +from skema.program_analysis.CAST.python.ts2cast import TS2CAST +from skema.program_analysis.CAST2FN.model.cast import ( + Assignment, + Var, + Call, + Name, + LiteralValue, + ModelIf, + Loop, + Operator +) + +def identifier1(): + return """x = 2""" + +def generate_cast(test_file_string): + # use Python to CAST + out_cast = TS2CAST(test_file_string, from_file=False).out_cast + + return out_cast + + +# Tests to make sure that identifiers are correctly being generated +def test_identifier1(): + cast = generate_cast(identifier1()) + + asg_node = cast.nodes[0].body[0] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '2' + diff --git a/skema/program_analysis/tests/test_while_cast.py b/skema/program_analysis/tests/test_while_cast.py new file mode 100644 index 00000000000..5d1f2613175 --- /dev/null +++ b/skema/program_analysis/tests/test_while_cast.py @@ -0,0 +1,146 @@ +# import json NOTE: json and Path aren't used right now, +# from pathlib import Path but will be used in the future +from skema.program_analysis.CAST.python.ts2cast import TS2CAST +from skema.program_analysis.CAST2FN.model.cast import ( + Assignment, + Var, + Call, + Name, + LiteralValue, + ModelIf, + Loop, + Operator +) + +def while1(): + return """ +x = 2 +while x < 5: + x = x + 1 + """ + +def while2(): + return """ +x = 2 +y = 3 + +while x < 5: + x = x + 1 + x = x + y + """ + +def generate_cast(test_file_string): + # use Python to CAST + out_cast = TS2CAST(test_file_string, from_file=False).out_cast + + return out_cast + +def test_while1(): + cast = generate_cast(while1()) + + asg_node = cast.nodes[0].body[0] + loop_node = cast.nodes[0].body[1] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '2' + + assert isinstance(loop_node, Loop) + assert len(loop_node.pre) == 0 + + # Loop Test + loop_test = loop_node.expr + assert isinstance(loop_test, Operator) + assert loop_test.op == "ast.Lt" + assert isinstance(loop_test.operands[0], Name) + assert loop_test.operands[0].name == "x" + + assert isinstance(loop_test.operands[1], LiteralValue) + assert loop_test.operands[1].value_type == "Integer" + assert loop_test.operands[1].value == "5" + + # Loop Body + loop_body = loop_node.body + asg = loop_body[0] + assert isinstance(asg, Assignment) + assert isinstance(asg.left, Var) + assert asg.left.val.name == "x" + + assert isinstance(asg.right, Operator) + assert asg.right.op == "ast.Add" + assert isinstance(asg.right.operands[0], Name) + assert isinstance(asg.right.operands[1], LiteralValue) + assert asg.right.operands[1].value == "1" + +def test_while2(): + cast = generate_cast(while2()) + + asg_node = cast.nodes[0].body[0] + asg_node_2 = cast.nodes[0].body[1] + loop_node = cast.nodes[0].body[2] + + assert isinstance(asg_node, Assignment) + assert isinstance(asg_node.left, Var) + assert isinstance(asg_node.left.val, Name) + assert asg_node.left.val.name == "x" + + assert isinstance(asg_node.right, LiteralValue) + assert asg_node.right.value_type == "Integer" + assert asg_node.right.value == '2' + + assert isinstance(asg_node_2, Assignment) + assert isinstance(asg_node_2.left, Var) + assert isinstance(asg_node_2.left.val, Name) + assert asg_node_2.left.val.name == "y" + + assert isinstance(asg_node_2.right, LiteralValue) + assert asg_node_2.right.value_type == "Integer" + assert asg_node_2.right.value == '3' + + assert isinstance(loop_node, Loop) + assert len(loop_node.pre) == 0 + + # Loop Test + loop_test = loop_node.expr + assert isinstance(loop_test, Operator) + assert loop_test.op == "ast.Lt" + assert isinstance(loop_test.operands[0], Name) + assert loop_test.operands[0].name == "x" + + assert isinstance(loop_test.operands[1], LiteralValue) + assert loop_test.operands[1].value_type == "Integer" + assert loop_test.operands[1].value == "5" + + # Loop Body + loop_body = loop_node.body + asg = loop_body[0] + assert isinstance(asg, Assignment) + assert isinstance(asg.left, Var) + assert asg.left.val.name == "x" + + assert isinstance(asg.right, Operator) + assert asg.right.op == "ast.Add" + assert isinstance(asg.right.operands[0], Name) + assert asg.right.operands[0].name == "x" + + assert isinstance(asg.right.operands[1], LiteralValue) + assert asg.right.operands[1].value == "1" + + asg = loop_body[1] + assert isinstance(asg, Assignment) + assert isinstance(asg.left, Var) + assert asg.left.val.name == "x" + + assert isinstance(asg.right, Operator) + assert asg.right.op == "ast.Add" + assert isinstance(asg.right.operands[0], Name) + assert asg.right.operands[0].name == "x" + + assert isinstance(asg.right.operands[1], Name) + assert asg.right.operands[1].name == "y" + From 7c2932cf0e545516c299d0df1fa6bd9e339432e4 Mon Sep 17 00:00:00 2001 From: Deepsana Shahi Date: Fri, 12 Jan 2024 13:59:54 -0700 Subject: [PATCH 16/22] Parsers to handle additional mathematical expressions (#751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary of Changes Parser can parse: - `/` operator in equation - change in a variable, for example Δx - Handles Summation operation (Munderover only) - Explicit handling of dot product. E.g. f⋅u - Explicit handling of cross product. E.g. f×u - Handles Msub function of components. E.g. Q_i ( t_{i-1}, s_{i-1} ) identifies ( t_{i-1}, s_{i-1} ) as identifiers. Additionally handles e.g. ... in the equation E.g. Q_i(s_{i-1}, T_{i-1}, ... ) - Handles first order partial derivative. E.g ∂_{t} S - Handles hat operation. E.g. r \hat{x} - Handles Gradient of multiple components. e.g.∇(f*(g+h)) - Handles gradient subscript e.g. ∇_{h} as operator - Handlles vector identity notation E.g. (v ⋅ ∇) u. ### Related issues Resolves ??? --------- Co-authored-by: Deepsana Shahi Co-authored-by: Justin --- skema/skema-rs/mathml/src/ast.rs | 24 + skema/skema-rs/mathml/src/ast/operator.rs | 42 +- .../mathml/src/parsers/generic_mathml.rs | 18 + .../mathml/src/parsers/interpreted_mathml.rs | 370 ++++++++++-- .../src/parsers/math_expression_tree.rs | 531 ++++++++++++++++-- 5 files changed, 904 insertions(+), 81 deletions(-) diff --git a/skema/skema-rs/mathml/src/ast.rs b/skema/skema-rs/mathml/src/ast.rs index ee058f7e231..e9718967625 100644 --- a/skema/skema-rs/mathml/src/ast.rs +++ b/skema/skema-rs/mathml/src/ast.rs @@ -4,6 +4,7 @@ use std::fmt; pub mod operator; use operator::Operator; +//use crate::ast::MathExpression::SummationOp; #[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] pub struct Mi(pub String); @@ -40,6 +41,19 @@ pub struct Differential { pub func: Box, } +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +pub struct SummationMath { + pub op: Box, + pub func: Box, +} + +/// Hat operation +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +pub struct HatComp { + pub op: Box, + pub comp: Box, +} + /// The MathExpression enum is not faithful to the corresponding element type in MathML 3 /// (https://www.w3.org/TR/MathML3/appendixa.html#parsing_MathExpression) #[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Hash, Default, new)] @@ -66,8 +80,10 @@ pub enum MathExpression { //GroupTuple(Vec), Ci(Ci), Differential(Differential), + SummationMath(SummationMath), AbsoluteSup(Box, Box), Absolute(Box, Box), + HatComp(HatComp), //Differential(Box, Box), #[default] None, @@ -110,6 +126,14 @@ impl fmt::Display for MathExpression { write!(f, "{superscript:?}") } MathExpression::Mtext(text) => write!(f, "{}", text), + MathExpression::SummationMath(SummationMath { op, func }) => { + write!(f, "{op}")?; + write!(f, "{func}") + } + MathExpression::HatComp(HatComp { op, comp }) => { + write!(f, "{op}")?; + write!(f, "{comp}") + } expression => write!(f, "{expression:?}"), } } diff --git a/skema/skema-rs/mathml/src/ast/operator.rs b/skema/skema-rs/mathml/src/ast/operator.rs index f1e8dcb8eba..3e406261073 100644 --- a/skema/skema-rs/mathml/src/ast/operator.rs +++ b/skema/skema-rs/mathml/src/ast/operator.rs @@ -1,4 +1,5 @@ use crate::ast::Ci; +use crate::ast::MathExpression; use derive_new::new; use std::fmt; @@ -9,6 +10,7 @@ pub struct Derivative { pub var_index: u8, pub bound_var: Ci, } + /// Partial derivative operator #[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] pub struct PartialDerivative { @@ -17,6 +19,26 @@ pub struct PartialDerivative { pub bound_var: Ci, } +/// Summation operator with under and over components +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +pub struct SumUnderOver { + pub op: Box, + pub under: Box, + pub over: Box, +} + +/// Hat operation +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +pub struct HatOp { + pub comp: Box, +} + +/// Handles grad operations with subscript. E.g. ∇_{x} +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +pub struct GradSub { + pub sub: Box, +} + #[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] pub enum Operator { Add, @@ -33,6 +55,7 @@ pub enum Operator { Power, Comma, Grad, + GradSub(GradSub), Dot, Period, Div, @@ -52,6 +75,11 @@ pub enum Operator { Arccsc, Arccot, Mean, + Sum, + SumUnderOver(SumUnderOver), + Cross, + Hat, + HatOp(HatOp), // Catchall for operators we haven't explicitly defined as enum variants yet. Other(String), } @@ -68,7 +96,7 @@ impl fmt::Display for Operator { Operator::Lparen => write!(f, "("), Operator::Rparen => write!(f, ")"), Operator::Compose => write!(f, "."), - Operator::Comma => write!(f, ""), + Operator::Comma => write!(f, ","), Operator::Factorial => write!(f, "!"), Operator::Derivative(Derivative { order, @@ -101,10 +129,20 @@ impl fmt::Display for Operator { Operator::Arccot => write!(f, "Arccot"), Operator::Mean => write!(f, "Mean"), Operator::Grad => write!(f, "Grad"), - Operator::Dot => write!(f, "Dot"), + Operator::GradSub(GradSub {sub}) =>{ + write!(f, "Grad_{sub})") + } + Operator::Dot => write!(f, "⋅"), Operator::Period => write!(f, ""), Operator::Div => write!(f, "Div"), Operator::Abs => write!(f, "Abs"), + Operator::Sum => write!(f, "∑"), + Operator::SumUnderOver(SumUnderOver { op, under, over }) => { + write!(f, "{op}_{{{under}}}^{{{over}}}") + } + Operator::Cross => write!(f, "×"), + Operator::Hat => write!(f, "Hat"), + Operator::HatOp(HatOp { comp }) => write!(f, "Hat({comp})"), } } } diff --git a/skema/skema-rs/mathml/src/parsers/generic_mathml.rs b/skema/skema-rs/mathml/src/parsers/generic_mathml.rs index d9becd0b8b3..0330ae2cc48 100644 --- a/skema/skema-rs/mathml/src/parsers/generic_mathml.rs +++ b/skema/skema-rs/mathml/src/parsers/generic_mathml.rs @@ -198,6 +198,11 @@ pub fn multiply(input: Span) -> IResult { Ok((s, op)) } +pub fn divide(input: Span) -> IResult { + let (s, op) = value(Operator::Divide, alt((ws(tag("∕")), ws(tag("∕")))))(input)?; + Ok((s, op)) +} + pub fn equals(input: Span) -> IResult { let (s, op) = value(Operator::Equals, ws(tag("=")))(input)?; Ok((s, op)) @@ -228,6 +233,11 @@ pub fn mean(input: Span) -> IResult { Ok((s, op)) } +pub fn hat(input: Span) -> IResult { + let (s, op) = value(Operator::Hat, alt((ws(tag("^")), ws(tag("^")))))(input)?; + Ok((s, op)) +} + pub fn grad(input: Span) -> IResult { let (s, op) = value(Operator::Grad, alt((ws(tag("∇")), ws(tag("∇")))))(input)?; Ok((s, op)) @@ -237,6 +247,11 @@ pub fn dot(input: Span) -> IResult { Ok((s, op)) } +pub fn cross(input: Span) -> IResult { + let (s, op) = value(Operator::Cross, alt((ws(tag("×")), ws(tag("×")))))(input)?; + Ok((s, op)) +} + fn operator_other(input: Span) -> IResult { let (s, consumed) = ws(recognize(not_line_ending))(input)?; let op = Operator::Other(consumed.to_string()); @@ -251,9 +266,12 @@ pub fn operator(input: Span) -> IResult { lparen, rparen, mean, + hat, multiply, + divide, grad, dot, + cross, period, operator_other, ))(input)?; diff --git a/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs b/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs index 7a6fd5a3dee..1d642da5a19 100644 --- a/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs +++ b/skema/skema-rs/mathml/src/parsers/interpreted_mathml.rs @@ -6,16 +6,17 @@ use crate::{ ast::{ - operator::{Derivative, Operator, PartialDerivative}, - Ci, Differential, Math, MathExpression, Mi, Mrow, Type, + operator::{Derivative, GradSub, HatOp, Operator, PartialDerivative, SumUnderOver}, + Ci, Differential, HatComp, Math, MathExpression, Mi, Mrow, SummationMath, Type, }, parsers::generic_mathml::{ - add, attribute, dot, elem_many0, equals, etag, grad, lparen, mean, mi, mn, msub, msubsup, - mtext, multiply, rparen, stag, subtract, tag_parser, ws, xml_declaration, IResult, - ParseError, Span, + add, attribute, cross, divide, dot, elem_many0, equals, etag, grad, hat, lparen, mean, mi, + mn, msub, msubsup, mtext, multiply, rparen, stag, subtract, tag_parser, ws, + xml_declaration, IResult, ParseError, Span, }, }; +use nom::sequence::terminated; use nom::{ branch::alt, bytes::complete::tag, @@ -29,7 +30,9 @@ use nom::{ pub fn operator(input: Span) -> IResult { let (s, op) = ws(delimited( stag!("mo"), - alt((add, subtract, multiply, equals, lparen, rparen, mean, dot)), + alt(( + add, subtract, multiply, divide, equals, lparen, rparen, mean, dot, cross, hat, + )), etag!("mo"), ))(input)?; Ok((s, op)) @@ -45,6 +48,27 @@ fn parenthesized_identifier(input: Span) -> IResult> { Ok((s, bound_vars)) } +/// Parses function of identifiers +/// Example: Q_i ( t_{i-1}, s_{i-1} ) identifies ( t_{i-1}, s_{i-1} ) as identifiers. +fn parenthesized_msub_identifier(input: Span) -> IResult> { + let mo_lparen = delimited(stag!("mo"), lparen, etag!("mo")); + let mo_rparen = delimited(stag!("mo"), rparen, etag!("mo")); + let mo_mrow_lparen = delimited(tag(""), lparen, tag("")); + let mo_mrow_rparen = delimited(tag(""), rparen, tag("")); + let mo_comma = delimited(stag!("mo"), ws(tag(",")), etag!("mo")); + let (s, bound_vars) = delimited( + alt((mo_lparen, mo_mrow_lparen)), + separated_list1(mo_comma, alt((msub, multiple_dots))), + alt((mo_mrow_rparen, mo_rparen)), + )(input)?; + let mut mi_func_of: Vec = Vec::new(); + for bvar in bound_vars { + let b = Mi(bvar.to_string()); + mi_func_of.push(b.clone()); + } + Ok((s, mi_func_of)) +} + /// Parse empty univariate function. /// Example: S fn empty_parenthesis(input: Span) -> IResult> { @@ -55,7 +79,10 @@ fn empty_parenthesis(input: Span) -> IResult> { /// Parse content identifiers corresponding to univariate functions. /// Example: S(t) pub fn ci_univariate_with_bounds(input: Span) -> IResult { - let (s, (Mi(x), bound_vars)) = tuple((mi, parenthesized_identifier))(input)?; + let (s, (Mi(x), bound_vars)) = tuple(( + mi, + alt((parenthesized_identifier, parenthesized_msub_identifier)), + ))(input)?; let mut ci_func_of: Vec = Vec::new(); for bvar in bound_vars { let b = Ci::new(Some(Type::Real), Box::new(MathExpression::Mi(bvar)), None); @@ -88,8 +115,14 @@ pub fn ci_univariate_without_bounds(input: Span) -> IResult { /// Parse identifiers corresponding to univariate functions for ordinary derivatives /// such that it can identify content identifiers with and without parenthesis identifiers. pub fn ci_univariate_func(input: Span) -> IResult { - let (s, (x, bound_vars)) = - tuple((mi, alt((parenthesized_identifier, empty_parenthesis))))(input)?; + let (s, (x, bound_vars)) = tuple(( + mi, + alt(( + parenthesized_identifier, + empty_parenthesis, + parenthesized_msub_identifier, + )), + ))(input)?; let mut ci_func_of: Vec = Vec::new(); for bvar in bound_vars { let b = Ci::new(Some(Type::Real), Box::new(MathExpression::Mi(bvar)), None); @@ -114,8 +147,14 @@ pub fn ci_subscript(input: Span) -> IResult { /// Parse contest identifier for Msub corresponding to univariate functions for ordinary /// derivatives pub fn ci_subscript_func(input: Span) -> IResult { - let (s, (x, bound_vars)) = - tuple((msub, alt((parenthesized_identifier, empty_parenthesis))))(input)?; + let (s, (x, bound_vars)) = tuple(( + msub, + alt(( + parenthesized_msub_identifier, + parenthesized_identifier, + empty_parenthesis, + )), + ))(input)?; let mut ci_func_of: Vec = Vec::new(); for bvar in bound_vars { let b = Ci::new(Some(Type::Real), Box::new(MathExpression::Mi(bvar)), None); @@ -139,10 +178,34 @@ pub fn superscript(input: Span) -> IResult { /// Parse Mover pub fn over_term(input: Span) -> IResult { let (s, over) = ws(map( - tag_parser!("mover", pair(math_expression, math_expression)), + alt(( + tag_parser!("mover", pair(math_expression, math_expression)), + delimited( + tag(""), + pair(math_expression, math_expression), + tag(""), + ), + )), |(x, y)| MathExpression::Mover(Box::new(x), Box::new(y)), ))(input)?; - Ok((s, over)) + if let MathExpression::Mover(ref x, ref y) = over.clone() { + if MathExpression::Mo(Operator::Hat) == **y { + let new_op = Operator::HatOp(HatOp::new(x.clone())); + return Ok((s, MathExpression::Mo(new_op))); + } else { + return Ok((s, over)); + } + } + Err(nom::Err::Error(ParseError::new( + "Unable to obtain Mover term".to_string(), + input, + ))) +} + +/// Parse Hat operator with components. Example: r \hat{x} +pub fn hat_operator(input: Span) -> IResult<(MathExpression, MathExpression)> { + let (s, (comp, op)) = pair(mi, over_term)(input)?; + Ok((s, (op, MathExpression::Mi(comp)))) } /// Parse the identifier 'd' @@ -347,6 +410,76 @@ pub fn first_order_derivative_leibniz_notation(input: Span) -> IResult<(Derivati ))) } +/// Parse first order partial derivative. Example: ∂_{t} S +pub fn first_order_partial_derivative_partial_func( + input: Span, +) -> IResult<(PartialDerivative, Ci)> { + let (s, _) = tuple((stag!("msub"), partial))(input)?; + let (s, with_respect_to) = ws(terminated(mi, etag!("msub")))(s)?; + let (s, func) = ws(alt(( + ci_univariate_func, + map( + ci_unknown, + |Ci { + content, func_of, .. + }| { + Ci { + r#type: Some(Type::Function), + content, + func_of, + } + }, + ), + ci_subscript_func, + )))(s)?; + if let Some(ref ci_vec) = func.func_of { + for (indx, bvar) in ci_vec.iter().enumerate() { + if Some(bvar.content.clone()) + == Some(Box::new(MathExpression::Mi(with_respect_to.clone()))) + { + return Ok(( + s, + ( + PartialDerivative::new( + 1, + (indx + 1) as u8, + Ci::new( + Some(Type::Real), + Box::new(MathExpression::Mi(with_respect_to)), + None, + ), + ), + func, + ), + )); + } else if Some(bvar.content.clone()) + == Some(Box::new(MathExpression::Mi(Mi("".to_string())))) + { + return Ok(( + s, + ( + PartialDerivative::new( + 1, + 1, + Ci::new( + Some(Type::Real), + Box::new(MathExpression::Mi(with_respect_to)), + None, + ), + ), + func, + ), + )); + } + } + } + + Err(nom::Err::Error(ParseError::new( + "Unable to match function_of with with_respect_to in ∂_{t} S".to_string(), + input, + ))) +} + /// Parse a first-order partial ordinary derivative written in Leibniz notation. pub fn first_order_partial_derivative_leibniz_notation( input: Span, @@ -539,12 +672,25 @@ pub fn gradient(input: Span) -> IResult { Ok((s, op)) } +/// Gradient sub E.g. ∇_{x} +pub fn gradient_subscript(input: Span) -> IResult { + let (s, _) = tuple((stag!("msub"), gradient))(input)?; + let (s, mi) = ws(terminated(mi, etag!("msub")))(s)?; + let grad_sub = Operator::GradSub(GradSub::new(Box::new(MathExpression::Mi(mi.clone())))); + Ok((s, grad_sub)) +} + pub fn grad_func(input: Span) -> IResult<(Operator, Ci)> { let (s, (op, id)) = ws(pair(gradient, mi))(input)?; let ci = Ci::new(Some(Type::Real), Box::new(MathExpression::Mi(id)), None); Ok((s, (op, ci))) } +pub fn functions_of_grad(input: Span) -> IResult<(Operator, Mrow)> { + let (s, (op, id)) = ws(pair(gradient, map(ws(many0(math_expression)), Mrow)))(input)?; + Ok((s, (op, id))) +} + ///Absolute with Msup value pub fn absolute_with_msup(input: Span) -> IResult { let (s, sup) = ws(map( @@ -592,13 +738,157 @@ pub fn sqrt(input: Span) -> IResult { )) } +/// Parser for change in a variable : +/// Example: Δx +pub fn change_in_variable(input: Span) -> IResult { + let (s, elements) = ws(preceded( + alt((tag("Δ"), tag("Δ"))), + math_expression, + ))(input)?; + let temp_sum = format!("Δ{}", elements); + let change_in_var = Ci::new( + Some(Type::Real), + Box::new(MathExpression::Mi(Mi(temp_sum.to_string()))), + None, + ); + Ok((s, change_in_var)) +} + +/// Parser handles vector identity notation. +/// E.g. (v ⋅ ∇) u +pub fn gradient_with_closed_paren(input: Span) -> IResult> { + let (s, (_lp, (mi, (op, _gg)))) = ws(pair( + tag("("), + pair( + mi, + pair( + ws(delimited(stag!("mo"), dot, etag!("mo"))), + terminated(gradient, tag(")")), + ), + ), + ))(input)?; + let mut expression: Vec = Vec::new(); + expression.push(MathExpression::Mi(mi)); + expression.push(MathExpression::Mo(op)); + let ci = Ci::new( + Some(Type::Real), + Box::new(MathExpression::Mi(Mi("Grad".to_string()))), + None, + ); + expression.push(MathExpression::Ci(ci.clone())); + Ok((s, expression)) +} + +/// Parser handles e.g. `...` in the equation +/// E.g. Q_i(s_{i-1}, T_{i-1}, ... ) +pub fn multiple_dots(input: Span) -> IResult { + let (s, x) = ws(delimited(tag(""), tag("…"), tag("")))(input)?; + let ci = Ci::new( + Some(Type::List), + Box::new(MathExpression::Mi(Mi(x.to_string()))), + None, + ); + Ok((s, MathExpression::Ci(ci))) +} + +/// Handles summation as operator +pub fn munderover_summation(input: Span) -> IResult<(SumUnderOver, Mrow)> { + let (s, (under, over)) = ws(delimited( + alt(( + tag(""), + tag(""), + )), + pair( + ws(delimited( + tag(""), + many0(math_expression), + tag(""), + )), + many0(math_expression), + ), + tag(""), + ))(input)?; + let (s, comps) = many0(math_expression)(s)?; + let sum_operator = Operator::Sum; + let under_comp = MathExpression::Mrow(Mrow(under)); + let over_comp = MathExpression::Mrow(Mrow(over)); + let other_comps = Mrow::new(comps); + let operator = SumUnderOver::new( + Box::new(MathExpression::Mo(sum_operator)), + Box::new(under_comp), + Box::new(over_comp), + ); + Ok((s, (operator, other_comps))) +} + /// Parser for math expressions. This varies from the one in the generic_mathml module, since it /// assumes that expressions such as S(t) are actually univariate functions. pub fn math_expression(input: Span) -> IResult { ws(alt(( - map(div, MathExpression::Mo), - alt((ws(absolute_with_msup), ws(paren_as_msup))), - //sqrt, + alt(( + map(gradient_with_closed_paren, |row| { + MathExpression::Mrow(Mrow(row)) + }), + map(gradient_subscript, MathExpression::Mo), + )), + alt(( + map(div, MathExpression::Mo), + map(hat_operator, |(op, row)| { + MathExpression::HatComp(HatComp { + op: Box::new(op), + comp: Box::new(row), + }) + }), + )), + map( + first_order_partial_derivative_partial_func, + |( + PartialDerivative { + order, + var_index, + bound_var, + }, + Ci { + r#type, + content, + func_of, + }, + )| { + MathExpression::Differential(Differential { + diff: Box::new(MathExpression::Mo(Operator::PartialDerivative( + PartialDerivative { + order, + var_index, + bound_var, + }, + ))), + func: Box::new(MathExpression::Ci(Ci { + r#type, + content, + func_of, + })), + }) + }, + ), + alt(( + ws(absolute_with_msup), + ws(paren_as_msup), + map(change_in_variable, MathExpression::Ci), + //sqrt, + map( + munderover_summation, + |(SumUnderOver { op, under, over }, comp)| { + MathExpression::SummationMath(SummationMath { + op: Box::new(MathExpression::Mo(Operator::SumUnderOver(SumUnderOver { + op, + under, + over, + }))), + func: Box::new(MathExpression::Mrow(comp)), + }) + }, + ), + )), map( grad_func, |( @@ -717,6 +1007,18 @@ pub fn math_expression(input: Span) -> IResult { }) }, ), + map( + ci_subscript_func, + |Ci { + content, func_of, .. + }| { + MathExpression::Ci(Ci { + r#type: Some(Type::Real), + content, + func_of, + }) + }, + ), map( first_order_with_func_in_parenthesis, |( @@ -759,6 +1061,12 @@ pub fn math_expression(input: Span) -> IResult { }) }, ), + map(functions_of_grad, |(op, comp)| { + MathExpression::Differential(Differential { + diff: Box::new(MathExpression::Mo(op)), + func: Box::new(MathExpression::Mrow(comp)), + }) + }), map(ci_univariate_without_bounds, MathExpression::Ci), map(ci_subscript, MathExpression::Ci), map( @@ -780,32 +1088,18 @@ pub fn math_expression(input: Span) -> IResult { func_of: None, }) }), - map( - grad_func, - |( - op, - Ci { - r#type, - content, - func_of, - }, - )| { - MathExpression::Differential(Differential { - diff: Box::new(MathExpression::Mo(op)), - func: Box::new(MathExpression::Ci(Ci { - r#type, - content, - func_of, - })), - }) - }, - ), - alt((absolute, sqrt)), alt(( + absolute, + sqrt, map(operator, MathExpression::Mo), map(gradient, MathExpression::Mo), + mn, + msub, + superscript, + mfrac, + mtext, + over_term, )), - alt((mn, msub, superscript, mfrac, mtext, over_term)), map(mrow, MathExpression::Mrow), msubsup, )))(input) diff --git a/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs b/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs index c745f4cc396..718c900d45d 100644 --- a/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs +++ b/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs @@ -3,7 +3,7 @@ use crate::{ ast::{ - operator::{Derivative, Operator, PartialDerivative}, + operator::{Derivative, GradSub, HatOp, Operator, PartialDerivative, SumUnderOver}, Math, MathExpression, Mi, Mrow, }, parsers::interpreted_mathml::interpreted_math, @@ -246,30 +246,30 @@ fn unicode_to_latex(input: &str) -> String { } fn is_unary_operator(op: &Operator) -> bool { - match op { + matches!( + op, Operator::Sqrt - | Operator::Factorial - | Operator::Exp - | Operator::Power - | Operator::Grad - | Operator::Div - | Operator::Abs - | Operator::Derivative(_) - | Operator::Sin - | Operator::Cos - | Operator::Tan - | Operator::Sec - | Operator::Csc - | Operator::Cot - | Operator::Arcsin - | Operator::Arccos - | Operator::Arctan - | Operator::Arcsec - | Operator::Arccsc - | Operator::Arccot - | Operator::Mean => true, - _ => false, - } + | Operator::Factorial + | Operator::Exp + | Operator::Power + | Operator::Grad + | Operator::Div + | Operator::Abs + | Operator::Derivative(_) + | Operator::Sin + | Operator::Cos + | Operator::Tan + | Operator::Sec + | Operator::Csc + | Operator::Cot + | Operator::Arcsin + | Operator::Arccos + | Operator::Arctan + | Operator::Arcsec + | Operator::Arccsc + | Operator::Arccot + | Operator::Mean + ) } // Process parentheses in an expression and update the LaTeX string. @@ -417,6 +417,11 @@ impl MathExpressionTree { Operator::Divide => content_mathml.push_str(""), Operator::Power => content_mathml.push_str(""), Operator::Exp => content_mathml.push_str(""), + Operator::Abs => content_mathml.push_str(""), + Operator::Grad => content_mathml.push_str(""), + Operator::Div => content_mathml.push_str(""), + Operator::Cos => content_mathml.push_str(""), + Operator::Sin => content_mathml.push_str(""), Operator::Derivative(Derivative { order, var_index, @@ -700,7 +705,7 @@ impl MathExpression { tokens.push(MathExpression::Mo(Operator::Lparen)); for element in elements { if let MathExpression::Ci(x) = element { - // Handles cos and sin as operators + // Handles cos, sin, tan as operators if x.content == Box::new(MathExpression::Mi(Mi("cos".to_string()))) { tokens.push(MathExpression::Mo(Operator::Cos)); if let Some(vec) = x.func_of.clone() { @@ -710,6 +715,18 @@ impl MathExpression { } } else if x.content == Box::new(MathExpression::Mi(Mi("sin".to_string()))) { tokens.push(MathExpression::Mo(Operator::Sin)); + if let Some(vec) = x.func_of.clone() { + for v in vec { + tokens.push(MathExpression::Ci(v)); + } + } + } else if x.content == Box::new(MathExpression::Mi(Mi("tan".to_string()))) { + tokens.push(MathExpression::Mo(Operator::Tan)); + if let Some(vec) = x.func_of.clone() { + for v in vec { + tokens.push(MathExpression::Ci(v)); + } + } } else { element.flatten(tokens); } @@ -740,6 +757,7 @@ impl MathExpression { denominator.flatten(tokens); tokens.push(MathExpression::Mo(Operator::Rparen)); } + /// Insert implicit `exponential` and `power` operators MathExpression::Msup(base, superscript) => { if let MathExpression::Ci(x) = &**base { if x.content == Box::new(MathExpression::Mi(Mi("e".to_string()))) { @@ -757,7 +775,6 @@ impl MathExpression { tokens.push(MathExpression::Mo(Operator::Rparen)); } } else { - //tokens.push(MathExpression::Mo(Operator::Lparen)); base.flatten(tokens); tokens.push(MathExpression::Mo(Operator::Power)); tokens.push(MathExpression::Mo(Operator::Lparen)); @@ -801,6 +818,19 @@ impl MathExpression { tokens.push(MathExpression::Mo(Operator::Rparen)); } } + MathExpression::SummationMath(x) => { + tokens.push(MathExpression::Mo(Operator::Lparen)); + x.op.flatten(tokens); + x.func.flatten(tokens); + tokens.push(MathExpression::Mo(Operator::Rparen)); + } + MathExpression::HatComp(x) => { + //tokens.push(MathExpression::Mo(Operator::Lparen)); + x.op.flatten(tokens); + tokens.push(MathExpression::Mo(Operator::Lparen)); + x.comp.flatten(tokens); + tokens.push(MathExpression::Mo(Operator::Rparen)); + } t => tokens.push(t.clone()), } } @@ -1001,13 +1031,17 @@ fn prefix_binding_power(op: &Operator) -> ((), u8) { Operator::Sin => ((), 21), Operator::Tan => ((), 21), Operator::Mean => ((), 25), - Operator::Dot => ((), 25), + Operator::Hat => ((), 25), + //Operator::Cross => ((), 25), Operator::Grad => ((), 25), + Operator::GradSub(GradSub { .. }) => ((), 25), Operator::Derivative(Derivative { .. }) => ((), 25), Operator::PartialDerivative(PartialDerivative { .. }) => ((), 25), Operator::Div => ((), 25), Operator::Abs => ((), 25), Operator::Sqrt => ((), 25), + Operator::SumUnderOver(SumUnderOver { .. }) => ((), 25), + Operator::HatOp(HatOp { .. }) => ((), 25), _ => panic!("Bad operator: {:?}", op), } } @@ -1016,6 +1050,7 @@ fn prefix_binding_power(op: &Operator) -> ((), u8) { fn postfix_binding_power(op: &Operator) -> Option<(u8, ())> { let res = match op { Operator::Factorial => (11, ()), + //Operator::HatOp(HatOp { .. }) => (11, ()), _ => return None, }; Some(res) @@ -1031,6 +1066,9 @@ fn infix_binding_power(op: &Operator) -> Option<(u8, u8)> { Operator::Divide => (9, 10), Operator::Compose => (14, 13), Operator::Power => (16, 15), + Operator::Dot => (18, 17), + Operator::Cross => (18, 17), + //Operator::Comma => (18, 17), Operator::Other(op) => panic!("Unhandled operator: {}!", op), _ => return None, }; @@ -1446,6 +1484,8 @@ fn test_trig_cos() { "; let exp = input.parse::().unwrap(); + let cmml = exp.to_cmml(); + assert_eq!(cmml, "x"); let s_exp = exp.to_string(); assert_eq!(s_exp, "(Cos x)"); } @@ -1589,6 +1629,8 @@ fn test_grad() { "; let exp = input.parse::().unwrap(); + let cmml = exp.to_cmml(); + assert_eq!(cmml, "H"); let s_exp = exp.to_string(); assert_eq!(s_exp, "(Grad H)"); } @@ -1717,6 +1759,8 @@ fn test_divergence() { "; let exp = input.parse::().unwrap(); + let cmml = exp.to_cmml(); + assert_eq!(cmml, "H"); let s_exp = exp.to_string(); assert_eq!(s_exp, "(Div H)"); } @@ -2125,12 +2169,12 @@ fn test_equation_with_mtext() { #[test] fn new_msqrt_test_function() { let input = " - - 4 - a - c - -"; + + 4 + a + c + + "; let exp = input.parse::().unwrap(); let s_exp = exp.to_string(); assert_eq!(s_exp, "(√ (* (* 4 a) c))"); @@ -2176,16 +2220,16 @@ fn new_quadratic_equation() { #[test] fn test_dot_in_derivative() { let input = " - + - S - ˙ + S + ˙ - -( - t - ) -"; + + ( + t + ) + "; let exp = input.parse::().unwrap(); let s_exp = exp.to_string(); assert_eq!(s_exp, "(D(1, t) S)"); @@ -2242,3 +2286,408 @@ fn test_sidarthe_equation() { "(= (D(1, t) S) (* (- S) (+ (+ (+ (* α I) (* β D)) (* γ A)) (* δ R))))" ); } + +#[test] +fn test_heating_rate() { + let input = " + + Q + i + + = + + ( + + T + i + + + + T + + i + + 1 + + + ) + + + + ( + + C + p + + Δ + t + ) + + + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(= Q_{i} (/ (- T_{i} T_{i-1}) (* C_{p} Δt)))"); +} + +#[test] +fn test_sum_munderover() { + let input = " + + + + l + = + k + + K + + S + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(∑_{l=k}^{K} S)"); +} + +#[test] +fn test_hydrostatic() { + let input = " + + Φ + k + + = + + Φ + s + + + + R + + + + l + = + k + + K + + + H + + k + l + + + + T + + v + l + + + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + //println!("s_exp={:?}", s_exp); + assert_eq!( + s_exp, + "(= Φ_{k} (+ Φ_{s} (* R (∑_{l=k}^{K} (* H_{kl} T_{vl})))))" + ); +} + +#[test] +fn test_temperature_evolution() { + let input = " + + + Δ + + s + i + + + + Δ + t + + + + + C + p + + = + + + ( + + s + i + + + + s + + i + + 1 + + + ) + + + Δ + t + + + + + C + p + + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!( + s_exp, + "(= (/ (/ Δs_{i} Δt) C_{p}) (/ (/ (- s_{i} s_{i-1}) Δt) C_{p}))" + ); +} + +#[test] +fn test_cross_product() { + let input = " + f + × + u + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(× f u)"); +} +#[test] +fn test_dot_product() { + let input = " + f + + u + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(⋅ f u)"); +} + +#[test] +fn test_partial_with_msub_t() { + let input = " + + + t + + S + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(PD(1, t) S)"); +} + +#[test] +fn test_dry_static_energy() { + let input = " + + s + i + + = + + s + + i + + 1 + + + + + ( + Δ + t + ) + + Q + i + + + ( + + s + + i + + 1 + + + , + + T + + i + + 1 + + + , + + Φ + + i + + 1 + + + , + + q + + i + + 1 + + + , + + ) + + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(= s_{i} (+ s_{i-1} (* Δt Q_{i})))"); +} + +#[test] +fn test_hat_operator() { + let input = " + ζ + + + z + ^ + + + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(Hat(z) ζ)"); +} + +#[test] +fn test_vector_invariant_form() { + let input = " + + + t + + u + + + ( + ζ + + + z + ^ + + + + + f + ) + × + u + = + + + + [ + g + ( + h + + + b + ) + + + + 1 + 2 + + u + + u + ] + + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(= (+ (PD(1, t) u) (× (+ (Hat(z) ζ) f) u)) (- (Grad (+ (* g (+ h b)) (* (/ 1 2) (⋅ u u))))))"); +} + +#[test] +fn test_mi_dot_gradient() { + let input = " + ( + v + + + ) + u + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(* (⋅ v Grad) u)"); +} + +#[test] +fn test_momentum_conservation() { + let input = " + + + t + + u + = + + ( + v + + + ) + u + + f + × + u + + + + h + + ( + p + + + g + η + ) + + + + τ + + + + F + + u + + + "; + let exp = input.parse::().unwrap(); + let s_exp = exp.to_string(); + assert_eq!(s_exp, "(= (PD(1, t) u) (+ (- (- (- (* (- (⋅ v Grad)) u) (× f u)) (Grad_h) (+ p (* g η)))) (Div τ)) F_{u}))"); +} From bf65aaa9478ace101a2c70bb0802fb0729c05edf Mon Sep 17 00:00:00 2001 From: Joseph Astier Date: Fri, 12 Jan 2024 16:16:30 -0700 Subject: [PATCH 17/22] CAST to GroMEt bugfix (#734) ## Summary of Changes - Adds a validation test to detect GroMEt issues - Added a Character scalar type ### Related issues Resolves #731 --------- Co-authored-by: Joseph Astier Co-authored-by: titomeister --- .../CAST/matlab/matlab_to_cast.py | 30 +++++++++++-------- .../CAST/matlab/tests/utils.py | 17 +++++++++++ .../visitors/cast_to_agraph_visitor.py | 4 +++ 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/skema/program_analysis/CAST/matlab/matlab_to_cast.py b/skema/program_analysis/CAST/matlab/matlab_to_cast.py index 22cf8015679..da171083daf 100644 --- a/skema/program_analysis/CAST/matlab/matlab_to_cast.py +++ b/skema/program_analysis/CAST/matlab/matlab_to_cast.py @@ -156,15 +156,16 @@ def visit_assignment(self, node): def visit_boolean(self, node): """ Translate Tree-sitter boolean node """ + value_type = "Boolean" for child in node.children: # set the first letter to upper case for python value = child.type value = value[0].upper() + value[1:].lower() # store as string, use Python Boolean capitalization. return LiteralValue( - value_type="Boolean", + value_type=value_type, value = value, - source_code_data_type=["matlab", MATLAB_VERSION, "boolean"], + source_code_data_type=["matlab", MATLAB_VERSION, value_type], source_refs=[self.node_helper.get_source_ref(node)], ) @@ -394,10 +395,11 @@ def get_values(element, ret)-> List: if len(values) > 0: value = values[0] + value_type="List", return LiteralValue( - value_type="List", + value_type=value_type, value = value, - source_code_data_type=["matlab", MATLAB_VERSION, "matrix"], + source_code_data_type=["matlab", MATLAB_VERSION, value_type], source_refs=[self.node_helper.get_source_ref(node)], ) @@ -437,16 +439,18 @@ def visit_number(self, node) -> LiteralValue: literal_value = self.node_helper.get_identifier(node) # Check if this is a real value, or an Integer if "e" in literal_value.lower() or "." in literal_value: + value_type = "AbstractFloat" return LiteralValue( - value_type="AbstractFloat", + value_type=value_type, value=float(literal_value), - source_code_data_type=["matlab", MATLAB_VERSION, "real"], + source_code_data_type=["matlab", MATLAB_VERSION, value_type], source_refs=[self.node_helper.get_source_ref(node)] ) + value_type = "Integer" return LiteralValue( - value_type="Integer", + value_type=value_type, value=int(literal_value), - source_code_data_type=["matlab", MATLAB_VERSION, "integer"], + source_code_data_type=["matlab", MATLAB_VERSION, value_type], source_refs=[self.node_helper.get_source_ref(node)] ) @@ -465,10 +469,11 @@ def visit_operator(self, node): ) def visit_string(self, node): + value_type = "Character" return LiteralValue( - value_type="Character", + value_type=value_type, value=self.node_helper.get_identifier(node), - source_code_data_type=["matlab", MATLAB_VERSION, "character"], + source_code_data_type=["matlab", MATLAB_VERSION, value_type], source_refs=[self.node_helper.get_source_ref(node)] ) @@ -490,10 +495,11 @@ def get_case_expression(case_node, switch_var): cell_node = get_first_child_by_type(case_node, "cell") # multiple case arguments if (cell_node): + value_type="List", operand = LiteralValue( - value_type="List", + value_type=value_type, value = self.visit(cell_node), - source_code_data_type=["matlab", MATLAB_VERSION, "unknown"], + source_code_data_type=["matlab", MATLAB_VERSION, value_type], source_refs=[self.node_helper.get_source_ref(cell_node)] ) return self.get_operator( diff --git a/skema/program_analysis/CAST/matlab/tests/utils.py b/skema/program_analysis/CAST/matlab/tests/utils.py index 4fc9eba06f3..9375b95df56 100644 --- a/skema/program_analysis/CAST/matlab/tests/utils.py +++ b/skema/program_analysis/CAST/matlab/tests/utils.py @@ -13,6 +13,11 @@ Name, Var ) +from skema.program_analysis.CAST2FN.visitors.cast_to_agraph_visitor import ( + CASTToAGraphVisitor, +) +from skema.program_analysis.CAST2FN.cast import CAST + def check(result, expected = None): """ Test for match with the same datatypes. """ @@ -60,9 +65,21 @@ def cast(source): """ Return the MatlabToCast output """ # there should only be one CAST object in the cast output list cast = MatlabToCast(source = source).out_cast + # the cast should be parsable + # assert validate(cast) == True # there should be one module in the CAST object assert len(cast.nodes) == 1 module = cast.nodes[0] assert isinstance(module, Module) # return the module body node list return module.body + +def validate(cast): + """ Test that the cast can be parsed """ + try: + foo = CASTToAGraphVisitor(cast) + foo.to_pdf("/dev/null") + return True + except: + return False + diff --git a/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py b/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py index 5bffaaa8ef4..e32ebabf32a 100644 --- a/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py +++ b/skema/program_analysis/CAST2FN/visitors/cast_to_agraph_visitor.py @@ -582,6 +582,10 @@ def _(self, node: LiteralValue): node_uid = uuid.uuid4() self.G.add_node(node_uid, label=f"Boolean: {str(node.value)}") return node_uid + elif node.value_type == ScalarType.CHARACTER: + node_uid = uuid.uuid4() + self.G.add_node(node_uid, label=f"Character: {str(node.value)}") + return node_uid elif node.value_type == ScalarType.ABSTRACTFLOAT: node_uid = uuid.uuid4() self.G.add_node(node_uid, label=f"abstractFloat: {node.value}") From ea7e2b699f726db798dc4b6ae04aa1c2286b8f38 Mon Sep 17 00:00:00 2001 From: Gus Hahn-Powell Date: Tue, 16 Jan 2024 09:31:55 -0700 Subject: [PATCH 18/22] [REST] Performance improvements (async and client reuse) (#749) ## Summary of Changes This PR introduces performance-related improvements. - Avoids proxied calls to our remote deployment for the `llm_proxy` service (see #746 ) - Moves from `requests` to `httpx` for asynchronous calls to other services (see #747 ) - Reuses clients for improved performance via dependency injection (see #748 ) ### Related issues - Resolves #746 - Resolves #747 - Resolves #748 --------- Co-authored-by: Justin --- pyproject.toml | 3 +- skema/isa/isa_service.py | 12 ++-- .../comment_extractor/comment_extractor.py | 2 +- .../comment_extractor/server.py | 2 +- .../tests/test_comment_server.py | 6 +- skema/rest/api.py | 12 ++-- skema/rest/config.py | 9 +++ skema/rest/integrated_text_reading_proxy.py | 9 +-- skema/rest/llm_proxy.py | 40 ++++++----- skema/rest/morae_proxy.py | 34 +++++---- skema/rest/tests/test_eqn_to_latex.py | 11 +-- skema/rest/tests/test_model_to_amr.py | 60 +++++++++------- skema/rest/utils.py | 10 +++ skema/rest/workflows.py | 70 +++++++++++-------- skema/skema_py/server.py | 5 +- 15 files changed, 170 insertions(+), 115 deletions(-) create mode 100644 skema/rest/config.py diff --git a/pyproject.toml b/pyproject.toml index af3f401680b..20028ab88d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies=[ "typing_extensions", # see https://github.com/pydantic/pydantic/issues/5821#issuecomment-1559196859 "fastapi~=0.100.0", "starlette", + "httpx", "pydantic>=2.0.0", "uvicorn", "python-multipart", @@ -42,7 +43,7 @@ dynamic = ["readme"] # Pygraphviz is often tricky to install, so we reserve it for the dev extras # list. # - six: Required by auto-generated Swagger models -dev = ["pytest", "pytest-cov", "pytest-xdist", "httpx", "black", "mypy", "coverage", "pygraphviz", "six"] +dev = ["pytest", "pytest-cov", "pytest-xdist", "pytest-asyncio", "black", "mypy", "coverage", "pygraphviz", "six"] demo = ["notebook"] diff --git a/skema/isa/isa_service.py b/skema/isa/isa_service.py index a1a5d6d912e..bc0d2b4de16 100644 --- a/skema/isa/isa_service.py +++ b/skema/isa/isa_service.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- -from fastapi import FastAPI, APIRouter, status +from fastapi import Depends, FastAPI, APIRouter, status from skema.isa.lib import align_mathml_eqs import skema.isa.data as isa_data +from skema.rest import utils from pydantic import BaseModel -import requests +import httpx from skema.rest.proxies import SKEMA_RS_ADDESS @@ -23,8 +24,9 @@ class ISA_Result(BaseModel): response_model=int, status_code=status.HTTP_200_OK ) -async def healthcheck() -> int: - return requests.get(f"{SKEMA_RS_ADDESS}/ping").status_code +async def healthcheck(client: httpx.AsyncClient = Depends(utils.get_client)) -> int: + res = await client.get(f"{SKEMA_RS_ADDESS}/ping") + return res.status_code @router.post( @@ -47,7 +49,7 @@ async def align_eqns( "mml2": {isa_data.mml} }} - response=client.post("/isa/align-eqns", json=request) + response=requests.post("/isa/align-eqns", json=request) res = response.json() """ ( diff --git a/skema/program_analysis/comment_extractor/comment_extractor.py b/skema/program_analysis/comment_extractor/comment_extractor.py index f43023ddec7..9a8bb4e1ac3 100644 --- a/skema/program_analysis/comment_extractor/comment_extractor.py +++ b/skema/program_analysis/comment_extractor/comment_extractor.py @@ -245,7 +245,7 @@ def extract_comments_multi( request: MultiFileCommentRequest, ) -> MultiFileCommentResponse: """Wrapper for processing multiple source files at a time.""" - return MultiFileCommentResponse.parse_obj( + return MultiFileCommentResponse(** { "files": { file_name: extract_comments_single(file_request) diff --git a/skema/program_analysis/comment_extractor/server.py b/skema/program_analysis/comment_extractor/server.py index f570ff67314..3b8ff908e3f 100644 --- a/skema/program_analysis/comment_extractor/server.py +++ b/skema/program_analysis/comment_extractor/server.py @@ -83,7 +83,7 @@ async def comments_extract_zip( } return comment_service.extract_comments_multi( - MultiFileCommentRequest.parse_obj(request) + MultiFileCommentRequest(**request) ) app = FastAPI() diff --git a/skema/program_analysis/comment_extractor/tests/test_comment_server.py b/skema/program_analysis/comment_extractor/tests/test_comment_server.py index 9b5691cd219..bc9457677b6 100644 --- a/skema/program_analysis/comment_extractor/tests/test_comment_server.py +++ b/skema/program_analysis/comment_extractor/tests/test_comment_server.py @@ -14,7 +14,7 @@ def test_comments_get_supported_languages(): response = client.get("/comment_service/comments-get-supported-languages") assert response.status_code == 200 - languages = comment_service.SupportedLanguageResponse.parse_obj(response.json()) + languages = comment_service.SupportedLanguageResponse(**response.json()) assert isinstance(languages, comment_service.SupportedLanguageResponse) assert len(languages.languages) > 0 @@ -37,7 +37,7 @@ def test_comments_extract(): response = client.post("/comment_service/comments-extract", json=request) assert response.status_code == 200 - comments = comment_service.SingleFileCommentResponse.parse_obj(response.json()) + comments = comment_service.SingleFileCommentResponse(**response.json()) assert isinstance(comments, comment_service.SingleFileCommentResponse) @@ -72,5 +72,5 @@ def test_comments_extract_zip(): ) assert response.status_code == 200 - comments = comment_service.MultiFileCommentResponse.parse_obj(response.json()) + comments = comment_service.MultiFileCommentResponse(**response.json()) assert isinstance(comments, comment_service.MultiFileCommentResponse) \ No newline at end of file diff --git a/skema/rest/api.py b/skema/rest/api.py index 69ba06ab95c..862fa58b577 100644 --- a/skema/rest/api.py +++ b/skema/rest/api.py @@ -1,10 +1,11 @@ import os from typing import Dict -from fastapi import FastAPI, Response, status +from fastapi import Depends, FastAPI, Response, status from fastapi.responses import PlainTextResponse from skema.rest import ( + config, schema, workflows, proxies, @@ -12,12 +13,14 @@ morae_proxy, metal_proxy, llm_proxy, + utils ) from skema.isa import isa_service from skema.img2mml import eqn2mml from skema.skema_py import server as code2fn from skema.gromet.execution_engine import server as execution_engine from skema.program_analysis.comment_extractor import server as comment_service +import httpx VERSION: str = os.environ.get("APP_VERSION", "????") @@ -170,7 +173,7 @@ summary="API version", status_code=status.HTTP_200_OK ) -async def version() -> str: +def version() -> str: return PlainTextResponse(VERSION) @@ -190,8 +193,8 @@ async def version() -> str: }, }, ) -async def healthcheck(response: Response) -> schema.HealthStatus: - morae_status = await morae_proxy.healthcheck() +async def healthcheck(response: Response, client: httpx.AsyncClient = Depends(utils.get_client)) -> schema.HealthStatus: + morae_status = await morae_proxy.healthcheck(client) mathjax_status = eqn2mml.latex2mml_healthcheck() eqn2mml_status = eqn2mml.img2mml_healthcheck() code2fn_status = code2fn.healthcheck() @@ -230,6 +233,7 @@ async def environment_variables() -> Dict: "SKEMA_GRAPH_DB_HOST": proxies.SKEMA_GRAPH_DB_HOST, "SKEMA_GRAPH_DB_PORT": proxies.SKEMA_GRAPH_DB_PORT, "SKEMA_RS_ADDRESS": proxies.SKEMA_RS_ADDESS, + "SKEMA_RS_DEFAULT_TIMEOUT": config.SKEMA_RS_DEFAULT_TIMEOUT, "SKEMA_MATHJAX_PROTOCOL": proxies.SKEMA_MATHJAX_PROTOCOL, "SKEMA_MATHJAX_HOST": proxies.SKEMA_MATHJAX_HOST, diff --git a/skema/rest/config.py b/skema/rest/config.py new file mode 100644 index 00000000000..92038fbb368 --- /dev/null +++ b/skema/rest/config.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- +""" +ENV-based config +""" + +import os + + +SKEMA_RS_DEFAULT_TIMEOUT = float(os.environ.get("SKEMA_RS_DEFAULT_TIMEOUT", "60.0")) \ No newline at end of file diff --git a/skema/rest/integrated_text_reading_proxy.py b/skema/rest/integrated_text_reading_proxy.py index d167455d057..6cb64420242 100644 --- a/skema/rest/integrated_text_reading_proxy.py +++ b/skema/rest/integrated_text_reading_proxy.py @@ -11,9 +11,10 @@ import pandas as pd import requests +import httpx from askem_extractions.data_model import AttributeCollection from askem_extractions.importers import import_arizona -from fastapi import APIRouter, FastAPI, UploadFile, Response, status +from fastapi import APIRouter, Depends, FastAPI, UploadFile, Response, status from skema.rest.proxies import SKEMA_TR_ADDRESS, MIT_TR_ADDRESS, OPENAI_KEY, COSMOS_ADDRESS from skema.rest.schema import ( @@ -22,7 +23,7 @@ TextReadingDocumentResults, TextReadingError, MiraGroundingInputs, MiraGroundingOutputItem, TextReadingEvaluationResults, ) -from skema.rest.utils import compute_text_reading_evaluation +from skema.rest import utils router = APIRouter() @@ -676,7 +677,7 @@ def quantitative_eval() -> TextReadingEvaluationResults: # Read the SKEMA extractions extractions = AttributeCollection.from_json(Path(__file__).parents[0] / "data" / "extractions_sidarthe_skema.json") - return compute_text_reading_evaluation(gt_data, extractions) + return utils.compute_text_reading_evaluation(gt_data, extractions) @router.post("/eval", response_model=TextReadingEvaluationResults, status_code=200) @@ -716,7 +717,7 @@ def quantitative_eval(extractions_file: UploadFile, extractions = AttributeCollection( attributes=list(it.chain.from_iterable(c.attributes for c in collections))) - return compute_text_reading_evaluation(gt_data, extractions, json_contents) + return utils.compute_text_reading_evaluation(gt_data, extractions, json_contents) app = FastAPI() diff --git a/skema/rest/llm_proxy.py b/skema/rest/llm_proxy.py index 191f97e84c4..a7454e1d453 100644 --- a/skema/rest/llm_proxy.py +++ b/skema/rest/llm_proxy.py @@ -11,10 +11,10 @@ from fastapi import APIRouter, FastAPI, File, UploadFile from io import BytesIO from zipfile import ZipFile -import requests from pathlib import Path from pydantic import BaseModel, Field from typing import List, Optional +from skema.skema_py import server as code2fn from skema.rest.proxies import SKEMA_OPENAI_KEY import time @@ -121,13 +121,14 @@ async def get_lines_of_model(zip_file: UploadFile = File()) -> List[Dynamics]: function_name = parsed_output['model_function'] - # Get the FN from it - url = "https://api.askem.lum.ai/code2fn/fn-given-filepaths" - time.sleep(0.5) - response_zip = requests.post(url, json=single_snippet_payload) - + # FIXME: we should rewrite things to avoid this need + #time.sleep(0.5) + system = code2fn.System(**single_snippet_payload) + print(f"System:\t{system}") + response_zip = await code2fn.fn_given_filepaths(system) + #print(f"response_zip:\t{response_zip}") # get metadata entry for function - for entry in response_zip.json()['modules'][0]['fn_array']: + for entry in response_zip['modules'][0]['fn_array']: try: if entry['b'][0]['name'][0:len(function_name)] == function_name: metadata_idx = entry['b'][0]['metadata'] @@ -135,26 +136,29 @@ async def get_lines_of_model(zip_file: UploadFile = File()) -> List[Dynamics]: continue # get line span using metadata - for (i,metadata) in enumerate(response_zip.json()['modules'][0]['metadata_collection']): + for (i,metadata) in enumerate(response_zip['modules'][0]['metadata_collection']): if i == (metadata_idx - 1): line_begin = metadata[0]['line_begin'] line_end = metadata[0]['line_end'] - except: + # if the line_begin of meta entry 2 (base 0) and meta entry 3 (base 0) are we add a slice from [meta2.line_begin, meta3.line_begin) + # to capture all the imports, return a Dynamics.block with 2 entries, both of which need to be concatenated to pass forward + file_line_begin = response_zip['modules'][0]['metadata_collection'][2][0]['line_begin'] + + code_line_begin = response_zip['modules'][0]['metadata_collection'][3][0]['line_begin'] - 1 + + if (file_line_begin != code_line_begin) and (code_line_begin > file_line_begin): + block.append(f"L{file_line_begin}-L{code_line_begin}") + + block.append(f"L{line_begin}-L{line_end}") + except Exception as e: print("Failed to parse dynamics") + print(f"e:\t{e}") description = "Failed to parse dynamics" line_begin = 0 line_end = 0 + block.append(f"L{line_begin}-L{line_end}") - # if the line_begin of meta entry 2 (base 0) and meta entry 3 (base 0) are we add a slice from [meta2.line_begin, meta3.line_begin) - # to capture all the imports, return a Dynamics.block with 2 entries, both of which need to be concatenated to pass forward - file_line_begin = response_zip.json()['modules'][0]['metadata_collection'][2][0]['line_begin'] - - code_line_begin = response_zip.json()['modules'][0]['metadata_collection'][3][0]['line_begin'] - 1 - - if file_line_begin != code_line_begin: - block.append(f"L{file_line_begin}-L{code_line_begin}") - block.append(f"L{line_begin}-L{line_end}") output = Dynamics(name=file, description=description, block=block) diff --git a/skema/rest/morae_proxy.py b/skema/rest/morae_proxy.py index ce1bc1e2651..c88d40cb9e9 100644 --- a/skema/rest/morae_proxy.py +++ b/skema/rest/morae_proxy.py @@ -6,8 +6,10 @@ from typing import Any, Dict, List, Text from skema.rest.proxies import SKEMA_RS_ADDESS -from fastapi import APIRouter -import requests +from fastapi import APIRouter, Depends +from skema.rest import utils +# TODO: replace use of requests with httpx +import httpx router = APIRouter() @@ -15,26 +17,30 @@ # FIXME: make GrometFunctionModuleCollection a pydantic model via code gen @router.post("/model", summary="Pushes gromet (function network) to the graph database", include_in_schema=False) -async def post_model(gromet: Dict[Text, Any]): - return requests.post(f"{SKEMA_RS_ADDESS}/models", json=gromet).json() +async def post_model(gromet: Dict[Text, Any], client: httpx.AsyncClient = Depends(utils.get_client)): + res = await client.post(f"{SKEMA_RS_ADDESS}/models", json=gromet) + return res.json() @router.get("/models", summary="Gets function network IDs from the graph database") -async def get_models() -> List[int]: - request = requests.get(f"{SKEMA_RS_ADDESS}/models") - print(f"request: {request}") - return request.json() +async def get_models(client: httpx.AsyncClient = Depends(utils.get_client)) -> List[int]: + res = await client.get(f"{SKEMA_RS_ADDESS}/models") + print(f"request: {res}") + return res.json() @router.get("/ping", summary="Status of MORAE service") -async def healthcheck() -> int: - return requests.get(f"{SKEMA_RS_ADDESS}/ping").status_code +async def healthcheck(client: httpx.AsyncClient = Depends(utils.get_client)) -> int: + res = await client.get(f"{SKEMA_RS_ADDESS}/ping") + return res.status_code @router.get("/version", summary="Status of MORAE service") -async def versioncheck() -> str: - return requests.get(f"{SKEMA_RS_ADDESS}/version").text +async def versioncheck(client: httpx.AsyncClient = Depends(utils.get_client)) -> str: + res = await client.get(f"{SKEMA_RS_ADDESS}/version") + return res.text @router.post("/mathml/decapodes", summary="Gets Decapodes from a list of MathML strings") -async def get_decapodes(mathml: List[str]) -> Dict[Text, Any]: - return requests.put(f"{SKEMA_RS_ADDESS}/mathml/decapodes", json=mathml).json() \ No newline at end of file +async def get_decapodes(mathml: List[str], client: httpx.AsyncClient = Depends(utils.get_client)) -> Dict[Text, Any]: + res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/decapodes", json=mathml) + return res.json() \ No newline at end of file diff --git a/skema/rest/tests/test_eqn_to_latex.py b/skema/rest/tests/test_eqn_to_latex.py index b0f5de39a5e..8ed6339b89d 100644 --- a/skema/rest/tests/test_eqn_to_latex.py +++ b/skema/rest/tests/test_eqn_to_latex.py @@ -1,14 +1,15 @@ from pathlib import Path -from fastapi.testclient import TestClient +from httpx import AsyncClient from skema.rest.workflows import app import pytest import json -client = TestClient(app) + @pytest.mark.ci_only -def test_post_image_to_latex(): +@pytest.mark.asyncio +async def test_post_image_to_latex(): """Test case for /images/equations-to-latex endpoint.""" cwd = Path(__file__).parents[0] @@ -18,7 +19,9 @@ def test_post_image_to_latex(): } endpoint = "/images/equations-to-latex" - response = client.post(endpoint, files=files) + # see https://fastapi.tiangolo.com/advanced/async-tests/#async-tests + async with AsyncClient(app=app, base_url="http://eqn-to-latex-test") as ac: + response = await ac.post(endpoint, files=files) expected = "\\frac{d H}{dt}=\\nabla \\cdot {(\\Gamma*H^{n+2}*\\left|\\nabla{H}\\right|^{n-1}*\\nabla{H})}" # check for route's existence assert ( diff --git a/skema/rest/tests/test_model_to_amr.py b/skema/rest/tests/test_model_to_amr.py index 00a88bcbd17..b59d23a4140 100644 --- a/skema/rest/tests/test_model_to_amr.py +++ b/skema/rest/tests/test_model_to_amr.py @@ -12,8 +12,10 @@ ) from skema.rest.llm_proxy import Dynamics from skema.rest.proxies import SKEMA_RS_ADDESS -from skema.skema_py.server import System -import time +from skema.skema_py import server as code2fn +import json +import httpx +import pytest CHIME_SIR_URL = ( "https://artifacts.askem.lum.ai/askem/data/models/zip-archives/CHIME-SIR-model.zip" @@ -23,7 +25,8 @@ "https://artifacts.askem.lum.ai/askem/data/models/zip-archives/SIDARTHE.zip" ) -def test_any_amr_chime_sir(): +@pytest.mark.asyncio +async def test_any_amr_chime_sir(): """ Unit test for checking that Chime-SIR model produces any AMR. This test zip contains 4 versions of CHIME SIR. This will test if just the core dynamics works, the whole script, and also rewritten scripts work. @@ -82,15 +85,16 @@ def test_any_amr_chime_sir(): else: blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) try: - time.sleep(0.5) - code_snippet_response = asyncio.run( - code_snippets_to_pn_amr( - System( - files=[files[i]], - blobs=[blobs[i]], - ) - ) - ) + async with httpx.AsyncClient() as client: + code_snippet_response = await code_snippets_to_pn_amr( + system=code2fn.System( + files=[files[i]], + blobs=[blobs[i]], + ), + client=client + ) + # code_snippet_response = json.loads(code_snippet_response.body) + # print(f"code_snippet_response for test_any_amr_chime_sir: {code_snippet_response}") if "model" in code_snippet_response: code_snippet_response["header"]["name"] = "LLM-assisted code to amr model" code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}" @@ -99,8 +103,9 @@ def test_any_amr_chime_sir(): else: print("snippets failure") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") - except: + except Exception as e: print("Hit except to snippets failure") + print(f"Exception for test_any_amr_chime_sir:\t{e}") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") # we will return the amr with most states, in assumption it is the most "correct" # by default it returns the first entry @@ -115,13 +120,15 @@ def test_any_amr_chime_sir(): amr = temp_amr except: continue - except: + except Exception as e: + print(f"Exception for test_any_amr_chime_sir:\t{e}") amr = logging print(f"final amr: {amr}\n") # For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy. assert "model" in amr, f"'model' should be in AMR response, but got {amr}" -def test_any_amr_sidarthe(): +@pytest.mark.asyncio +async def test_any_amr_sidarthe(): """ Unit test for checking that Chime-SIR model produces any AMR. This test zip contains 4 versions of CHIME SIR. This will test if just the core dynamics works, the whole script, and also rewritten scripts work. @@ -179,15 +186,14 @@ def test_any_amr_sidarthe(): else: blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) try: - time.sleep(0.5) - code_snippet_response = asyncio.run( - code_snippets_to_pn_amr( - System( - files=[files[i]], - blobs=[blobs[i]], - ) - ) - ) + async with httpx.AsyncClient() as client: + code_snippet_response = await code_snippets_to_pn_amr( + system=code2fn.System( + files=[files[i]], + blobs=[blobs[i]], + ), + client=client + ) if "model" in code_snippet_response: code_snippet_response["header"]["name"] = "LLM-assisted code to amr model" code_snippet_response["header"]["description"] = f"This model came from code file: {files[i]}" @@ -196,8 +202,9 @@ def test_any_amr_sidarthe(): else: print("snippets failure") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") - except: + except Exception as e: print("Hit except to snippets failure") + print(f"Exception for test_any_amr_sidarthe:\t{e}") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") # we will return the amr with most states, in assumption it is the most "correct" # by default it returns the first entry @@ -212,7 +219,8 @@ def test_any_amr_sidarthe(): amr = temp_amr except: continue - except: + except Exception as e: + print(f"Exception for final amr of test_any_amr_sidarthe:\t{e}") amr = logging print(f"final amr: {amr}\n") # For this test, we are just checking that AMR was generated without crashing. We are not checking for accuracy. diff --git a/skema/rest/utils.py b/skema/rest/utils.py index 51a4d151bc4..319a6265f23 100644 --- a/skema/rest/utils.py +++ b/skema/rest/utils.py @@ -1,13 +1,23 @@ import itertools as it +import httpx from collections import defaultdict from typing import Any, Dict from askem_extractions.data_model import AttributeCollection, AttributeType, AnchoredEntity from bs4 import BeautifulSoup, Comment +from skema.rest import config from skema.rest.schema import TextReadingEvaluationResults, AMRLinkingEvaluationResults +# see https://stackoverflow.com/a/74401249 +async def get_client(): + # create a new client for each request + async with httpx.AsyncClient(timeout=config.SKEMA_RS_DEFAULT_TIMEOUT, follow_redirects=True) as client: + # yield the client to the endpoint function + yield client + # close the client when the request is done + def fn_preprocessor(function_network: Dict[str, Any]): fn_data = function_network.copy() diff --git a/skema/rest/workflows.py b/skema/rest/workflows.py index 0411814f18b..922dbac0594 100644 --- a/skema/rest/workflows.py +++ b/skema/rest/workflows.py @@ -3,20 +3,22 @@ End-to-end skema workflows """ import copy -import requests import time from zipfile import ZipFile from io import BytesIO from typing import List from pathlib import Path +import httpx +import json +import requests -from fastapi import APIRouter, File, UploadFile, FastAPI +from fastapi import APIRouter, Depends, File, UploadFile, FastAPI from starlette.responses import JSONResponse from skema.img2mml import eqn2mml from skema.img2mml.eqn2mml import image2mathml_db from skema.img2mml.api import get_mathml_from_bytes -from skema.rest import schema, utils, llm_proxy +from skema.rest import config, schema, utils, llm_proxy from skema.rest.proxies import SKEMA_RS_ADDESS from skema.skema_py import server as code2fn @@ -27,7 +29,7 @@ @router.post( "/images/base64/equations-to-amr", summary="Equations (base64 images) → MML → AMR" ) -async def equations_to_amr(data: schema.EquationImagesToAMR): +async def equations_to_amr(data: schema.EquationImagesToAMR, client: httpx.AsyncClient = Depends(utils.get_client)): """ Converts images of equations to AMR. @@ -57,7 +59,7 @@ async def equations_to_amr(data: schema.EquationImagesToAMR): ] payload = {"mathml": mml, "model": data.model} # FIXME: why is this a PUT? - res = requests.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload) + res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload) if res.status_code != 200: return JSONResponse( status_code=400, @@ -71,7 +73,7 @@ async def equations_to_amr(data: schema.EquationImagesToAMR): # equation images -> mml -> latex @router.post("/images/equations-to-latex", summary="Equations (images) → MML → LaTeX") -async def equations_to_latex(data: UploadFile): +async def equations_to_latex(data: UploadFile, client: httpx.AsyncClient = Depends(utils.get_client)): """ Converts images of equations to LaTeX. @@ -96,8 +98,9 @@ async def equations_to_latex(data: UploadFile): # pass image bytes to get_mathml_from_bytes function mml_res = get_mathml_from_bytes(image_bytes, image2mathml_db) proxy_url = f"{SKEMA_RS_ADDESS}/mathml/latex" + print(f"MMML:\t{mml_res}") print(f"Proxying request to {proxy_url}") - response = requests.post(proxy_url, data=mml_res) + response = await client.post(proxy_url, data=mml_res) # Check the response if response.status_code == 200: # The request was successful @@ -111,7 +114,7 @@ async def equations_to_latex(data: UploadFile): # tex equations -> pmml -> amr @router.post("/latex/equations-to-amr", summary="Equations (LaTeX) → pMML → AMR") -async def equations_to_amr(data: schema.EquationLatexToAMR): +async def equations_to_amr(data: schema.EquationLatexToAMR, client: httpx.AsyncClient = Depends(utils.get_client)): """ Converts equations (in LaTeX) to AMR. @@ -131,7 +134,7 @@ async def equations_to_amr(data: schema.EquationLatexToAMR): utils.clean_mml(eqn2mml.get_mathml_from_latex(tex)) for tex in data.equations ] payload = {"mathml": mml, "model": data.model} - res = requests.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload) + res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload) if res.status_code != 200: return JSONResponse( status_code=400, @@ -145,9 +148,9 @@ async def equations_to_amr(data: schema.EquationLatexToAMR): # pmml -> amr @router.post("/pmml/equations-to-amr", summary="Equations pMML → AMR") -async def equations_to_amr(data: schema.MmlToAMR): +async def equations_to_amr(data: schema.MmlToAMR, client: httpx.AsyncClient = Depends(utils.get_client)): payload = {"mathml": data.equations, "model": data.model} - res = requests.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload) + res = await client.put(f"{SKEMA_RS_ADDESS}/mathml/amr", json=payload) if res.status_code != 200: return JSONResponse( status_code=400, @@ -159,17 +162,20 @@ async def equations_to_amr(data: schema.MmlToAMR): return res.json() -# code snippets -> fn -> petrinet amr +# code snippets -> fn -> petrinet amr @router.post("/code/snippets-to-pn-amr", summary="Code snippets → PetriNet AMR") -async def code_snippets_to_pn_amr(system: code2fn.System): +async def code_snippets_to_pn_amr(system: code2fn.System, client: httpx.AsyncClient = Depends(utils.get_client)): gromet = await code2fn.fn_given_filepaths(system) - gromet, logs = utils.fn_preprocessor(gromet) - res = requests.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet) + gromet, _ = utils.fn_preprocessor(gromet) + # print(f"gromet:{gromet}") + # print(f"client.follow_redirects:\t{client.follow_redirects}") + # print(f"client.timeout:\t{client.timeout}") + res = await client.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet) if res.status_code != 200: return JSONResponse( status_code=400, content={ - "error": f"MORAE PUT /models/PN failed to process payload", + "error": f"MORAE PUT /models/PN failed to process payload ({res.text})", "payload": gromet, }, ) @@ -199,10 +205,10 @@ async def code_snippets_to_rn_amr(system: code2fn.System): @router.post( "/code/codebase-to-pn-amr", summary="Code repo (zip archive) → PetriNet AMR" ) -async def repo_to_pn_amr(zip_file: UploadFile = File()): +async def repo_to_pn_amr(zip_file: UploadFile = File(), client: httpx.AsyncClient = Depends(utils.get_client)): gromet = await code2fn.fn_given_filepaths_zip(zip_file) - gromet, logs = utils.fn_preprocessor(gromet) - res = requests.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet) + gromet, _ = utils.fn_preprocessor(gromet) + res = await client.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet) if res.status_code != 200: return JSONResponse( status_code=400, @@ -219,7 +225,7 @@ async def repo_to_pn_amr(zip_file: UploadFile = File()): "/code/llm-assisted-codebase-to-pn-amr", summary="Code repo (zip archive) → PetriNet AMR", ) -async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()): +async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File(), client: httpx.AsyncClient = Depends(utils.get_client)): """Codebase->AMR workflow using an llm to extract the dynamics line span. ### Python example ``` @@ -271,25 +277,26 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()): # The source code is a string, so to slice using the line spans, we must first convert it to a list. # Then we can convert it back to a string using .join logging = [] + import_counter = 0 for i in range(len(blobs)): if line_begin[i] == line_end[i]: print("failed linespan") else: - if blocks == 2: - temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[i]:import_end[i]]) + if len(linespans[i].block) == 2: + temp = "".join(blobs[i].splitlines(keepends=True)[import_begin[import_counter]:import_end[import_counter]]) blobs[i] = temp + "\n" + "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) + import_counter += 1 else: blobs[i] = "".join(blobs[i].splitlines(keepends=True)[line_begin[i]:line_end[i]]) try: - time.sleep(0.5) print(f"Time call code-snippets: {time.time()}") - print(blobs[i]) - code_snippet_response = await code_snippets_to_pn_amr( - code2fn.System( - files=[files[i]], - blobs=[blobs[i]], - ) - ) + gromet = await code2fn.fn_given_filepaths(code2fn.System( + files=[files[i]], + blobs=[blobs[i]], + )) + gromet, _ = utils.fn_preprocessor(gromet) + code_snippet_response = await client.put(f"{SKEMA_RS_ADDESS}/models/PN", json=gromet) + code_snippet_response = code_snippet_response.json() print(f"Time response code-snippets: {time.time()}") if "model" in code_snippet_response: code_snippet_response["header"]["name"] = "LLM-assisted code to amr model" @@ -299,8 +306,9 @@ async def llm_assisted_codebase_to_pn_amr(zip_file: UploadFile = File()): else: print("snippets failure") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") - except: + except Exception as e: print("Hit except to snippets failure") + print(f"Exception:\t{e}") logging.append(f"{files[i]} failed to parse an AMR from the dynamics") # we will return the amr with most states, in assumption it is the most "correct" # by default it returns the first entry diff --git a/skema/skema_py/server.py b/skema/skema_py/server.py index 77ebdb1e755..c7f2b4fc661 100644 --- a/skema/skema_py/server.py +++ b/skema/skema_py/server.py @@ -7,7 +7,6 @@ from typing import List, Dict, Optional from io import BytesIO from zipfile import ZipFile -from urllib.request import urlopen from fastapi import APIRouter, FastAPI, status, Body, File, UploadFile from fastapi.responses import JSONResponse from pydantic import BaseModel, Field @@ -131,11 +130,11 @@ async def system_to_enriched_system(system: System) -> System: comments = {"files": {}} for file_path, result in zip(file_paths, results): comments["files"][str(file_path)] = result - system.comments = MultiFileCommentResponse.parse_obj(comments) + system.comments = MultiFileCommentResponse(**comments) return system - +# returns an abbreviated Dict representing a GrometFNModuleCollection async def system_to_gromet(system: System): """Convert a System to Gromet JSON""" From e017c99e8b8a66f869f75dba6b2432619ce10214 Mon Sep 17 00:00:00 2001 From: Enrique Noriega Date: Tue, 16 Jan 2024 14:18:03 -0700 Subject: [PATCH 19/22] Fixed #755 (#756) ## Summary of Changes Updated the TR grading code to stop double counting automated extractions which lead to an incorrect amount of true positives. ## Expanded explanation There is a many to one relationship between SKEMA extractions and manual annotations. Multiple extractions can match an annotation (be correct). To compute P, R and F1, all the extractions associated to a manual annotation should be counted as a single true positive. This change addresses this issue. ### Related issues Resolves 755 --- .../test_integrated_text_reading_proxy.py | 6 ++-- skema/rest/utils.py | 36 ++++++++++++------- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/skema/rest/tests/test_integrated_text_reading_proxy.py b/skema/rest/tests/test_integrated_text_reading_proxy.py index 46c0bff8822..d4e726e2440 100644 --- a/skema/rest/tests/test_integrated_text_reading_proxy.py +++ b/skema/rest/tests/test_integrated_text_reading_proxy.py @@ -114,9 +114,9 @@ def test_extraction_evaluation(): results = response.json() assert results['num_manual_annotations'] == 220, "There should be 220 gt manual annotations" - assert results['precision'] == approx(0.7230769230768118), "Precision drastically different from the expected value" - assert results['recall'] == approx(0.21363636363636362), "Recall drastically different from the expected value" - assert results['f1'] == approx(0.32982456136828636), "F1 drastically different from the expected value" + assert results['precision'] == approx(0.5230769230768426), "Precision drastically different from the expected value" + assert results['recall'] == approx(0.154545454545454542), "Recall drastically different from the expected value" + assert results['f1'] == approx(0.23859649119285095), "F1 drastically different from the expected value" def test_healthcheck(): diff --git a/skema/rest/utils.py b/skema/rest/utils.py index 319a6265f23..765add351a6 100644 --- a/skema/rest/utils.py +++ b/skema/rest/utils.py @@ -18,6 +18,7 @@ async def get_client(): yield client # close the client when the request is done + def fn_preprocessor(function_network: Dict[str, Any]): fn_data = function_network.copy() @@ -180,23 +181,32 @@ def compute_text_reading_evaluation(gt_data: list, attributes: AttributeCollecti page = a["page"] annotations_by_page[page].append(a) + def annotation_key(a: Dict): + return a['page'], tuple(a['start_xy']), a['text'] + # Count the matches tp, tn, fp, fn = 0, 0, 0, 0 + matched_annotations = set() for e in extractions: + matched = False for m in e.mentions: - if m.extraction_source is not None: - te = m.extraction_source - if te.page is not None: - e_page = te.page - page_annotations = annotations_by_page[e_page] - matched = False - for a in page_annotations: - if extraction_matches_annotation(m, a, json_contents): - matched = True - tp += 1 - break - if not matched: - fp += 1 + if not matched: + if m.extraction_source is not None: + te = m.extraction_source + if te.page is not None: + e_page = te.page + page_annotations = annotations_by_page[e_page] + + for a in page_annotations: + key = annotation_key(a) + if key not in matched_annotations: + if extraction_matches_annotation(m, a, json_contents): + matched_annotations.add(key) + matched = True + tp += 1 + break + if not matched: + fp += 1 recall = tp / len(gt_data) precision = tp / (tp + fp + 0.00000000001) From 2af61215ddd7d4257345ef258b7470b012597485 Mon Sep 17 00:00:00 2001 From: Justin Lieffers <76677555+Free-Quarks@users.noreply.github.com> Date: Tue, 16 Jan 2024 15:25:28 -0700 Subject: [PATCH 20/22] Code2MET Rust Endpoint (#757) ## Summary of Changes This PR adds a new Rust side endpoint that converts gromets to a vector of MET. This will be of use as we rollout more ISA based workflows. This also added a commented out endpoint stub in workflows.py for taking in code-snippets and converting them into a vector of MET's as well, to further lay some foundation for ISA workflows. ### Related issues Resolves ??? --------- Co-authored-by: Justin --- skema/rest/workflows.py | 28 ++++++++++++++ skema/skema-rs/mathml/src/ast.rs | 18 ++++----- skema/skema-rs/mathml/src/ast/operator.rs | 13 ++++--- .../src/parsers/math_expression_tree.rs | 4 +- skema/skema-rs/skema/src/bin/skema_service.rs | 2 + skema/skema-rs/skema/src/services/gromet.rs | 38 +++++++++++++++++++ 6 files changed, 86 insertions(+), 17 deletions(-) diff --git a/skema/rest/workflows.py b/skema/rest/workflows.py index 922dbac0594..c95c8d43cb3 100644 --- a/skema/rest/workflows.py +++ b/skema/rest/workflows.py @@ -345,6 +345,34 @@ async def repo_to_rn_amr(zip_file: UploadFile = File()): ) return res.json() """ +""" +# code snippets -> fn -> Vec -> ???? +@router.post("/isa/code-align", summary="ISA aided inference") +async def code_snippets_to_isa_align(system: code2fn.System, client: httpx.AsyncClient = Depends(utils.get_client)): + gromet = await code2fn.fn_given_filepaths(system) + gromet, _ = utils.fn_preprocessor(gromet) + # print(f"gromet:{gromet}") + # print(f"client.follow_redirects:\t{client.follow_redirects}") + # print(f"client.timeout:\t{client.timeout}") + res = await client.put(f"{SKEMA_RS_ADDESS}/models/MET", json=gromet) + # res is a vector of MET's from the code (assuming it could extract correctly) + if res.status_code != 200: + return JSONResponse( + status_code=400, + content={ + "error": f"MORAE PUT /models/PN failed to process payload ({res.text})", + "payload": gromet, + }, + ) + + # Liang, if you want to put your ISA portion here? + # ISA: + # + # + # + # + return res.json() +""" app = FastAPI() app.include_router(router) \ No newline at end of file diff --git a/skema/skema-rs/mathml/src/ast.rs b/skema/skema-rs/mathml/src/ast.rs index e9718967625..99a8479c85f 100644 --- a/skema/skema-rs/mathml/src/ast.rs +++ b/skema/skema-rs/mathml/src/ast.rs @@ -2,17 +2,17 @@ use derive_new::new; use std::fmt; pub mod operator; - +use serde::{Deserialize, Serialize}; use operator::Operator; //use crate::ast::MathExpression::SummationOp; -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct Mi(pub String); -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct Mrow(pub Vec); -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub enum Type { Integer, Rational, @@ -28,27 +28,27 @@ pub enum Type { Matrix, } -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct Ci { pub r#type: Option, pub content: Box, pub func_of: Option>, } -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct Differential { pub diff: Box, pub func: Box, } -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct SummationMath { pub op: Box, pub func: Box, } /// Hat operation -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct HatComp { pub op: Box, pub comp: Box, @@ -56,7 +56,7 @@ pub struct HatComp { /// The MathExpression enum is not faithful to the corresponding element type in MathML 3 /// (https://www.w3.org/TR/MathML3/appendixa.html#parsing_MathExpression) -#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Hash, Default, new)] +#[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Hash, Default, new, Deserialize, Serialize)] pub enum MathExpression { Mi(Mi), Mo(Operator), diff --git a/skema/skema-rs/mathml/src/ast/operator.rs b/skema/skema-rs/mathml/src/ast/operator.rs index 3e406261073..04a542ce87e 100644 --- a/skema/skema-rs/mathml/src/ast/operator.rs +++ b/skema/skema-rs/mathml/src/ast/operator.rs @@ -2,9 +2,10 @@ use crate::ast::Ci; use crate::ast::MathExpression; use derive_new::new; use std::fmt; +use serde::{Deserialize, Serialize}; /// Derivative operator, in line with Spivak notation: http://ceres-solver.org/spivak_notation.html -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct Derivative { pub order: u8, pub var_index: u8, @@ -12,7 +13,7 @@ pub struct Derivative { } /// Partial derivative operator -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct PartialDerivative { pub order: u8, pub var_index: u8, @@ -20,7 +21,7 @@ pub struct PartialDerivative { } /// Summation operator with under and over components -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct SumUnderOver { pub op: Box, pub under: Box, @@ -28,18 +29,18 @@ pub struct SumUnderOver { } /// Hat operation -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct HatOp { pub comp: Box, } /// Handles grad operations with subscript. E.g. ∇_{x} -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub struct GradSub { pub sub: Box, } -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub enum Operator { Add, Multiply, diff --git a/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs b/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs index 718c900d45d..58f24732573 100644 --- a/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs +++ b/skema/skema-rs/mathml/src/parsers/math_expression_tree.rs @@ -11,7 +11,7 @@ use crate::{ use derive_new::new; use nom::error::Error; use regex::Regex; - +use serde::{Deserialize, Serialize}; use std::{fmt, str::FromStr}; #[cfg(test)] @@ -19,7 +19,7 @@ use crate::parsers::first_order_ode::{first_order_ode, FirstOrderODE}; ///New whitespace handler before parsing /// An S-expression like structure to represent mathematical expressions. -#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new)] +#[derive(Debug, Ord, PartialOrd, PartialEq, Eq, Clone, Hash, new, Deserialize, Serialize)] pub enum MathExpressionTree { Atom(MathExpression), Cons(Operator, Vec), diff --git a/skema/skema-rs/skema/src/bin/skema_service.rs b/skema/skema-rs/skema/src/bin/skema_service.rs index f3a7e03e68b..6a218d0575f 100644 --- a/skema/skema-rs/skema/src/bin/skema_service.rs +++ b/skema/skema-rs/skema/src/bin/skema_service.rs @@ -58,6 +58,7 @@ async fn main() -> std::io::Result<()> { gromet::get_model_RN, gromet::model2PN, gromet::model2RN, + gromet::model2MET, ping, version ), @@ -141,6 +142,7 @@ async fn main() -> std::io::Result<()> { .service(gromet::get_model_RN) .service(gromet::model2PN) .service(gromet::model2RN) + .service(gromet::model2MET) .service(ping) .service(version) .service(SwaggerUi::new("/docs/{_:.*}").url("/api-doc/openapi.json", openapi.clone())) diff --git a/skema/skema-rs/skema/src/services/gromet.rs b/skema/skema-rs/skema/src/services/gromet.rs index 79479c00a31..747f2cdf717 100644 --- a/skema/skema-rs/skema/src/services/gromet.rs +++ b/skema/skema-rs/skema/src/services/gromet.rs @@ -7,6 +7,8 @@ use actix_web::web::ServiceConfig; use actix_web::{delete, get, post, put, web, HttpResponse}; use mathml::acset::{PetriNet, RegNet}; +use mathml::ast::MathExpression; +use mathml::parsers::math_expression_tree::MathExpressionTree; use neo4rs; use neo4rs::{query, Error, Node}; use std::collections::HashMap; @@ -327,3 +329,39 @@ pub async fn model2RN( model_to_RN(payload.into_inner(), config1).await.unwrap(), )) } + +/// This returns a MET vector from a gromet. +#[allow(non_snake_case)] +#[utoipa::path( + request_body = ModuleCollection, + responses( + ( + status = 200, description = "Successfully retrieved MET" + ) + ) +)] +#[put("/models/MET")] +pub async fn model2MET( + payload: web::Json, + config: web::Data, +) -> HttpResponse { + let config1 = Config { + db_host: config.db_host.clone(), + db_port: config.db_port, + db_protocol: config.db_protocol.clone(), + }; + let module_id = push_model_to_db(payload.into_inner(), config1.clone()).await; // pushes model to db and gets id + let ref_module_id1 = module_id.as_ref(); + let ref_module_id2 = module_id.as_ref(); + let mathml_ast = module_id2mathml_MET_ast(*ref_module_id1.unwrap(), config1.clone()).await; // turns model into mathml ast equations + let _del_response = delete_module(*ref_module_id2.unwrap(), config1.clone()).await; // deletes model from db + let mut mets = Vec::::new(); + for equation in mathml_ast.iter() { + let mut equal_args = Vec::::new(); + equal_args.push(MathExpressionTree::Atom(MathExpression::Ci(equation.lhs_var.clone()))); + equal_args.push(equation.rhs.clone()); + let met = MathExpressionTree::Cons(mathml::ast::operator::Operator::Equals, equal_args.clone()); + mets.push(met.clone()); + } + HttpResponse::Ok().json(web::Json(mets)) +} From b089c47e4db95df8e81ba88065123933f79f6a49 Mon Sep 17 00:00:00 2001 From: titomeister Date: Thu, 18 Jan 2024 15:56:03 -0700 Subject: [PATCH 21/22] Gromet While Loop Wiring Bug PR (#753) This PR primarily fixes a crash in the Gromet generation when there's a While loop condition consists only of a function call. It also fixes another minor wiring bug related to LiteralValues ### Summary of Changes - Fixed Gromet generation crash for While loop condition of a function call. - Fixes a missing wiring bug when a function returns a plain LiteralValue. - Adds some small unit tests to maintain consistency. Resolves #728 --- .../CAST2FN/ann_cast/to_gromet_pass.py | 8 +- .../tests/test_literal_returns.py | 141 ++++++++++++++++++ 2 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 skema/program_analysis/tests/test_literal_returns.py diff --git a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py index f3dc35ea75d..b9247be87fb 100644 --- a/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py +++ b/skema/program_analysis/CAST2FN/ann_cast/to_gromet_pass.py @@ -2520,6 +2520,7 @@ def visit_call( from_assignment = False from_call = False from_operator = False + from_loop = False func_name, qual_func_name = get_func_name(node) if isinstance(parent_cast_node, AnnCastAssignment): @@ -2528,6 +2529,8 @@ def visit_call( from_call = True elif isinstance(parent_cast_node, AnnCastOperator): from_operator = True + elif isinstance(parent_cast_node, AnnCastLoop): + from_loop = True if isinstance(node.func, AnnCastAttribute): self.visit(node.func, parent_gromet_fn, parent_cast_node) @@ -2732,7 +2735,7 @@ def visit_call( ) # if isinstance(arg.right) - if from_call or from_operator or from_assignment: + if from_call or from_operator or from_assignment or from_loop: # Operator and calls need a pof appended here because they dont # do it themselves # At some point we would like the call handler to always append a POF @@ -2912,6 +2915,9 @@ def wire_return_node(self, node, gromet_fn): if isinstance(node, AnnCastLiteralValue): if is_tuple(node): self.pack_return_tuple(node, gromet_fn) + else: + gromet_fn.opo = insert_gromet_object(gromet_fn.opo, GrometPort(box=len(gromet_fn.b))) + gromet_fn.wfopo = insert_gromet_object(gromet_fn.wfopo, GrometWire(src=len(gromet_fn.opo),tgt=len(gromet_fn.pof))) return elif isinstance(node, AnnCastVar): var_name = node.val.name diff --git a/skema/program_analysis/tests/test_literal_returns.py b/skema/program_analysis/tests/test_literal_returns.py new file mode 100644 index 00000000000..f59bb0d4a8e --- /dev/null +++ b/skema/program_analysis/tests/test_literal_returns.py @@ -0,0 +1,141 @@ +# import json NOTE: json and Path aren't used right now, +# from pathlib import Path but will be used in the future +from skema.program_analysis.multi_file_ingester import process_file_system +from skema.gromet.fn import ( + GrometFNModuleCollection, + FunctionType, + TypedValue, +) +import ast + +from skema.program_analysis.CAST.pythonAST import py_ast_to_cast +from skema.program_analysis.CAST2FN.model.cast import SourceRef +from skema.program_analysis.CAST2FN import cast +from skema.program_analysis.CAST2FN.cast import CAST +from skema.program_analysis.run_ann_cast_pipeline import ann_cast_pipeline + + +def return1(): + return """ +def return_true(): + return True + """ + +def return2(): + return """ +def return_true(): + return True + +while (return_true()): + print("Test") + """ + + +def generate_gromet(test_file_string): + # use ast.Parse to get Python AST + contents = ast.parse(test_file_string) + + # use Python to CAST + line_count = len(test_file_string.split("\n")) + convert = py_ast_to_cast.PyASTToCAST("temp") + C = convert.visit(contents, {}, {}) + C.source_refs = [SourceRef("temp", None, None, 1, line_count)] + out_cast = cast.CAST([C], "python") + + # use AnnCastPipeline to create GroMEt + gromet = ann_cast_pipeline(out_cast, gromet=True, to_file=False, from_obj=True) + + return gromet + +def test_return1(): + gromet = generate_gromet(return1()) + + base_fn = gromet.fn + + assert len(base_fn.b) == 1 + + func_fn = gromet.fn_array[0] + assert len(func_fn.b) == 1 + + assert len(func_fn.opo) == 1 + assert func_fn.opo[0].box == 1 + + assert len(func_fn.bf) == 1 + assert func_fn.bf[0].function_type == FunctionType.LITERAL + assert func_fn.bf[0].value.value_type == "Boolean" + assert func_fn.bf[0].value.value == "True" + + assert len(func_fn.pof) == 1 + assert func_fn.pof[0].box == 1 + + assert len(func_fn.wfopo) == 1 + assert func_fn.wfopo[0].src == 1 and func_fn.wfopo[0].tgt == 1 + + +def test_return2(): + exp_gromet = generate_gromet(return2()) + + base_fn = exp_gromet.fn + assert len(base_fn.bl) == 1 + assert base_fn.bl[0].condition == 2 + assert base_fn.bl[0].body == 3 + + func_fn = exp_gromet.fn_array[0] + assert len(func_fn.b) == 1 + + assert len(func_fn.opo) == 1 + assert func_fn.opo[0].box == 1 + + assert len(func_fn.bf) == 1 + assert func_fn.bf[0].function_type == FunctionType.LITERAL + assert func_fn.bf[0].value.value_type == "Boolean" + assert func_fn.bf[0].value.value == "True" + + assert len(func_fn.pof) == 1 + assert func_fn.pof[0].box == 1 + + assert len(func_fn.wfopo) == 1 + assert func_fn.wfopo[0].src == 1 and func_fn.wfopo[0].tgt == 1 + + predicate_fn = exp_gromet.fn_array[1] + assert len(predicate_fn.b) == 1 + assert len(predicate_fn.opo) == 1 + assert predicate_fn.opo[0].box == 1 + + assert len(predicate_fn.bf) == 1 + assert predicate_fn.bf[0].body == 1 + + assert len(predicate_fn.pof) == 1 + assert predicate_fn.pof[0].box == 1 + + assert len(predicate_fn.wfopo) == 1 + assert predicate_fn.wfopo[0].src == 1 + assert predicate_fn.wfopo[0].tgt == 1 + + loop_fn = exp_gromet.fn_array[2] + assert len(loop_fn.bf) == 1 + assert loop_fn.bf[0].body == 4 + + loop_body_fn = exp_gromet.fn_array[3] + assert len(loop_body_fn.opo) == 1 + assert loop_body_fn.opo[0].box == 1 + + assert len(loop_body_fn.bf) == 2 + assert loop_body_fn.bf[1].function_type == FunctionType.LITERAL + assert loop_body_fn.bf[1].value.value_type == "List" + + assert len(loop_body_fn.pif) == 1 + assert loop_body_fn.pif[0].box == 1 + + assert len(loop_body_fn.pof) == 2 + assert loop_body_fn.pof[0].box == 1 + assert loop_body_fn.pof[1].box == 2 + + assert len(loop_body_fn.wff) == 1 + assert loop_body_fn.wff[0].src == 1 + assert loop_body_fn.wff[0].tgt == 2 + + assert len(loop_body_fn.wfopo) == 1 + assert loop_body_fn.wfopo[0].src == 1 + assert loop_body_fn.wfopo[0].tgt == 1 + \ No newline at end of file From 4605924a447f23f65fc036ca9e5ce14683905b54 Mon Sep 17 00:00:00 2001 From: Liang Zhang <68933075+ualiangzhang@users.noreply.github.com> Date: Tue, 23 Jan 2024 20:13:48 -0700 Subject: [PATCH 22/22] [Equations] Add the base64 support to equations_to_latex (#759) ## Summary of Changes Add the base64 support to the `equations_to_latex` endpoint --- skema/rest/tests/test_eqn_to_latex.py | 34 +++++++++++++++++-- skema/rest/workflows.py | 48 +++++++++++++++++++++++++-- 2 files changed, 76 insertions(+), 6 deletions(-) diff --git a/skema/rest/tests/test_eqn_to_latex.py b/skema/rest/tests/test_eqn_to_latex.py index 8ed6339b89d..89cfa81e7b9 100644 --- a/skema/rest/tests/test_eqn_to_latex.py +++ b/skema/rest/tests/test_eqn_to_latex.py @@ -1,3 +1,4 @@ +import base64 from pathlib import Path from httpx import AsyncClient from skema.rest.workflows import app @@ -5,8 +6,6 @@ import json - - @pytest.mark.ci_only @pytest.mark.asyncio async def test_post_image_to_latex(): @@ -21,7 +20,7 @@ async def test_post_image_to_latex(): endpoint = "/images/equations-to-latex" # see https://fastapi.tiangolo.com/advanced/async-tests/#async-tests async with AsyncClient(app=app, base_url="http://eqn-to-latex-test") as ac: - response = await ac.post(endpoint, files=files) + response = await ac.post(endpoint, files=files) expected = "\\frac{d H}{dt}=\\nabla \\cdot {(\\Gamma*H^{n+2}*\\left|\\nabla{H}\\right|^{n-1}*\\nabla{H})}" # check for route's existence assert ( @@ -35,3 +34,32 @@ async def test_post_image_to_latex(): assert ( json.loads(response.text) == expected ), f"Response should be {expected}, but instead received {response.text}" + + +@pytest.mark.ci_only +@pytest.mark.asyncio +async def test_post_image_to_latex_base64(): + """Test case for /images/base64/equations-to-latex endpoint.""" + cwd = Path(__file__).parents[0] + image_path = cwd / "data" / "img2latex" / "halfar.png" + with Path(image_path).open("rb") as infile: + img_bytes = infile.read() + img_b64 = base64.b64encode(img_bytes).decode("utf-8") + + endpoint = "/images/base64/equations-to-latex" + # see https://fastapi.tiangolo.com/advanced/async-tests/#async-tests + async with AsyncClient(app=app, base_url="http://eqn-to-latex-base64-test") as ac: + response = await ac.post(endpoint, data=img_b64) + expected = "\\frac{d H}{dt}=\\nabla \\cdot {(\\Gamma*H^{n+2}*\\left|\\nabla{H}\\right|^{n-1}*\\nabla{H})}" + # check for route's existence + assert ( + any(route.path == endpoint for route in app.routes) == True + ), "{endpoint} does not exist for app" + # check status code + assert ( + response.status_code == 200 + ), f"Request was unsuccessful (status code was {response.status_code} instead of 200)" + # check response + assert ( + json.loads(response.text) == expected + ), f"Response should be {expected}, but instead received {response.text}" \ No newline at end of file diff --git a/skema/rest/workflows.py b/skema/rest/workflows.py index c95c8d43cb3..5f0e53745db 100644 --- a/skema/rest/workflows.py +++ b/skema/rest/workflows.py @@ -12,11 +12,11 @@ import json import requests -from fastapi import APIRouter, Depends, File, UploadFile, FastAPI +from fastapi import APIRouter, Depends, File, UploadFile, FastAPI, Request from starlette.responses import JSONResponse from skema.img2mml import eqn2mml -from skema.img2mml.eqn2mml import image2mathml_db +from skema.img2mml.eqn2mml import image2mathml_db, b64_image_to_mml from skema.img2mml.api import get_mathml_from_bytes from skema.rest import config, schema, utils, llm_proxy from skema.rest.proxies import SKEMA_RS_ADDESS @@ -112,6 +112,48 @@ async def equations_to_latex(data: UploadFile, client: httpx.AsyncClient = Depen return f"Error: {response.status_code} {response.text}" +# equation images -> base64 -> mml -> latex +@router.post("/images/base64/equations-to-latex", summary="Equations (images) → MML → LaTeX") +async def equations_to_latex(request: Request, client: httpx.AsyncClient = Depends(utils.get_client)): + """ + Converts images of equations to LaTeX. + + ### Python example + + Endpoint for generating LaTeX from an input image. + + ``` + from pathlib import Path + import base64 + import requests + + url = "http://127.0.0.1:8000/workflows/images/base64/equations-to-latex" + with Path("test.png").open("rb") as infile: + img_bytes = infile.read() + img_b64 = base64.b64encode(img_bytes).decode("utf-8") + r = requests.post(url, data=img_b64) + print(r.text) + ``` + """ + # Read image data + img_b64 = await request.body() + mml_res = b64_image_to_mml(img_b64) + + # pass image bytes to get_mathml_from_bytes function + proxy_url = f"{SKEMA_RS_ADDESS}/mathml/latex" + print(f"MML:\t{mml_res}") + print(f"Proxying request to {proxy_url}") + response = await client.post(proxy_url, data=mml_res) + # Check the response + if response.status_code == 200: + # The request was successful + return response.text + else: + # The request failed + print(f"Error: {response.status_code}") + print(response.text) + return f"Error: {response.status_code} {response.text}" + # tex equations -> pmml -> amr @router.post("/latex/equations-to-amr", summary="Equations (LaTeX) → pMML → AMR") async def equations_to_amr(data: schema.EquationLatexToAMR, client: httpx.AsyncClient = Depends(utils.get_client)): @@ -162,7 +204,7 @@ async def equations_to_amr(data: schema.MmlToAMR, client: httpx.AsyncClient = De return res.json() -# code snippets -> fn -> petrinet amr +# code snippets -> fn -> petrinet amr @router.post("/code/snippets-to-pn-amr", summary="Code snippets → PetriNet AMR") async def code_snippets_to_pn_amr(system: code2fn.System, client: httpx.AsyncClient = Depends(utils.get_client)): gromet = await code2fn.fn_given_filepaths(system)