forked from deepeshdm/Neural-Style-Transfer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
API.py
53 lines (39 loc) · 1.95 KB
/
API.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
def _load_image_batch(image_path):
    """Read an image file and return it as a float32 RGB batch in [0, 1].

    ``plt.imread`` yields different dtypes depending on the file format
    (float arrays already in [0, 1] for PNG, integer arrays otherwise), and
    may yield grayscale (2-D) or RGBA (4-channel) arrays. This helper
    normalises all of those to a ``(1, height, width, 3)`` float32 array.

    :param image_path: path of the image file to read.
    :return: float32 numpy array with a leading batch dimension of size 1.
    """
    image = plt.imread(image_path)

    if np.issubdtype(image.dtype, np.integer):
        # Scale by the dtype's own maximum (255 for uint8, 65535 for uint16)
        # instead of a hard-coded 255.
        image = image.astype(np.float32) / np.iinfo(image.dtype).max
    else:
        # Float data from imread (e.g. PNG) is already in [0, 1]; dividing
        # it by 255 again would make the image almost black.
        image = image.astype(np.float32)

    if image.ndim == 2:
        # Grayscale: replicate the single channel into three RGB channels.
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[-1] == 4:
        # RGBA: drop the alpha channel so only RGB is passed on.
        image = image[..., :3]

    # Add the batch dimension.
    return image[np.newaxis, ...]


def transfer_style(content_image, style_image, model_path):
    """
    Apply the style of one image to the content of another.

    :param content_image: path of the content image
    :param style_image: path of the style image
    :param model_path: path to the downloaded pre-trained model.
    The 'model' directory already contains the downloaded pre-trained model, but
    you can also download the pre-trained model from the below TF HUB link:
    https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2
    :return: An image as 3D numpy array (height, width, channels).
    NOTE(review): values are presumably floats in [0, 1] as produced by the
    TF Hub model - confirm against the model documentation.
    """
    print("Loading images...")
    # Reading and normalisation happen together in the helper, because the
    # correct scaling depends on the dtype that imread returns.
    content_batch = _load_image_batch(content_image)
    style_batch = _load_image_batch(style_image)

    print("Resizing and Normalizing images...")
    # It is recommended that the style image is about 256 pixels (this size
    # was used when training the style transfer network). The content image
    # can be any size, so it is left as-is.
    style_batch = tf.image.resize(style_batch, (256, 256))

    print("Loading pre-trained model...")
    # hub.load() loads any TF Hub model.
    hub_module = hub.load(model_path)

    print("Generating stylized image now...wait a minute")
    # The first output of the module is the batch of stylized images.
    outputs = hub_module(tf.constant(content_batch), tf.constant(style_batch))
    stylized_image = np.array(outputs[0])

    # Remove the batch dimension of size 1 to return a plain 3D image.
    stylized_image = np.squeeze(stylized_image, axis=0)

    print("Stylizing completed...")
    return stylized_image