Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add doc #27

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

---

Easytorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。
EasyTorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。

## 功能亮点

Expand All @@ -35,9 +35,7 @@ python >= 3.6 (推荐 >= 3.7)

### PyTorch及CUDA

[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7)

[CUDA](https://developer.nvidia.com/zh-cn/cuda-toolkit) >= 9.2 (推荐 >= 11.0)
[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7),如需使用CUDA,安装PyTorch时选择对应CUDA版本编译的包。

注意:如需使用安培(Ampere)架构GPU,PyTorch版本需 >= 1.7且CUDA版本 >= 11.0。

Expand All @@ -51,6 +49,23 @@ pip install -r requirements.txt

* [线性回归](examples/linear_regression)
* [MNIST手写数字识别](examples/mnist)
* [ImageNet图像分类](examples/imagenet)

## 开始使用

### 安装EasyTorch

```shell
pip install easy-torch
```

### 复制EasyTorch模板至工作目录

```shell
cp -r easytorch/examples/template/* .
```

*接下来就可以使用EasyTorch构建你的深度学习项目。*

## README 徽章

Expand Down
Empty file added docs/config.md
Empty file.
Empty file.
39 changes: 39 additions & 0 deletions examples/template/configs/config_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from easydict import EasyDict

from runners.runner_template import RunnerTemplate

CFG = EasyDict()

CFG.DESC = 'config template'
CFG.RUNNER = RunnerTemplate
CFG.GPU_NUM = 1

CFG.MODEL = EasyDict()
CFG.MODEL.NAME = 'model_template'

CFG.TRAIN = EasyDict()

CFG.TRAIN.NUM_EPOCHS = 100
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
'checkpoints',
'_'.join([CFG.MODEL.NAME, str(CFG.TRAIN.NUM_EPOCHS)])
)
CFG.TRAIN.CKPT_SAVE_STRATEGY = None

CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = 'SGD'
CFG.TRAIN.OPTIM.PARAM = {
'lr': 0.002,
'momentum': 0.1,
}

CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = 4
CFG.TRAIN.DATA.SHUFFLE = True

CFG.VAL = EasyDict()

CFG.VAL.INTERVAL = 1

CFG.VAL.DATA = EasyDict()
7 changes: 7 additions & 0 deletions examples/template/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .model_template import ModelTemplate


MODEL_DICT = {
'model_template': ModelTemplate
# other models...
}
13 changes: 13 additions & 0 deletions examples/template/models/model_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from torch import nn


class ModelTemplate(nn.Module):
def __init__(self):
super().__init__()
self.op = nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.op(x)
Empty file.
64 changes: 64 additions & 0 deletions examples/template/runners/runner_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Dict

import torch
from torch import nn
from torch.utils.data import Dataset

from easytorch import Runner

from ..models import MODEL_DICT


class RunnerTemplate(Runner):
def __init__(self, cfg: Dict):
super().__init__(cfg)

def init_training(self, cfg: Dict):
super().init_training(cfg)

# init loss
# e.g.
# self.loss = nn.MSELoss()
# self.loss = self.to_running_device(self.loss)

# register meters by calling:
# self.register_epoch_meter('train_loss', 'train', '{:.2f}')

def init_validation(self, cfg: Dict):
super().init_validation(cfg)

# self.register_epoch_meter('val_acc', 'val', '{:.2f}%')

@staticmethod
def define_model(cfg: Dict) -> nn.Module:
return MODEL_DICT[cfg['MODEL']['NAME']](**cfg['MODEL'].get('PARAM', {}))

@staticmethod
def build_train_dataset(cfg: Dict) -> Dataset:
# return your train Dataset
pass

@staticmethod
def build_val_dataset(cfg: Dict):
# return your val Dataset
pass

def train_iters(self, epoch: int, iter_index: int, data: torch.Tensor or tuple) -> torch.Tensor:
# forward and compute loss
# update meters if necessary
# return loss (will be auto backward and update params) or don't return any thing

# e.g.
# _input, _target = data
# _input = self.to_running_device(_input)
# _target = self.to_running_device(_target)
#
# output = self.model(_input)
# loss = self.loss(output, _target)
# self.update_epoch_meter('train_loss', loss.item())
# return loss
pass

def val_iters(self, iter_index: int, data: torch.Tensor or tuple):
# forward and update meters
pass
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torch>=1.7
torchvision
easydict
tensorboard
tqdm