-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
258 additions
and
531 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
name: CI | ||
|
||
on: | ||
push: | ||
branches: [main] | ||
pull_request: | ||
branches: [main] | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- name: Checkout code | ||
uses: actions/checkout@v3 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v4 | ||
with: | ||
python-version: '3.12' | ||
|
||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -e . | ||
- name: Run tests | ||
run: | | ||
python run_tests.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
*.pyc | ||
dist | ||
tfrecord.egg-info | ||
/test_* | ||
/*.proto | ||
/*.sh | ||
/.pypirc |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import unittest | ||
import sys | ||
|
||
if __name__ == '__main__': | ||
loader = unittest.TestLoader() | ||
tests = loader.discover('tests') | ||
testRunner = unittest.TextTestRunner() | ||
result = testRunner.run(tests) | ||
# Exit with a non-zero status code if tests failed | ||
if not result.wasSuccessful(): | ||
sys.exit(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import os | ||
import tempfile | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from tfrecord.reader import example_loader, tfrecord_iterator | ||
from tfrecord.writer import TFRecordWriter | ||
|
||
|
||
class TestReadWrite(unittest.TestCase): | ||
|
||
def write_tfrecord(self, filename, records): | ||
writer = TFRecordWriter(filename) | ||
for datum in records: | ||
writer.write(datum) | ||
writer.close() | ||
|
||
def read_tfrecord(self, filename): | ||
iterator = tfrecord_iterator(filename) | ||
records = list(iterator) | ||
return records | ||
|
||
def test_write_and_read_integers(self): | ||
datum = {"int_key": (123, "int")} | ||
with tempfile.NamedTemporaryFile(delete=False) as temp_file: | ||
filename = temp_file.name | ||
self.write_tfrecord(filename, [datum]) | ||
|
||
records = self.read_tfrecord(filename) | ||
|
||
self.assertEqual(len(records), 1) | ||
example = list(example_loader(filename, None)) | ||
np.testing.assert_array_equal( | ||
example[0]["int_key"], np.array([123], dtype=np.int64) | ||
) | ||
|
||
os.remove(filename) | ||
|
||
def test_write_and_read_floats(self): | ||
datum = {"float_key": (1.23, "float")} | ||
with tempfile.NamedTemporaryFile(delete=False) as temp_file: | ||
filename = temp_file.name | ||
self.write_tfrecord(filename, [datum]) | ||
|
||
records = self.read_tfrecord(filename) | ||
|
||
self.assertEqual(len(records), 1) | ||
example = list(example_loader(filename, None)) | ||
np.testing.assert_array_equal( | ||
example[0]["float_key"], np.array([1.23], dtype=np.float32) | ||
) | ||
|
||
os.remove(filename) | ||
|
||
def test_write_and_read_string_arrays(self): | ||
datum = {"string_key": ([b"test1", b"test2"], "byte")} | ||
with tempfile.NamedTemporaryFile(delete=False) as temp_file: | ||
filename = temp_file.name | ||
self.write_tfrecord(filename, [datum]) | ||
|
||
records = self.read_tfrecord(filename) | ||
|
||
self.assertEqual(len(records), 1) | ||
example = list(example_loader(filename, None)) | ||
self.assertEqual(example[0]["string_key"], b"test1") | ||
|
||
os.remove(filename) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import unittest | ||
from unittest.mock import mock_open, patch | ||
|
||
import numpy as np | ||
from tfrecord.reader import ( | ||
example_loader, | ||
sequence_loader, | ||
tfrecord_iterator, | ||
process_feature, | ||
) | ||
|
||
from tfrecord import example_pb2 | ||
|
||
|
||
class TestFeatureProcessing(unittest.TestCase): | ||
|
||
def setUp(self): | ||
self.feature_bytes = example_pb2.Feature( | ||
bytes_list=example_pb2.BytesList(value=[b"test"]) | ||
) | ||
self.feature_float = example_pb2.Feature( | ||
float_list=example_pb2.FloatList(value=[1.0]) | ||
) | ||
self.feature_int = example_pb2.Feature( | ||
int64_list=example_pb2.Int64List(value=[1]) | ||
) | ||
|
||
def test_process_feature_bytes(self): | ||
result = process_feature( | ||
self.feature_bytes, "byte", {"byte": "bytes_list"}, "key" | ||
) | ||
self.assertEqual(result, b"test") | ||
|
||
def test_process_feature_float(self): | ||
result = process_feature( | ||
self.feature_float, "float", {"float": "float_list"}, "key" | ||
) | ||
np.testing.assert_array_equal(result, np.array([1.0], dtype=np.float32)) | ||
|
||
def test_process_feature_int(self): | ||
result = process_feature(self.feature_int, "int", {"int": "int64_list"}, "key") | ||
np.testing.assert_array_equal(result, np.array([1], dtype=np.int64)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import unittest | ||
import tempfile | ||
import os | ||
import numpy as np | ||
|
||
from tfrecord.reader import tfrecord_iterator | ||
from tfrecord.writer import TFRecordWriter | ||
|
||
|
||
class TestTFRecordWriter(unittest.TestCase): | ||
|
||
def test_tfrecord_writer_write_example(self): | ||
datum = {"key": (b"value", "byte")} | ||
with tempfile.NamedTemporaryFile(delete=False) as temp_file: | ||
filename = temp_file.name | ||
writer = TFRecordWriter(filename) | ||
writer.write(datum) | ||
writer.close() | ||
|
||
iterator = tfrecord_iterator(filename) | ||
records = list(iterator) | ||
self.assertEqual(records[0], b"\n\x12\n\x10\n\x03key\x12\t\n\x07\n\x05value") | ||
os.remove(filename) | ||
|
||
def test_tfrecord_writer_write_sequence_example(self): | ||
datum = {"key": (b"value", "byte")} | ||
sequence_datum = {"seq_key": ([b"seq_value"], "byte")} | ||
with tempfile.NamedTemporaryFile(delete=False) as temp_file: | ||
filename = temp_file.name | ||
writer = TFRecordWriter(filename) | ||
writer.write(datum, sequence_datum) | ||
writer.close() | ||
|
||
iterator = tfrecord_iterator(filename) | ||
records = list(iterator) | ||
self.assertTrue(records[0].tobytes().startswith(b"\n\x12\n\x10\n\x03key\x12\t\n\x07\n\x05value")) | ||
os.remove(filename) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.