diff --git a/README.md b/README.md index 893a96c..a99beb6 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,7 @@ -# easytorch +# EasyTorch +[![LICENSE](https://img.shields.io/github/license/cnstark/easytorch.svg)](https://github.com/cnstark/easytorch/blob/master/LICENSE) +[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/cnstark/easytorch.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/cnstark/easytorch/context:python) +[![gitee mirror](https://github.com/cnstark/easytorch/actions/workflows/git-mirror.yml/badge.svg)](https://gitee.com/cnstark/easytorch) + +[English](README.md) **|** [简体中文](README_CN.md) diff --git a/README_CN.md b/README_CN.md new file mode 100644 index 0000000..1abf244 --- /dev/null +++ b/README_CN.md @@ -0,0 +1,61 @@ +# EasyTorch + +[![LICENSE](https://img.shields.io/github/license/cnstark/easytorch.svg)](https://github.com/cnstark/easytorch/blob/master/LICENSE) +[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/cnstark/easytorch.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/cnstark/easytorch/context:python) +[![gitee mirror](https://github.com/cnstark/easytorch/actions/workflows/git-mirror.yml/badge.svg)](https://gitee.com/cnstark/easytorch) + +[English](README.md) **|** [简体中文](README_CN.md) + +--- + +Easytorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。 + +## 功能亮点 + +* **最小代码量**。封装通用神经网络训练流程,用户仅需实现`Dataset`、`Model`以及训练/推理代码等关键代码,就能完成深度学习项目的构建。 +* **万物基于Config**。通过配置文件控制训练模式与超参,根据配置内容的MD5自动生成唯一的结果存放目录,调整超参不再凌乱。 +* **支持所有设备**。支持CPU、GPU与GPU分布式训练,通过配置参数一键完成设置。 +* **持久化训练日志**。支持`logging`日志系统与`Tensorboard`,并封装为统一接口,用户通过一键调用即可保存自定义的训练日志。 + +## 环境依赖 + +### 操作系统 + +* [Linux](https://pytorch.org/get-started/locally/#linux-prerequisites) +* [Windows](https://pytorch.org/get-started/locally/#windows-prerequisites) +* [MacOS](https://pytorch.org/get-started/locally/#mac-prerequisites) + +推荐使用Ubuntu16.04及更高版本或CentOS7及以更高版本。 + +### Python + +python >= 3.6 (推荐 >= 3.7) + 
+推荐使用[Anaconda](https://www.anaconda.com/) + +### PyTorch及CUDA + +[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7) + +[CUDA](https://developer.nvidia.com/zh-cn/cuda-toolkit) >= 9.2 (推荐 >= 11.0) + +注意:如需使用安培(Ampere)架构GPU,PyTorch版本需 >= 1.7且CUDA版本 >= 11.0。 + +### 其他依赖 + +```shell +pip install -r requirements.txt +``` + +## 示例 + +* [线性回归](examples/linear_regression) +* [MNIST手写数字识别](examples/mnist) + +## README 徽章 + +如果你的项目正在使用EasyTorch,可以将EasyTorch徽章 [![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) 添加到你的 README 中: + +``` +[![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) +``` diff --git a/__init__.py b/easytorch/__init__.py similarity index 100% rename from __init__.py rename to easytorch/__init__.py diff --git a/config.py b/easytorch/config.py similarity index 100% rename from config.py rename to easytorch/config.py diff --git a/core/__init__.py b/easytorch/core/__init__.py similarity index 100% rename from core/__init__.py rename to easytorch/core/__init__.py diff --git a/core/checkpoint.py b/easytorch/core/checkpoint.py similarity index 100% rename from core/checkpoint.py rename to easytorch/core/checkpoint.py diff --git a/core/data_loader.py b/easytorch/core/data_loader.py similarity index 100% rename from core/data_loader.py rename to easytorch/core/data_loader.py diff --git a/core/launcher.py b/easytorch/core/launcher.py similarity index 100% rename from core/launcher.py rename to easytorch/core/launcher.py diff --git a/core/meter_pool.py b/easytorch/core/meter_pool.py similarity index 100% rename from core/meter_pool.py rename to easytorch/core/meter_pool.py diff --git a/core/optimizer_builder.py b/easytorch/core/optimizer_builder.py similarity index 100% rename from core/optimizer_builder.py rename to easytorch/core/optimizer_builder.py diff --git a/core/runner.py b/easytorch/core/runner.py similarity index 100% 
rename from core/runner.py rename to easytorch/core/runner.py diff --git a/easyoptim/__init__.py b/easytorch/easyoptim/__init__.py similarity index 100% rename from easyoptim/__init__.py rename to easytorch/easyoptim/__init__.py diff --git a/easyoptim/easy_lr_scheduler.py b/easytorch/easyoptim/easy_lr_scheduler.py similarity index 100% rename from easyoptim/easy_lr_scheduler.py rename to easytorch/easyoptim/easy_lr_scheduler.py diff --git a/utils/__init__.py b/easytorch/utils/__init__.py similarity index 100% rename from utils/__init__.py rename to easytorch/utils/__init__.py diff --git a/utils/data_prefetcher.py b/easytorch/utils/data_prefetcher.py similarity index 100% rename from utils/data_prefetcher.py rename to easytorch/utils/data_prefetcher.py diff --git a/utils/dist.py b/easytorch/utils/dist.py similarity index 100% rename from utils/dist.py rename to easytorch/utils/dist.py diff --git a/utils/env.py b/easytorch/utils/env.py similarity index 100% rename from utils/env.py rename to easytorch/utils/env.py diff --git a/utils/logging.py b/easytorch/utils/logging.py similarity index 100% rename from utils/logging.py rename to easytorch/utils/logging.py diff --git a/utils/timer.py b/easytorch/utils/timer.py similarity index 100% rename from utils/timer.py rename to easytorch/utils/timer.py diff --git a/examples/linear_regression/.gitignore b/examples/linear_regression/.gitignore new file mode 100644 index 0000000..bfea04a --- /dev/null +++ b/examples/linear_regression/.gitignore @@ -0,0 +1 @@ +checkpoints diff --git a/examples/linear_regression/README.md b/examples/linear_regression/README.md new file mode 100644 index 0000000..614fcef --- /dev/null +++ b/examples/linear_regression/README.md @@ -0,0 +1,9 @@ +# EasyTorch Example - MNIST Classification + +## Train + +* CPU + +```shell +python train.py -c linear_regression_cpu_cfg.py +``` diff --git a/examples/linear_regression/dataset.py b/examples/linear_regression/dataset.py new file mode 100644 index 
0000000..71df1c0 --- /dev/null +++ b/examples/linear_regression/dataset.py @@ -0,0 +1,15 @@ +import torch +from torch.utils.data import Dataset + + +class LinearDataset(Dataset): + def __init__(self, k: float, b: float, num: int): + self.num = num + self.x = torch.unsqueeze(torch.linspace(-1, 1, self.num), dim=1) + self.y = k * self.x + b + torch.rand(self.x.size()) - 0.5 + + def __getitem__(self, index): + return self.x[index], self.y[index] + + def __len__(self): + return self.num diff --git a/examples/linear_regression/linear_regression_cpu_cfg.py b/examples/linear_regression/linear_regression_cpu_cfg.py new file mode 100644 index 0000000..72a7b11 --- /dev/null +++ b/examples/linear_regression/linear_regression_cpu_cfg.py @@ -0,0 +1,36 @@ +import os +from easydict import EasyDict + +from linear_regression_runner import LinearRegressionRunner + +CFG = EasyDict() + +CFG.DESC = 'linear_regression' +CFG.RUNNER = LinearRegressionRunner +CFG.USE_GPU = False + +CFG.MODEL = EasyDict() +CFG.MODEL.NAME = 'linear' + +CFG.TRAIN = EasyDict() + +CFG.TRAIN.NUM_EPOCHS = 10000 +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) +) +CFG.TRAIN.CKPT_SAVE_STRATEGY = None + +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = 'SGD' +CFG.TRAIN.OPTIM.PARAM = { + 'lr': 0.001, + 'momentum': 0.1, +} + +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 10 +CFG.TRAIN.DATA.K = 10 +CFG.TRAIN.DATA.B = 6 +CFG.TRAIN.DATA.NUM = 100 +CFG.TRAIN.DATA.SHUFFLE = True diff --git a/examples/linear_regression/linear_regression_runner.py b/examples/linear_regression/linear_regression_runner.py new file mode 100644 index 0000000..250268f --- /dev/null +++ b/examples/linear_regression/linear_regression_runner.py @@ -0,0 +1,81 @@ +from torch import nn + +from easytorch import Runner + +from dataset import LinearDataset + + +class LinearRegressionRunner(Runner): + def init_training(self, cfg): + """Initialize training. 
+ + Including loss, training meters, etc. + + Args: + cfg (dict): config + """ + + super().init_training(cfg) + + self.loss = nn.MSELoss() + self.loss = self.to_running_device(self.loss) + + self.register_epoch_meter('train_loss', 'train', '{:.2f}') + + @staticmethod + def define_model(cfg: dict) -> nn.Module: + """Define model. + + Args: + cfg (dict): config + + Returns: + model (nn.Module) + """ + + return nn.Linear(1, 1) + + @staticmethod + def build_train_dataset(cfg: dict): + """Build MNIST train dataset + + Args: + cfg (dict): config + + Returns: + train dataset (Dataset) + """ + + return LinearDataset( + cfg['TRAIN']['DATA']['K'], + cfg['TRAIN']['DATA']['B'], + cfg['TRAIN']['DATA']['NUM'], + ) + + def train_iters(self, epoch, iter_index, data): + """Training details. + + Args: + epoch (int): current epoch. + iter_index (int): current iter. + data (torch.Tensor or tuple): Data provided by DataLoader + + Returns: + loss (torch.Tensor) + """ + + x, y = data + x = self.to_running_device(x) + y = self.to_running_device(y) + + output = self.model(x) + loss = self.loss(output, y) + self.update_epoch_meter('train_loss', loss.item()) + return loss + + def on_training_end(self): + """Print result on training end. 
+ """ + + super().on_training_end() + self.logger.info('Result: k: {}, b: {}'.format(self.model.weight.item(), self.model.bias.item())) diff --git a/examples/linear_regression/train.py b/examples/linear_regression/train.py new file mode 100644 index 0000000..2d443e9 --- /dev/null +++ b/examples/linear_regression/train.py @@ -0,0 +1,19 @@ +import sys +sys.path.append('../..') +from argparse import ArgumentParser + +from easytorch import launch_training + + +def parse_args(): + parser = ArgumentParser(description='Welcome to EasyTorch!') + parser.add_argument('-c', '--cfg', help='training config', required=True) + parser.add_argument('--gpus', help='visible gpus', type=str) + parser.add_argument('--tf32', help='enable tf32 on Ampere device', action='store_true') + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + launch_training(args.cfg, args.gpus, args.tf32) diff --git a/examples/mnist/.gitignore b/examples/mnist/.gitignore new file mode 100644 index 0000000..3320204 --- /dev/null +++ b/examples/mnist/.gitignore @@ -0,0 +1,2 @@ +checkpoints +mnist_data diff --git a/examples/mnist/README.md b/examples/mnist/README.md new file mode 100644 index 0000000..e34ff9a --- /dev/null +++ b/examples/mnist/README.md @@ -0,0 +1,29 @@ +# EasyTorch Example - MNIST Classification + +## Train + +* CPU + +```shell +python train.py -c config\mnist_cpu_cfg.py +``` + +* GPU (1x) + +```shell +python train.py -c config\mnist_1x_cfg.py --gpus 0 +``` + +## Validate + +* CPU + +```shell +python validate.py -c config\mnist_cpu_cfg.py +``` + +* GPU (1x) + +```shell +python validate.py -c config\mnist_1x_cfg.py --gpus 0 +``` diff --git a/examples/mnist/config/__init__.py b/examples/mnist/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/mnist/config/mnist_1x_cfg.py b/examples/mnist/config/mnist_1x_cfg.py new file mode 100644 index 0000000..b3e3fda --- /dev/null +++ b/examples/mnist/config/mnist_1x_cfg.py @@ -0,0 +1,41 @@ 
+import os +from easydict import EasyDict + +from mnist_runner import MNISTRunner + +CFG = EasyDict() + +CFG.DESC = 'mnist' +CFG.RUNNER = MNISTRunner +CFG.GPU_NUM = 1 + +CFG.MODEL = EasyDict() +CFG.MODEL.NAME = 'conv_net' + +CFG.TRAIN = EasyDict() + +CFG.TRAIN.NUM_EPOCHS = 30 +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) +) +CFG.TRAIN.CKPT_SAVE_STRATEGY = None + +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = 'SGD' +CFG.TRAIN.OPTIM.PARAM = { + 'lr': 0.002, + 'momentum': 0.1, +} + +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 4 +CFG.TRAIN.DATA.DIR = 'mnist_data' +CFG.TRAIN.DATA.SHUFFLE = True + +CFG.VAL = EasyDict() + +CFG.VAL.INTERVAL = 1 + +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.DIR = 'mnist_data' diff --git a/examples/mnist/config/mnist_cpu_cfg.py b/examples/mnist/config/mnist_cpu_cfg.py new file mode 100644 index 0000000..aeeafd4 --- /dev/null +++ b/examples/mnist/config/mnist_cpu_cfg.py @@ -0,0 +1,41 @@ +import os +from easydict import EasyDict + +from mnist_runner import MNISTRunner + +CFG = EasyDict() + +CFG.DESC = 'mnist' +CFG.RUNNER = MNISTRunner +CFG.USE_GPU = False + +CFG.MODEL = EasyDict() +CFG.MODEL.NAME = 'conv_net' + +CFG.TRAIN = EasyDict() + +CFG.TRAIN.NUM_EPOCHS = 30 +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) +) +CFG.TRAIN.CKPT_SAVE_STRATEGY = None + +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = 'SGD' +CFG.TRAIN.OPTIM.PARAM = { + 'lr': 0.002, + 'momentum': 0.1, +} + +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 4 +CFG.TRAIN.DATA.DIR = 'mnist_data' +CFG.TRAIN.DATA.SHUFFLE = True + +CFG.VAL = EasyDict() + +CFG.VAL.INTERVAL = 1 + +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.DIR = 'mnist_data' diff --git a/examples/mnist/config/mnist_lr_cpu_cfg.py b/examples/mnist/config/mnist_lr_cpu_cfg.py new file mode 100644 index 0000000..5c60eeb --- /dev/null +++ 
b/examples/mnist/config/mnist_lr_cpu_cfg.py @@ -0,0 +1,48 @@ +import os +from easydict import EasyDict + +from mnist_runner import MNISTRunner + +CFG = EasyDict() + +CFG.DESC = 'mnist, lr scheduler' +CFG.RUNNER = MNISTRunner +CFG.USE_GPU = False + +CFG.MODEL = EasyDict() +CFG.MODEL.NAME = 'conv_net' + +CFG.TRAIN = EasyDict() + +CFG.TRAIN.NUM_EPOCHS = 30 +CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( + 'checkpoints', + '_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) +) +CFG.TRAIN.CKPT_SAVE_STRATEGY = None + +CFG.TRAIN.OPTIM = EasyDict() +CFG.TRAIN.OPTIM.TYPE = 'SGD' +CFG.TRAIN.OPTIM.PARAM = { + 'lr': 0.002, + 'momentum': 0.1, +} + +CFG.TRAIN.LR_SCHEDULER = EasyDict() +CFG.TRAIN.LR_SCHEDULER.TYPE = 'CosineAnnealingLR' +CFG.TRAIN.LR_SCHEDULER.PARAM = { + 'T_max': CFG.TRAIN.NUM_EPOCHS, + 'eta_min': 1e-6 +} + +CFG.TRAIN.DATA = EasyDict() +CFG.TRAIN.DATA.BATCH_SIZE = 4 +CFG.TRAIN.DATA.DIR = 'mnist_data' +CFG.TRAIN.DATA.SHUFFLE = True + +CFG.VAL = EasyDict() + +CFG.VAL.INTERVAL = 1 + +CFG.VAL.DATA = EasyDict() +CFG.VAL.DATA.DIR = 'mnist_data' diff --git a/examples/mnist/conv_net.py b/examples/mnist/conv_net.py new file mode 100644 index 0000000..c015158 --- /dev/null +++ b/examples/mnist/conv_net.py @@ -0,0 +1,30 @@ +from torch import nn + + +class ConvNet(nn.Module): + def __init__(self): + super().__init__() + self.conv_block = nn.Sequential( + nn.Conv2d(1, 10, kernel_size=5), + nn.MaxPool2d(2), + nn.ReLU(inplace=True), + nn.Conv2d(10, 20, kernel_size=5), + nn.Dropout2d(), + nn.MaxPool2d(2), + nn.ReLU(inplace=True), + ) + + self.fc_block = nn.Sequential( + nn.Linear(320, 50), + nn.ReLU(inplace=True), + nn.Dropout(), + nn.Linear(50, 10), + nn.LogSoftmax(dim=1) + ) + + def forward(self, x): + y = self.conv_block(x) + y = y.view(-1, 320) + y = self.fc_block(y) + + return y diff --git a/examples/mnist/mnist_runner.py b/examples/mnist/mnist_runner.py new file mode 100644 index 0000000..53db6a5 --- /dev/null +++ b/examples/mnist/mnist_runner.py @@ -0,0 +1,132 @@ +from torch 
import nn +import torchvision + +from easytorch import Runner + +from conv_net import ConvNet + + +class MNISTRunner(Runner): + def init_training(self, cfg): + """Initialize training. + + Including loss, training meters, etc. + + Args: + cfg (dict): config + """ + + super().init_training(cfg) + + self.loss = nn.NLLLoss() + self.loss = self.to_running_device(self.loss) + + self.register_epoch_meter('train_loss', 'train', '{:.2f}') + + def init_validation(self, cfg: dict): + """Initialize validation. + + Including validation meters, etc. + + Args: + cfg (dict): config + """ + + super().init_validation(cfg) + + self.register_epoch_meter('val_acc', 'val', '{:.2f}%') + + @staticmethod + def define_model(cfg: dict) -> nn.Module: + """Define model. + + If you have multiple models, insert the name and class into the dict below, + and select it through ```config```. + + Args: + cfg (dict): config + + Returns: + model (nn.Module) + """ + + return { + 'conv_net': ConvNet + }[cfg['MODEL']['NAME']](**cfg['MODEL'].get('PARAM', {})) + + @staticmethod + def build_train_dataset(cfg: dict): + """Build MNIST train dataset + + Args: + cfg (dict): config + + Returns: + train dataset (Dataset) + """ + + return torchvision.datasets.MNIST( + cfg['TRAIN']['DATA']['DIR'], train=True, download=True, + transform=torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.1307,), (0.3081,)) + ]) + ) + + @staticmethod + def build_val_dataset(cfg: dict): + """Build MNIST val dataset + + Args: + cfg (dict): config + + Returns: + train dataset (Dataset) + """ + + return torchvision.datasets.MNIST( + cfg['VAL']['DATA']['DIR'], train=False, download=True, + transform=torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.1307,), (0.3081,)) + ]) + ) + + def train_iters(self, epoch, iter_index, data): + """Training details. + + Args: + epoch (int): current epoch. + iter_index (int): current iter. 
+ data (torch.Tensor or tuple): Data provided by DataLoader + + Returns: + loss (torch.Tensor) + """ + + _input, _target = data + _input = self.to_running_device(_input) + _target = self.to_running_device(_target) + + output = self.model(_input) + loss = self.loss(output, _target) + self.update_epoch_meter('train_loss', loss.item()) + return loss + + def val_iters(self, iter_index, data): + """Validation details. + + Args: + iter_index (int): current iter. + data (torch.Tensor or tuple): Data provided by DataLoader + """ + + _input, _target = data + _input = self.to_running_device(_input) + _target = self.to_running_device(_target) + + output = self.model(_input) + pred = output.data.max(1, keepdim=True)[1] + self.update_epoch_meter('val_acc', 100 * pred.eq(_target.data.view_as(pred)).float().mean().item()) diff --git a/examples/mnist/train.py b/examples/mnist/train.py new file mode 100644 index 0000000..2d443e9 --- /dev/null +++ b/examples/mnist/train.py @@ -0,0 +1,19 @@ +import sys +sys.path.append('../..') +from argparse import ArgumentParser + +from easytorch import launch_training + + +def parse_args(): + parser = ArgumentParser(description='Welcome to EasyTorch!') + parser.add_argument('-c', '--cfg', help='training config', required=True) + parser.add_argument('--gpus', help='visible gpus', type=str) + parser.add_argument('--tf32', help='enable tf32 on Ampere device', action='store_true') + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + launch_training(args.cfg, args.gpus, args.tf32) diff --git a/examples/mnist/validate.py b/examples/mnist/validate.py new file mode 100644 index 0000000..a79ecdb --- /dev/null +++ b/examples/mnist/validate.py @@ -0,0 +1,28 @@ +import sys +sys.path.append('../..') +from argparse import ArgumentParser + +from easytorch import launch_runner, Runner + + +def parse_args(): + parser = ArgumentParser(description='Welcome to EasyTorch!') + parser.add_argument('-c', '--cfg', help='training config', required=True) + 
parser.add_argument('--ckpt', help='ckpt path. if it is None, load default ckpt in ckpt save dir', type=str) + parser.add_argument('--gpus', help='visible gpus', type=str) + parser.add_argument('--tf32', help='enable tf32 on Ampere device', action='store_true') + return parser.parse_args() + + +def main(cfg: dict, runner: Runner, ckpt: str = None): + # init logger + runner.init_logger(logger_name='easytorch-inference', log_file_name='validate_result') + + runner.load_model(ckpt_path=ckpt) + + runner.validate(cfg) + + +if __name__ == '__main__': + args = parse_args() + launch_runner(args.cfg, main, (args.ckpt, ), gpus=args.gpus, tf32_mode=args.tf32) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1ebc35b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +easydict +tensorboard +prefetch_generator diff --git a/tests/random_test/.gitignore b/tests/random_test/.gitignore new file mode 100644 index 0000000..bfea04a --- /dev/null +++ b/tests/random_test/.gitignore @@ -0,0 +1 @@ +checkpoints diff --git a/tests/random_test/random_test.py b/tests/random_test/random_test.py new file mode 100644 index 0000000..d7fe176 --- /dev/null +++ b/tests/random_test/random_test.py @@ -0,0 +1,85 @@ +import os +import sys +sys.path.append('../..') +import random +import numpy as np + +from easydict import EasyDict +import torch +from torch import nn +from torch.utils.data import Dataset + +from easytorch import Runner, get_rank, launch_training + + +class FakeDataset(Dataset): + def __init__(self, num: int, min: int, max: int): + self.num = num + self.min = min + self.max = max + + def __getitem__(self, index): + return index, \ + random.randint(self.min, self.max), \ + np.random.randint(self.min, self.max + 1), \ + torch.randint(self.min, self.max + 1, (1,)).item() + + def __len__(self): + return self.num + + +class DDPTestRunner(Runner): + @staticmethod + def define_model(cfg: dict) -> nn.Module: + return nn.Conv2d(3, 3, 3) + + @staticmethod + def 
build_train_dataset(cfg: dict): + return FakeDataset(cfg['TRAIN']['DATA']['NUM'], cfg['TRAIN']['DATA']['MIN'], cfg['TRAIN']['DATA']['MAX']) + + def train_iters(self, epoch, iter_index, data): + print('rank: {:d}, epoch: {:d}, iter: {:d}, data: {}'.format(get_rank(), epoch, iter_index, data)) + if torch.distributed.is_initialized(): + torch.distributed.barrier() + + +def build_cfg(): + CFG = EasyDict() + + CFG.DESC = 'ddp test' + CFG.RUNNER = DDPTestRunner + CFG.GPU_NUM = 8 + CFG.SEED = 6 + + CFG.MODEL = EasyDict() + CFG.MODEL.NAME = 'conv' + + CFG.TRAIN = EasyDict() + + CFG.TRAIN.NUM_EPOCHS = 5 + CFG.TRAIN.CKPT_SAVE_DIR = 'checkpoints' + + CFG.TRAIN.CKPT_SAVE_STRATEGY = None + + CFG.TRAIN.OPTIM = EasyDict() + CFG.TRAIN.OPTIM.TYPE = 'SGD' + CFG.TRAIN.OPTIM.PARAM = { + 'lr': 0.002, + 'momentum': 0.1, + } + + CFG.TRAIN.DATA = EasyDict() + CFG.TRAIN.DATA.NUM = 100 + CFG.TRAIN.DATA.MIN = 0 + CFG.TRAIN.DATA.MAX = 100 + CFG.TRAIN.DATA.BATCH_SIZE = 4 + CFG.TRAIN.DATA.NUM_WORKERS = 2 + CFG.TRAIN.DATA.SHUFFLE = True + + return CFG + + +if __name__ == "__main__": + cfg = build_cfg() + + launch_training(cfg, gpus='0,1,2,3,4,5,6,7', tf32_mode=False)