This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Add separate init, expose gather/scatter for WholeMemoryTensor and update example #81

Merged
merged 8 commits into from
Oct 17, 2023
@@ -37,8 +37,8 @@ void scatter_integer_int64_temp_func(const void* input,

REGISTER_DISPATCH_TWO_TYPES(ScatterFuncIntegerInt64,
scatter_integer_int64_temp_func,
HALF_FLOAT_DOUBLE,
HALF_FLOAT_DOUBLE)
ALLSINT,
ALLSINT)

wholememory_error_code_t scatter_integer_int64_func(const void* input,
wholememory_matrix_description_t input_desc,
10 changes: 8 additions & 2 deletions python/pylibwholegraph/examples/node_classfication.py
@@ -31,6 +31,10 @@
wgth.add_common_sampler_options(parser)
wgth.add_node_classfication_options(parser)
wgth.add_dataloader_options(parser)
parser.add_option(
"--fp16_embedding", action="store_true", dest="fp16_mbedding", default=False, help="Whether to use fp16 embedding"
)


(options, args) = parser.parse_args()

@@ -188,13 +192,15 @@ def main_func():
else wgth.create_wholememory_optimizer("adam", {})
)

    embedding_dtype = torch.float16 if options.fp16_embedding else torch.float32

if wm_optimizer is None:
node_feat_wm_embedding = wgth.create_embedding_from_filelist(
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
os.path.join(options.root_dir, "node_feat.bin"),
torch.float,
embedding_dtype,
options.feat_dim,
optimizer=wm_optimizer,
cache_policy=cache_policy,
@@ -204,7 +210,7 @@
feature_comm,
embedding_wholememory_type,
embedding_wholememory_location,
torch.float,
embedding_dtype,
[graph_structure.node_count, options.feat_dim],
optimizer=wm_optimizer,
cache_policy=cache_policy,
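
As background for the two hunks above: with --fp16_embedding set, the node-feature embedding is stored as torch.float16, while the model can still gather features as float32 for compute (that is what the force_dtype=torch.float32 change in gnn_model.py further down provides). A toy sketch of the dtype-selection pattern, independent of WholeMemory:

import torch

fp16_embedding = True  # corresponds to the example's --fp16_embedding flag

# storage dtype: fp16 halves embedding memory at some precision cost
embedding_dtype = torch.float16 if fp16_embedding else torch.float32

# toy stand-in for the node-feature table; the real example keeps it in WholeMemory
node_feat = torch.randn(10, 4).to(embedding_dtype)

# gathered rows are upcast to float32 for compute, mirroring the model's
# force_dtype=torch.float32 gather below
feats = node_feat[torch.tensor([0, 3, 7])].to(torch.float32)
assert feats.dtype == torch.float32
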
86 changes: 86 additions & 0 deletions python/pylibwholegraph/examples/ogbn_papers100m_convert.py
@@ -0,0 +1,86 @@
import argparse
import os
import numpy as np
from scipy.sparse import coo_matrix
import pickle
from ogb.nodeproppred import NodePropPredDataset


def save_array(np_array, save_path, array_file_name):
array_full_path = os.path.join(save_path, array_file_name)
with open(array_full_path, 'wb') as f:
np_array.tofile(f)


def convert_papers100m_dataset(args):
ogb_root = args.ogb_root_dir
dataset = NodePropPredDataset(name='ogbn-papers100M', root=ogb_root)
graph, label = dataset[0]
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = (
split_idx["train"],
split_idx["valid"],
split_idx["test"],
)
train_label = label[train_idx]
valid_label = label[valid_idx]
test_label = label[test_idx]
data_and_label = {
"train_idx": train_idx,
"valid_idx": valid_idx,
"test_idx": test_idx,
"train_label": train_label,
"valid_label": valid_label,
"test_label": test_label,
}
num_nodes = graph["num_nodes"]
edge_index = graph["edge_index"]
node_feat = graph["node_feat"].astype(np.dtype(args.node_feat_format))
if not os.path.exists(args.convert_dir):
print(f"creating directory {args.convert_dir}...")
os.makedirs(args.convert_dir)
print("saving idx and labels...")
with open(
os.path.join(args.convert_dir, 'ogbn_papers100M_data_and_label.pkl'), "wb"
) as f:
pickle.dump(data_and_label, f)
print("saving node feature...")
with open(
os.path.join(args.convert_dir, 'node_feat.bin'), "wb"
) as f:
node_feat.tofile(f)

print("converting graph to csr...")
assert len(edge_index.shape) == 2
assert edge_index.shape[0] == 2
coo_src_ids = edge_index[0, :].astype(np.int32)
coo_dst_ids = edge_index[1, :].astype(np.int32)
if args.add_reverse_edges:
arg_graph_src = np.concatenate([coo_src_ids, coo_dst_ids])
arg_graph_dst = np.concatenate([coo_dst_ids, coo_src_ids])
else:
arg_graph_src = coo_src_ids
arg_graph_dst = coo_dst_ids
values = np.arange(len(arg_graph_src), dtype='int64')
coo_graph = coo_matrix((values, (arg_graph_src, arg_graph_dst)), shape=(num_nodes, num_nodes))
csr_graph = coo_graph.tocsr()
csr_row_ptr = csr_graph.indptr.astype(dtype='int64')
csr_col_ind = csr_graph.indices.astype(dtype='int32')
print("saving csr graph...")
save_array(csr_row_ptr, args.convert_dir, 'homograph_csr_row_ptr')
save_array(csr_col_ind, args.convert_dir, 'homograph_csr_col_idx')


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--ogb_root_dir', type=str, default='dataset',
                        help='root dir containing the ogb datasets')
parser.add_argument('--convert_dir', type=str, default='dataset_papers100m_converted',
help='output dir containing converted datasets')
parser.add_argument('--node_feat_format', type=str, default='float32',
choices=['float32', 'float16'],
help='save format of node feature')
    # argparse's type=bool would treat any non-empty string as True, so parse the flag explicitly
    parser.add_argument('--add_reverse_edges', type=lambda x: x.lower() in ('true', '1', 'yes'),
                        default=True, help='whether to add reverse edges (true/false)')
args = parser.parse_args()
convert_papers100m_dataset(args)
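
The one non-obvious step in the conversion above is the COO-to-CSR transform: the edge ids 0..num_edges-1 are stored as the sparse-matrix values so that, after tocsr(), they stay aligned with the reordered column indices. A toy illustration of what ends up in the two saved arrays (the graph and its sizes are made up):

import numpy as np
from scipy.sparse import coo_matrix

# toy graph: 4 nodes, 4 directed edges
src = np.array([0, 0, 2, 3], dtype=np.int32)
dst = np.array([1, 2, 3, 0], dtype=np.int32)
values = np.arange(len(src), dtype='int64')  # edge ids, carried through the conversion

csr = coo_matrix((values, (src, dst)), shape=(4, 4)).tocsr()
csr_row_ptr = csr.indptr.astype('int64')   # [0, 2, 2, 3, 4]
csr_col_ind = csr.indices.astype('int32')  # [1, 2, 3, 0]

# node i's out-neighbors are csr_col_ind[csr_row_ptr[i]:csr_row_ptr[i + 1]]
assert list(csr_col_ind[csr_row_ptr[0]:csr_row_ptr[1]]) == [1, 2]
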
2 changes: 1 addition & 1 deletion python/pylibwholegraph/pylibwholegraph/torch/__init__.py
@@ -41,7 +41,7 @@
)
from .embedding import WholeMemoryEmbeddingModule

from .initialize import init_torch_env, init_torch_env_and_create_wm_comm, finalize
from .initialize import init, init_torch_env, init_torch_env_and_create_wm_comm, finalize

from .tensor import (
WholeMemoryTensor,
@@ -109,7 +109,7 @@ def add_common_model_options(parser: OptionParser):
default="cugraph",
help="framework type, valid values are: dgl, pyg, wg, cugraph",
)
parser.add_option("--heads", type="int", dest="heads", default=1, help="num heads")
parser.add_option("--heads", type="int", dest="heads", default=4, help="num heads")
parser.add_option(
"-d", "--dropout", type="float", dest="dropout", default=0.5, help="dropout"
)
@@ -126,9 +126,8 @@ def add_common_sampler_options(parser: OptionParser):
parser.add_option(
"-s",
"--inferencesample",
type="int",
dest="inferencesample",
default=30,
default="30",
help="inference sample count, -1 is all",
)

@@ -59,10 +59,13 @@ def __init__(
self.reset_parameters()

def reset_parameters(self):
self.lin.reset_parameters()
gain = torch.nn.init.calculate_gain("relu")
torch.nn.init.xavier_normal_(self.lin.weight, gain=gain)
torch.nn.init.xavier_normal_(
self.att.view(2, self.heads, self.out_channels), gain=gain
self.att.view(2, self.heads, self.out_channels)[0, :, :], gain=gain
)
torch.nn.init.xavier_normal_(
self.att.view(2, self.heads, self.out_channels)[1, :, :], gain=gain
)
torch.nn.init.zeros_(self.bias)

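
A note on the initialization change above: xavier_normal_ derives its standard deviation from the fan of the tensor it receives. Initializing the whole (2, heads, out_channels) view uses fan_in = heads * out_channels and fan_out = 2 * out_channels, whereas initializing each (heads, out_channels) half separately uses fan_in = out_channels and fan_out = heads, which gives a larger std. A small sketch with illustrative sizes:

import torch

heads, out_channels = 4, 32
gain = torch.nn.init.calculate_gain("relu")
att = torch.empty(2, heads, out_channels)

# whole tensor: std ~= gain * sqrt(2 / (heads*out_channels + 2*out_channels)) ~= 0.14
torch.nn.init.xavier_normal_(att, gain=gain)
print(att.std())

# per-half, as the updated reset_parameters does:
# std ~= gain * sqrt(2 / (heads + out_channels)) ~= 0.33
torch.nn.init.xavier_normal_(att[0], gain=gain)
torch.nn.init.xavier_normal_(att[1], gain=gain)
print(att.std())
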
@@ -64,9 +64,11 @@ def __init__(
self.reset_parameters()

def reset_parameters(self):
gain = torch.nn.init.calculate_gain("relu")
torch.nn.init.xavier_uniform_(self.lin.weight, gain=gain)
if self.project:
self.pre_lin.reset_parameters()
self.lin.reset_parameters()
torch.nn.init.xavier_uniform_(self.pre_lin.weight, gain=gain)
torch.nn.init.xavier_uniform_(self.lin.weight, gain=gain)

def forward(
self,
37 changes: 16 additions & 21 deletions python/pylibwholegraph/pylibwholegraph/torch/gnn_model.py
@@ -150,26 +150,19 @@ def create_sub_graph(
return edge_index
elif framework_name == "dgl":
if add_self_loop:
self_loop_ids = torch.arange(
0,
target_gid_1.numel(),
dtype=edge_data[0].dtype,
device=target_gid.device,
)
block = dgl.create_block(
csr_row_ptr, csr_col_ind = add_csr_self_loop(csr_row_ptr, csr_col_ind)
block = dgl.create_block(
(
'csc',
(
torch.cat([edge_data[0], self_loop_ids]),
torch.cat([edge_data[1], self_loop_ids]),
csr_row_ptr,
csr_col_ind,
torch.empty(0, dtype=torch.int),
),
num_src_nodes=target_gid.size(0),
num_dst_nodes=target_gid_1.size(0),
)
else:
block = dgl.create_block(
(edge_data[0], edge_data[1]),
num_src_nodes=target_gid.size(0),
num_dst_nodes=target_gid_1.size(0),
)
),
num_src_nodes=target_gid.size(0),
num_dst_nodes=target_gid_1.size(0),
)
return block
elif framework_name == "cugraph":
if add_self_loop:
@@ -224,19 +217,21 @@ def __init__(
self.gather_fn = WholeMemoryEmbeddingModule(self.node_embedding)
self.dropout = options.dropout
self.max_neighbors = parse_max_neighbors(options.layernum, options.neighbors)
self.max_inference_neighbors = parse_max_neighbors(options.layernum, options.inferencesample)

def forward(self, ids):
global framework_name
max_neighbors = self.max_neighbors if self.training else self.max_inference_neighbors
ids = ids.to(self.graph_structure.csr_col_ind.dtype).cuda()
(
target_gids,
edge_indice,
csr_row_ptrs,
csr_col_inds,
) = self.graph_structure.multilayer_sample_without_replacement(
ids, self.max_neighbors
ids, max_neighbors
)
x_feat = self.gather_fn(target_gids[0])
x_feat = self.gather_fn(target_gids[0], force_dtype=torch.float32)
for i in range(self.num_layer):
x_target_feat = x_feat[: target_gids[i + 1].numel()]
sub_graph = create_sub_graph(
Expand All @@ -245,7 +240,7 @@ def forward(self, ids):
edge_indice[i],
csr_row_ptrs[i],
csr_col_inds[i],
self.max_neighbors[i],
max_neighbors[self.num_layer - 1 - i],
self.add_self_loop,
)
x_feat = layer_forward(
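
One behavioral consequence of the change above: the sampling fanout now follows self.training, so the usual model.train()/model.eval() toggle also switches between the --neighbors and --inferencesample fanouts. A toy sketch of the pattern (the module internals here are made up):

import torch

class FanoutAwareModel(torch.nn.Module):
    """Toy stand-in showing how train/eval mode selects the sampling fanout."""

    def __init__(self, train_fanouts, inference_fanouts):
        super().__init__()
        self.max_neighbors = train_fanouts
        self.max_inference_neighbors = inference_fanouts

    def forward(self, ids):
        fanouts = self.max_neighbors if self.training else self.max_inference_neighbors
        # ... sampling and layer computation would use `fanouts` here ...
        return fanouts

model = FanoutAwareModel(train_fanouts=[30, 30], inference_fanouts=[-1, -1])
model.train()
assert model(torch.arange(4)) == [30, 30]
model.eval()
assert model(torch.arange(4)) == [-1, -1]
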
5 changes: 5 additions & 0 deletions python/pylibwholegraph/pylibwholegraph/torch/initialize.py
@@ -18,6 +18,11 @@
from .comm import set_world_info, get_global_communicator, get_local_node_communicator


def init(world_rank: int, world_size: int, local_rank: int, local_size: int):
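    r"""Init the WholeGraph environment only, without creating communicators.
    Unlike init_torch_env, this only initializes the WholeMemory library and
    records the process layout; device selection and any torch.distributed
    setup are left to the caller.
    :param world_rank: world rank of current process
    :param world_size: world size of all processes
    :param local_rank: local rank of current process
    :param local_size: number of processes on the local node
    """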
wmb.init(0)
set_world_info(world_rank, world_size, local_rank, local_size)


def init_torch_env(world_rank: int, world_size: int, local_rank: int, local_size: int):
r"""Init WholeGraph environment for PyTorch.
:param world_rank: world rank of current process
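
A minimal sketch of how the new, lighter-weight init entry point might be called when the application manages devices and process groups itself. The rank/size values are assumed to come from torchrun-style environment variables; the launch setup itself is not part of this diff:

import os
import torch
import pylibwholegraph.torch as wgth

# rank/size as provided by the process launcher (env-var names follow torchrun)
world_rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
local_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))

# the caller keeps control of the CUDA device and any torch.distributed setup
torch.cuda.set_device(local_rank)

# init only initializes the WholeMemory library and records the process layout;
# it does not create communicators
wgth.init(world_rank, world_size, local_rank, local_size)
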
38 changes: 38 additions & 0 deletions python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -23,6 +23,7 @@
from .comm import WholeMemoryCommunicator
from typing import Union, List
from .dlpack_utils import torch_import_from_dlpack
from .wholegraph_env import wrap_torch_tensor, get_wholegraph_env_fns, get_stream


WholeMemoryMemoryType = wmb.WholeMemoryMemoryType
@@ -57,6 +58,43 @@ def get_comm(self):
self.wmb_tensor.get_wholememory_handle().get_communicator()
)

def gather(self,
indice: torch.Tensor,
*,
force_dtype: Union[torch.dtype, None] = None):
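        """Gather rows of this tensor at the given indices into a new tensor on the current CUDA device.
        :param indice: 1-D tensor of row indices
        :param force_dtype: optional output dtype; defaults to this tensor's dtype
        :return: tensor of shape [indice.shape[0], self.shape[1]]
        """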
assert indice.dim() == 1
embedding_dim = self.shape[1]
embedding_count = indice.shape[0]
current_cuda_device = "cuda:%d" % (torch.cuda.current_device(),)
output_dtype = (
            force_dtype if force_dtype is not None else self.dtype
)
output_tensor = torch.empty(
[embedding_count, embedding_dim],
device=current_cuda_device,
dtype=output_dtype,
requires_grad=False,
)
wmb.wholememory_gather_op(self.wmb_tensor,
wrap_torch_tensor(indice),
wrap_torch_tensor(output_tensor),
get_wholegraph_env_fns(),
get_stream())
return output_tensor

def scatter(self,
input_tensor: torch.Tensor,
indice: torch.Tensor):
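        """Scatter rows of input_tensor into this tensor at the given indices.
        :param input_tensor: 2-D tensor of shape [indice.shape[0], self.shape[1]]
        :param indice: 1-D tensor of row indices
        """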
assert indice.dim() == 1
assert input_tensor.dim() == 2
assert indice.shape[0] == input_tensor.shape[0]
assert input_tensor.shape[1] == self.shape[1]
wmb.wholememory_scatter_op(wrap_torch_tensor(input_tensor),
wrap_torch_tensor(indice),
self.wmb_tensor,
get_wholegraph_env_fns(),
get_stream())

def get_sub_tensor(self, starts, ends):
"""
Get sub tensor of WholeMemory Tensor
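
A minimal sketch of a scatter-then-gather round trip with the new methods. The WholeMemoryTensor is assumed to be created elsewhere (creation depends on the communicator setup, which this sketch leaves out); only the gather/scatter calls and the shape/dtype properties used below come from the tensor's existing API:

import torch
from pylibwholegraph.torch import WholeMemoryTensor


def roundtrip_rows(wm_tensor: WholeMemoryTensor, row_count: int = 10) -> torch.Tensor:
    """Scatter random rows into the first row_count slots, then gather them back."""
    embedding_dim = wm_tensor.shape[1]
    indice = torch.arange(row_count, dtype=torch.int64, device="cuda")
    update = torch.randn(row_count, embedding_dim, device="cuda", dtype=wm_tensor.dtype)

    # scatter writes the rows of `update` to the rows of wm_tensor selected by `indice`
    wm_tensor.scatter(update, indice)

    # gather reads rows back onto the current CUDA device; force_dtype optionally
    # casts the output, e.g. float32 compute on top of fp16 storage
    return wm_tensor.gather(indice, force_dtype=torch.float32)
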