-
Notifications
You must be signed in to change notification settings - Fork 6
/
transfer.py
75 lines (51 loc) · 2.08 KB
/
transfer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
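Transfer images between domains A and B with a trained CycleGAN model, either for a single file (src/out) or for every file in a directory (src_dir/out_dir).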
"""
import os
import scipy.misc as misc
import cyclegan.config as config
from cyclegan.model import CycleGAN
from cyclegan.utils.data import load_file
class Transfer:
    """Run a trained CycleGAN generator over single images or whole directories.

    ``args`` is the parsed configuration (``config.parse_args()`` in this
    script). Attributes read here: the CycleGAN constructor settings
    (``ngf``, ``ndf``, ``num_resnet``, ``lrG``, ``lrD``, ``beta1``, ``beta2``,
    ``lambdaA``, ``lambdaB``, ``num_pool``), ``model`` and ``type`` (passed to
    ``CycleGAN.load``; ``type`` also selects the direction) and
    ``input_size`` (passed to ``load_file``).
    """

    def __init__(self, args):
        self.args = args
        # The constructor takes the full training configuration; this class
        # only ever calls the generate_* methods afterwards.
        self.cycleGAN = CycleGAN(args.ngf, args.ndf, args.num_resnet,
                                 args.lrG, args.lrD, args.beta1, args.beta2,
                                 args.lambdaA, args.lambdaB, args.num_pool)
        self.cycleGAN.load(args.model, args.type)

    def generate(self, data_src):
        """Translate one input with the generator selected by ``args.type``.

        ``args.type == 'BtoA'`` uses ``generate_B_to_A``; any other value
        falls back to ``generate_A_to_B``.

        Args:
            data_src: model input as returned by ``load_file`` (presumably a
                batched image tensor -- TODO confirm).

        Returns:
            numpy array holding the first element of the generated batch with
            its leading axis moved last (presumably CHW -> HWC for saving).
        """
        # Compare by value: the former `is 'BtoA'` check relied on string
        # interning, so an equal but distinct string (e.g. one read from the
        # command line) silently selected the A->B generator.
        if self.args.type == 'BtoA':
            generator = self.cycleGAN.generate_B_to_A
        else:
            generator = self.cycleGAN.generate_A_to_B
        generated, _ = generator(data_src, recon=False)
        # First batch item -> host memory -> numpy, leading axis moved last.
        return generated[0].cpu().data.numpy().transpose(1, 2, 0)

    def run_dir(self, src_dir, out_dir):
        """Transfer every regular file in ``src_dir`` into ``out_dir``.

        Each output keeps its source file name. Entries are processed in
        sorted order so runs are reproducible; sub-directories and other
        non-regular entries are skipped instead of being handed to
        ``load_file``.
        """
        for filename in sorted(os.listdir(src_dir)):
            src_path = os.path.join(src_dir, filename)
            if not os.path.isfile(src_path):
                continue
            self.run_file(src_path, os.path.join(out_dir, filename))

    def run_file(self, src_path, out_path):
        """Load ``src_path``, transfer it, and write the result to ``out_path``."""
        print('{0} -> {1}'.format(src_path, out_path))
        data_src = load_file(src_path, self.args.input_size)
        data_out = self.generate(data_src)
        # NOTE(review): scipy.misc.imsave was deprecated in SciPy 1.0 and
        # removed in 1.2, so this call fails on current SciPy. Replacing it
        # (imageio / Pillow) needs a new dependency.
        misc.imsave(out_path, data_out)
def prepare_output_dir(args):
    """Attach log/model/test sub-directory paths to ``args`` and create them.

    Sets ``args.log_dir``, ``args.model_dir`` and ``args.test_output_dir`` to
    ``<output_dir>/log``, ``<output_dir>/model`` and ``<output_dir>/test``,
    then creates each directory; directories that already exist are left as is.
    """
    layout = (
        ('log_dir', 'log'),
        ('model_dir', 'model'),
        ('test_output_dir', 'test'),
    )
    base = args.output_dir
    # Assign every path first, then create the directories in the same order.
    for attr_name, sub_dir in layout:
        setattr(args, attr_name, os.path.join(base, sub_dir))
    for attr_name, _ in layout:
        os.makedirs(getattr(args, attr_name), exist_ok=True)
def main():
    """Command-line entry point: transfer one file or a whole directory.

    Directory mode (``src_dir`` + ``out_dir``) takes precedence over
    single-file mode (``src`` + ``out``); when neither pair is given, the
    configuration's help text is printed.
    """
    args = config.parse_args()
    dir_mode = bool(args.src_dir and args.out_dir)
    # Short-circuit like the original elif: src/out are only read when
    # directory mode does not apply.
    file_mode = not dir_mode and bool(args.src and args.out)
    if not (dir_mode or file_mode):
        # Bail out before Transfer(): a bad invocation should not first pay
        # for building and loading the model (or fail while doing so).
        config.print_help()
        return
    transfer = Transfer(args)
    print('Running {0} Transfer'.format(args.type))
    if dir_mode:
        # exist_ok=False: an existing out_dir raises FileExistsError instead
        # of having its files overwritten by run_dir.
        os.makedirs(args.out_dir, exist_ok=False)
        transfer.run_dir(args.src_dir, args.out_dir)
    else:
        transfer.run_file(args.src, args.out)


if __name__ == '__main__':
    main()