Skip to content

Commit

Permalink
Add an option to tie input parameters in HCLTs
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Jan 8, 2025
1 parent fafab08 commit 1d778ef
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
23 changes: 20 additions & 3 deletions src/pyjuice/structures/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def BayesianTreeToHiddenRegionGraph(tree: nx.Graph,
InputDist: Type[Distribution],
dist_params: dict,
num_root_ns: int = 1,
block_size: Optional[int] = None) -> CircuitNodes:
block_size: Optional[int] = None,
tie_input_params: bool = False) -> CircuitNodes:
"""
Given a Tree Bayesian Network tree T1 (i.e. at most one parents),
Expand Down Expand Up @@ -48,6 +49,22 @@ def children(n: int):
for n in clt.nodes:
assert len(list(clt.predecessors(n))) <= 1

# For input parameter tying
template_ni = None

def get_input_ns(v):
nonlocal template_ni
if tie_input_params:
if template_ni is None:
ni = inputs(v, num_node_blocks = num_node_blocks, dist = InputDist(**dist_params))
template_ni = ni
else:
ni = template_ni.duplicate(v, tie_params = True)
else:
ni = inputs(v, num_node_blocks = num_node_blocks, dist = InputDist(**dist_params))

return ni

# Compile the region graph for the circuit equivalent to T2
node_seq = list(nx.dfs_postorder_nodes(tree, root))
var2rnode = dict()
Expand All @@ -57,7 +74,7 @@ def children(n: int):

if len(chs) == 0:
# Input Region
r = inputs(v, num_node_blocks = num_node_blocks, dist = InputDist(**dist_params))
r = get_input_ns(v)
var2rnode[v] = r
else:
# Inner Region
Expand All @@ -66,7 +83,7 @@ def children(n: int):
ch_regions = [var2rnode[c] for c in chs]

# Add x_v to children(z_v)
leaf_r = inputs(v, num_node_blocks = num_node_blocks, dist = InputDist(**dist_params))
leaf_r = get_input_ns(v)
ch_regions.append(leaf_r)

rp = multiply(*ch_regions)
Expand Down
8 changes: 6 additions & 2 deletions src/pyjuice/structures/hclt.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def HCLT(x: torch.Tensor, num_latents: int,
block_size: Optional[int] = None,
input_dist: Optional[Distribution] = None,
input_node_type: Type[Distribution] = Categorical,
input_node_params: dict = {"num_cats": 256}):
input_node_params: dict = {"num_cats": 256},
tie_input_params: bool = False):
"""
Construct Hidden Chow-Liu Trees (https://arxiv.org/pdf/2106.02264.pdf).
Expand All @@ -99,6 +100,9 @@ def HCLT(x: torch.Tensor, num_latents: int,
:param input_dist: input distribution
:type input_dist: Distribution
:param tie_input_params: whether to tie the input parameters
:type tie_input_params: bool
"""

if input_dist is not None:
Expand All @@ -111,7 +115,7 @@ def HCLT(x: torch.Tensor, num_latents: int,
root_r = BayesianTreeToHiddenRegionGraph(
T, root, num_latents, input_node_type,
input_node_params, num_root_ns = num_root_ns,
block_size = block_size
block_size = block_size, tie_input_params = tie_input_params
)

return root_r

0 comments on commit 1d778ef

Please sign in to comment.