Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add seed option #42

Merged
merged 4 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loaded into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=unicodedata,cv2 # synthtiger
extension-pkg-whitelist=unicodedata,numpy,cv2 # synthtiger

# Specify a score threshold to be exceeded before program exits with error.
fail-under=10.0
Expand Down
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ $ export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
```

```
usage: synthtiger [-h] [-o DIR] [-c NUM] [-w NUM] [-v] SCRIPT NAME [CONFIG]
usage: synthtiger [-h] [-o DIR] [-c NUM] [-w NUM] [-s NUM] [-v] SCRIPT NAME [CONFIG]

positional arguments:
SCRIPT Script file path.
Expand All @@ -60,8 +60,9 @@ positional arguments:
optional arguments:
-h, --help show this help message and exit
-o DIR, --output DIR Directory path to save data.
-c NUM, --count NUM Number of data. [default: 100]
-c NUM, --count NUM Number of output data. [default: 100]
-w NUM, --worker NUM Number of workers. If 0, It generates data in the main process. [default: 0]
-s NUM, --seed NUM Random seed. [default: None]
-v, --verbose Print error messages while generating data.
```

Expand Down
71 changes: 54 additions & 17 deletions synthtiger/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
MIT license
"""

import itertools
import os
import random
import sys
import traceback
from multiprocessing import Process, Queue

import imgaug
import numpy as np
import yaml

Expand All @@ -31,21 +33,33 @@ def read_config(path):
return config


def generator(path, name, config=None, worker=0, verbose=False):
def generator(path, name, config=None, count=None, worker=0, seed=None, verbose=False):
counter = range(count) if count is not None else itertools.count()
tasks = _task_generator(seed)

if worker > 0:
queue = Queue(maxsize=1024)
for _ in range(worker):
_run(_worker, (path, name, config, queue, verbose))
task_queue = Queue(maxsize=worker)
data_queue = Queue(maxsize=worker)
pre_count = min(worker, count) if count is not None else worker
post_count = count - pre_count if count is not None else None

while True:
data = queue.get()
yield data
for _ in range(worker):
_run(_worker, (path, name, config, task_queue, data_queue, verbose))
for _ in range(pre_count):
task_queue.put(next(tasks))

for idx in counter:
task_idx, data = data_queue.get()
if post_count is None or idx < post_count:
task_queue.put(next(tasks))
yield task_idx, data
else:
template = read_template(path, name, config)

while True:
data = _generate(template, verbose)
yield data
for _ in counter:
task_idx, task_seed = next(tasks)
data = _generate(template, task_seed, verbose)
yield task_idx, data


def _run(func, args):
Expand All @@ -55,22 +69,45 @@ def _run(func, args):
return proc


def _worker(path, name, config, queue, verbose):
random.seed()
np.random.seed()
def _task_generator(seed):
random_generator = random.Random(seed)
task_idx = -1

while True:
task_idx += 1
task_seed = random_generator.getrandbits(128)
yield task_idx, task_seed


def _worker(path, name, config, task_queue, data_queue, verbose):
template = read_template(path, name, config)

while True:
data = _generate(template, verbose)
queue.put(data)
task_idx, task_seed = task_queue.get()
data = _generate(template, task_seed, verbose)
data_queue.put((task_idx, data))


def _generate(template, seed, verbose):
temp_state = random.getstate()
temp_np_state = np.random.get_state()
temp_imgaug_state = imgaug.random.get_global_rng().state

random.seed(seed)
np.random.set_state(np.random.RandomState(np.random.MT19937(seed)).get_state())
imgaug.seed(seed)

def _generate(template, verbose):
while True:
try:
data = template.generate()
except:
if verbose:
print(f"{traceback.format_exc()}")
continue
return data
break

random.setstate(temp_state)
np.random.set_state(temp_np_state)
imgaug.random.get_global_rng().state = temp_imgaug_state

return data
26 changes: 20 additions & 6 deletions synthtiger/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,27 @@
def run(args):
if args.config is not None:
config = synthtiger.read_config(args.config)

pprint.pprint(config)

template = synthtiger.read_template(args.script, args.name, config)
generator = synthtiger.generator(
args.script, args.name, config, worker=args.worker, verbose=args.verbose
args.script,
args.name,
config=config,
count=args.count,
worker=args.worker,
seed=args.seed,
verbose=args.verbose,
)

if args.output is not None:
template.init_save(args.output)

for idx in range(args.count):
data = next(generator)
for idx, (task_idx, data) in enumerate(generator):
if args.output is not None:
template.save(args.output, data, idx)
print(f"Generated {idx + 1} data")
template.save(args.output, data, task_idx)
print(f"Generated {idx + 1} data (task {task_idx})")

if args.output is not None:
template.end_save(args.output)
Expand All @@ -49,7 +55,7 @@ def parse_args():
metavar="NUM",
type=int,
default=100,
help="Number of data. [default: 100]",
help="Number of output data. [default: 100]",
)
parser.add_argument(
"-w",
Expand All @@ -59,6 +65,14 @@ def parse_args():
default=0,
help="Number of workers. If 0, It generates data in the main process. [default: 0]",
)
parser.add_argument(
"-s",
"--seed",
metavar="NUM",
type=int,
default=None,
help="Random seed. [default: None]",
)
parser.add_argument(
"-v",
"--verbose",
Expand Down