Skip to content

Commit

Permalink
Make rulenode2expr work for StateFixedShapedHoles
Browse files Browse the repository at this point in the history
  • Loading branch information
Whebon authored and ReubenJ committed Apr 18, 2024
1 parent 435c838 commit 9a3a412
Showing 1 changed file with 59 additions and 59 deletions.
118 changes: 59 additions & 59 deletions src/rulenode_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ function rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode)
return retval
end

if node.ind ruleset
union!(retval, [node.ind])
if get_rule(node) ruleset
union!(retval, [get_rule(node)])
end

if isempty(node.children)
Expand Down Expand Up @@ -96,55 +96,55 @@ Extract the derivation sequence from a path (sequence of child indices) and an [
If the path is deeper than the deepest node, it returns what it has.
"""
function get_rulesequence(node::RuleNode, path::Vector{Int})
if node.ind == 0 # sign for empty node
if get_rule(node) == 0 # sign for empty node
return Vector{Int}()
elseif isempty(node.children) # no children, nowhere to follow the path; still return the index
return [node.ind]
return [get_rule(node)]
elseif isempty(path)
return [node.ind]
return [get_rule(node)]
elseif isassigned(path, 2)
# at least two items are left in the path
# need to access the child with get because it can happen that the child is not yet built
return append!([node.ind], get_rulesequence(get(node.children, path[begin], RuleNode(0)), path[2:end]))
return append!([get_rule(node)], get_rulesequence(get(node.children, path[begin], RuleNode(0)), path[2:end]))
else
# if only one item left in the path
# need to access the child with get because it can happen that the child is not yet built
return append!([node.ind], get_rulesequence(get(node.children, path[begin], RuleNode(0)), Vector{Int}()))
return append!([get_rule(node)], get_rulesequence(get(node.children, path[begin], RuleNode(0)), Vector{Int}()))
end
end

get_rulesequence(::Hole, ::Vector{Int}) = Vector{Int}()

"""
rulesonleft(expr::RuleNode, path::Vector{Int})::Set{Int}
rulesonleft(node::RuleNode, path::Vector{Int})::Set{Int}
Finds all rules that are used in the left subtree defined by the path.
"""
function rulesonleft(expr::RuleNode, path::Vector{Int})::Set{Int}
if isempty(expr.children)
function rulesonleft(node::RuleNode, path::Vector{Int})::Set{Int}
if isempty(node.children)
# if the encountered node is terminal or non-expanded non-terminal, return node id
Set{Int}(expr.ind)
Set{Int}(get_rule(node))
elseif isempty(path)
# if path is empty, collect the entire subtree
ruleset = Set{Int}(expr.ind)
for ch in expr.children
ruleset = Set{Int}(get_rule(node))
for ch in node.children
union!(ruleset, rulesonleft(ch, Vector{Int}()))
end
return ruleset
elseif length(path) == 1
# if there is only one element left in the path, collect all children except the one indicated in the path
ruleset = Set{Int}(expr.ind)
ruleset = Set{Int}(get_rule(node))
for i in 1:path[begin]-1
union!(ruleset, rulesonleft(expr.children[i], Vector{Int}()))
union!(ruleset, rulesonleft(node.children[i], Vector{Int}()))
end
return ruleset
else
# collect all subtrees up to the child indexed in the path
ruleset = Set{Int}(expr.ind)
ruleset = Set{Int}(get_rule(node))
for i in 1:path[begin]-1
union!(ruleset, rulesonleft(expr.children[i], Vector{Int}()))
union!(ruleset, rulesonleft(node.children[i], Vector{Int}()))
end
union!(ruleset, rulesonleft(expr.children[path[begin]], path[2:end]))
union!(ruleset, rulesonleft(node.children[path[begin]], path[2:end]))
return ruleset
end
end
Expand Down Expand Up @@ -175,46 +175,47 @@ rulesonleft(::Hole, ::Vector{Int}) = Set{Int}()


"""
rulenode2expr(rulenode::RuleNode, grammar::AbstractGrammar)
rulenode2expr(rulenode::AbstractRuleNode, grammar::AbstractGrammar)
Converts a [`RuleNode`](@ref) into a Julia expression corresponding to the rule definitions in the grammar.
Converts an [`AbstractRuleNode`](@ref) into a Julia expression corresponding to the rule definitions in the grammar.
The returned expression can be evaluated with Julia semantics using `eval()`.
"""
function rulenode2expr(rulenode::RuleNode, grammar::AbstractGrammar)
root = (rulenode._val !== nothing) ?
rulenode._val : deepcopy(grammar.rules[rulenode.ind])
if !grammar.isterminal[rulenode.ind] # not terminal
function rulenode2expr(rulenode::AbstractRuleNode, grammar::AbstractGrammar)
if !isfilled(rulenode)
return _get_hole_type(rulenode, grammar)
end
root = hasdynamicvalue(rulenode) ? rulenode._val : deepcopy(grammar.rules[get_rule(rulenode)])
if !grammar.isterminal[get_rule(rulenode)] # not terminal
root,_ = _rulenode2expr(root, rulenode, grammar)
end
return root
end


function _rulenode2expr(rulenode::Hole, grammar::AbstractGrammar)
# Find the index of the first element that is true
index = findfirst(==(true), rulenode.domain)
function _get_hole_type(hole::Hole, grammar::AbstractGrammar)
#TODO: convert the children of UniformHoles to subexpressions
@assert !isfilled(hole) "Hole $(hole) is convertable to an expression. There is no need to represent it using a symbol."
index = findfirst(rulenode.domain)
return isnothing(index) ? :Nothing : grammar.types[index]
end
rulenode2expr(rulenode::Hole, grammar::AbstractGrammar) = _rulenode2expr(rulenode::Hole, grammar::AbstractGrammar)

function _rulenode2expr(expr::Expr, rulenode::RuleNode, grammar::AbstractGrammar, j=0)
for (k,arg) in enumerate(expr.args)
if isa(arg, Expr)
expr.args[k],j = _rulenode2expr(arg, rulenode, grammar, j)
elseif haskey(grammar.bytype, arg)
child = rulenode.children[j+=1]
if isa(child, Hole)
expr.args[k] = _rulenode2expr(child, grammar)
continue
end
expr.args[k] = (child._val !== nothing) ?
child._val : deepcopy(grammar.rules[child.ind])
if !isterminal(grammar, child)
expr.args[k],_ = _rulenode2expr(expr.args[k], child, grammar, 0)

function _rulenode2expr(expr::Expr, rulenode::AbstractRuleNode, grammar::AbstractGrammar, j=0)
if isfilled(rulenode)
for (k,arg) in enumerate(expr.args)
if isa(arg, Expr)
expr.args[k],j = _rulenode2expr(arg, rulenode, grammar, j)
elseif haskey(grammar.bytype, arg)
child = rulenode.children[j+=1]
if !isfilled(rulenode)
return _get_hole_type(rulenode, grammar)
end
expr.args[k] = hasdynamicvalue(rulenode) ? child._val : deepcopy(grammar.rules[get_rule(child)])
if !isterminal(grammar, child)
expr.args[k],_ = _rulenode2expr(expr.args[k], child, grammar, 0)
end
end
end
return expr, j
end
return expr, j
end


Expand All @@ -225,9 +226,8 @@ function _rulenode2expr(typ::Symbol, rulenode::RuleNode, grammar::AbstractGramma
if isa(child, Hole)
return retval, j
end
retval = (child._val !== nothing) ?
child._val : deepcopy(grammar.rules[child.ind])
if !grammar.isterminal[child.ind]
retval = hasdynamicvalue(rulenode) ? child._val : deepcopy(grammar.rules[get_rule(child)])
if !grammar.isterminal[get_rule(child)]
retval,_ = _rulenode2expr(retval, child, grammar, 0)
end
end
Expand All @@ -239,7 +239,7 @@ end
Calculates the log probability associated with a rulenode in a probabilistic grammar.
"""
function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar)
return log_probability(grammar, node.ind) + sum((rulenode_log_probability(c, grammar) for c node.children), init=1)
return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c node.children), init=1)
end

