Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate init_seeds() #4849

Merged
merged 1 commit into from
Sep 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

from utils.downloads import gsutil_getsize
from utils.metrics import box_iou, fitness
from utils.torch_utils import init_torch_seeds

# Settings
torch.set_printoptions(linewidth=320, precision=5, profile='long')
Expand Down Expand Up @@ -91,10 +90,13 @@ def set_logging(rank=-1, verbose=True):


def init_seeds(seed=0):
# Initialize random number generator (RNG) seeds
# Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
# cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
import torch.backends.cudnn as cudnn
random.seed(seed)
np.random.seed(seed)
init_torch_seeds(seed)
torch.manual_seed(seed)
cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)


def get_latest_run(search_dir='.'):
Expand Down
10 changes: 0 additions & 10 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -41,15 +40,6 @@ def torch_distributed_zero_first(local_rank: int):
dist.barrier(device_ids=[0])


def init_torch_seeds(seed=0):
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
torch.manual_seed(seed)
if seed == 0: # slower, more reproducible
cudnn.benchmark, cudnn.deterministic = False, True
else: # faster, less reproducible
cudnn.benchmark, cudnn.deterministic = True, False


def date_modified(path=__file__):
# return human-readable file modification date, i.e. '2021-3-26'
t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
Expand Down