Add unit tests #3

Merged · 2 commits · Feb 6, 2024
40 changes: 40 additions & 0 deletions .github/workflows/ci.yaml
@@ -0,0 +1,40 @@
name: CI

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  run-tests:
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-latest]
        python-version:
          - '3.8'
          - '3.9'
          - '3.10'
          - '3.11'
          - '3.12'

    name: Test
    runs-on: ${{ matrix.os }}

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: python -m pip install .

      - name: Run tests
        run: python -m unittest discover
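Because the final step is plain standard-library test discovery, the CI run can be reproduced locally with the same command (python -m unittest discover) or programmatically. A minimal sketch of the programmatic equivalent, run from the repository root:

import sys
import unittest

# Mirror `python -m unittest discover`: find test*.py modules under the
# current directory and run them, exiting non-zero on failure like CI would.
suite = unittest.TestLoader().discover(start_dir=".")
result = unittest.TextTestRunner(verbosity=2).run(suite)
sys.exit(0 if result.wasSuccessful() else 1)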
15 changes: 9 additions & 6 deletions pytorch_tcn/tcn.py
@@ -187,6 +187,9 @@ def __init__(
             self.norm2 = None
             self.conv1 = weight_norm(self.conv1)
             self.conv2 = weight_norm(self.conv2)
+        elif use_norm is None:
+            self.norm1 = None
+            self.norm2 = None

         self.activation1 = activation_fn[ self.activation_name ]()
         self.activation2 = activation_fn[ self.activation_name ]()
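This new elif branch matters because use_norm=None is an allowed value (see the validation change below): previously the chain fell through without ever assigning self.norm1/self.norm2 for that case. A self-contained toy of the same dispatch pattern, with illustrative names rather than the library's exact code:

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Illustrative only: mirrors the if/elif shape of the norm setup."""
    def __init__(self, channels, use_norm):
        super().__init__()
        self.conv = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        if use_norm == 'batch_norm':
            self.norm = nn.BatchNorm1d(channels)
        elif use_norm == 'weight_norm':
            self.norm = None
            self.conv = nn.utils.weight_norm(self.conv)
        elif use_norm is None:
            # Without this explicit branch, self.norm would never be assigned
            # and forward() would raise AttributeError when use_norm is None.
            self.norm = None

    def forward(self, x):
        y = self.conv(x)
        return y if self.norm is None else self.norm(y)

y = ToyBlock(8, use_norm=None)(torch.randn(2, 8, 16))  # works: norm is defined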
@@ -269,10 +272,10 @@ def __init__(
         if dilations is not None and len(dilations) != len(num_channels):
             raise ValueError("Length of dilations must match length of num_channels")

-        allowed_norm_values = ['batch_norm', 'layer_norm', 'weight_norm', None]
-        if use_norm not in allowed_norm_values:
+        self.allowed_norm_values = ['batch_norm', 'layer_norm', 'weight_norm', None]
+        if use_norm not in self.allowed_norm_values:
             raise ValueError(
-                f"Argument 'use_norm' must be one of: {allowed_norm_values}"
+                f"Argument 'use_norm' must be one of: {self.allowed_norm_values}"
             )

         if activation not in activation_fn.keys():
@@ -285,10 +288,10 @@ def __init__(
f"Argument 'kernel_initializer' must be one of: {kernel_init_fn.keys()}"
)

allowed_input_shapes = ['NCL', 'NLC']
if input_shape not in allowed_input_shapes:
self.allowed_input_shapes = ['NCL', 'NLC']
if input_shape not in self.allowed_input_shapes:
raise ValueError(
f"Argument 'input_shape' must be one of: {allowed_input_shapes}"
f"Argument 'input_shape' must be one of: {self.allowed_input_shapes}"
)

if dilations is None:
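Moving allowed_norm_values and allowed_input_shapes onto self is what lets the new tests enumerate the valid options instead of hard-coding them. For example (constructor arguments as used in test_norms below):

from pytorch_tcn import TCN

tcn = TCN(num_inputs=10, num_channels=[10])
print(tcn.allowed_norm_values)    # ['batch_norm', 'layer_norm', 'weight_norm', None]
print(tcn.allowed_input_shapes)   # ['NCL', 'NLC']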
Empty file added tests/__init__.py
Empty file added tests/unit/__init__.py
118 changes: 118 additions & 0 deletions tests/unit/test_tcn.py
@@ -0,0 +1,118 @@
import unittest
import torch
import pytorch_tcn
from pytorch_tcn import TCN


class TestTCN(unittest.TestCase):

    def __init__(self, methodName: str = "runTest") -> None:
        super().__init__(methodName)

        self.num_inputs = 20
        self.num_channels = [
            32, 64, 64, 128,
            32, 64, 64, 128,
        ]

        return

    def test_tcn(self, **kwargs):
        tcn = TCN(
            num_inputs = self.num_inputs,
            num_channels = self.num_channels,
            **kwargs,
        )

        time_steps = 196
        x = torch.randn( 10, self.num_inputs, time_steps )
        y = tcn(x)

        self.assertEqual( y.shape, (10, self.num_channels[-1], time_steps) )
        return

    def test_kernel_size(self):
        self.test_tcn( kernel_size = 7 )
        return

    def test_dilations(self):
        # dilations list len != len(num_channels)
        with self.assertRaises(ValueError):
            self.test_tcn( dilations = [1, 2, 3, 4] )

        # dilations list len == len(num_channels)
        self.test_tcn( dilations = [1, 2, 3, 4, 1, 2, 3, 4] )
        return

    def test_dropout(self):
        self.test_tcn( dropout = 0.5 )
        return

    def test_causal(self):
        self.test_tcn( causal = True )
        return

    def test_non_causal(self):
        self.test_tcn( causal = False )
        return

    def test_norms(self):
        available_norms = TCN(10, [10]).allowed_norm_values
        for norm in available_norms:
            print( 'Testing norm:', norm )
            self.test_tcn( use_norm = norm )

        with self.assertRaises(ValueError):
            self.test_tcn( use_norm = 'invalid' )
        return

    def test_activations(self):
        available_activations = pytorch_tcn.tcn.activation_fn.keys()
        for activation in available_activations:
            self.test_tcn( activation = activation )

        with self.assertRaises(ValueError):
            self.test_tcn( activation = 'invalid' )
        return

    def test_kernel_initializers(self):
        available_initializers = pytorch_tcn.tcn.kernel_init_fn.keys()
        for initializer in available_initializers:
            self.test_tcn( kernel_initializer = initializer )

        with self.assertRaises(ValueError):
            self.test_tcn( kernel_initializer = 'invalid' )
        return

    def test_skip_connections(self):
        self.test_tcn( use_skip_connections = True )
        self.test_tcn( use_skip_connections = False )
        return

    def test_input_shape(self):
        self.test_tcn( input_shape = 'NCL' )

        # Test NLC
        tcn = TCN(
            num_inputs = self.num_inputs,
            num_channels = self.num_channels,
            input_shape = 'NLC',
        )

        time_steps = 196
        x = torch.randn( 10, time_steps, self.num_inputs )
        y = tcn(x)

        self.assertEqual( y.shape, (10, time_steps, self.num_channels[-1]) )

        with self.assertRaises(ValueError):
            self.test_tcn( input_shape = 'invalid' )
        return


if __name__ == '__main__':
    unittest.main()
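For reference, the two layouts exercised by test_input_shape: 'NCL' expects (batch, channels, length) and 'NLC' expects (batch, length, channels), and the output keeps the same layout with the channel dimension replaced by the last entry of num_channels. A quick shape check with illustrative sizes:

import torch
from pytorch_tcn import TCN

tcn_ncl = TCN(num_inputs=20, num_channels=[32], input_shape='NCL')
tcn_nlc = TCN(num_inputs=20, num_channels=[32], input_shape='NLC')

print(tcn_ncl(torch.randn(4, 20, 100)).shape)  # torch.Size([4, 32, 100])
print(tcn_nlc(torch.randn(4, 100, 20)).shape)  # torch.Size([4, 100, 32])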