Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Aug 10, 2023
1 parent 4d2aa0d commit 5929a07
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions test/nn/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def test_summary_basic(gcn):
| ├─(convs)ModuleList | -- | -- | 1,072 |
| │ └─(0)GCNConv | [100, 32], [2, 20] | [100, 16] | 528 |
| │ └─(1)GCNConv | [100, 16], [2, 20] | [100, 32] | 544 |
| ├─(norms)ModuleList | -- | -- | -- |
| │ └─(0)Identity | [100, 16] | [100, 16] | -- |
| │ └─(1)Identity | -- | -- | -- |
+---------------------+--------------------+----------------+----------+
"""
assert summary(gcn['model'], gcn['x'], gcn['edge_index']) == expected[1:-1]
Expand All @@ -81,6 +84,9 @@ def test_summary_with_sparse_tensor(gcn):
| ├─(convs)ModuleList | -- | -- | 1,072 |
| │ └─(0)GCNConv | [100, 32], [100, 100] | [100, 16] | 528 |
| │ └─(1)GCNConv | [100, 16], [100, 100] | [100, 32] | 544 |
| ├─(norms)ModuleList | -- | -- | -- |
| │ └─(0)Identity | [100, 16] | [100, 16] | -- |
| │ └─(1)Identity | -- | -- | -- |
+---------------------+-----------------------+----------------+----------+
"""
assert summary(gcn['model'], gcn['x'], gcn['adj_t']) == expected[1:-1]
Expand All @@ -96,10 +102,15 @@ def test_summary_with_max_depth(gcn):
| ├─(dropout)Dropout | [100, 16] | [100, 16] | -- |
| ├─(act)ReLU | [100, 16] | [100, 16] | -- |
| ├─(convs)ModuleList | -- | -- | 1,072 |
| ├─(norms)ModuleList | -- | -- | -- |
+---------------------+--------------------+----------------+----------+
"""
assert summary(gcn['model'], gcn['x'], gcn['edge_index'],
max_depth=1) == expected[1:-1]
assert summary(
gcn['model'],
gcn['x'],
gcn['edge_index'],
max_depth=1,
) == expected[1:-1]


@withPackage('tabulate')
Expand All @@ -118,10 +129,17 @@ def test_summary_with_leaf_module(gcn):
| │ └─(1)GCNConv | [100, 16], [2, 20] | [100, 32] | 544 |
| │ │ └─(aggr_module)SumAggregation | [120, 32], [120] | [100, 32] | -- |
| │ │ └─(lin)Linear | [100, 16] | [100, 32] | 512 |
| ├─(norms)ModuleList | -- | -- | -- |
| │ └─(0)Identity | [100, 16] | [100, 16] | -- |
| │ └─(1)Identity | -- | -- | -- |
+-----------------------------------------+--------------------+----------------+----------+
"""
assert summary(gcn['model'], gcn['x'], gcn['edge_index'],
leaf_module=None) == expected[13:-1]
assert summary(
gcn['model'],
gcn['x'],
gcn['edge_index'],
leaf_module=None,
) == expected[13:-1]


@withPackage('tabulate')
Expand Down

0 comments on commit 5929a07

Please sign in to comment.