forked from aws/amazon-sagemaker-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
46 lines (37 loc) · 1.44 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Converts MNIST data to TFRecords file format with Example protos."""
import os
import tensorflow as tf
def _int64_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by an ``Int64List``."""
    # The proto field is repeated, so the scalar goes into a one-element list.
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def _bytes_feature(value):
    """Wrap a single value in a ``tf.train.Feature`` backed by a ``BytesList``."""
    # The proto field is repeated, so the scalar goes into a one-element list.
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def convert_to(data_set, name, directory):
    """Serialize a dataset to ``<directory>/<name>.tfrecords``.

    Every example is written as a serialized ``tf.train.Example`` carrying
    the features ``height``, ``width``, ``depth``, ``label`` and
    ``image_raw`` (the raw bytes of the image array).

    Args:
        data_set: Object exposing ``images``, ``labels`` and ``num_examples``.
            ``images`` must have at least four dimensions, since
            ``shape[0]`` through ``shape[3]`` are read; presumably the layout
            is (num_examples, rows, cols, depth) -- confirm against caller.
        name: Base name of the output file, without the extension.
        directory: Directory in which the file is created.

    Raises:
        ValueError: If ``images.shape[0]`` differs from ``num_examples``.
    """
    images = data_set.images
    labels = data_set.labels
    num_examples = data_set.num_examples

    if images.shape[0] != num_examples:
        raise ValueError(
            "Images size %d does not match label size %d." % (images.shape[0], num_examples)
        )

    rows = images.shape[1]
    cols = images.shape[2]
    depth = images.shape[3]

    filename = os.path.join(directory, name + ".tfrecords")
    print("Writing", filename)

    # The context manager closes the writer even if building or writing an
    # example raises, so the file handle is never leaked.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in range(num_examples):
            # tobytes() yields the same bytes as the deprecated tostring(),
            # which was removed in NumPy 2.0.
            image_raw = images[index].tobytes()
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "height": _int64_feature(rows),
                        "width": _int64_feature(cols),
                        "depth": _int64_feature(depth),
                        "label": _int64_feature(int(labels[index])),
                        "image_raw": _bytes_feature(image_raw),
                    }
                )
            )
            writer.write(example.SerializeToString())