forked from phizaz/diffae
-
Notifications
You must be signed in to change notification settings - Fork 0
/
choices.py
executable file
·179 lines (138 loc) · 3.97 KB
/
choices.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
from enum import Enum
from torch import nn
class TrainMode(Enum):
    """Top-level training mode selected for a run."""

    # train the classifier used for manipulation
    manipulate = 'manipulate'
    # the default training mode
    diffusion = 'diffusion'
    # the default latent training mode: fit a DDPM to a given latent
    latent_diffusion = 'latentdiffusion'

    def is_manipulate(self):
        """True only for the classifier-training (manipulate) mode."""
        return self is TrainMode.manipulate

    def is_diffusion(self):
        """True for both the image-level and the latent diffusion modes."""
        return self in (TrainMode.diffusion, TrainMode.latent_diffusion)

    def is_autoenc(self):
        """True when the network may perform autoencoding (diffusion mode only)."""
        return self is TrainMode.diffusion

    def is_latent_diffusion(self):
        """True only for the latent diffusion mode."""
        return self is TrainMode.latent_diffusion

    def use_latent_net(self):
        """Whether a latent network is used; identical to is_latent_diffusion()."""
        return self.is_latent_diffusion()

    def require_dataset_infer(self):
        """
        Whether training in this mode needs the latent variables available up front.

        In these modes all latents are precalculated beforehand and the dataset
        then consists of the predicted latents.
        """
        return self in (TrainMode.latent_diffusion, TrainMode.manipulate)
class ManipulateMode(Enum):
    """How the classifier used for manipulation is trained."""

    # train on the whole CelebA attribute dataset
    celebahq_all = 'celebahq_all'
    # CelebA with D2C's crop
    d2c_fewshot = 'd2cfewshot'
    # variant of d2c_fewshot; the exact meaning of "allneg" is not defined in this file
    d2c_fewshot_allneg = 'd2cfewshotallneg'

    def is_celeba_attr(self):
        """True for the CelebA-attribute modes (currently every member)."""
        return self in (
            ManipulateMode.d2c_fewshot,
            ManipulateMode.d2c_fewshot_allneg,
            ManipulateMode.celebahq_all,
        )

    def is_single_class(self):
        """True for the two D2C few-shot modes."""
        return self in (ManipulateMode.d2c_fewshot, ManipulateMode.d2c_fewshot_allneg)

    def is_fewshot(self):
        """True for the two D2C few-shot modes."""
        return self in (ManipulateMode.d2c_fewshot, ManipulateMode.d2c_fewshot_allneg)

    def is_fewshot_allneg(self):
        """True only for the d2c_fewshot_allneg mode."""
        return self is ManipulateMode.d2c_fewshot_allneg
class ModelType(Enum):
    """Kind of backbone model."""

    # unconditional DDPM
    ddpm = 'ddpm'
    # autoencoding DDPM; cannot do unconditional generation
    autoencoder = 'autoencoder'

    def has_autoenc(self):
        """True only for the autoencoding backbone."""
        return self is ModelType.autoencoder

    def can_sample(self):
        """True only for the unconditional DDPM backbone."""
        return self is ModelType.ddpm
class ModelName(Enum):
    """Every model class this codebase supports, keyed by its config name."""

    beatgans_ddpm = 'beatgans_ddpm'
    beatgans_autoenc = 'beatgans_autoenc'
class ModelMeanType(Enum):
    """What quantity the model's output represents."""

    # the model predicts epsilon (the noise)
    eps = 'eps'
class ModelVarType(Enum):
    """
    What is used as the model's output variance.

    Only fixed (non-learned) variances are available in this enum.
    """

    # posterior beta_t
    fixed_small = 'fixed_small'
    # beta_t
    fixed_large = 'fixed_large'
class LossType(Enum):
    """Training loss applied to the model's prediction."""

    # raw MSE loss (no learned-variance term: ModelVarType offers only fixed variances)
    mse = 'mse'
    # L1 loss
    l1 = 'l1'
class GenerativeType(Enum):
    """How a sample is generated: via the DDPM or the DDIM procedure."""

    ddpm = 'ddpm'
    ddim = 'ddim'
class OptimizerType(Enum):
    """Choice of optimizer: Adam or AdamW."""

    adam = 'adam'
    adamw = 'adamw'
class Activation(Enum):
    """Activation function choices; ``get_act`` builds the matching torch module."""

    none = 'none'
    relu = 'relu'
    lrelu = 'lrelu'
    silu = 'silu'
    tanh = 'tanh'

    def get_act(self):
        """
        Instantiate a new ``nn.Module`` for this activation.

        ``none`` maps to ``nn.Identity`` and ``lrelu`` to ``nn.LeakyReLU`` with
        a negative slope of 0.2.

        Raises:
            NotImplementedError: if this member has no factory registered below.
        """
        # Factories are lambdas so each ``nn`` attribute is only looked up
        # (and the module only constructed) for the member actually requested.
        builders = {
            Activation.none: lambda: nn.Identity(),
            Activation.relu: lambda: nn.ReLU(),
            Activation.lrelu: lambda: nn.LeakyReLU(negative_slope=0.2),
            Activation.silu: lambda: nn.SiLU(),
            Activation.tanh: lambda: nn.Tanh(),
        }
        build = builders.get(self)
        if build is None:
            raise NotImplementedError()
        return build()
class ManipulateLossType(Enum):
    """Loss option, presumably for training the manipulation classifier ('bce' or 'mse')."""

    bce = 'bce'
    mse = 'mse'