-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* move to easytorch * feature: mnist example * fix: mnist * feature: linear regression (temp) * feature: linear regression * feature: linear regression * feature: mnist cpu cfg * fix * feature: MNIST doc & readme * feature: validate * feature: requirements.txt * feature: remove val * feature: README.md * feature: add notes * feature: README * update README * update README * feature: lr * feature: lgtm svg * gitee-mirror * feature: random_test
- Loading branch information
Showing
38 changed files
with
687 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,7 @@ | ||
# easytorch | ||
# EasyTorch | ||
|
||
[![LICENSE](https://img.shields.io/github/license/cnstark/easytorch.svg)](https://github.com/cnstark/easytorch/blob/master/LICENSE) | ||
[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/cnstark/easytorch.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/cnstark/easytorch/context:python) | ||
[![gitee mirror](https://github.com/cnstark/easytorch/actions/workflows/git-mirror.yml/badge.svg)](https://gitee.com/cnstark/easytorch) | ||
|
||
[English](README.md) **|** [简体中文](README_CN.md) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# EasyTorch | ||
|
||
[![LICENSE](https://img.shields.io/github/license/cnstark/easytorch.svg)](https://github.com/cnstark/easytorch/blob/master/LICENSE) | ||
[![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/cnstark/easytorch.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/cnstark/easytorch/context:python) | ||
[![gitee mirror](https://github.com/cnstark/easytorch/actions/workflows/git-mirror.yml/badge.svg)](https://gitee.com/cnstark/easytorch) | ||
|
||
[English](README.md) **|** [简体中文](README_CN.md) | ||
|
||
--- | ||
|
||
Easytorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。 | ||
|
||
## 功能亮点 | ||
|
||
* **最小代码量**。封装通用神经网络训练流程,用户仅需实现`Dataset`、`Model`以及训练/推理代码等关键代码,就能完成深度学习项目的构建。 | ||
* **万物基于Config**。通过配置文件控制训练模式与超参,根据配置内容的MD5自动生成唯一的结果存放目录,调整超参不再凌乱。 | ||
* **支持所有设备**。支持CPU、GPU与GPU分布式训练,通过配置参数一键完成设置。 | ||
* **持久化训练日志**。支持`logging`日志系统与`Tensorboard`,并封装为统一接口,用户通过一键调用即可保存自定义的训练日志。 | ||
|
||
## 环境依赖 | ||
|
||
### 操作系统 | ||
|
||
* [Linux](https://pytorch.org/get-started/locally/#linux-prerequisites) | ||
* [Windows](https://pytorch.org/get-started/locally/#windows-prerequisites) | ||
* [MacOS](https://pytorch.org/get-started/locally/#mac-prerequisites) | ||
|
||
推荐使用Ubuntu16.04及更高版本或CentOS7及以更高版本。 | ||
|
||
### Python | ||
|
||
python >= 3.6 (推荐 >= 3.7) | ||
|
||
推荐使用[Anaconda](https://www.anaconda.com/) | ||
|
||
### PyTorch及CUDA | ||
|
||
[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7) | ||
|
||
[CUDA](https://developer.nvidia.com/zh-cn/cuda-toolkit) >= 9.2 (推荐 >= 11.0) | ||
|
||
注意:如需使用安培(Ampere)架构GPU,PyTorch版本需 >= 1.7且CUDA版本 >= 11.0。 | ||
|
||
### 其他依赖 | ||
|
||
```shell | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## 示例 | ||
|
||
* [线性回归](examples/linear_regression) | ||
* [MNIST手写数字识别](examples/mnist) | ||
|
||
## README 徽章 | ||
|
||
如果你的项目正在使用EasyTorch,可以将EasyTorch徽章 [![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) 添加到你的 README 中: | ||
|
||
``` | ||
[![EasyTorch](https://img.shields.io/badge/Developing%20with-EasyTorch-2077ff.svg)](https://github.com/cnstark/easytorch) | ||
``` |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
checkpoints |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# EasyTorch Example - MNIST Classification | ||
|
||
## Train | ||
|
||
* CPU | ||
|
||
```shell | ||
python train.py -c linear_regression_cpu_cfg.py | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class LinearDataset(Dataset): | ||
def __init__(self, k: float, b: float, num: int): | ||
self.num = num | ||
self.x = torch.unsqueeze(torch.linspace(-1, 1, self.num), dim=1) | ||
self.y = k * self.x + b + torch.rand(self.x.size()) - 0.5 | ||
|
||
def __getitem__(self, index): | ||
return self.x[index], self.y[index] | ||
|
||
def __len__(self): | ||
return self.num |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import os | ||
from easydict import EasyDict | ||
|
||
from linear_regression_runner import LinearRegressionRunner | ||
|
||
CFG = EasyDict() | ||
|
||
CFG.DESC = 'linear_regression' | ||
CFG.RUNNER = LinearRegressionRunner | ||
CFG.USE_GPU = False | ||
|
||
CFG.MODEL = EasyDict() | ||
CFG.MODEL.NAME = 'linear' | ||
|
||
CFG.TRAIN = EasyDict() | ||
|
||
CFG.TRAIN.NUM_EPOCHS = 10000 | ||
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( | ||
'checkpoints', | ||
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) | ||
) | ||
CFG.TRAIN.CKPT_SAVE_STRATEGY = None | ||
|
||
CFG.TRAIN.OPTIM = EasyDict() | ||
CFG.TRAIN.OPTIM.TYPE = 'SGD' | ||
CFG.TRAIN.OPTIM.PARAM = { | ||
'lr': 0.001, | ||
'momentum': 0.1, | ||
} | ||
|
||
CFG.TRAIN.DATA = EasyDict() | ||
CFG.TRAIN.DATA.BATCH_SIZE = 10 | ||
CFG.TRAIN.DATA.K = 10 | ||
CFG.TRAIN.DATA.B = 6 | ||
CFG.TRAIN.DATA.NUM = 100 | ||
CFG.TRAIN.DATA.SHUFFLE = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
from torch import nn | ||
|
||
from easytorch import Runner | ||
|
||
from dataset import LinearDataset | ||
|
||
|
||
class LinearRegressionRunner(Runner): | ||
def init_training(self, cfg): | ||
"""Initialize training. | ||
Including loss, training meters, etc. | ||
Args: | ||
cfg (dict): config | ||
""" | ||
|
||
super().init_training(cfg) | ||
|
||
self.loss = nn.MSELoss() | ||
self.loss = self.to_running_device(self.loss) | ||
|
||
self.register_epoch_meter('train_loss', 'train', '{:.2f}') | ||
|
||
@staticmethod | ||
def define_model(cfg: dict) -> nn.Module: | ||
"""Define model. | ||
Args: | ||
cfg (dict): config | ||
Returns: | ||
model (nn.Module) | ||
""" | ||
|
||
return nn.Linear(1, 1) | ||
|
||
@staticmethod | ||
def build_train_dataset(cfg: dict): | ||
"""Build MNIST train dataset | ||
Args: | ||
cfg (dict): config | ||
Returns: | ||
train dataset (Dataset) | ||
""" | ||
|
||
return LinearDataset( | ||
cfg['TRAIN']['DATA']['K'], | ||
cfg['TRAIN']['DATA']['B'], | ||
cfg['TRAIN']['DATA']['NUM'], | ||
) | ||
|
||
def train_iters(self, epoch, iter_index, data): | ||
"""Training details. | ||
Args: | ||
epoch (int): current epoch. | ||
iter_index (int): current iter. | ||
data (torch.Tensor or tuple): Data provided by DataLoader | ||
Returns: | ||
loss (torch.Tensor) | ||
""" | ||
|
||
x, y = data | ||
x = self.to_running_device(x) | ||
y = self.to_running_device(y) | ||
|
||
output = self.model(x) | ||
loss = self.loss(output, y) | ||
self.update_epoch_meter('train_loss', loss.item()) | ||
return loss | ||
|
||
def on_training_end(self): | ||
"""Print result on training end. | ||
""" | ||
|
||
super().on_training_end() | ||
self.logger.info('Result: k: {}, b: {}'.format(self.model.weight.item(), self.model.bias.item())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import sys | ||
sys.path.append('../..') | ||
from argparse import ArgumentParser | ||
|
||
from easytorch import launch_training | ||
|
||
|
||
def parse_args(): | ||
parser = ArgumentParser(description='Welcome to EasyTorch!') | ||
parser.add_argument('-c', '--cfg', help='training config', required=True) | ||
parser.add_argument('--gpus', help='visible gpus', type=str) | ||
parser.add_argument('--tf32', help='enable tf32 on Ampere device', action='store_true') | ||
return parser.parse_args() | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
|
||
launch_training(args.cfg, args.gpus, args.tf32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
checkpoints | ||
mnist_data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# EasyTorch Example - MNIST Classification | ||
|
||
## Train | ||
|
||
* CPU | ||
|
||
```shell | ||
python train.py -c config\mnist_cpu_cfg.py | ||
``` | ||
|
||
* GPU (1x) | ||
|
||
```shell | ||
python train.py -c config\mnist_1x_cfg.py --gpus 0 | ||
``` | ||
|
||
## Validate | ||
|
||
* CPU | ||
|
||
```shell | ||
python validate.py -c config\mnist_cpu_cfg.py | ||
``` | ||
|
||
* GPU (1x) | ||
|
||
```shell | ||
python validate.py -c config\mnist_1x_cfg.py --gpus 0 | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
from easydict import EasyDict | ||
|
||
from mnist_runner import MNISTRunner | ||
|
||
CFG = EasyDict() | ||
|
||
CFG.DESC = 'mnist' | ||
CFG.RUNNER = MNISTRunner | ||
CFG.GPU_NUM = 1 | ||
|
||
CFG.MODEL = EasyDict() | ||
CFG.MODEL.NAME = 'conv_net' | ||
|
||
CFG.TRAIN = EasyDict() | ||
|
||
CFG.TRAIN.NUM_EPOCHS = 30 | ||
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( | ||
'checkpoints', | ||
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) | ||
) | ||
CFG.TRAIN.CKPT_SAVE_STRATEGY = None | ||
|
||
CFG.TRAIN.OPTIM = EasyDict() | ||
CFG.TRAIN.OPTIM.TYPE = 'SGD' | ||
CFG.TRAIN.OPTIM.PARAM = { | ||
'lr': 0.002, | ||
'momentum': 0.1, | ||
} | ||
|
||
CFG.TRAIN.DATA = EasyDict() | ||
CFG.TRAIN.DATA.BATCH_SIZE = 4 | ||
CFG.TRAIN.DATA.DIR = 'mnist_data' | ||
CFG.TRAIN.DATA.SHUFFLE = True | ||
|
||
CFG.VAL = EasyDict() | ||
|
||
CFG.VAL.INTERVAL = 1 | ||
|
||
CFG.VAL.DATA = EasyDict() | ||
CFG.VAL.DATA.DIR = 'mnist_data' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
from easydict import EasyDict | ||
|
||
from mnist_runner import MNISTRunner | ||
|
||
CFG = EasyDict() | ||
|
||
CFG.DESC = 'mnist' | ||
CFG.RUNNER = MNISTRunner | ||
CFG.USE_GPU = False | ||
|
||
CFG.MODEL = EasyDict() | ||
CFG.MODEL.NAME = 'conv_net' | ||
|
||
CFG.TRAIN = EasyDict() | ||
|
||
CFG.TRAIN.NUM_EPOCHS = 30 | ||
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join( | ||
'checkpoints', | ||
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)]) | ||
) | ||
CFG.TRAIN.CKPT_SAVE_STRATEGY = None | ||
|
||
CFG.TRAIN.OPTIM = EasyDict() | ||
CFG.TRAIN.OPTIM.TYPE = 'SGD' | ||
CFG.TRAIN.OPTIM.PARAM = { | ||
'lr': 0.002, | ||
'momentum': 0.1, | ||
} | ||
|
||
CFG.TRAIN.DATA = EasyDict() | ||
CFG.TRAIN.DATA.BATCH_SIZE = 4 | ||
CFG.TRAIN.DATA.DIR = 'mnist_data' | ||
CFG.TRAIN.DATA.SHUFFLE = True | ||
|
||
CFG.VAL = EasyDict() | ||
|
||
CFG.VAL.INTERVAL = 1 | ||
|
||
CFG.VAL.DATA = EasyDict() | ||
CFG.VAL.DATA.DIR = 'mnist_data' |
Oops, something went wrong.