-
Notifications
You must be signed in to change notification settings - Fork 18
/
setting.py
77 lines (61 loc) · 3.17 KB
/
setting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
""" Defines the TraditionalSLSetting, as a variant of the TaskIncremental setting with
only one task.
"""
from dataclasses import dataclass
from typing import ClassVar, List, Optional, Type, TypeVar, Union
from sequoia.utils.utils import constant
# TODO: Re-arrange the 'multiple-inheritance' with domain-incremental and
# task-incremental, this might not be 100% accurate, as the "IID" you get from
# moving down from domain-incremental (+ only one task) might not be exactly the same as
# the one you get from TaskIncremental (+ only one task)
from ..incremental import IncrementalSLSetting
from .results import IIDResults
# TODO: IDEA: Add the pytorch lightning datamodules in the list of
# 'available datasets' for the IID setting, and make sure that it doesn't mess
# up the methods in the parents (train/val loop, dataloader construction, etc.)
# IDEA: Maybe overwrite the 'train/val/test_dataloader' methods on the setting
# and when the chosen dataset is a LightningDataModule, then just return the
# result from the corresponding method on the LightningDataModule, rather than
# from super().
# from pl_bolts.datamodules import (CIFAR10DataModule, FashionMNISTDataModule,
# ImagenetDataModule, MNISTDataModule)
@dataclass
class TraditionalSLSetting(IncrementalSLSetting):
    """Your 'usual' supervised learning Setting, where the samples are i.i.d.

    This Setting is slightly different than the others, in that it can be recovered in
    *two* different ways:
    - As a variant of Task-Incremental learning, but where there is only one task;
    - As a variant of Domain-Incremental learning, but where there is only one task.
    """

    # Results class used to wrap/report this setting's evaluation results.
    # NOTE(review): inside the class body the name `Results` in the annotation
    # resolves to the value assigned on this same line (IIDResults), not to a base
    # `Results` class — presumably a base `Results` type exists in the project and
    # was meant here; confirm the intended import.
    Results: ClassVar[Type[Results]] = IIDResults

    # Number of tasks.
    nb_tasks: int = 5
    # Fixed to True via `constant(...)`: the data is presented as a single
    # stationary (i.i.d.) stream rather than a sequence of distinct contexts.
    stationary_context: bool = constant(True)
    # increment: Union[int, List[int]] = constant(None)

    # A different task size applied only for the first task.
    # Deactivated if `increment` is a list.
    # Pinned to None here since the i.i.d. setting has no per-task increments.
    initial_increment: int = constant(None)
    # An optional custom class order, used for NC. Pinned to None in this setting.
    class_order: Optional[List[int]] = constant(None)
    # Either number of classes per task, or a list specifying for
    # every task the amount of new classes (defaults to the value of
    # `increment`). Pinned to None in this setting.
    test_increment: Optional[Union[List[int], int]] = constant(None)
    # A different task size applied only for the first test task.
    # Deactivated if `test_increment` is a list. Defaults to the
    # value of `initial_increment`. Pinned to None in this setting.
    test_initial_increment: Optional[int] = constant(None)
    # An optional custom class order for testing, used for NC.
    # Defaults to the value of `class_order`. Pinned to None in this setting.
    test_class_order: Optional[List[int]] = constant(None)

    @property
    def phases(self) -> int:
        """The number of training 'phases', i.e. how many times `method.fit` will be
        called.

        Defaults to the number of tasks, but may be different, for instance in so-called
        Multi-Task Settings, this is set to 1.
        """
        # `stationary_context` is pinned to True above, so this is expected to
        # always return 1 here; the else branch mirrors the parent settings'
        # behavior in case a subclass relaxes that constraint.
        return 1 if self.stationary_context else self.nb_tasks
# Type variable bound to TraditionalSLSetting, for annotating code that is
# generic over this setting or any of its subclasses.
SettingType = TypeVar("SettingType", bound=TraditionalSLSetting)

# Allow launching this setting directly as a script.
if __name__ == "__main__":
    TraditionalSLSetting.main()