Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stochastic using solver (4 PRs) #64

Merged
merged 2 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/nodelocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ Returns a NodeLoc pointing to the root node.
root_node_loc(root::RuleNode) = NodeLoc(root, 0)

"""
get(root::RuleNode, loc::NodeLoc)
get(root::AbstractRuleNode, loc::NodeLoc)
Obtain the node pointed to by loc.
"""
function Base.get(root::RuleNode, loc::NodeLoc)
function Base.get(root::AbstractRuleNode, loc::NodeLoc)
parent, i = loc.parent, loc.i
if loc.i > 0
return parent.children[i]
Expand Down
126 changes: 67 additions & 59 deletions src/rulenode_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ function rulesoftype(node::RuleNode, ruleset::Set{Int}, ignoreNode::RuleNode)
return retval
end

if node.ind ∈ ruleset
union!(retval, [node.ind])
if get_rule(node) ∈ ruleset
union!(retval, [get_rule(node)])
end

if isempty(node.children)
Expand Down Expand Up @@ -96,55 +96,55 @@ Extract the derivation sequence from a path (sequence of child indices) and an [
If the path is deeper than the deepest node, it returns what it has.
"""
function get_rulesequence(node::RuleNode, path::Vector{Int})
if node.ind == 0 # sign for empty node
if get_rule(node) == 0 # sign for empty node
return Vector{Int}()
elseif isempty(node.children) # no children, nowhere to follow the path; still return the index
return [node.ind]
return [get_rule(node)]
elseif isempty(path)
return [node.ind]
return [get_rule(node)]
elseif isassigned(path, 2)
# at least two items are left in the path
# need to access the child with get because it can happen that the child is not yet built
return append!([node.ind], get_rulesequence(get(node.children, path[begin], RuleNode(0)), path[2:end]))
return append!([get_rule(node)], get_rulesequence(get(node.children, path[begin], RuleNode(0)), path[2:end]))
else
# if only one item left in the path
# need to access the child with get because it can happen that the child is not yet built
return append!([node.ind], get_rulesequence(get(node.children, path[begin], RuleNode(0)), Vector{Int}()))
return append!([get_rule(node)], get_rulesequence(get(node.children, path[begin], RuleNode(0)), Vector{Int}()))
end
end

get_rulesequence(::Hole, ::Vector{Int}) = Vector{Int}()

"""
rulesonleft(expr::RuleNode, path::Vector{Int})::Set{Int}
rulesonleft(node::RuleNode, path::Vector{Int})::Set{Int}

Finds all rules that are used in the left subtree defined by the path.
"""
function rulesonleft(expr::RuleNode, path::Vector{Int})::Set{Int}
if isempty(expr.children)
function rulesonleft(node::RuleNode, path::Vector{Int})::Set{Int}
if isempty(node.children)
# if the encountered node is terminal or non-expanded non-terminal, return node id
Set{Int}(expr.ind)
Set{Int}(get_rule(node))
elseif isempty(path)
# if path is empty, collect the entire subtree
ruleset = Set{Int}(expr.ind)
for ch in expr.children
ruleset = Set{Int}(get_rule(node))
for ch in node.children
union!(ruleset, rulesonleft(ch, Vector{Int}()))
end
return ruleset
elseif length(path) == 1
# if there is only one element left in the path, collect all children except the one indicated in the path
ruleset = Set{Int}(expr.ind)
ruleset = Set{Int}(get_rule(node))
for i in 1:path[begin]-1
union!(ruleset, rulesonleft(expr.children[i], Vector{Int}()))
union!(ruleset, rulesonleft(node.children[i], Vector{Int}()))
end
return ruleset
else
# collect all subtrees up to the child indexed in the path
ruleset = Set{Int}(expr.ind)
ruleset = Set{Int}(get_rule(node))
for i in 1:path[begin]-1
union!(ruleset, rulesonleft(expr.children[i], Vector{Int}()))
union!(ruleset, rulesonleft(node.children[i], Vector{Int}()))
end
union!(ruleset, rulesonleft(expr.children[path[begin]], path[2:end]))
union!(ruleset, rulesonleft(node.children[path[begin]], path[2:end]))
return ruleset
end
end
Expand Down Expand Up @@ -175,46 +175,47 @@ rulesonleft(::Hole, ::Vector{Int}) = Set{Int}()


"""
rulenode2expr(rulenode::RuleNode, grammar::AbstractGrammar)
rulenode2expr(rulenode::AbstractRuleNode, grammar::AbstractGrammar)

Converts a [`RuleNode`](@ref) into a Julia expression corresponding to the rule definitions in the grammar.
Converts an [`AbstractRuleNode`](@ref) into a Julia expression corresponding to the rule definitions in the grammar.
The returned expression can be evaluated with Julia semantics using `eval()`.
"""
function rulenode2expr(rulenode::RuleNode, grammar::AbstractGrammar)
root = (rulenode._val !== nothing) ?
rulenode._val : deepcopy(grammar.rules[rulenode.ind])
if !grammar.isterminal[rulenode.ind] # not terminal
function rulenode2expr(rulenode::AbstractRuleNode, grammar::AbstractGrammar)
if !isfilled(rulenode)
return _get_hole_type(rulenode, grammar)
end
root = hasdynamicvalue(rulenode) ? rulenode._val : deepcopy(grammar.rules[get_rule(rulenode)])
if !grammar.isterminal[get_rule(rulenode)] # not terminal
root,_ = _rulenode2expr(root, rulenode, grammar)
end
return root
end


function _rulenode2expr(rulenode::Hole, grammar::AbstractGrammar)
# Find the index of the first element that is true
index = findfirst(==(true), rulenode.domain)
function _get_hole_type(hole::Hole, grammar::AbstractGrammar)
#TODO: convert the children of UniformHoles to subexpressions
@assert !isfilled(hole) "Hole $(hole) is convertable to an expression. There is no need to represent it using a symbol."
index = findfirst(rulenode.domain)
return isnothing(index) ? :Nothing : grammar.types[index]
end
rulenode2expr(rulenode::Hole, grammar::AbstractGrammar) = _rulenode2expr(rulenode::Hole, grammar::AbstractGrammar)

function _rulenode2expr(expr::Expr, rulenode::RuleNode, grammar::AbstractGrammar, j=0)
for (k,arg) in enumerate(expr.args)
if isa(arg, Expr)
expr.args[k],j = _rulenode2expr(arg, rulenode, grammar, j)
elseif haskey(grammar.bytype, arg)
child = rulenode.children[j+=1]
if isa(child, Hole)
expr.args[k] = _rulenode2expr(child, grammar)
continue
end
expr.args[k] = (child._val !== nothing) ?
child._val : deepcopy(grammar.rules[child.ind])
if !isterminal(grammar, child)
expr.args[k],_ = _rulenode2expr(expr.args[k], child, grammar, 0)

function _rulenode2expr(expr::Expr, rulenode::AbstractRuleNode, grammar::AbstractGrammar, j=0)
if isfilled(rulenode)
for (k,arg) in enumerate(expr.args)
if isa(arg, Expr)
expr.args[k],j = _rulenode2expr(arg, rulenode, grammar, j)
elseif haskey(grammar.bytype, arg)
child = rulenode.children[j+=1]
if !isfilled(rulenode)
return _get_hole_type(rulenode, grammar)
end
expr.args[k] = hasdynamicvalue(rulenode) ? child._val : deepcopy(grammar.rules[get_rule(child)])
if !isterminal(grammar, child)
expr.args[k],_ = _rulenode2expr(expr.args[k], child, grammar, 0)
end
end
end
return expr, j
end
return expr, j
end


Expand All @@ -225,9 +226,8 @@ function _rulenode2expr(typ::Symbol, rulenode::RuleNode, grammar::AbstractGramma
if isa(child, Hole)
return retval, j
end
retval = (child._val !== nothing) ?
child._val : deepcopy(grammar.rules[child.ind])
if !grammar.isterminal[child.ind]
retval = hasdynamicvalue(rulenode) ? child._val : deepcopy(grammar.rules[get_rule(child)])
if !grammar.isterminal[get_rule(child)]
retval,_ = _rulenode2expr(retval, child, grammar, 0)
end
end
Expand All @@ -239,7 +239,7 @@ end
Calculates the log probability associated with a rulenode in a probabilistic grammar.
"""
function rulenode_log_probability(node::RuleNode, grammar::AbstractGrammar)
return log_probability(grammar, node.ind) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=1)
return log_probability(grammar, get_rule(node)) + sum((rulenode_log_probability(c, grammar) for c ∈ node.children), init=1)
end

