Skip to content

Commit

Permalink
remove VE
Browse files Browse the repository at this point in the history
  • Loading branch information
naseweisssss committed Nov 18, 2024
1 parent 34a17b2 commit 0921f2e
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 321 deletions.
175 changes: 1 addition & 174 deletions src/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,177 +283,4 @@ function is_conditionally_independent(
end

return true
end

using LinearAlgebra

# Factor representation shared by the variable-elimination helpers in this file.
"""
    Factor

A single factor tracked during variable elimination.

# Fields
- `variables`: names of the variables the factor is defined over. `create_factor`
  fills this with the node itself.
- `distribution`: the distribution object attached to the factor.
- `parents`: names of the nodes the factor is conditioned on. `create_factor`
  fills this with the node's in-neighbors in the network graph.
"""
struct Factor
variables::Vector{Symbol}
distribution::Distribution
parents::Vector{Symbol}
end

"""
Create a factor from a node in the Bayesian network.
"""
function create_factor(bn::BayesianNetwork, node::Symbol)
node_id = bn.names_to_ids[node]
if !bn.is_stochastic[node_id]
error("Cannot create factor for deterministic node")
end

dist_idx = findfirst(id -> id == node_id, bn.stochastic_ids)
dist = bn.distributions[dist_idx]
parent_ids = inneighbors(bn.graph, node_id)
parents = Symbol[bn.names[pid] for pid in parent_ids]

return Factor([node], dist, parents)
end

"""
Multiply two factors.
"""
function multiply_factors(f1::Factor, f2::Factor)
new_vars = unique(vcat(f1.variables, f2.variables))
new_parents = unique(vcat(f1.parents, f2.parents))

if f1.distribution isa Normal && f2.distribution isa Normal
μ = mean(f1.distribution) + mean(f2.distribution)
σ = sqrt(var(f1.distribution) + var(f2.distribution))
new_dist = Normal(μ, σ)
elseif f1.distribution isa Categorical && f2.distribution isa Categorical
p = f1.distribution.p .* f2.distribution.p
p = p ./ sum(p)
new_dist = Categorical(p)
else
new_dist = Normal(0, 1)
end

return Factor(new_vars, new_dist, new_parents)
end

"""
Marginalize (sum/integrate) out a variable from a factor.
"""
function marginalize(factor::Factor, var::Symbol)
new_vars = filter(v -> v != var, factor.variables)
new_parents = filter(v -> v != var, factor.parents)

if factor.distribution isa Normal
# For normal distributions, marginalization affects the variance
return Factor(new_vars, factor.distribution, new_parents)
elseif factor.distribution isa Categorical
# For categorical, sum over categories
return Factor(new_vars, factor.distribution, new_parents)
end

return Factor(new_vars, factor.distribution, new_parents)
end

"""
variable_elimination(bn::BayesianNetwork, query::Symbol, evidence::Dict{Symbol,Any})
Perform variable elimination to compute P(query | evidence).
"""
function variable_elimination(
bn::BayesianNetwork{Symbol,Int,Any}, query::Symbol, evidence::Dict{Symbol,Float64}
)
println("\nStarting Variable Elimination")
println("Query variable: ", query)
println("Evidence: ", evidence)

# Step 1: Create initial factors
factors = Dict{Symbol,Factor}()
for node in bn.names
if bn.is_stochastic[bn.names_to_ids[node]]
println("Creating factor for: ", node)
factors[node] = create_factor(bn, node)
end
end

# Step 2: Incorporate evidence
for (var, val) in evidence
println("Incorporating evidence: ", var, " = ", val)
node_id = bn.names_to_ids[var]
if bn.is_stochastic[node_id]
dist_idx = findfirst(id -> id == node_id, bn.stochastic_ids)
if bn.distributions[dist_idx] isa Normal
factors[var] = Factor([var], Normal(val, 0.1), Symbol[])
elseif bn.distributions[dist_idx] isa Categorical
p = zeros(length(bn.distributions[dist_idx].p))
p[Int(val)] = 1.0
factors[var] = Factor([var], Categorical(p), Symbol[])
end
end
end

# Step 3: Determine elimination ordering
eliminate_vars = Symbol[]
for node in bn.names
if node != query && !haskey(evidence, node)
push!(eliminate_vars, node)
end
end
println("Variables to eliminate: ", eliminate_vars)

# Step 4: Variable elimination
for var in eliminate_vars
println("\nEliminating variable: ", var)

# Find factors containing this variable
relevant_factors = Factor[]
relevant_keys = Symbol[]
for (k, f) in factors
if var in f.variables || var in f.parents
push!(relevant_factors, f)
push!(relevant_keys, k)
end
end

if !isempty(relevant_factors)
# Multiply factors
combined_factor = reduce(multiply_factors, relevant_factors)

# Marginalize out the variable
new_factor = marginalize(combined_factor, var)

# Update factors
for k in relevant_keys
delete!(factors, k)
end

# Only add the new factor if it has variables
if !isempty(new_factor.variables)
factors[new_factor.variables[1]] = new_factor
end
end
end

# Step 5: Multiply remaining factors
final_factors = collect(values(factors))
if isempty(final_factors)
# If no factors remain, return a default probability
return 1.0
end

result_factor = reduce(multiply_factors, final_factors)

# Return normalized probability
if result_factor.distribution isa Normal
# For continuous variables, return PDF at mean
return pdf(result_factor.distribution, mean(result_factor.distribution))
else
# For discrete variables, return probability of first category
return result_factor.distribution.p[1]
end
end

"""
    variable_elimination(
        bn::BayesianNetwork{Symbol,Int,Any}, query::Symbol, evidence::Dict{Symbol,<:Any}
    )

Convenience method for evidence dictionaries whose values are not `Float64`.

Every value is converted with `Float64(value)` and the call is forwarded to the
`Dict{Symbol,Float64}` method, whose result is returned unchanged.
"""
function variable_elimination(
    bn::BayesianNetwork{Symbol,Int,Any}, query::Symbol, evidence::Dict{Symbol,<:Any}
)
    as_float = Dict{Symbol,Float64}(
        name => Float64(value) for (name, value) in evidence
    )
    return variable_elimination(bn, query, as_float)
end
end
147 changes: 0 additions & 147 deletions test/experimental/ProbabilisticGraphicalModels/bayesnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,151 +280,4 @@ using JuliaBUGS.ProbabilisticGraphicalModels:
@test_throws KeyError is_conditionally_independent(bn, :A, :B, [:NonExistent])
end
end

@testset "Variable Elimination Tests" begin
    println("\nTesting Variable Elimination")

    @testset "Simple Chain Network (Z → X → Y)" begin
        # Chain network Z → X → Y with one discrete root and two Gaussian nodes.
        bn = BayesianNetwork{Symbol}()

        println("Adding vertices...")
        vertices = (
            (:Z, Categorical([0.7, 0.3])),  # P(Z)
            (:X, Normal(0, 1)),             # P(X|Z)
            (:Y, Normal(1, 2)),             # P(Y|X)
        )
        for (name, dist) in vertices
            add_stochastic_vertex!(bn, name, dist, false)
        end

        println("Adding edges...")
        for (from, to) in ((:Z, :X), (:X, :Y))
            add_edge!(bn, from, to)
        end

        # (header line, result label, query, evidence) for each scenario.
        cases = (
            ("\nTest case 1: P(X | Y=1.5)", "P(X | Y=1.5) = ", :X, Dict(:Y => 1.5)),
            ("\nTest case 2: P(X | Z=1)", "P(X | Z=1) = ", :X, Dict(:Z => 1)),
            ("\nTest case 3: P(Y | Z=1)", "P(Y | Z=1) = ", :Y, Dict(:Z => 1)),
        )
        for (header, label, query, evidence) in cases
            println(header)
            outcome = variable_elimination(bn, query, evidence)
            @test outcome isa Number
            @test outcome >= 0
            println(label, outcome)
        end
    end
end

@testset "Variable Elimination Tests" begin
    println("\nTesting Variable Elimination")

    @testset "Simple Chain Network (Z → X → Y)" begin
        # Chain network Z → X → Y with one discrete root and two Gaussian nodes.
        bn = BayesianNetwork{Symbol}()

        println("Adding vertices...")
        vertices = (
            (:Z, Categorical([0.7, 0.3])),  # P(Z)
            (:X, Normal(0, 1)),             # P(X|Z)
            (:Y, Normal(1, 2)),             # P(Y|X)
        )
        for (name, dist) in vertices
            add_stochastic_vertex!(bn, name, dist, false)
        end

        println("Adding edges...")
        for (from, to) in ((:Z, :X), (:X, :Y))
            add_edge!(bn, from, to)
        end

        # (header line, result label, query, evidence) for each scenario.
        cases = (
            ("\nTest case 1: P(X | Y=1.5)", "P(X | Y=1.5) = ", :X, Dict(:Y => 1.5)),
            ("\nTest case 2: P(X | Z=1)", "P(X | Z=1) = ", :X, Dict(:Z => 1)),
            ("\nTest case 3: P(Y | Z=1)", "P(Y | Z=1) = ", :Y, Dict(:Z => 1)),
        )
        for (header, label, query, evidence) in cases
            println(header)
            outcome = variable_elimination(bn, query, evidence)
            @test outcome isa Number
            @test outcome >= 0
            println(label, outcome)
        end
    end

    @testset "Mixed Network (Discrete and Continuous)" begin
        # Network A → B → D ← C mixing discrete (A, C) and continuous (B, D) nodes.
        bn = BayesianNetwork{Symbol}()

        println("\nAdding vertices for mixed network...")
        vertices = (
            (:A, Categorical([0.4, 0.6])),  # Discrete
            (:B, Normal(0, 1)),             # Continuous
            (:C, Categorical([0.3, 0.7])),  # Discrete
            (:D, Normal(1, 2)),             # Continuous
        )
        for (name, dist) in vertices
            add_stochastic_vertex!(bn, name, dist, false)
        end

        println("Adding edges...")
        for (from, to) in ((:A, :B), (:B, :D), (:C, :D))
            add_edge!(bn, from, to)
        end

        # (header line, result label, query, evidence) for each scenario.
        cases = (
            ("\nTest case 1: P(B | D=1.0)", "P(B | D=1.0) = ", :B, Dict(:D => 1.0)),
            (
                "\nTest case 2: P(D | A=1, C=1)",
                "P(D | A=1, C=1) = ",
                :D,
                Dict(:A => 1, :C => 1),
            ),
        )
        for (header, label, query, evidence) in cases
            println(header)
            outcome = variable_elimination(bn, query, evidence)
            @test outcome isa Number
            @test outcome >= 0
            println(label, outcome)
        end
    end

    @testset "Special Cases" begin
        bn = BayesianNetwork{Symbol}()

        # A network consisting of a single node, queried without evidence.
        add_stochastic_vertex!(bn, :X, Normal(0, 1), false)
        single_node = variable_elimination(bn, :X, Dict{Symbol,Any}())
        @test single_node isa Number
        @test single_node >= 0

        # Extend the same network to X → Y and query the child without evidence.
        add_stochastic_vertex!(bn, :Y, Normal(1, 2), false)
        add_edge!(bn, :X, :Y)
        no_evidence = variable_elimination(bn, :Y, Dict{Symbol,Any}())
        @test no_evidence isa Number
        @test no_evidence >= 0
    end
end
end

0 comments on commit 0921f2e

Please sign in to comment.