-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils_gray.py
67 lines (49 loc) · 1.94 KB
/
utils_gray.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
def normalization(data):
    """Min-max scale ``data`` into the range [0, 1].

    Args:
        data: array-like of real numbers.

    Returns:
        ``(data - min(data)) / (max(data) - min(data))``.  When every element
        is equal (max == min) an all-zero array of the same shape is returned
        instead of the NaNs a division by zero would produce; its dtype is the
        input's floating dtype, or float64 for non-floating input.
    """
    lo = np.min(data)
    span = np.max(data) - lo
    shifted = data - lo
    if span == 0:
        # Constant input: every value sits at the minimum, so map it to 0
        # rather than dividing by zero.
        out_dtype = np.result_type(shifted)
        if not np.issubdtype(out_dtype, np.floating):
            out_dtype = np.float64
        return np.zeros_like(shifted, dtype=out_dtype)
    return shifted / span
# Helpers for Fourier-domain amplitude swapping between a source and a target image
def extract_ampl_phase(fft_im):
    """Split a real/imaginary-stacked spectrum into amplitude and phase.

    Args:
        fft_im: tensor whose last dimension has size 2, holding the real part
            at index 0 and the imaginary part at index 1 (e.g. shape
            b x 3 x h x w x 2, the layout ``torch.view_as_real`` produces for
            a complex tensor).  Any number of leading dimensions is accepted.

    Returns:
        ``(fft_amp, fft_pha)``: tensors shaped like ``fft_im`` without its
        last dimension, holding ``sqrt(re**2 + im**2)`` and
        ``atan2(im, re)`` respectively.
    """
    # Ellipsis indexing selects along the last axis for any rank; for the
    # 5-D case it is identical to fft_im[:, :, :, :, k].
    re = fft_im[..., 0]
    im = fft_im[..., 1]
    # NOTE: sqrt has an infinite derivative at 0, so backpropagating through
    # bins with zero magnitude produces NaN gradients.
    fft_amp = torch.sqrt(re ** 2 + im ** 2)
    fft_pha = torch.atan2(im, re)
    return fft_amp, fft_pha
def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
    """Replace the centered low-frequency block of ``amp_src`` with ``amp_trg``'s.

    Both spectra are fftshift-ed over their last two axes so the zero
    frequency moves to the center.  A square window of half-width
    ``b = floor(min(h, w) * L)`` (``2*b + 1`` bins per side) around that
    center is copied from the target into the shifted source, and the result
    is shifted back.

    Args:
        amp_src: source amplitude spectrum, shape (..., h, w), in standard
            (unshifted) FFT layout.
        amp_trg: target amplitude spectrum; expected to have the same shape
            as ``amp_src``.
        L: window half-width as a fraction of ``min(h, w)``.  Windows that
            would extend past the array edges are clipped at the borders.

    Returns:
        A new array shaped like ``amp_src``; the inputs are not modified
        (``fftshift`` returns a fresh array).
    """
    a_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    a_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
    # Only the last two axes are spatial; any leading axes (e.g. channels)
    # are carried along by the ellipsis slices below.
    h, w = a_src.shape[-2:]
    b = int(np.floor(min(h, w) * L))
    c_h = h // 2
    c_w = w // 2
    # Clamp the lower bounds: a negative start index would make Python
    # slicing count from the end and copy the wrong region for large L.
    h1 = max(c_h - b, 0)
    h2 = c_h + b + 1
    w1 = max(c_w - b, 0)
    w2 = c_w + b + 1
    a_src[..., h1:h2, w1:w2] = a_trg[..., h1:h2, w1:w2]
    return np.fft.ifftshift(a_src, axes=(-2, -1))
def FDA_source_to_target_np(src_img, trg_img, L=0.01):
    """Give ``src_img`` the low-frequency amplitude spectrum of ``trg_img``.

    Both images are transformed with a 2-D FFT over their last two axes.
    The source and target amplitudes are handed to ``low_freq_mutate_np``,
    whose output is recombined with the unchanged source phase and
    transformed back to the spatial domain.

    Args:
        src_img: source image array; its shape must be one that
            ``low_freq_mutate_np`` accepts.
        trg_img: target image array, presumably the same shape as
            ``src_img`` -- TODO confirm with callers.
        L: window-size ratio forwarded to ``low_freq_mutate_np``.

    Returns:
        The real part of the inverse FFT of the mutated source spectrum.
    """
    spectrum_axes = (-2, -1)
    src_spectrum = np.fft.fft2(src_img, axes=spectrum_axes)
    trg_spectrum = np.fft.fft2(trg_img, axes=spectrum_axes)

    # Only the amplitude is exchanged; the source phase is kept as-is.
    mixed_amp = low_freq_mutate_np(np.abs(src_spectrum), np.abs(trg_spectrum), L=L)
    mixed_spectrum = mixed_amp * np.exp(1j * np.angle(src_spectrum))

    return np.fft.ifft2(mixed_spectrum, axes=spectrum_axes).real