Skip to content

Commit

Permalink
update version
Browse files Browse the repository at this point in the history
  • Loading branch information
cnstark committed May 12, 2022
1 parent f3e029f commit bfeb105
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 39 deletions.
21 changes: 5 additions & 16 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

---

Easytorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。
EasyTorch是一个基于PyTorch的开源神经网络训练框架,封装了PyTorch项目中常用的功能,帮助用户快速构建深度学习项目。

## 功能亮点

Expand All @@ -35,9 +35,7 @@ python >= 3.6 (推荐 >= 3.7)

### PyTorch及CUDA

[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7)

[CUDA](https://developer.nvidia.com/zh-cn/cuda-toolkit) >= 9.2 (推荐 >= 11.0)
[pytorch](https://pytorch.org/) >= 1.4(推荐 >= 1.7),如需使用CUDA,安装PyTorch时选择对应CUDA版本编译的包。

注意:如需使用安培(Ampere)架构GPU,PyTorch版本需 >= 1.7且CUDA版本 >= 11.0。

Expand All @@ -51,23 +49,14 @@ pip install -r requirements.txt

* [线性回归](examples/linear_regression)
* [MNIST手写数字识别](examples/mnist)
* [ImageNet图像分类](examples/imagenet)

## 开始使用

### 新建项目文件夹并初始化git仓库

```shell
mkdir my_deeplearning_project
cd my_deeplearning_project
git init
```

### 添加easytorch子仓库
### 安装EasyTorch

```shell
git submodule add https://github.com/cnstark/easytorch.git
git add .
git commit -m "init by easytorch"
pip install easy-torch
```

### 复制EasyTorch模板至工作目录
Expand Down
14 changes: 8 additions & 6 deletions examples/template/runners/runner_template.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Dict

import torch
from torch import nn
from torch.utils.data import Dataset
Expand All @@ -8,10 +10,10 @@


class RunnerTemplate(Runner):
def __init__(self, cfg: dict):
def __init__(self, cfg: Dict):
super().__init__(cfg)

def init_training(self, cfg: dict):
def init_training(self, cfg: Dict):
super().init_training(cfg)

# init loss
Expand All @@ -22,22 +24,22 @@ def init_training(self, cfg: dict):
# register meters by calling:
# self.register_epoch_meter('train_loss', 'train', '{:.2f}')

def init_validation(self, cfg: dict):
def init_validation(self, cfg: Dict):
super().init_validation(cfg)

# self.register_epoch_meter('val_acc', 'val', '{:.2f}%')

@staticmethod
def define_model(cfg: dict) -> nn.Module:
def define_model(cfg: Dict) -> nn.Module:
return MODEL_DICT[cfg['MODEL']['NAME']](**cfg['MODEL'].get('PARAM', {}))

@staticmethod
def build_train_dataset(cfg: dict) -> Dataset:
def build_train_dataset(cfg: Dict) -> Dataset:
# return your train Dataset
pass

@staticmethod
def build_val_dataset(cfg: dict):
def build_val_dataset(cfg: Dict):
# return your val Dataset
pass

Expand Down
17 changes: 0 additions & 17 deletions examples/template/train.py

This file was deleted.

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
torch>=1.7
torchvision
easydict
tensorboard
tqdm

0 comments on commit bfeb105

Please sign in to comment.