Skip to content

Commit

Permalink
Fix transforms (#38)
Browse files Browse the repository at this point in the history
* Fix subsets
* Make random split random
  • Loading branch information
Maxim Zhiltsov committed Oct 12, 2020
1 parent 752ad3f commit 245c770
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 9 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-

### Fixed
-
- Default `label-map` parameter value for VOC converter (<https://github.com/openvinotoolkit/datumaro/pull/34>)
- Randomness of random split transform (<https://github.com/openvinotoolkit/datumaro/pull/38>)
- `Transform.subsets()` method (<https://github.com/openvinotoolkit/datumaro/pull/38>)

### Security
-
Expand Down
2 changes: 1 addition & 1 deletion datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ def categories(self):
def subsets(self):
if self._subsets is None:
self._subsets = set(self._extractor.subsets())
return self._subsets
return super().subsets()

def __len__(self):
assert self._length in {None, 'parent'} or isinstance(self._length, int)
Expand Down
17 changes: 10 additions & 7 deletions datumaro/plugins/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,24 +355,27 @@ def __init__(self, extractor, splits, seed=None):

dataset_size = len(extractor)
indices = list(range(dataset_size))

random.seed(seed)
random.shuffle(indices)
parts = []
s = 0
for subset, ratio in splits:
lower_boundary = 0
for split_idx, (subset, ratio) in enumerate(splits):
s += ratio
boundary = int(s * dataset_size)
parts.append((boundary, subset))

upper_boundary = int(s * dataset_size)
if split_idx == len(splits) - 1:
upper_boundary = dataset_size
subset_indices = set(indices[lower_boundary : upper_boundary])
parts.append((subset_indices, subset))
lower_boundary = upper_boundary
self._parts = parts

self._subsets = set(s[0] for s in splits)
self._length = 'parent'

def _find_split(self, index):
for boundary, subset in self._parts:
if index < boundary:
for subset_indices, subset in self._parts:
if index in subset_indices:
return subset
return subset # all the possible remainder goes to the last split

Expand Down

0 comments on commit 245c770

Please sign in to comment.