Skip to content

Commit

Permalink
Don't crash building detection when a building touches multiple regio…
Browse files Browse the repository at this point in the history
…ns in a dedup stage.

PiperOrigin-RevId: 703184935
  • Loading branch information
jzxu authored and copybara-github committed Dec 5, 2024
1 parent 3f080ec commit 6414c61
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 55 deletions.
100 changes: 53 additions & 47 deletions src/skai/detect_buildings.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,58 +639,64 @@ def process(self, example: Example) -> Iterator[Example]:
)


def _overlaps_rows(sparse_tensor: SparseTensor, start_row: int,
end_row: int) -> bool:
start = [start_row, 0]
size = [end_row - start_row, sparse_tensor.dense_shape[1]]
sparse_slice = tf.sparse.slice(sparse_tensor, start=start, size=size)
return bool(len(sparse_slice.indices))


def _overlaps_columns(sparse_tensor: SparseTensor, start_col: int,
end_col: int) -> bool:
start = [0, start_col]
size = [sparse_tensor.dense_shape[0], end_col - start_col]
sparse_slice = tf.sparse.slice(sparse_tensor, start=start, size=size)
return bool(len(sparse_slice.indices))


def _get_regions_overlapped(mask: SparseTensor, margin_size: int) -> List[bool]:
def _get_regions_overlapped(mask: SparseTensor, margin_size: int) -> list[int]:
"""Computes which tile regions a building mask overlaps.
Args:
mask: Building mask as a SparseTensor.
margin_size: Size of the margin in pixels.
Returns:
A list O of 9 booleans, where O[k] represents whether the building mask
overlaps that region.
A list of 9 integers, where the kth element is the number of mask pixels in
that region. 0 pixels means no overlap.
"""
tile_height = mask.dense_shape[0].numpy()
tile_width = mask.dense_shape[1].numpy()

