Skip to content

Commit

Permalink
feature: template
Browse files Browse the repository at this point in the history
  • Loading branch information
cnstark committed Aug 4, 2021
1 parent b815dce commit b7f4213
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 2 deletions.
22 changes: 20 additions & 2 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,29 @@ pip install -r requirements.txt

## 开始使用

### 初始化git仓库
### 新建项目文件夹并初始化git仓库

```shell
mkdir my_deeplearning_project
cd my_deeplearning_project
git init
```

### 添加easytorch子仓库

###
```shell
git submodule add https://github.com/cnstark/easytorch.git
git add .
git commit -m "init by easytorch"
```

### 复制EasyTorch模板至工作目录

```shell
cp easytorch/examples/template/* .
```

*接下来就可以使用EasyTorch构建你的深度学习项目。*

## README 徽章

Expand Down
Empty file.
39 changes: 39 additions & 0 deletions examples/template/configs/config_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from easydict import EasyDict

from runners.runner_template import RunnerTemplate

CFG = EasyDict()

CFG.DESC = 'config template'
CFG.RUNNER = RunnerTemplate
CFG.GPU_NUM = 1

CFG.MODEL = EasyDict()
CFG.MODEL.NAME = 'model_template'

CFG.TRAIN = EasyDict()

CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
'checkpoints',
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
CFG.TRAIN.CKPT_SAVE_STRATEGY = None

CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = 'SGD'
CFG.TRAIN.OPTIM.PARAM = {
'lr': 0.002,
'momentum': 0.1,
}

CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 4
CFG.TRAIN.DATA.SHUFFLE = True

CFG.VAL = EasyDict()

CFG.VAL.INTERVAL = 1

CFG.VAL.DATA = EasyDict()
7 changes: 7 additions & 0 deletions examples/template/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .model_template import ModelTemplate


MODEL_DICT = {
'model_template': ModelTemplate
# other models...
}
13 changes: 13 additions & 0 deletions examples/template/models/model_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch import nn


class ModelTemplate(nn.Module):
def __init__(self):
super().__init__()
self.op = nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.op(x)
Empty file.
62 changes: 62 additions & 0 deletions examples/template/runners/runner_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch
from torch import nn
from torch.utils.data import Dataset

from easytorch import Runner

from ..models import MODEL_DICT


class RunnerTemplate(Runner):
def __init__(self, cfg: dict):
super().__init__(cfg)

def init_training(self, cfg: dict):
super().init_training(cfg)

# init loss
# e.g.
# self.loss = nn.MSELoss()
# self.loss = self.to_running_device(self.loss)

# register meters by calling:
# self.register_epoch_meter('train_loss', 'train', '{:.2f}')

def init_validation(self, cfg: dict):
super().init_validation(cfg)

# self.register_epoch_meter('val_acc', 'val', '{:.2f}%')

@staticmethod
def define_model(cfg: dict) -> nn.Module:
return MODEL_DICT[cfg['MODEL']['NAME']](**cfg['MODEL'].get('PARAM', {}))

@staticmethod
def build_train_dataset(cfg: dict) -> Dataset:
# return your train Dataset
pass

@staticmethod
def build_val_dataset(cfg: dict):
# return your val Dataset
pass

def train_iters(self, epoch: int, iter_index: int, data: torch.Tensor or tuple) -> torch.Tensor:
# forward and compute loss
# update meters if necessary
# return loss (will be auto backward and update params) or don't return any thing

# e.g.
# _input, _target = data
# _input = self.to_running_device(_input)
# _target = self.to_running_device(_target)
#
# output = self.model(_input)
# loss = self.loss(output, _target)
# self.update_epoch_meter('train_loss', loss.item())
# return loss
pass

def val_iters(self, iter_index: int, data: torch.Tensor or tuple):
# forward and update meters
pass
17 changes: 17 additions & 0 deletions examples/template/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from argparse import ArgumentParser

from easytorch.easytorch import launch_training


def parse_args():
parser = ArgumentParser(description='Welcome to EasyTorch!')
parser.add_argument('-c', '--cfg', help='training config', required=True)
parser.add_argument('--gpus', help='visible gpus', type=str)
parser.add_argument('--tf32', help='enable tf32 on Ampere device', action='store_true')
return parser.parse_args()


if __name__ == "__main__":
args = parse_args()

launch_training(args.cfg, args.gpus, args.tf32)

0 comments on commit b7f4213

Please sign in to comment.