rulenode_log_probability(::Hole, ::AbstractGrammar) = 1
Expand Down Expand Up @@ -270,23 +270,31 @@ iscomplete(grammar::AbstractGrammar, ::Hole) = false

Gives the return type or nonterminal symbol in the production rule used by `node`.
"""
return_type(grammar::AbstractGrammar, node::RuleNode)::Symbol = grammar.types[node.ind]
return_type(grammar::AbstractGrammar, node::RuleNode)::Symbol = grammar.types[get_rule(node)]


"""
return_type(grammar::AbstractGrammar, hole::UniformHole)

Gives the return type or nonterminal symbol in the production rule used by `hole`.
"""
return_type(grammar::AbstractGrammar, hole::UniformHole)::Symbol = grammar.types[findfirst(hole.domain)]


"""
child_types(grammar::AbstractGrammar, node::RuleNode)

Returns the list of child types (nonterminal symbols) in the production rule used by `node`.
"""
child_types(grammar::AbstractGrammar, node::RuleNode)::Vector{Symbol} = grammar.childtypes[node.ind]
child_types(grammar::AbstractGrammar, node::RuleNode)::Vector{Symbol} = grammar.childtypes[get_rule(node)]


"""
isterminal(grammar::AbstractGrammar, node::RuleNode)::Bool
isterminal(grammar::AbstractGrammar, node::AbstractRuleNode)::Bool

