forked from lenskit/lkpy
-
Notifications
You must be signed in to change notification settings - Fork 1
/
conftest.py
81 lines (63 loc) · 2.24 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# This file is part of LensKit.
# Copyright (C) 2018-2023 Boise State University
# Copyright (C) 2023-2024 Drexel University
# Licensed under the MIT license, see LICENSE.md for details.
# SPDX-License-Identifier: MIT
import logging
import os
import warnings
import torch
from seedbank import initialize, numpy_rng
from hypothesis import settings
from pytest import fixture, skip
from lenskit.parallel import ensure_parallel_init
from lenskit.util.test import ml_100k, ml_ds, ml_ratings # noqa: F401
# Logger used by the fixtures and hooks in this file.
_log = logging.getLogger("lenskit.tests")
# Set the "numba" logger's threshold to INFO, so its DEBUG records are dropped.
logging.getLogger("numba").setLevel(logging.INFO)
# Seed shared by the RNG fixtures below.  It is fixed at 42 unless the
# LK_TEST_FREE_RNG environment variable is present (any value, even empty),
# in which case it is None and each fixture draws a fresh seed from os.urandom.
RNG_SEED = None if "LK_TEST_FREE_RNG" in os.environ else 42
if RNG_SEED is None:
    warnings.warn("using nondeterministic RNG initialization")
@fixture
def rng():
    """
    Fixture returning the result of seedbank's ``numpy_rng``.

    The generator is seeded with ``RNG_SEED``; when that is ``None``, 4 bytes
    from ``os.urandom`` are used instead, so each test gets a different seed.
    """
    seed = os.urandom(4) if RNG_SEED is None else RNG_SEED
    return numpy_rng(seed)
@fixture(autouse=True)
def init_rng(request):
    """
    Autouse fixture that calls seedbank's ``initialize`` before every test.

    The seed is ``RNG_SEED``, or 4 bytes from ``os.urandom`` when ``RNG_SEED``
    is ``None``.  ``request`` is accepted but not used.
    """
    seed = RNG_SEED
    if seed is None:
        seed = os.urandom(4)
    initialize(seed)
@fixture(scope="module", params=["cpu", "cuda"])
def torch_device(request):
    """
    Module-scoped, parameterized fixture yielding a Torch device name.

    A test that requests ``torch_device`` is run once per entry in ``params``
    (``"cpu"`` and ``"cuda"``).  When the requested backend reports that it is
    not available, the test is skipped instead of run.
    """
    device = request.param
    if device == "cuda":
        if not torch.cuda.is_available():
            skip("CUDA not available")
    elif device == "mps":
        # NOTE(review): "mps" is not listed in ``params`` above, so this branch
        # is currently never taken — confirm whether MPS was dropped on purpose.
        if not torch.backends.mps.is_available():
            skip("MPS not available")
    yield device
@fixture(autouse=True)
def log_test(request):
    """
    Autouse fixture that logs the module and function name of each test.

    A name is logged as ``"<unknown>"`` when the corresponding attribute on
    ``request`` is falsy, or (module only) when resolving it raises.
    """
    try:
        module = request.module
        if module:
            module_name = module.__name__
        else:
            module_name = "<unknown>"
    except Exception:
        # Any failure while resolving the module name falls back to the
        # placeholder rather than failing the test.
        module_name = "<unknown>"

    func = request.function
    if func:
        func_name = func.__name__
    else:
        func_name = "<unknown>"

    _log.info("running test %s:%s", module_name, func_name)
def pytest_collection_modifyitems(items):
    """
    Pytest collection hook: mark every ``eval`` test as ``slow`` as well.

    Items that have no ``eval`` marker, or that already carry a ``slow``
    marker, are left untouched.
    """
    for test_item in items:
        eval_mark = test_item.get_closest_marker("eval")
        slow_mark = test_item.get_closest_marker("slow")
        if eval_mark is None or slow_mark is not None:
            continue
        _log.debug("adding slow mark to %s", test_item)
        test_item.add_marker("slow")
# Register a Hypothesis settings profile named "default" with deadline=1000
# (Hypothesis documents integer deadlines as milliseconds).
# NOTE(review): no load_profile() call is visible here — confirm the profile
# actually takes effect for the test run.
settings.register_profile("default", deadline=1000)
# Run LensKit's parallel initialization once, when this conftest is imported.
# Presumably this configures the parallel backend before any test executes —
# TODO confirm against lenskit.parallel.
ensure_parallel_init()