Skip to content

Commit

Permalink
revert "speedup sampling"
Browse files Browse the repository at this point in the history
  • Loading branch information
liuanji committed Nov 7, 2024
1 parent 41994f7 commit 5cc6f53
Showing 1 changed file with 43 additions and 153 deletions.
196 changes: 43 additions & 153 deletions src/pyjuice/queries/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from pyjuice.nodes import CircuitNodes
from pyjuice.model import TensorCircuit
from pyjuice.utils.kernel_launcher import FastJITFunction


@njit
Expand All @@ -23,68 +22,19 @@ def _assign_cids_ind_target(ind_target, element_pointers, ind_b, num_samples):
element_pointers[bid] = ind_t + 1


@triton.jit
def _assign_nids_ind_target_kernel(ind_target_ptr, ind_ch_count_ptr, node_pointers_ptr, ind_b_ptr,
                                   num_samples, num_nodes, BLOCK_SIZE: tl.constexpr, NUM_BLKS: tl.constexpr):
    """
    For every node `nid` whose `ind_b[nid]` equals this program's id, write
    `ind_target[nid] = start * num_samples + bid`, where `start` is the value loaded
    from `node_pointers[bid]` plus the child counts of all earlier nodes with the same `bid`.

    One program is launched per `bid`; the nodes are scanned in `NUM_BLKS` chunks of
    `BLOCK_SIZE` lanes. `node_pointers` is only read here, never written back.
    """
    bid = tl.program_id(0) # The batch ID for this node block

    # Running start offset for the next node that belongs to `bid`
    target_val_sid = tl.load(node_pointers_ptr + bid)

    offsets = tl.arange(0, BLOCK_SIZE)

    for i in range(NUM_BLKS):
        mask = (offsets < num_nodes)

        # `other = -1` keeps out-of-range lanes from ever comparing equal to `bid`
        # (program ids are non-negative)
        inds_b = tl.load(ind_b_ptr + offsets, mask = mask, other = -1)
        mask_b = mask & (inds_b == bid)

        # Child counts of the nodes handled by this program; every other lane contributes 0
        count_c = tl.load(ind_ch_count_ptr + offsets, mask = mask_b, other = 0)

        # Exclusive prefix sum: the start offset of every matching node in this chunk
        cumcount_c = tl.cumsum(count_c, axis = 0) - count_c + target_val_sid

        tl.store(ind_target_ptr + offsets, cumcount_c * num_samples + bid, mask = mask_b)

        # Advance by the total number of children seen in this chunk. Deriving this from
        # `max(exclusive cumsum) + count of the last matching lane` adds that count twice
        # whenever the last matching lane is not the final lane of the chunk.
        target_val_sid += tl.sum(count_c, axis = 0)

        offsets += BLOCK_SIZE


def _assign_nids_ind_target(ind_target, ind_ch_count, node_pointers, ind_b, num_samples):
"""
A GPU implementation of the following:
@njit
def _assign_nids_ind_target(ind_target, ind_ch_count, node_pointers, ind_b, num_samples):
for nid in range(ind_target.shape[0]):
bid = ind_b[nid]
ind_t = node_pointers[bid]
ind_target[i] = ind_t * num_samples + bid
node_pointers[bid] = ind_t + ind_ch_count[nid]
"""

num_nodes = ind_b.size(0)

BLOCK_SIZE = min(512, triton.next_power_of_2(num_nodes))
NUM_BLKS = triton.cdiv(num_nodes, BLOCK_SIZE)

grid = (num_samples,)

_assign_nids_ind_target_kernel[grid](
ind_target,
ind_ch_count,
node_pointers,
ind_b,
num_samples,
num_nodes,
BLOCK_SIZE = BLOCK_SIZE,
NUM_BLKS = NUM_BLKS
)
@njit
def _assign_nids_ind_target(ind_target, ind_target_sid, node_pointers, ind_ch_count, ind_b, num_samples):
nid = 0
for i in range(ind_target.shape[0]):
if nid < ind_target_sid.shape[0] - 1 and i >= ind_target_sid[nid+1]:
nid += 1
bid = ind_b[nid]
ind_t = node_pointers[bid]
ind_target[i] = ind_t * num_samples + bid
node_pointers[bid] = ind_t + 1


