Skip to content

Commit

Permalink
Fix TPU testing and collect all tests (#11098)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
  • Loading branch information
5 people committed Jul 27, 2022
1 parent 95f5f17 commit fff62f0
Show file tree
Hide file tree
Showing 23 changed files with 213 additions and 203 deletions.
9 changes: 9 additions & 0 deletions .azure/gpu-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ jobs:
timeoutInMinutes: "35"
condition: eq(variables['continue'], '1')

- bash: bash run_standalone_tasks.sh
workingDirectory: tests/tests_pytorch
env:
PL_USE_MOCKED_MNIST: "1"
PL_RUN_CUDA_TESTS: "1"
displayName: 'Testing: PyTorch standalone tasks'
timeoutInMinutes: "10"
condition: eq(variables['continue'], '1')

- bash: |
python -m coverage report
python -m coverage xml
Expand Down
10 changes: 3 additions & 7 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ references:
job_name=$(jsonnet -J ml-testing-accelerators/ dockers/tpu-tests/tpu_test_cases.jsonnet | kubectl create -f -) && \
job_name=${job_name#job.batch/}
job_name=${job_name% created}
pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}')
echo "GKE pod name: $pod_name"
echo "Waiting on kubernetes job: $job_name"
i=0 && \
# Up to MAX_CHECKS checks, spaced CHECK_SPEEP seconds apart.
Expand All @@ -92,8 +94,6 @@ references:
printf "Waiting for job to finish: " && \
while [ $i -lt $MAX_CHECKS ]; do ((i++)); if kubectl get jobs $job_name -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then status_code=1 && break; elif kubectl get jobs $job_name -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1" ; then status_code=0 && break; else printf "."; fi; sleep $CHECK_SPEEP; done && \
echo "Done waiting. Job status code: $status_code" && \
pod_name=$(kubectl get po -l controller-uid=`kubectl get job $job_name -o "jsonpath={.metadata.labels.controller-uid}"` | awk 'match($0,!/NAME/) {print $1}') && \
echo "GKE pod name: $pod_name" && \
kubectl logs -f $pod_name --container=train > /tmp/full_output.txt
if grep -q '<?xml version="1.0" ?>' /tmp/full_output.txt ; then csplit /tmp/full_output.txt '/<?xml version="1.0" ?>/'; else mv /tmp/full_output.txt xx00; fi && \
# First portion is the test logs. Print these to Github Action stdout.
Expand All @@ -106,10 +106,6 @@ references:
name: Statistics
command: |
mv ./xx01 coverage.xml
# TODO: add human readable report
cat coverage.xml
sudo pip install pycobertura
pycobertura show coverage.xml
jobs:

Expand All @@ -119,7 +115,7 @@ jobs:
environment:
- XLA_VER: 1.9
- PYTHON_VER: 3.7
- MAX_CHECKS: 240
- MAX_CHECKS: 1000
- CHECK_SPEEP: 5
steps:
- checkout
Expand Down
15 changes: 5 additions & 10 deletions dockers/tpu-tests/tpu_test_cases.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ local tputests = base.BaseTest {
mode: 'postsubmit',
configMaps: [],

timeout: 1200, # 20 minutes, in seconds.
timeout: 6000, # 100 minutes, in seconds.

image: 'pytorchlightning/pytorch_lightning',
imageTag: 'base-xla-py{PYTHON_VERSION}-torch{PYTORCH_VERSION}',
Expand All @@ -34,16 +34,11 @@ local tputests = base.BaseTest {
pip install -e .[test]
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
export PL_RUN_TPU_TESTS=1
cd tests/tests_pytorch
echo $PWD
# TODO (@kaushikb11): Add device stats tests here
coverage run --source pytorch_lightning -m pytest -v --capture=no \
strategies/test_tpu_spawn.py \
profilers/test_xla_profiler.py \
accelerators/test_tpu.py \
models/test_tpu.py \
plugins/environments/test_xla_environment.py \
utilities/test_xla_device_utils.py
coverage run --source=pytorch_lightning -m pytest -vv --durations=0 ./
echo "\n||| Running standalone tests |||\n"
bash run_standalone_tests.sh -b 1
test_exit_code=$?
echo "\n||| END PYTEST LOGS |||\n"
coverage xml
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/training_type/single_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class SingleTPUPlugin(SingleTPUStrategy):
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
rank_zero_deprecation(
"The `pl.plugins.training_type.single_tpu.SingleTPUPlugin` is deprecated in v1.6 and will be removed in."
"The `pl.plugins.training_type.single_tpu.SingleTPUPlugin` is deprecated in v1.6 and will be removed in"
" v1.8. Use `pl.strategies.single_tpu.SingleTPUStrategy` instead."
)
super().__init__(*args, **kwargs)
44 changes: 33 additions & 11 deletions src/pytorch_lightning/strategies/launchers/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.
import os
import time
from functools import wraps
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING

import torch.multiprocessing as mp
from torch.multiprocessing import ProcessContext

import pytorch_lightning as pl
from pytorch_lightning.strategies.launchers.multiprocessing import _FakeQueue, _MultiProcessingLauncher, _WorkerOutput
Expand All @@ -26,9 +28,10 @@
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
else:
xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
xm, xmp = None, None

if TYPE_CHECKING:
from pytorch_lightning.strategies import Strategy
Expand Down Expand Up @@ -72,7 +75,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
"""
context = mp.get_context(self._start_method)
return_queue = context.SimpleQueue()
xmp.spawn(
_save_spawn(
self._wrapping_function,
args=(trainer, function, args, kwargs, return_queue),
nprocs=len(self._strategy.parallel_devices),
Expand Down Expand Up @@ -103,14 +106,6 @@ def _wrapping_function(
if self._strategy.local_rank == 0:
return_queue.put(move_data_to_device(results, "cpu"))

# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self._strategy.barrier("end-process")

# Ensure that the rank 0 process is the one exiting last
# https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
if self._strategy.local_rank == 0:
time.sleep(2)

def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
rank_zero_debug("Collecting results from rank 0 process.")
checkpoint_callback = trainer.checkpoint_callback
Expand Down Expand Up @@ -138,3 +133,30 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
self.add_to_queue(trainer, extra)

return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)


def _save_spawn(
    fn: Callable,
    args: Tuple = (),
    nprocs: Optional[int] = None,
    join: bool = True,
    daemon: bool = False,
    start_method: str = "spawn",
) -> Optional[ProcessContext]:
    """Launch ``fn`` through :func:`torch_xla.distributed.xla_multiprocessing.spawn`, appending teardown steps that
    every worker process runs after ``fn`` has returned.

    All arguments are forwarded unchanged to ``xmp.spawn``, and its return value is returned as-is.
    """

    @wraps(fn)
    def _run_then_teardown(rank: int, *worker_args: Any) -> None:
        fn(rank, *worker_args)

        # All workers meet at this rendezvous before any of them goes on to exit
        # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
        xm.rendezvous("end-process")

        # Only rank 0 is delayed, so that it is the last process to exit
        # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358
        if rank != 0:
            return
        time.sleep(1)

    spawn_kwargs = dict(args=args, nprocs=nprocs, join=join, daemon=daemon, start_method=start_method)
    return xmp.spawn(_run_then_teardown, **spawn_kwargs)
15 changes: 10 additions & 5 deletions src/pytorch_lightning/strategies/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
start_method="fork",
)
self.debug = debug
self._launched = False

@property
def checkpoint_io(self) -> CheckpointIO:
Expand All @@ -90,6 +91,8 @@ def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:

@property
def root_device(self) -> torch.device:
if not self._launched:
raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.")
return xm.xla_device()

@staticmethod
Expand Down Expand Up @@ -130,7 +133,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)

if self.debug:
os.environ["PT_XLA_DEBUG"] = str(1)
os.environ["PT_XLA_DEBUG"] = "1"

shared_params = find_shared_parameters(self.model)
self.model_to_device()
Expand All @@ -150,8 +153,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:

@property
def is_distributed(self) -> bool:
# HOST_WORLD_SIZE is None outside the xmp.spawn process
return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1
# HOST_WORLD_SIZE is not set outside the xmp.spawn process
return (xenv.HOST_WORLD_SIZE in os.environ) and self.world_size != 1

def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
TPUSpawnStrategy._validate_dataloader(dataloader)
Expand Down Expand Up @@ -189,8 +192,9 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if invalid_reduce_op or invalid_reduce_op_str:
raise MisconfigurationException(
"Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
raise ValueError(
"Currently, the TPUSpawnStrategy only supports `sum`, `mean`, `avg` for the reduce operation, got:"
f" {reduce_op}"
)

output = xm.mesh_reduce("reduce", output, sum)
Expand All @@ -201,6 +205,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
return output

def _worker_setup(self, process_idx: int):
self._launched = True
reset_seed()
self.set_world_ranks(process_idx)
rank_zero_only.rank = self.global_rank
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def test_devices_auto_choice_mps():

@pytest.mark.parametrize(
["parallel_devices", "accelerator"],
[([torch.device("cpu")], "cuda"), ([torch.device("cuda", i) for i in range(8)], ("tpu"))],
[([torch.device("cpu")], "cuda"), ([torch.device("cuda", i) for i in range(8)], "tpu")],
)
def test_parallel_devices_in_strategy_confilict_with_accelerator(parallel_devices, accelerator):
with pytest.raises(MisconfigurationException, match=r"parallel_devices set through"):
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/accelerators/test_ipu.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def test_strategy_choice_ipu_plugin(tmpdir):


@RunIf(ipu=True)
def test_device_type_when_training_plugin_ipu_passed(tmpdir):
def test_device_type_when_ipu_strategy_passed(tmpdir):
trainer = Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8)
assert isinstance(trainer.strategy, IPUStrategy)
assert isinstance(trainer.accelerator, IPUAccelerator)
Expand Down
21 changes: 9 additions & 12 deletions tests/tests_pytorch/accelerators/test_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from pytorch_lightning.strategies import DDPStrategy, TPUSpawnStrategy
from pytorch_lightning.utilities import find_shared_parameters
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.utils import pl_multi_process_test


class WeightSharingModule(BoringModel):
Expand All @@ -46,8 +45,7 @@ def forward(self, x):
return x


@RunIf(tpu=True)
@pl_multi_process_test
@RunIf(tpu=True, standalone=True)
def test_resume_training_on_cpu(tmpdir):
"""Checks if training can be resumed from a saved checkpoint on CPU."""
# Train a model on TPU
Expand All @@ -65,11 +63,9 @@ def test_resume_training_on_cpu(tmpdir):
# Verify that training is resumed on CPU
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
trainer.fit(model, ckpt_path=model_path)
assert trainer.state.finished, f"Training failed with {trainer.state}"


@RunIf(tpu=True)
@pl_multi_process_test
def test_if_test_works_after_train(tmpdir):
"""Ensure that .test() works after .fit()"""

Expand Down Expand Up @@ -293,12 +289,14 @@ def test_xla_checkpoint_plugin_being_default():
assert isinstance(trainer.strategy.checkpoint_io, XLACheckpointIO)


@RunIf(tpu=True)
@patch("pytorch_lightning.strategies.tpu_spawn.xm")
def test_mp_device_dataloader_attribute(_):
@patch("pytorch_lightning.strategies.tpu_spawn.MpDeviceLoader")
@patch("pytorch_lightning.strategies.tpu_spawn.TPUSpawnStrategy.root_device")
def test_mp_device_dataloader_attribute(root_device_mock, mp_loader_mock):
dataset = RandomDataset(32, 64)
dataloader = TPUSpawnStrategy().process_dataloader(DataLoader(dataset))
assert dataloader.dataset == dataset
dataloader = DataLoader(dataset)
processed_dataloader = TPUSpawnStrategy().process_dataloader(dataloader)
mp_loader_mock.assert_called_with(dataloader, root_device_mock)
assert processed_dataloader.dataset == processed_dataloader._loader.dataset


@RunIf(tpu=True)
Expand All @@ -307,8 +305,7 @@ def test_warning_if_tpus_not_used():
Trainer()


@RunIf(tpu=True)
@pl_multi_process_test
@RunIf(tpu=True, standalone=True)
@pytest.mark.parametrize(
["devices", "expected_device_ids"],
[
Expand Down
12 changes: 5 additions & 7 deletions tests/tests_pytorch/callbacks/test_device_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
assert cpu_stats_mock.call_count == expected


@pytest.mark.skipif(True, reason="TODO (@kaushikb11): fix this test, timeout")
@RunIf(tpu=True)
def test_device_stats_monitor_tpu(tmpdir):
"""Test TPU stats are logged using a logger."""
Expand All @@ -106,24 +105,23 @@ def test_device_stats_monitor_tpu(tmpdir):

class DebugLogger(CSVLogger):
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
def log_metrics(self, metrics, step=None) -> None:
fields = ["avg. free memory (MB)", "avg. peak memory (MB)"]
for f in fields:
assert any(f in h for h in metrics)

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=2,
max_epochs=2,
limit_train_batches=5,
accelerator="tpu",
devices=1,
devices=8,
log_every_n_steps=1,
callbacks=[device_stats],
logger=DebugLogger(tmpdir),
enable_checkpointing=False,
enable_progress_bar=False,
)

trainer.fit(model)


Expand All @@ -146,7 +144,7 @@ def test_device_stats_monitor_no_logger(tmpdir):
trainer.fit(model)


def test_prefix_metric_keys(tmpdir):
def test_prefix_metric_keys():
"""Test that metric key names are converted correctly."""
metrics = {"1": 1.0, "2": 2.0, "3": 3.0}
prefix = "foo"
Expand Down
1 change: 1 addition & 0 deletions tests/tests_pytorch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def pytest_collection_modifyitems(items: List[pytest.Function], config: pytest.C
min_cuda_gpus="PL_RUN_CUDA_TESTS",
slow="PL_RUN_SLOW_TESTS",
ipu="PL_RUN_IPU_TESTS",
tpu="PL_RUN_TPU_TESTS",
)
if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1":
# special case: we don't have a CPU job for standalone tests, so we shouldn't run only cuda tests.
Expand Down
6 changes: 2 additions & 4 deletions tests/tests_pytorch/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,12 +360,10 @@ def test_v1_8_0_deprecated_single_device_plugin_class():
SingleDevicePlugin("cpu")


@RunIf(tpu=True)
@RunIf(tpu=True, standalone=True)
def test_v1_8_0_deprecated_single_tpu_plugin_class():
with pytest.deprecated_call(
match=(
"SingleTPUPlugin` is deprecated in v1.6 and will be removed in v1.8." " Use `.*SingleTPUStrategy` instead."
)
match="SingleTPUPlugin` is deprecated in v1.6 and will be removed in v1.8. Use `.*SingleTPUStrategy` instead."
):
SingleTPUPlugin(0)

Expand Down
2 changes: 2 additions & 0 deletions tests/tests_pytorch/helpers/runif.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def __new__(
if tpu:
conditions.append(not _TPU_AVAILABLE)
reasons.append("TPU")
# used in conftest.py::pytest_collection_modifyitems
kwargs["tpu"] = True

if ipu:
conditions.append(not _IPU_AVAILABLE)
Expand Down
Loading

0 comments on commit fff62f0

Please sign in to comment.