-
Notifications
You must be signed in to change notification settings - Fork 3
/
get_fda_image.py
138 lines (99 loc) · 4.51 KB
/
get_fda_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import numpy as np
import os
import pandas as pd
from PIL import Image
import scipy.misc
def make_dir(path):
    """Create directory ``path`` (including missing parents) if it is absent.

    ``exist_ok=True`` replaces the original "check, then create" sequence,
    which could raise FileExistsError when another process created the
    directory between the existence check and ``os.makedirs``.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
def extract_ampl_phase(fft_im):
    """Split a real/imaginary-packed spectrum into amplitude and phase.

    Args:
        fft_im: 5-D tensor whose last axis holds (real, imaginary) pairs;
            expected layout is b x 3 x h x w x 2.

    Returns:
        Tuple ``(amplitude, phase)`` of 4-D tensors, where amplitude is
        ``sqrt(re**2 + im**2)`` and phase is ``atan2(im, re)``.
    """
    real_part = fft_im[:, :, :, :, 0]
    imag_part = fft_im[:, :, :, :, 1]
    magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
    angle = torch.atan2(imag_part, real_part)
    return magnitude, angle
def low_freq_mutate(amp_src, amp_trg, L=0.1):
    """Copy the four low-frequency corner patches of ``amp_trg`` into ``amp_src``.

    In an un-shifted 2-D FFT the lowest frequencies sit at the corners, so
    overwriting four ``band x band`` corner patches replaces the source's
    low-frequency amplitude with the target's.

    Args:
        amp_src: 4-D tensor (unpacked as ``_, _, h, w``); modified IN PLACE.
        amp_trg: tensor indexable with the same corner slices as ``amp_src``.
        L: fraction of ``min(h, w)``; the corner size is
            ``band = floor(min(h, w) * L)`` (0 leaves ``amp_src`` unchanged).

    Returns:
        ``amp_src`` itself (the same object, after mutation).
    """
    _, _, height, width = amp_src.size()
    band = int(np.floor(np.amin((height, width)) * L))
    row_spans = (slice(0, band), slice(height - band, height))  # top, bottom
    col_spans = (slice(0, band), slice(width - band, width))    # left, right
    # Writes happen in the order top-left, top-right, bottom-left, bottom-right.
    for rows in row_spans:
        for cols in col_spans:
            amp_src[:, :, rows, cols] = amp_trg[:, :, rows, cols]
    return amp_src
def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
    """Replace the centred low-frequency window of ``amp_src`` with ``amp_trg``'s.

    Both amplitude spectra are ``fftshift``-ed over their last two axes so
    the DC term sits at the centre. A ``(2b+1) x (2b+1)`` window around it
    is copied from target to source, and the result is shifted back.
    ``fftshift`` returns a new array, so neither input is modified.

    Args:
        amp_src: array whose last two axes are (H, W); any leading axes
            (e.g. channels) are carried along.
        amp_trg: array with the same shape as ``amp_src``.
        L: fraction of ``min(H, W)``; half-width ``b = floor(min(H, W) * L)``.
            Values above 0.5 clamp the window to the whole spectrum.

    Returns:
        New un-shifted amplitude array with the shape of ``amp_src``.
    """
    a_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
    h, w = a_src.shape[-2:]
    b = int(np.floor(np.amin((h, w)) * L))
    c_h = h // 2  # fftshift puts the DC term at index n // 2
    c_w = w // 2
    # Clamp at 0: a negative start would wrap to the end of the axis in
    # NumPy and select a tiny or empty window instead of the full spectrum.
    h1, h2 = max(c_h - b, 0), c_h + b + 1
    w1, w2 = max(c_w - b, 0), c_w + b + 1
    a_src[..., h1:h2, w1:w2] = a_trg[..., h1:h2, w1:w2]
    return np.fft.ifftshift(a_src, axes=(-2, -1))
def FDA_source_to_target(src_img, trg_img, L=0.1):
    """Torch FDA: give ``src_img`` the low-frequency amplitude of ``trg_img``.

    Ported from the ``torch.rfft``/``torch.irfft`` API, which was removed in
    PyTorch 1.8, to the ``torch.fft`` module (requires PyTorch >= 1.8).
    The result now stays on the input's device and keeps its floating dtype;
    the old code rebuilt the spectrum in a CPU float32 ``torch.zeros``.

    Args:
        src_img: real 4-D tensor unpacked as (_, _, H, W), presumably
            (B, C, H, W).
        trg_img: tensor with the same shape as ``src_img``
            (``low_freq_mutate`` slices it using the source's H and W).
        L: corner-size fraction forwarded to ``low_freq_mutate``.

    Returns:
        Real tensor with the shape of ``src_img``, combining the source phase
        with the target's low-frequency amplitude.
    """
    # Full (two-sided) 2-D FFT over H, W, viewed as (..., H, W, 2) real/imag
    # pairs: the layout torch.rfft(..., onesided=False) produced and the one
    # extract_ampl_phase indexes. fft2 does not modify its input.
    fft_src = torch.view_as_real(torch.fft.fft2(src_img, dim=(-2, -1)))
    fft_trg = torch.view_as_real(torch.fft.fft2(trg_img, dim=(-2, -1)))
    amp_src, pha_src = extract_ampl_phase(fft_src)
    amp_trg, _ = extract_ampl_phase(fft_trg)
    # low_freq_mutate writes into its first argument; amp_src is a fresh
    # tensor computed just above, so in-place mutation is safe here.
    amp_src_ = low_freq_mutate(amp_src, amp_trg, L=L)
    # Recombine: amp * (cos(pha) + i*sin(pha)).
    fft_src_ = torch.polar(amp_src_, pha_src)
    # The mutated spectrum need not be exactly Hermitian, so keep the real
    # part of the full inverse FFT (as FDA_source_to_target_np does).
    return torch.fft.ifft2(fft_src_, dim=(-2, -1)).real
def FDA_source_to_target_np(src_img, trg_img, L=0.1):
    """Give ``src_img`` the low-frequency amplitude ("style") of ``trg_img``.

    Both inputs are Fourier-transformed over their last two axes. The
    source's amplitude spectrum has its centred low-frequency window
    replaced by the target's (via ``low_freq_mutate_np``). It is then
    recombined with the source's own phase and transformed back.

    Args:
        src_img: NumPy array with spatial axes last; the ``__main__`` block
            passes (C, H, W) float32 arrays.
        trg_img: array with the same shape as ``src_img``.
        L: window-size fraction forwarded to ``low_freq_mutate_np``.

    Returns:
        Real part of the inverse FFT, with the shape of ``src_img``.
    """
    spectrum_src = np.fft.fft2(src_img, axes=(-2, -1))
    spectrum_trg = np.fft.fft2(trg_img, axes=(-2, -1))
    # Only the target's amplitude is used; its phase plays no role.
    mixed_amplitude = low_freq_mutate_np(
        np.abs(spectrum_src), np.abs(spectrum_trg), L=L
    )
    mixed_spectrum = mixed_amplitude * np.exp(1j * np.angle(spectrum_src))
    return np.real(np.fft.ifft2(mixed_spectrum, axes=(-2, -1)))
if __name__ == '__main__':
    # Batch FDA: for each (source, target) pair listed in ./pair_<i>.csv,
    # re-style the source image from ./ori_png/ with the target's
    # low-frequency amplitude and save it under ./fda_data/s<i>/.
    img_path = './ori_png/'
    dir_path = './fda_data/'
    for i in range(1, 2):
        output_path = os.path.join(dir_path, 's{}'.format(i))
        make_dir(output_path)
        pair_csv = pd.read_csv('./pair_{}.csv'.format(i))
        for src_name, trg_name in zip(pair_csv['source_img'], pair_csv['target_img']):
            print(src_name, trg_name)
            # `with` closes the file handles; convert() returns an
            # independent, fully loaded image.
            with Image.open(os.path.join(img_path, src_name)) as im:
                src_img = im.convert('RGB')
            with Image.open(os.path.join(img_path, trg_name)) as im:
                trg_img = im.convert('RGB')
            print(src_img.size)
            print(trg_img.size)
            # PIL sizes are (width, height); match the target to the source.
            trg_img = trg_img.resize(src_img.size, Image.BICUBIC)
            print(trg_img.size)
            # HWC -> CHW float32, the layout FDA_source_to_target_np expects.
            src_arr = np.asarray(src_img, np.float32).transpose((2, 0, 1))
            trg_arr = np.asarray(trg_img, np.float32).transpose((2, 0, 1))
            src_in_trg = FDA_source_to_target_np(src_arr, trg_arr, L=0.01)
            src_in_trg = src_in_trg.transpose((1, 2, 0))
            # scipy.misc.toimage was removed in SciPy 1.2. Reproduce its
            # bytescale(cmin=0, cmax=255) mapping exactly: clip to [0, 255],
            # then round half up to uint8.
            out = (np.clip(src_in_trg, 0.0, 255.0) + 0.5).astype(np.uint8)
            Image.fromarray(out).save(os.path.join(output_path, src_name))