-
Notifications
You must be signed in to change notification settings - Fork 0
/
ifmnist.py
49 lines (40 loc) · 1.18 KB
/
ifmnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
class if_mnist:
    """MNIST classifier interface.

    Concrete classifiers subclass this and override ``name``, ``fit`` and
    ``infer``. The base implementations raise ``NotImplementedError`` so
    that a subclass which forgets to override one of them fails loudly at
    the call site, instead of silently returning ``None`` (which would
    violate the annotated return types of ``name`` and ``infer``).
    """

    def __init__(self):
        # No shared state; kept so subclasses can safely call
        # super().__init__().
        pass

    def name(self) -> str:
        """model name.

        Returns:
            str
                -- a name identifying the model.

        Raises:
            NotImplementedError
                -- always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement name()")

    def fit(self,
            tr_imgs: np.ndarray, tr_labels: np.ndarray,
            te_imgs: np.ndarray, te_labels: np.ndarray,
            sess_file: str) -> None:
        """fitting a model.

        Arguments:
            tr_imgs {np.ndarray}
                -- training images.
            tr_labels {np.ndarray}
                -- training labels.
            te_imgs {np.ndarray}
                -- testing images.
            te_labels {np.ndarray}
                -- testing labels.
            sess_file {str}
                -- where to checkpoint the model params.

        Raises:
            NotImplementedError
                -- always; subclasses must override this method.
        """
        # NOTE(review): image/label shapes and dtypes are not fixed by this
        # interface; confirm expectations against concrete implementations.
        raise NotImplementedError(
            f"{type(self).__name__} must implement fit()")

    def infer(self, imgs: np.ndarray, sess_file: str) -> np.ndarray:
        """produce inference on an array of images.

        Arguments:
            imgs {np.ndarray}
                -- images to be inferred.
            sess_file {str}
                -- where to restore model params checkpoint.

        Returns:
            np.ndarray
                -- an array of pmfs (presumably one per input image, e.g.
                shape (n_images, n_classes); confirm against implementations).

        Raises:
            NotImplementedError
                -- always; subclasses must override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement infer()")