@triton.jit
# @FastJITFunction
def sample_sum_layer_kernel(nids, cids, pids, node_mars, element_mars, params, node_samples, element_samples,
ind_target, ind_n, ind_b, seed, block_size: tl.constexpr, batch_size: tl.constexpr,
num_edges: tl.constexpr, num_samples: tl.constexpr, num_nblocks: tl.constexpr, BLOCK_S: tl.constexpr,
Expand Down Expand Up @@ -193,83 +143,20 @@ def sample_sum_layer(layer, nids, cids, pids, node_mars, element_mars, params, n
return None


@triton.jit
# @FastJITFunction
def push_non_neg_ones_to_front_kernel(matrix_ptr, counts_ptr, row_count, col_count,
                                      BLOCK_SIZE: tl.constexpr, NUM_BLKS: tl.constexpr):
    """
    In-place, per-column compaction: move every entry that is not -1 to the top of its
    column (keeping the original order), fill the rest of the column with -1, and write
    the number of non -1 entries of the column to `counts_ptr`.

    Each program handles the column given by `tl.program_id(0)`. Element (row, col) is
    addressed as `matrix_ptr + row * col_count + col`.
    # NOTE(review): this addressing assumes a contiguous row-major matrix - confirm at the call site.
    """
    off_col = tl.program_id(0)

    offs_row = tl.arange(0, BLOCK_SIZE)

    # Row index of the last entry written so far in this column (-1 = nothing written yet)
    target_row_id = -1
    target_row_id = target_row_id.to(tl.int64)

    for i in range(NUM_BLKS):
        mask_row = (offs_row < row_count)

        # Out-of-range rows are read as -1, i.e. treated as empty
        value = tl.load(matrix_ptr + offs_row * col_count + off_col, mask = mask_row, other = -1)

        mask_val = (value != -1)

        # Inclusive prefix count of kept entries -> destination row of each kept entry
        offs_target = tl.cumsum(mask_val.to(tl.int64), axis = 0) + target_row_id

        # Destinations never lie below their source row, and the whole chunk was loaded
        # before this store, so compacting in place does not clobber unread entries
        tl.store(matrix_ptr + offs_target * col_count + off_col, value, mask = mask_row & mask_val)

        offs_row += BLOCK_SIZE
        target_row_id += tl.sum(mask_val.to(tl.int64))

    # `target_row_id + 1` is the number of kept entries in this column
    tl.store(counts_ptr + off_col, target_row_id + 1)

    # From here on `target_row_id` is the first row that has to be reset to -1
    target_row_id = (target_row_id + 1).to(tl.int32)

    while target_row_id < row_count:
        offs_row = tl.arange(0, BLOCK_SIZE) + target_row_id
        mask_row = (offs_row < row_count)

        tl.store(matrix_ptr + offs_row * col_count + off_col, -1, mask = mask_row)

        target_row_id += BLOCK_SIZE


def push_non_neg_ones_to_front(matrix):
    """
    Compact every column of `matrix` in place so that all entries different from -1 come
    first (in their original top-to-bottom order), followed by the -1 entries.

    Args:
        matrix: 2-D integer tensor; -1 marks an empty entry. Modified in place.

    Returns:
        A 1-D `long` tensor holding, for each column, the number of entries that are not -1.
    """
    s_mask = (matrix != -1)

    # Stable sort of the "is empty" flag along the rows: per column this yields the row
    # indices of the kept entries first (order preserved), then those of the -1 entries.
    # Pairing `result[d_mask]` with `matrix[s_mask]` would instead match entries in
    # row-major order and can move values into a different column.
    order = torch.sort((~s_mask).to(torch.int32), dim = 0, stable = True)[1]

    matrix[:] = torch.gather(matrix, 0, order)

    return s_mask.long().sum(dim = 0)


@triton.jit
# @FastJITFunction
def count_prod_nch_kernel(nids, cids, element_samples, ind_ch_count, ind_nids, ind_nid_offs, ind_mask, ind_n, ind_b, partition_id,
block_size: tl.constexpr, num_samples: tl.constexpr, num_nblocks: tl.constexpr,
batch_size: tl.constexpr, num_edges: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_C: tl.constexpr,
Expand Down Expand Up @@ -353,8 +240,7 @@ def count_prod_nch(layer, nids, cids, element_samples, ind_ch_count, ind_nids, i


@triton.jit
# @FastJITFunction
def sample_prod_layer_kernel(nids, cids, node_samples, element_samples, ind_target, ind_n, ind_b,
def sample_prod_layer_kernel(nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b,
ind_nids, ind_nid_offs, ind_mask, partition_id, block_size: tl.constexpr,
num_samples: tl.constexpr, num_nblocks: tl.constexpr, batch_size: tl.constexpr, num_edges: tl.constexpr,
BLOCK_S: tl.constexpr, BLOCK_C: tl.constexpr, C_NUM_BLKS: tl.constexpr):
Expand Down Expand Up @@ -383,36 +269,35 @@ def sample_prod_layer_kernel(nids, cids, node_samples, element_samples, ind_targ
mask_child = offs_child < num_edges

# Main loop over blocks of child nodes
target_id_base = tl.load(ind_target + offs_sample, mask = mask_sample, other = 0)
target_sid = tl.load(ind_target_sid + offs_sample, mask = mask_sample, other = 0)
for i in range(C_NUM_BLKS):

c_ids = tl.load(cids + local_nids[:,None] * num_edges + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:]), other = 0)

target_id = target_id_base[:,None] + offs_child[None,:] * num_samples
target_id = tl.load(ind_target + target_sid[:,None] + offs_child[None,:], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0)), other = 0)

tl.store(node_samples + target_id, c_ids + local_nid_offs[:,None], mask = (mask_sample[:,None] & mask_child[None,:] & (c_ids > 0)))

offs_child += BLOCK_C
mask_child = offs_child < num_edges


def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target,
def sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid,
ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, block_size, partition_id):

num_samples = ind_n.size(0)
num_nblocks = nids.size(0)
num_edges = cids.size(1)
batch_size = node_samples.size(1)

BLOCK_C = min(128, triton.next_power_of_2(num_edges))
BLOCK_S = min(128 // BLOCK_C, triton.next_power_of_2(num_samples))
BLOCK_C = min(1024, triton.next_power_of_2(num_edges))
BLOCK_S = min(1024 // BLOCK_C, triton.next_power_of_2(num_samples))

C_NUM_BLKS = triton.cdiv(num_edges, BLOCK_C)

grid = (triton.cdiv(num_samples, BLOCK_S),)

sample_prod_layer_kernel[grid](
nids, cids, node_samples, element_samples, ind_target, ind_n, ind_b,
nids, cids, node_samples, element_samples, ind_target, ind_target_sid, ind_n, ind_b,
ind_nids, ind_nid_offs, ind_mask, partition_id, block_size, num_samples,
num_nblocks, batch_size, num_edges, BLOCK_S, BLOCK_C, C_NUM_BLKS
)
Expand Down Expand Up @@ -461,6 +346,7 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo

# Iterate over sum layers in the current layer group
for layer in layer_group:

# Gather the indices to be processed
lsid, leid = layer._layer_nid_range
ind_n, ind_b = torch.where((node_samples >= lsid) & (node_samples < leid))
Expand All @@ -479,9 +365,6 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo
nids = layer.partitioned_nids[partition_id]
cids = layer.partitioned_cids[partition_id]
pids = layer.partitioned_pids[partition_id]

if ind_n.size(0) == 0:
import pdb; pdb.set_trace()

sample_sum_layer(layer, nids, cids, pids, pc.node_mars, pc.element_mars, pc.params,
node_samples, element_samples, ind_target, ind_n, ind_b,
Expand Down Expand Up @@ -515,32 +398,39 @@ def sample(pc: TensorCircuit, num_samples: Optional[int] = None, conditional: bo
ind_nid_offs, ind_mask, ind_n, ind_b, layer.block_size, partition_id)

# Pre-compute the target indices in `node_samples`
ind_target = torch.zeros_like(ind_n)
_assign_nids_ind_target(ind_target, ind_ch_count, node_pointers, ind_b, num_samples)
ind_target_sid = np.zeros([ind_n.size(0)], dtype = np.int64)
ind_target_sid[1:] = ind_ch_count[:-1].cumsum(dim = 0).detach().cpu().numpy()
ind_target = np.zeros([ind_ch_count.sum()], dtype = np.int64)
_assign_nids_ind_target(ind_target, ind_target_sid,
node_pointers.detach().cpu().numpy(),
ind_ch_count.detach().cpu().numpy(),
ind_b.detach().cpu().numpy(), num_samples)
ind_target_sid = torch.from_numpy(ind_target_sid).to(pc.device)
ind_target = torch.from_numpy(ind_target).to(pc.device)

# Store child indices
for partition_id in range(layer.num_fw_partitions):
nids = layer.partitioned_nids[partition_id]
cids = layer.partitioned_cids[partition_id]

sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target,
sample_prod_layer(layer, nids, cids, node_samples, element_samples, ind_target, ind_target_sid,
ind_n, ind_b, ind_nids, ind_nid_offs, ind_mask, layer.block_size, partition_id)

if _sample_input_ns:
# Create tensor for the samples
data_dtype = pc.input_layer_group[0].get_data_dtype()
samples = torch.zeros([pc.num_vars, num_samples], dtype = data_dtype, device = pc.device)
# Create tensor for the samples
data_dtype = pc.input_layer_group[0].get_data_dtype()
samples = torch.zeros([pc.num_vars, num_samples], dtype = data_dtype, device = pc.device)

pc._init_buffer(name = "node_flows", shape = (pc.num_nodes, num_samples), set_value = 0.0)
ind_n, ind_b = torch.where(node_samples != -1)
ind_node = node_samples[ind_n, ind_b]
pc.node_flows[ind_node, ind_b] = 1.0
pc._init_buffer(name = "node_flows", shape = (pc.num_nodes, num_samples), set_value = 0.0)
ind_n, ind_b = torch.where(node_samples != -1)
ind_node = node_samples[ind_n, ind_b]
pc.node_flows[ind_node, ind_b] = 1.0

if _sample_input_ns:
for layer in pc.input_layer_group:
seed = random.randint(0, 2**31)
layer.sample(samples, pc.node_flows, seed = seed)

return samples.permute(1, 0).contiguous()
else:
# In this case, we do not explicitly sample input nodes
return node_samples
return node_samples

0 comments on commit 5cc6f53

Please sign in to comment.