
A high cohesion, low coupling, and plug-and-play project framework for PyTorch.



Lmy0217/PyTorch-Project-Framework


PyTorch-Project-Framework



Folder Structure

  ├── configs
  |    ├── BaseConfig.py  - the loader of all configuration files
  |    ├── BaseTest.py  - the test class for all configuration files
  |    ├── Env.py  - the loader of environment configuration files
  |    └── Run.py  - the loader of hyperparameter configuration files
  |
  ├── datasets
  |    ├── functional  - the package of functional methods
  |    ├── BaseDataset.py  - the abstract base class of all datasets
  |    ├── BaseTest.py  - the test class for all datasets
  |    └── ...  - any dataset of your project
  |
  ├── models
  |    ├── functional  - the package of functional methods
  |    ├── shallow  - the package of shallow methods
  |    ├── BaseModel.py  - the abstract base class of all models
  |    ├── BaseTest.py  - the test class for all models
  |    └── ...  - any model of your project
  |
  ├── res
  |    ├── env  - the folder containing the JSON files of environment configuration
  |    ├── datasets  - the folder containing the JSON files of dataset configuration
  |    ├── models  - the folder containing the JSON files of model configuration
  |    └── run  - the folder containing the JSON files of hyperparameter configuration
  |
  ├── test
  |    ├── test_configs.py  - the unittest classes for package configs
  |    ├── test_datasets.py  - the unittest classes for package datasets
  |    ├── test_models.py  - the unittest classes for package models
  |    └── test_utils.py  - the unittest classes for package utils
  |
  ├── utils
  |    ├── common.py  - the common methods
  |    ├── logger.py  - the logger class
  |    ├── summary.py  - the summary class
  |    └── ...  - any utils of your project
  |
  ├── main.py  - the main class of the framework
  |
  └── test_component.py  - the global test class
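The README does not show how configs/BaseConfig.py reads the files under res/. The sketch below is only an illustration of one way such a loader could work for this layout, assuming each file is named after the value passed on the command line (e.g. res/datasets/yourdataset.json) and that nested JSON objects are exposed as attributes. load_config and to_namespace are hypothetical names, and stripping // comments (which the example configuration files in this README use as annotations) is an assumption, not the framework's documented behavior.

```python
import json
import re
from pathlib import Path
from types import SimpleNamespace

# Hypothetical: remove `//` line comments, since the standard json module rejects them.
# Naive on purpose: it would also cut "//" inside string values such as URLs.
_COMMENT = re.compile(r"//[^\n]*")


def to_namespace(value):
    """Recursively turn dicts into objects with attribute access (cfg.source.depth)."""
    if isinstance(value, dict):
        return SimpleNamespace(**{k: to_namespace(v) for k, v in value.items()})
    if isinstance(value, list):
        return [to_namespace(v) for v in value]
    return value


def load_config(root, kind, name):
    """Load <root>/res/<kind>/<name>.json, e.g. kind='datasets', name='yourdataset'."""
    path = Path(root) / "res" / kind / f"{name}.json"
    text = _COMMENT.sub("", path.read_text(encoding="utf-8"))
    return to_namespace(json.loads(text))
```

Converting to attribute access keeps call sites such as self.cfg.lr short; a plain dict would work just as well if you prefer explicit key lookups.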

Main Components

Datasets

  • Base dataset

    Base dataset is an abstract class that must be inherited by every dataset you create, because much of the logic is shared across datasets. The base dataset mainly contains:

    • more - add or update dataset-specific configuration
    • load - load the dataset
    • split - create the training set and the test set
  • Your dataset

    This is where you implement your dataset. You should:

    • Create your dataset class and inherit the BaseDataset class
    • Override the load method
    • Override other methods if you need a special implementation
    • Add your dataset name to datasets/__init__.py
    • Create a JSON file of your dataset's configuration in res/datasets/
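To make this division of labor concrete, here is a minimal, self-contained sketch of an abstract dataset with the same three hooks. It is not the framework's BaseDataset: SketchBaseDataset and ToyDataset are hypothetical names, and the round-robin K-fold assignment in split is an assumed strategy that may differ from the real implementation.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class SketchBaseDataset(ABC):
    """Hypothetical stand-in for datasets.BaseDataset, not the real class."""

    def __init__(self, cfg, **kwargs):
        # Let subclasses add or update dataset-specific configuration first
        self.cfg = self.more(cfg)
        # Every concrete dataset must provide load()
        self.data, self.count = self.load()

    def more(self, cfg):
        """Add or update dataset-specific configuration (no-op by default)."""
        return cfg

    @abstractmethod
    def load(self):
        """Return (data_dict, data_count)."""

    def split(self, index):
        """Return (train_indices, test_indices) for fold `index`.

        Assumption: samples are assigned to folds round-robin (i % cross_folds).
        """
        folds = self.cfg.cross_folds
        if not 0 <= index < folds:
            raise ValueError(f"fold index must be in [0, {folds})")
        test = [i for i in range(self.count) if i % folds == index]
        train = [i for i in range(self.count) if i % folds != index]
        return train, test


class ToyDataset(SketchBaseDataset):
    def load(self):
        # Four samples with one feature each and binary labels
        source = [[float(i)] for i in range(4)]
        target = [i % 2 for i in range(4)]
        return {"source": source, "target": target}, len(source)


dataset = ToyDataset(SimpleNamespace(cross_folds=2))
train_idx, test_idx = dataset.split(0)
print(train_idx, test_idx)  # [1, 3] [0, 2]
```

Because load is abstract, forgetting to override it fails immediately with a TypeError at construction time instead of later during training.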

Models

  • Base model

    Base model is an abstract class that must be inherited by every model you create, because much of the logic is shared across models. The base model mainly contains:

    • check_cfg - filter datasets
    • train - train step
    • test - test step
    • load - load previously trained model
    • save - save model
  • Your model

    This is where you implement your model. You should:

    • Create your model class and inherit the BaseModel class
    • Override the train and test methods
    • Override other methods if you need a special implementation
    • Add your model name to models/__init__.py
    • Create a JSON file of your model's configuration in res/models/
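Exporting the class from models/__init__.py and making "name" in the JSON file match the class name suggests that the framework looks classes up by name. The following sketch shows that idea under stated assumptions: SketchBaseModel, build_from_config, and the SimpleNamespace standing in for the models package are all hypothetical, and the real loading code in main.py may work differently.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class SketchBaseModel(ABC):
    """Hypothetical stand-in for models.BaseModel, not the real class."""

    def __init__(self, cfg, data_cfg, run, **kwargs):
        self.cfg, self.data_cfg, self.run = cfg, data_cfg, run

    @abstractmethod
    def train(self, epoch_info, sample_dict):
        """One training step; return a dictionary of losses."""

    @abstractmethod
    def test(self, epoch_info, sample_dict):
        """One test step; return a dictionary of data to save."""


class YourModel(SketchBaseModel):
    def train(self, epoch_info, sample_dict):
        return {"loss": 0.0}  # placeholder loss

    def test(self, epoch_info, sample_dict):
        return {"predict": sample_dict["source"]}  # echo the input as a placeholder


# Stand-in for the `models` package after `from .YourModel import YourModel`
models = SimpleNamespace(YourModel=YourModel)


def build_from_config(package, cfg, *args, **kwargs):
    """Instantiate the class in `package` whose name matches cfg.name."""
    cls = getattr(package, cfg.name, None)
    if cls is None:
        raise KeyError(f"{cfg.name!r} is not exported by the package")
    return cls(cfg, *args, **kwargs)


model = build_from_config(models, SimpleNamespace(name="YourModel"), None, None)
print(type(model).__name__)  # YourModel
```

With a lookup like this, adding a model only requires a new class, one export line, and one JSON file, which matches the plug-and-play steps listed above.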

How to Use

To use this framework, do the following:

  • Dataset

    • In the datasets folder, create a class that inherits the BaseDataset class

       # YourDataset.py
       import numpy
       import datasets
       
       class YourDataset(datasets.BaseDataset):
           def __init__(self, cfg, **kwargs):
               super().__init__(cfg, **kwargs)
    • Override the load method to load your dataset

       # In YourDataset class (uses `numpy`, imported at the top of YourDataset.py)
       def load(self):
           """
           Load your dataset here.
           The parameters in `cfg` are loaded from the JSON file of your dataset's configuration.
           For example:
           - Create 4 random images of size (depth, height, width) as source data
           - Create 4 random labels as target data
           Return the data dictionary and the amount of data.
           """
      
           data_count = 4
           # `source` is a nested object in the dataset JSON; this assumes nested
           # keys are exposed as attributes (self.cfg.source.depth)
           source = numpy.random.rand(data_count, self.cfg.source.depth, self.cfg.source.height, self.cfg.source.width)
           target = numpy.random.randint(0, self.cfg.label_count, (data_count, 1))
      
           return {'source': source, 'target': target}, data_count
    • Add your dataset name to datasets/__init__.py

      from .YourDataset import YourDataset
    • Create a JSON file of your dataset's configuration in res/datasets/

      {
          "name": "YourDataset", // must match your dataset class name
          // All parameters your `YourDataset` class needs
          // For example, the size of images, the number of labels, and the folds of K-fold cross-validation
          "source": {
              "depth": 3,
              "height": 128,
              "width": 128
          },
          "label_count": 2,
          "cross_folds": 2
      }
      
  • Model

    • In the models folder, create a class that inherits the BaseModel class

       # YourModel.py
       import torch
       import models
       
       class YourModel(models.BaseModel):
           def __init__(self, cfg, data_cfg, run, **kwargs):
               super().__init__(cfg, data_cfg, run, **kwargs)
      
               # The parameters in `cfg` are loaded from the JSON file of your model's configuration
               # The parameters in `data_cfg` are loaded from the JSON file of the dataset's configuration
               # The parameters in `run` are loaded from the JSON file of the hyperparameter configuration
      
               # Create the model, optimizer, criterion, etc.
               # For example:
               # - model: Linear
               # - criterion: L1 loss
               # - optimizer: Adam
               self.model = torch.nn.Linear(self.cfg.input_dims, self.cfg.output_dims).to(self.device)
               self.criterion = torch.nn.L1Loss().to(self.device)
               self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.run.lr, betas=(self.run.b1, self.run.b2))
    • Override the train and test methods to implement the training and testing logic

      # In YourModel class
      def train(self, epoch_info, sample_dict):
          """
          epoch_info: the epoch information
          sample_dict: the dictionary of train data
      
          Implement the logic of the training process
          For example:
              source -> [model] -> predict -> [criterion] (+target) -> loss
          Return a dictionary of losses
          """
          source = sample_dict['source'].to(self.device)
          target = sample_dict['target'].to(self.device)
      
          self.model.train()
          self.optimizer.zero_grad()
          predict = self.model(source)
          loss = self.criterion(predict, target)
          loss.backward()
          self.optimizer.step()
      
          # Compute anything else you need here
      
          return {'loss': loss}
      
      def test(self, epoch_info, sample_dict):
          """
          epoch_info: the epoch information
          sample_dict: the dictionary of test data
      
          Implement the logic of the testing process
          For example:
              source -> [model] -> predict
          Return a dictionary of the data you want to save
          """
          source = sample_dict['source'].to(self.device)
          target = sample_dict['target'].to(self.device)
      
          self.model.eval()
          # Disable gradient tracking during evaluation
          with torch.no_grad():
              predict = self.model(source)
      
          # Compute anything else you need here
      
          return {'target': target, 'predict': predict}
    • Add your model name to models/__init__.py

      from .YourModel import YourModel
    • Create a JSON file of your model's configuration in res/models/

      {
          "name": "YourModel", // must match your model class name
          // All parameters your `YourModel` class needs
          // For example, the dimensions of input and output
          "input_dims": 256,
          "output_dims": 1
      }
      
  • Hyperparameter

    • Create a JSON file of your hyperparameter configuration in res/run/

      {
          "name": "YourHP",
          // Basic hyperparameters
          "batch_size": 32,
          "epochs": 200,
          "save_step": 10,
          // Hyperparameters needed when creating the optimizer in the `YourModel` class or elsewhere
          // For example, the learning rate and Adam's beta coefficients (PyTorch defaults shown)
          "lr": 2e-4,
          "b1": 0.9,
          "b2": 0.999
      }
      
  • Run main.py to start training or testing

    • To train with the configuration files res/datasets/yourdataset.json, res/models/yourmodel.json, and res/run/yourhp.json on GPU 0:

      python3 -m main -d "yourdataset" -m "yourmodel" -r "yourhp" -g 0

    Every save_step epochs, the trained model and the data you want to save are saved in the folder save/[yourmodel]-[yourhp]-[yourdataset]-[index of cross-validation].

    • To test the model saved at epoch 10:

      python3 -m main -d "yourdataset" -m "yourmodel" -r "yourhp" -g 0 -t 10
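To show how the command-line flags and the save folder name fit together, here is a hypothetical sketch. parse_args, save_dir, and checkpoint_epochs are not part of the framework; only the short flags -d, -m, -r, -g, and -t come from the commands above, so no long option names are invented. A 0-based cross-validation index and a save at every multiple of save_step are assumptions.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical parser mirroring the flags used in the commands above."""
    parser = argparse.ArgumentParser(description="Sketch of main.py's command line")
    parser.add_argument("-d", dest="dataset", required=True)  # res/datasets/<name>.json
    parser.add_argument("-m", dest="model", required=True)    # res/models/<name>.json
    parser.add_argument("-r", dest="run", required=True)      # res/run/<name>.json
    parser.add_argument("-g", dest="gpu", type=int, default=None)
    parser.add_argument("-t", dest="test_epoch", type=int, default=None)  # test instead of train
    return parser.parse_args(argv)


def save_dir(args, fold):
    """save/[model]-[hp]-[dataset]-[fold], following the naming described above."""
    return f"save/{args.model}-{args.run}-{args.dataset}-{fold}"


def checkpoint_epochs(epochs, save_step):
    """Epochs at which a model would be saved, assuming every multiple of save_step."""
    return list(range(save_step, epochs + 1, save_step))


args = parse_args(["-d", "yourdataset", "-m", "yourmodel", "-r", "yourhp", "-g", "0"])
print(save_dir(args, 0))               # save/yourmodel-yourhp-yourdataset-0
print(checkpoint_epochs(200, 10)[:3])  # [10, 20, 30]
```

With the example hyperparameters ("epochs": 200, "save_step": 10), this schedule yields 20 checkpoints, so -t 10 would refer to the first one.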

Contributing

Any kind of enhancement or contribution is welcome.

License

The code is licensed under the MIT license.