rulenode_log_probability(::Hole, ::AbstractGrammar) = 1
Expand Down Expand Up @@ -270,7 +270,7 @@ iscomplete(grammar::AbstractGrammar, ::Hole) = false
Gives the return type or nonterminal symbol in the production rule used by `node`.
"""
return_type(grammar::AbstractGrammar, node::RuleNode)::Symbol = grammar.types[node.ind]
return_type(grammar::AbstractGrammar, node::RuleNode)::Symbol = grammar.types[get_rule(node)]


"""
Expand All @@ -286,15 +286,15 @@ return_type(grammar::AbstractGrammar, hole::UniformHole)::Symbol = grammar.types
Returns the list of child types (nonterminal symbols) in the production rule used by `node`.
"""
child_types(grammar::AbstractGrammar, node::RuleNode)::Vector{Symbol} = grammar.childtypes[node.ind]
child_types(grammar::AbstractGrammar, node::RuleNode)::Vector{Symbol} = grammar.childtypes[get_rule(node)]


"""
isterminal(grammar::AbstractGrammar, node::RuleNode)::Bool
isterminal(grammar::AbstractGrammar, node::AbstractRuleNode)::Bool
Returns true if the production rule used by `node` is terminal, i.e., does not contain any nonterminal symbols.
"""
isterminal(grammar::AbstractGrammar, node::RuleNode)::Bool = grammar.isterminal[node.ind]
isterminal(grammar::AbstractGrammar, node::AbstractRuleNode)::Bool = grammar.isterminal[get_rule(node)]


"""
Expand All @@ -310,9 +310,9 @@ nchildren(grammar::AbstractGrammar, node::RuleNode)::Int = length(child_types(gr
Return true if the rule used by `node` represents a variable in a program (essentially, an input to the program)
"""
isvariable(grammar::AbstractGrammar, node::RuleNode)::Bool = (
grammar.isterminal[node.ind] &&
grammar.rules[node.ind] isa Symbol &&
!_is_defined_in_modules(grammar.rules[node.ind], [Main, Base])
grammar.isterminal[get_rule(node)] &&
grammar.rules[get_rule(node)] isa Symbol &&
!_is_defined_in_modules(grammar.rules[get_rule(node)], [Main, Base])
)
"""
isvariable(grammar::AbstractGrammar, node::RuleNode, mod::Module)::Bool
Expand All @@ -322,9 +322,9 @@ Return true if the rule used by `node` represents a variable.
Taking into account the symbols defined in the given module(s).
"""
isvariable(grammar::AbstractGrammar, node::RuleNode, mod::Module...)::Bool = (
grammar.isterminal[node.ind] &&
grammar.rules[node.ind] isa Symbol &&
!_is_defined_in_modules(grammar.rules[node.ind], [mod..., Main, Base])
grammar.isterminal[get_rule(node)] &&
grammar.rules[get_rule(node)] isa Symbol &&
!_is_defined_in_modules(grammar.rules[get_rule(node)], [mod..., Main, Base])
)

"""
Expand Down

0 comments on commit 9a3a412

Please sign in to comment.