Skip to content

Commit

Permalink
Add RGB band detection to building detection pipeline.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 699212554
  • Loading branch information
jzxu authored and copybara-github committed Nov 22, 2024
1 parent d922d0e commit bd8a045
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 18 deletions.
23 changes: 15 additions & 8 deletions src/skai/extract_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import rasterio
import rasterio.plot
from skai import extract_tiles_constants
from skai import read_raster
from skai import utils
import tensorflow as tf

Expand Down Expand Up @@ -220,13 +221,16 @@ def __init__(self, gdal_env: dict[str, str]) -> None:
def setup(self) -> None:
self._rasters = {}

def _get_raster(self, image_path: str):
raster = self._rasters.get(image_path)
if not raster:
def _get_raster(
self, image_path: str
) -> tuple[rasterio.io.DatasetReader, tuple[int, int, int]]:
raster, rgb_bands = self._rasters.get(image_path, (None, None))
if raster is None:
with rasterio.Env(**self._gdal_env):
raster = rasterio.open(image_path)
self._rasters[image_path] = raster
return raster
rgb_bands = read_raster.get_rgb_indices(raster)
self._rasters[image_path] = raster, rgb_bands
return raster, rgb_bands

def process(self, tile: Tile) -> Iterable[Example]:
"""Extract a tile from the source image and encode it as an Example.
Expand All @@ -240,15 +244,18 @@ def process(self, tile: Tile) -> Iterable[Example]:
if tile.x < 0 or tile.y < 0:
raise ValueError(f'Tile extents out of bounds: x={tile.x}, y={tile.y}')

raster = self._get_raster(tile.image_path)
raster, rgb_bands = self._get_raster(tile.image_path)
window = rasterio.windows.Window(tile.x, tile.y, tile.width, tile.height)
window_data = raster.read(window=window, boundless=True, fill_value=0)
window_data = raster.read(
indexes=rgb_bands, window=window, boundless=True, fill_value=0
)
if not np.any(window_data):
Metrics.counter('skai', 'empty_tiles').inc()
return
window_data = rasterio.plot.reshape_as_image(window_data)
# Dimensions should be (row, col, channel).
height, width, _ = window_data.shape
height, width, num_channels = window_data.shape
assert num_channels == 3, f'Expected 3 channels, got {num_channels}'

# Pad to size requested by the tile, if needed.
width_pad = tile.width - width
Expand Down
6 changes: 3 additions & 3 deletions src/skai/read_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def parse_json(json_dict: dict[str, Any]):
def detect_raster_info(raster_path: str, gdal_env: dict[str, str]):
with rasterio.Env(**gdal_env):
raster = rasterio.open(raster_path)
rgb_bands = _get_rgb_indices(raster)
rgb_bands = get_rgb_indices(raster)
bit_depth = 8
return RasterInfo(raster_path, rgb_bands, bit_depth)

Expand Down Expand Up @@ -453,7 +453,7 @@ def _generate_raster_points(
yield _RasterPoint(raster_path, longitude, latitude)


def _get_rgb_indices(raster: rasterio.io.DatasetReader) -> tuple[int, int, int]:
def get_rgb_indices(raster: rasterio.io.DatasetReader) -> tuple[int, int, int]:
"""Returns the indices of the RGB channels in the raster."""
colors = {}
for band in range(raster.count):
Expand Down Expand Up @@ -573,7 +573,7 @@ def _init_raster(self, raster_path: str) -> None:
else:
raster_info = self._raster_info[raster_path]
if raster_info.rgb_bands is None:
raster_info.rgb_bands = _get_rgb_indices(raster)
raster_info.rgb_bands = get_rgb_indices(raster)
if raster_info.bit_depth is None:
raster_info.bit_depth = 8 # TODO(jzxu): Try to auto-detect bit depth.
return raster
Expand Down
14 changes: 7 additions & 7 deletions src/skai/read_raster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,15 @@ def test_get_rgb_indices_grb_image(self):
[ColorInterp.green, ColorInterp.red, ColorInterp.blue]
)
dataset = rasterio.open(image_path)
indices = read_raster._get_rgb_indices(dataset)
indices = read_raster.get_rgb_indices(dataset)
self.assertSequenceEqual(indices, [2, 1, 3])

def test_get_rgb_indices_bgr_image(self):
image_path = _create_test_image_tiff_file(
[ColorInterp.blue, ColorInterp.green, ColorInterp.red]
)
dataset = rasterio.open(image_path)
indices = read_raster._get_rgb_indices(dataset)
indices = read_raster.get_rgb_indices(dataset)
self.assertSequenceEqual(indices, [3, 2, 1])

def test_get_rgb_indices_argb_image(self):
Expand All @@ -271,7 +271,7 @@ def test_get_rgb_indices_argb_image(self):
ColorInterp.blue,
])
dataset = rasterio.open(image_path)
indices = read_raster._get_rgb_indices(dataset)
indices = read_raster.get_rgb_indices(dataset)
self.assertSequenceEqual(indices, [2, 3, 4])

def test_get_rgb_indices_missing_red(self):
Expand All @@ -283,7 +283,7 @@ def test_get_rgb_indices_missing_red(self):
with self.assertRaisesRegex(
ValueError, 'Raster does not have a red channel.'
):
read_raster._get_rgb_indices(dataset)
read_raster.get_rgb_indices(dataset)

def test_get_rgb_indices_missing_green(self):
image_path = _create_test_image_tiff_file([
Expand All @@ -294,7 +294,7 @@ def test_get_rgb_indices_missing_green(self):
with self.assertRaisesRegex(
ValueError, 'Raster does not have a green channel.'
):
read_raster._get_rgb_indices(dataset)
read_raster.get_rgb_indices(dataset)

def test_get_rgb_indices_missing_blue(self):
image_path = _create_test_image_tiff_file([
Expand All @@ -305,7 +305,7 @@ def test_get_rgb_indices_missing_blue(self):
with self.assertRaisesRegex(
ValueError, 'Raster does not have a blue channel.'
):
read_raster._get_rgb_indices(dataset)
read_raster.get_rgb_indices(dataset)

def test_get_rgb_indices_band_name_tags(self):
image_path = _create_test_image_tiff_file_with_position_size(
Expand All @@ -329,7 +329,7 @@ def test_get_rgb_indices_band_name_tags(self):
],
)
dataset = rasterio.open(image_path)
self.assertSequenceEqual(read_raster._get_rgb_indices(dataset), (2, 3, 4))
self.assertSequenceEqual(read_raster.get_rgb_indices(dataset), (2, 3, 4))

def test_convert_image_to_uint8(self):
band = np.diag([4095, 2047, 1023, 511]).astype(np.uint16)
Expand Down

0 comments on commit bd8a045

Please sign in to comment.