Skip to content

Commit

Permalink
feature: release (#32)
Browse files Browse the repository at this point in the history
* feature: add torch>=1.7

* feature: setup.py

* feature: easytrain

* fix: format file
  • Loading branch information
cnstark authored Oct 13, 2021
1 parent 19a2c8c commit 647350e
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
__pycache__
.vscode
.idea

/build
/dist
*.egg-info
1 change: 1 addition & 0 deletions easytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .config import import_config
from .core import *
from .utils import *
from .version import __version__
4 changes: 2 additions & 2 deletions easytorch/core/launcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import random
from typing import Callable

import torch
Expand Down Expand Up @@ -39,7 +38,8 @@ def train(cfg: dict, use_gpu: bool, tf32_mode: bool):
runner.train(cfg)


def train_ddp(local_rank: int, world_size: int, backend: str or Backend, init_method: str, cfg: dict, tf32_mode: bool, node_rank: int = 0):
def train_ddp(local_rank: int, world_size: int, backend: str or Backend, init_method: str, cfg: dict, tf32_mode: bool,
node_rank: int = 0):
"""Start training with DistributedDataParallel
Args:
Expand Down
30 changes: 30 additions & 0 deletions easytorch/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import sys
from argparse import ArgumentParser

from easytorch import launch_training


def parse_args():
parser = ArgumentParser(description='Welcome to EasyTorch!')
parser.add_argument('-c', '--cfg', help='training config', required=True)
parser.add_argument('--node-rank', default=0, type=int, help='node rank for distributed training')
parser.add_argument('--gpus', help='visible gpus', type=str)
parser.add_argument('--tf32', help='enable tf32 on Ampere device', action='store_true')
return parser.parse_args()


def main():
# work dir
path = os.getcwd()
sys.path.append(path)

# parse arguments
args = parse_args()

# train
launch_training(args.cfg, args.gpus, args.tf32, args.node_rank)


if __name__ == '__main__':
main()
2 changes: 2 additions & 0 deletions easytorch/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
__version__ = '1.0.0'
__all__ = ['__version__']
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
torch>=1.7
easydict
tensorboard
tqdm
54 changes: 54 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import os
from setuptools import find_packages, setup


def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content


def get_version():
version_file = 'easytorch/version.py'
with open(version_file, 'r', encoding='utf-8') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']


def get_requirements(filename='requirements.txt'):
here = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(here, filename), 'r') as f:
requires = [line.replace('\n', '') for line in f.readlines()]
return requires


if __name__ == '__main__':
setup(
name='easytorch',
version=get_version(),
description='Simple and powerful pytorch framework.',
long_description=readme(),
long_description_content_type='text/markdown',
author='Yuhao Wang',
author_email='yuhaow97@gmail.com',
keywords='pytorch, deep learning',
url='https://github.com/cnstark/easytorch',
include_package_data=True,
packages=find_packages(exclude=('tests',)),
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Topic :: Utilities'
],
entry_points={
'console_scripts': ['easytrain=easytorch.train:main'],
},
license='Apache License 2.0',
install_requires=get_requirements(),
zip_safe=False
)

0 comments on commit 647350e

Please sign in to comment.