forked from alexlee-gk/lpips-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lpips_tf.py
90 lines (73 loc) · 3.2 KB
/
lpips_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import sys
import tensorflow as tf
from six.moves import urllib
# Base URL from which the pretrained LPIPS frozen graphs (`.pb` files) are
# downloaded by `lpips()`; the graph file name is appended to it.
# NOTE(review): this is plain HTTP and the downloaded file is not checked
# against any checksum -- confirm that is acceptable for your deployment.
_URL = 'http://rail.eecs.berkeley.edu/models/lpips'
def _download(url, output_dir):
    """Downloads the `url` file into `output_dir`.

    The file is first written to `<filename>.part` inside `output_dir` and
    only moved to its final name once the transfer has completed, so an
    interrupted download never leaves a truncated file under the final name.

    Modified from https://github.com/tensorflow/models/blob/master/research/slim/datasets/dataset_utils.py

    Args:
        url: URL of the file to fetch. The text after the last `/` is used as
            the local file name.
        output_dir: Existing directory in which the file is stored.

    Raises:
        Whatever `urllib.request.urlretrieve` raises (e.g.
        `urllib.error.HTTPError`); any partial file is removed first.
    """
    filename = url.split('/')[-1]
    filepath = os.path.join(output_dir, filename)
    partial_path = filepath + '.part'

    def _progress(count, block_size, total_size):
        if total_size > 0:
            # The last block is usually short, so clamp instead of
            # reporting more than 100%.
            percent = min(
                float(count * block_size) / float(total_size) * 100.0, 100.0)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            # urlretrieve reports -1 when the size is unknown, and 0 for an
            # empty file; a percentage is meaningless (or a division by
            # zero) in both cases.
            sys.stdout.write('\r>> Downloading %s %d bytes' % (
                filename, count * block_size))
        sys.stdout.flush()

    try:
        urllib.request.urlretrieve(url, partial_path, _progress)
        os.replace(partial_path, filepath)
    finally:
        # After a successful replace the partial file is gone; it only
        # still exists here if the download or the move failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
def lpips(input0, input1, model='net-lin', net='alex', version=0.1):
    """
    Learned Perceptual Image Patch Similarity (LPIPS) metric.

    The network is loaded from a frozen graph (`.pb` file) cached in
    `~/.lpips`. If the file is not cached yet it is downloaded from `_URL`,
    and the graph is then imported into the default TensorFlow graph with
    the two (flattened, NCHW, [-1, 1]-normalized) inputs mapped in.

    Args:
        input0: An image tensor of shape `[..., height, width, channels]`,
            with values in [0, 1].
        input1: An image tensor of shape `[..., height, width, channels]`,
            with values in [0, 1].
        model: Model name; first component of the graph file name.
        net: Network name; second component of the graph file name.
        version: Model version; formatted with `%s` into the graph file name.

    Returns:
        The Learned Perceptual Image Patch Similarity (LPIPS) distance,
        reshaped to the leading (batch) dimensions of `input0`.

    Raises:
        FileNotFoundError: If none of the candidate graph files is cached
            or could be downloaded.

    Reference:
        Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang.
        The Unreasonable Effectiveness of Deep Features as a Perceptual Metric.
        In CVPR, 2018.
    """
    # flatten the leading dimensions
    batch_shape = tf.shape(input0)[:-3]
    input0 = tf.reshape(input0, tf.concat([[-1], tf.shape(input0)[-3:]], axis=0))
    input1 = tf.reshape(input1, tf.concat([[-1], tf.shape(input1)[-3:]], axis=0))
    # NHWC to NCHW
    input0 = tf.transpose(input0, [0, 3, 1, 2])
    input1 = tf.transpose(input1, [0, 3, 1, 2])
    # normalize to [-1, 1]
    input0 = input0 * 2.0 - 1.0
    input1 = input1 * 2.0 - 1.0

    # tensor names inside the frozen graph that the inputs are mapped onto
    input0_name, input1_name = '0:0', '1:0'

    default_graph = tf.get_default_graph()
    producer_version = default_graph.graph_def_versions.producer

    cache_dir = os.path.expanduser('~/.lpips')
    os.makedirs(cache_dir, exist_ok=True)
    # files to try. try a specific producer version, but fallback to the version-less version (latest).
    pb_fnames = [
        '%s_%s_v%s_%d.pb' % (model, net, version, producer_version),
        '%s_%s_v%s.pb' % (model, net, version),
    ]
    for pb_fname in pb_fnames:
        pb_path = os.path.join(cache_dir, pb_fname)
        if not os.path.isfile(pb_path):
            try:
                # URLs always use '/', so build them by hand rather than
                # with os.path.join (which uses '\\' on Windows).
                _download('%s/%s' % (_URL, pb_fname), cache_dir)
            except urllib.error.HTTPError:
                # This candidate is not available on the server; the check
                # below fails and the next candidate is tried.
                pass
        if os.path.isfile(pb_path):
            break
    else:
        raise FileNotFoundError(
            'Could not find or download any of the LPIPS graph files %s '
            '(cache directory: %s, url: %s)' % (pb_fnames, cache_dir, _URL))

    with open(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def,
                            input_map={input0_name: input0, input1_name: input1})
        # The distance is taken from the most recently added operation of the
        # default graph, i.e. the last node of the graph imported just above.
        distance, = default_graph.get_operations()[-1].outputs

    if distance.shape.ndims == 4:
        # drop the three trailing axes so only the flattened batch axis is left
        distance = tf.squeeze(distance, axis=[-3, -2, -1])
    # reshape the leading dimensions
    distance = tf.reshape(distance, batch_shape)
    return distance