Merge branch 'master' into setDaemon
drivanov authored Oct 16, 2024
2 parents 4e131f2 + d92c98d commit 0cbc7e9
Showing 5 changed files with 46 additions and 11 deletions.
21 changes: 21 additions & 0 deletions python/dgl/distributed/partition.py
@@ -334,6 +334,27 @@ def load_partition(part_config, part_id, load_feats=True, use_graphbolt=False):
         "part-{}".format(part_id) in part_metadata
     ), "part-{} does not exist".format(part_id)
     part_files = part_metadata["part-{}".format(part_id)]
+
+    exist_dgl_graph = exist_graphbolt_graph = False
+    if os.path.exists(os.path.join(config_path, f"part{part_id}", "graph.dgl")):
+        use_graphbolt = False
+        exist_dgl_graph = True
+    if os.path.exists(
+        os.path.join(
+            config_path, f"part{part_id}", "fused_csc_sampling_graph.pt"
+        )
+    ):
+        use_graphbolt = True
+        exist_graphbolt_graph = True
+
+    # Make sure exactly one of the DGL graph and the GraphBolt graph exists.
+    if not exist_dgl_graph and not exist_graphbolt_graph:
+        raise ValueError("The graph object doesn't exist.")
+    if exist_dgl_graph and exist_graphbolt_graph:
+        raise ValueError(
+            "Both DGL graph and GraphBolt graph exist. Please remove one."
+        )
+
     if use_graphbolt:
         part_graph_field = "part_graph_graphbolt"
     else:
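
The effect of this hunk is that load_partition now derives use_graphbolt from what is actually on disk instead of trusting the caller's flag, and it fails fast when the partition directory is ambiguous. A minimal standalone sketch of the same detection logic; the helper name _detect_graphbolt is invented for illustration, and the two file names ("graph.dgl", "fused_csc_sampling_graph.pt") are taken from the diff above:

    import os

    def _detect_graphbolt(part_dir):
        # Mirrors the new checks: exactly one serialized graph may exist.
        has_dgl = os.path.exists(os.path.join(part_dir, "graph.dgl"))
        has_gb = os.path.exists(
            os.path.join(part_dir, "fused_csc_sampling_graph.pt")
        )
        if not has_dgl and not has_gb:
            raise ValueError("The graph object doesn't exist.")
        if has_dgl and has_gb:
            raise ValueError(
                "Both DGL graph and GraphBolt graph exist. Please remove one."
            )
        # True means load via GraphBolt, False means load the DGL binary.
        return has_gb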
2 changes: 2 additions & 0 deletions tests/distributed/test_partition.py
@@ -990,6 +990,7 @@ def test_dgl_partition_to_graphbolt_homo(
         orig_g = dgl.load_graphs(
             os.path.join(test_dir, f"part{part_id}/graph.dgl")
         )[0][0]
+        os.remove(os.path.join(test_dir, f"part{part_id}/graph.dgl"))
         new_g = load_partition(
             part_config, part_id, load_feats=False, use_graphbolt=True
         )[0]
@@ -1067,6 +1068,7 @@ def test_dgl_partition_to_graphbolt_hetero(
         orig_g = dgl.load_graphs(
             os.path.join(test_dir, f"part{part_id}/graph.dgl")
         )[0][0]
+        os.remove(os.path.join(test_dir, f"part{part_id}/graph.dgl"))
         new_g = load_partition(
             part_config, part_id, load_feats=False, use_graphbolt=True
         )[0]
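
These two os.remove calls are a direct consequence of the partition.py change above: dgl_partition_to_graphbolt leaves graph.dgl and fused_csc_sampling_graph.pt side by side, a layout that load_partition now rejects. A sketch of the failure the removal avoids, assuming a converted partition directory and the tests' part_config/part_id variables:

    # Without the os.remove, both serialized graphs are present and the
    # strengthened check raises instead of silently picking a format.
    try:
        load_partition(part_config, part_id, load_feats=False, use_graphbolt=True)
    except ValueError as err:
        assert "Both DGL graph and GraphBolt graph exist" in str(err)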
28 changes: 19 additions & 9 deletions tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py
@@ -17,10 +17,22 @@ def to_on_disk_numpy(test_dir, name, t):
     return path
 
 
+def _skip_condition_cached_feature():
+    return (F._default_context_str != "gpu") or (
+        torch.cuda.get_device_capability()[0] < 7
+    )
+
+
+def _reason_to_skip_cached_feature():
+    if F._default_context_str != "gpu":
+        return "GPUCachedFeature tests are available only when testing the GPU backend."
+
+    return "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
+
+
 @unittest.skipIf(
-    F._default_context_str != "gpu"
-    or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
+    _skip_condition_cached_feature(),
+    reason=_reason_to_skip_cached_feature(),
 )
 @pytest.mark.parametrize(
     "dtype",
@@ -116,9 +128,8 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b):
 
 
 @unittest.skipIf(
-    F._default_context_str != "gpu"
-    or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
+    _skip_condition_cached_feature(),
+    reason=_reason_to_skip_cached_feature(),
 )
 @pytest.mark.parametrize(
     "dtype",
@@ -155,9 +166,8 @@ def test_gpu_cached_feature_read_async(dtype, pin_memory):
 
 
 @unittest.skipIf(
-    F._default_context_str != "gpu"
-    or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
+    _skip_condition_cached_feature(),
+    reason=_reason_to_skip_cached_feature(),
 )
 @unittest.skipIf(
     not torch.ops.graphbolt.detect_io_uring(),
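
The refactor above hoists three identical copies of the skip condition into _skip_condition_cached_feature and pairs them with a reason that matches the actual cause; the old decorators claimed a GPU-generation problem even when the run was simply not targeting the GPU backend. The pattern in miniature, with a toy condition standing in for DGL's backend and compute-capability probes:

    import unittest

    def _skip_condition():
        # Stand-in for the real probe; evaluated once, at import time,
        # when the decorator argument is computed.
        return True

    def _skip_reason():
        return "toy reason: environment does not support this test"

    @unittest.skipIf(_skip_condition(), reason=_skip_reason())
    class ToyTest(unittest.TestCase):
        def test_noop(self):
            self.assertTrue(True)

    if __name__ == "__main__":
        unittest.main()

Because both helpers run when the module is imported, the reported reason is fixed at import time, which is exactly the behavior the old inline conditions had.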
@@ -11,7 +11,7 @@
 @unittest.skipIf(
     F._default_context_str != "gpu"
     or torch.cuda.get_device_capability()[0] < 7,
-    reason="GPUCachedFeature tests are available only on GPU."
+    reason="GPUCachedFeature tests are available only when testing the GPU backend."
     if F._default_context_str != "gpu"
     else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU.",
 )
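
This file gets the same two-way message without the helpers: the reason= argument is a single conditional expression spread over three lines, evaluated once when the decorator is applied. Reduced to its core with toy values:

    backend = "cpu"  # stand-in for F._default_context_str
    reason = (
        "tests are available only when testing the GPU backend"
        if backend != "gpu"
        else "requires a Volta or later generation NVIDIA GPU"
    )
    # Here reason picks the first string, since backend != "gpu".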
@@ -15,7 +15,9 @@ def test_hetero_cached_feature(cached_feature_type):
         or torch.cuda.get_device_capability()[0] < 7
     ):
         pytest.skip(
-            "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
+            "GPUCachedFeature tests are available only when testing the GPU backend."
+            if F._default_context_str != "gpu"
+            else "GPUCachedFeature requires a Volta or later generation NVIDIA GPU."
         )
     device = F.ctx() if cached_feature_type == gb.gpu_cached_feature else None
     pin_memory = cached_feature_type == gb.gpu_cached_feature
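
Unlike the decorator-based skips in the other files, this test decides at run time with pytest.skip, so the same two-way message is selected inside the test body. The contrast in miniature (toy condition, plain pytest):

    import pytest

    def test_toy():
        on_gpu = False  # stand-in for the real backend probe
        if not on_gpu:
            # Runtime skip: the reason is chosen when the test executes,
            # not when the module is imported.
            pytest.skip("toy reason: not running on the GPU backend")
        assert True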