# Region size is the margin size * 2 because the region includes the margin of
# the current tile AND the margin of the adjacent tile. For example, let's say
# there are two tiles, X and Y, and X is directly on top of Y. This diagram
# explains how X and Y overlap.
#
# +--> ----------------- Top edge of tile Y
# |
# | Y's top margin
# Overlap
# region ================= Where the central regions of X and Y touch
# |
# | X's bottom margin
# |
# +--> ----------------- Bottom edge of tile X
region_size = margin_size * 2
row_bands = [
_overlaps_rows(mask, 0, region_size),
_overlaps_rows(mask, region_size, tile_height - region_size),
_overlaps_rows(mask, tile_height - region_size, tile_height)

row_starts = [
0,
region_size,
tile_height - region_size,
]
column_bands = [
_overlaps_columns(mask, 0, region_size),
_overlaps_columns(mask, region_size, tile_width - region_size),
_overlaps_columns(mask, tile_width - region_size, tile_width)
col_starts = [
0,
region_size,
tile_width - region_size,
]
overlaps = [
row_bands[0] & column_bands[0],
row_bands[0] & column_bands[1],
row_bands[0] & column_bands[2],
row_bands[1] & column_bands[0],
row_bands[1] & column_bands[1],
row_bands[1] & column_bands[2],
row_bands[2] & column_bands[0],
row_bands[2] & column_bands[1],
row_bands[2] & column_bands[2],
row_sizes = [
region_size,
tile_height - 2 * region_size,
region_size,
]
return overlaps
col_sizes = [
region_size,
tile_width - 2 * region_size,
region_size,
]
output = []
for region in range(9):
row = int(region // 3)
col = int(region % 3)
start = (row_starts[row], col_starts[col])
size = (row_sizes[row], col_sizes[col])
output.append(len(tf.sparse.slice(mask, start=start, size=size).indices))
return output


def _encode_sparse_tensor(sparse_tensor: SparseTensor, example: Example,
Expand Down Expand Up @@ -732,7 +738,7 @@ def augment_overlap_region(building: Example) -> Example:
"""Identifies the tile region(s) that the building touches.
For the purposes of parallelizing building deduplication, each tile is divided
into 8 regions as follows:
into 9 regions as follows:
+---------+
|0| 1 |2|
Expand Down Expand Up @@ -822,16 +828,16 @@ def augment_overlap_region(building: Example) -> Example:
augmented = tf.train.Example()
augmented.CopyFrom(building)
for stage in range(4):
feature = f'dedup_stage_{stage}_region'
regions_touched = [r for r in stage_to_regions[stage] if overlaps[r]]
if len(regions_touched) == 1:
augmented.features.feature[feature].float_list.value[:] = region_coords[
regions_touched[0]
]
if not regions_touched:
continue
if len(regions_touched) > 1:
raise ValueError(
f'In stage {stage}, mask touches multiple regions: {regions_touched}'
)
regions_touched.sort(key=lambda r: overlaps[r], reverse=True)
Metrics.counter('skai', 'dedup_mask_touches_multiple_regions').inc()
feature = f'dedup_stage_{stage}_region'
augmented.features.feature[feature].float_list.value[:] = region_coords[
regions_touched[0]
]
return augmented


Expand Down
114 changes: 106 additions & 8 deletions src/skai/detect_buildings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from typing import List, Tuple

from absl.testing import parameterized
import apache_beam as beam
from apache_beam.testing import test_pipeline
from apache_beam.testing import util
Expand Down Expand Up @@ -196,7 +197,7 @@ def _create_fake_tile_example() -> tf.train.Example:
return example


class DetectBuildingsTest(tf.test.TestCase):
class DetectBuildingsTest(tf.test.TestCase, parameterized.TestCase):

def test_building_detection_empty_model_res(self):
"""Tests the Detect Buildings stage outputs zero building instances because the model returns empty detection."""
Expand Down Expand Up @@ -293,7 +294,105 @@ def _check_results(results):

util.assert_that(result, _check_results)

def test_augment_overlap_region(self):
@parameterized.named_parameters(
dict(
testcase_name='region_0',
pixel_coords=[(10, 10)],
stage_0_region=[2.5, 6.5],
stage_1_region=[],
stage_2_region=[],
stage_3_region=[],
),
dict(
testcase_name='region_1',
pixel_coords=[(10, 60)],
stage_0_region=[],
stage_1_region=[],
stage_2_region=[2.5, 7],
stage_3_region=[],
),
dict(
testcase_name='region_2',
pixel_coords=[(10, 360)],
stage_0_region=[2.5, 7.5],
stage_1_region=[],
stage_2_region=[],
stage_3_region=[],
),
dict(
testcase_name='region_3',
pixel_coords=[(60, 10)],
stage_0_region=[],
stage_1_region=[3, 6.5],
stage_2_region=[],
stage_3_region=[],
),
dict(
testcase_name='region_4',
pixel_coords=[(60, 60)],
stage_0_region=[],
stage_1_region=[],
stage_2_region=[],
stage_3_region=[3, 7],
),
dict(
testcase_name='region_5',
pixel_coords=[(60, 360)],
stage_0_region=[],
stage_1_region=[3, 7.5],
stage_2_region=[],
stage_3_region=[],
),
dict(
testcase_name='region_6',
pixel_coords=[(460, 10)],
stage_0_region=[3.5, 6.5],
stage_1_region=[],
stage_2_region=[],
stage_3_region=[],
),
dict(
testcase_name='region_7',
pixel_coords=[(460, 60)],
stage_0_region=[],
stage_1_region=[],
stage_2_region=[3.5, 7],
stage_3_region=[],
),
dict(
testcase_name='region_8',
pixel_coords=[(460, 360)],
stage_0_region=[3.5, 7.5],
stage_1_region=[],
stage_2_region=[],
stage_3_region=[],
),
)
def test_augment_overlap_single_region(
self,
pixel_coords: list[int],
stage_0_region: list[float],
stage_1_region: list[float],
stage_2_region: list[float],
stage_3_region: list[float],
):
margin_size = 25
tile_width = 400
tile_height = 500

building = _create_building_example(3, 7, 0, 0, tile_height, tile_width,
margin_size, 0.5, pixel_coords)
building = detect_buildings.augment_overlap_region(building)
self.assertListEqual(
_get_float_feature(building, 'dedup_stage_0_region'), stage_0_region)
self.assertListEqual(
_get_float_feature(building, 'dedup_stage_1_region'), stage_1_region)
self.assertListEqual(
_get_float_feature(building, 'dedup_stage_2_region'), stage_2_region)
self.assertListEqual(
_get_float_feature(building, 'dedup_stage_3_region'), stage_3_region)

def test_augment_overlap_multi_regions(self):
margin_size = 25
tile_width = 400
tile_height = 500
Expand All @@ -315,7 +414,7 @@ def test_augment_overlap_region(self):

# This building is on the corner of regions 4, 5, 7, and 8, and touches all
# of them.
pixel_coords2 = [(449, 349), (449, 350), (450, 349), (450, 450)]
pixel_coords2 = [(449, 349), (449, 350), (450, 349), (450, 350)]
building2 = _create_building_example(3, 7, 0, 0, tile_height, tile_width,
margin_size, 0.5, pixel_coords2)
building2 = detect_buildings.augment_overlap_region(building2)
Expand Down Expand Up @@ -405,11 +504,10 @@ def test_recursively_copy_directory(self):
dest_dir = self.create_tempdir()
src_subdir = src_dir.mkdir()
# Make some files in the src directory.
[src_dir.create_file(file_path=file_prefix) for file_prefix in ['a', 'b']]
[
src_subdir.create_file(file_path=file_prefix)
for file_prefix in ['c', 'd']
]
for file_prefix in ['a', 'b']:
src_dir.create_file(file_path=file_prefix)
for file_prefix in ['c', 'd']:
src_subdir.create_file(file_path=file_prefix)
detect_buildings._recursively_copy_directory(src_dir, dest_dir)
for src, dest in zip(gfile.walk(src_dir), gfile.walk(dest_dir)):
src_dir_name, src_subdir, src_leaf_files = src
Expand Down

0 comments on commit 6414c61

Please sign in to comment.