Returns true if the production rule used by `node` is terminal, i.e., does not contain any nonterminal symbols.
"""
isterminal(grammar::AbstractGrammar, node::RuleNode)::Bool = grammar.isterminal[node.ind]
isterminal(grammar::AbstractGrammar, node::AbstractRuleNode)::Bool = grammar.isterminal[get_rule(node)]


"""
Expand All @@ -302,9 +310,9 @@ nchildren(grammar::AbstractGrammar, node::RuleNode)::Int = length(child_types(gr
Return true if the rule used by `node` represents a variable in a program (essentially, an input to the program)
"""
isvariable(grammar::AbstractGrammar, node::RuleNode)::Bool = (
grammar.isterminal[node.ind] &&
grammar.rules[node.ind] isa Symbol &&
!_is_defined_in_modules(grammar.rules[node.ind], [Main, Base])
grammar.isterminal[get_rule(node)] &&
grammar.rules[get_rule(node)] isa Symbol &&
!_is_defined_in_modules(grammar.rules[get_rule(node)], [Main, Base])
)
"""
isvariable(grammar::AbstractGrammar, node::RuleNode, mod::Module)::Bool
Expand All @@ -314,9 +322,9 @@ Return true if the rule used by `node` represents a variable.
Taking into account the symbols defined in the given module(s).
"""
isvariable(grammar::AbstractGrammar, node::RuleNode, mod::Module...)::Bool = (
grammar.isterminal[node.ind] &&
grammar.rules[node.ind] isa Symbol &&
!_is_defined_in_modules(grammar.rules[node.ind], [mod..., Main, Base])
grammar.isterminal[get_rule(node)] &&
grammar.rules[get_rule(node)] isa Symbol &&
!_is_defined_in_modules(grammar.rules[get_rule(node)], [mod..., Main, Base])
)

"""
Expand Down
Loading