Skip to content

Commit

Permalink
Merge pull request #6 from MIERUNE/fix/alpha
Browse files Browse the repository at this point in the history
いくつかの修正・テスト
  • Loading branch information
Kanahiro authored Dec 10, 2023
2 parents e87ffb8 + 4c9b6b3 commit 3eb8ff7
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 16 deletions.
44 changes: 34 additions & 10 deletions csmap/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,40 @@ def parse_args():
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("input_dem_path", type=str)
parser.add_argument("output_path", type=str)
parser.add_argument("--chunk_size", type=int, default=1024)
parser.add_argument("--max_workers", type=int, default=1)
parser.add_argument("--gf_size", type=int, default=12)
parser.add_argument("--gf_sigma", type=int, default=3)
parser.add_argument("--curvature_size", type=int, default=1)
parser.add_argument("--height_scale", type=float, nargs=2, default=[0.0, 1000.0])
parser.add_argument("--slope_scale", type=float, nargs=2, default=[0.0, 1.5])
parser.add_argument("--curvature_scale", type=float, nargs=2, default=[-0.1, 0.1])
parser.add_argument("input_dem_path", type=str, help="input DEM path")
parser.add_argument("output_path", type=str, help="output path")
parser.add_argument(
"--chunk_size", type=int, default=1024, help="chunk size as pixel"
)
parser.add_argument(
"--max_workers", type=int, default=1, help="max workers for multiprocessing"
)
parser.add_argument("--gf_size", type=int, default=12, help="gaussian filter size")
parser.add_argument("--gf_sigma", type=int, default=3, help="gaussian filter sigma")
parser.add_argument(
"--curvature_size", type=int, default=1, help="curvature filter size"
)
parser.add_argument(
"--height_scale",
type=float,
nargs=2,
default=[0.0, 1000.0],
help="height scale, min max",
)
parser.add_argument(
"--slope_scale",
type=float,
nargs=2,
default=[0.0, 1.5],
help="slope scale, min max",
)
parser.add_argument(
"--curvature_scale",
type=float,
nargs=2,
default=[-0.1, 0.1],
help="curvature scale, min max",
)

args = parser.parse_args()

Expand Down
2 changes: 2 additions & 0 deletions csmap/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,6 @@ def blend(
+ curvature_blue * blend_params["curvature_blue"]
+ curvature_ryb * blend_params["curvature_ryb"]
)
_blend = _blend.astype(np.uint8) # force uint8
_blend[3, :, :] = 255 # alpha
return _blend
14 changes: 8 additions & 6 deletions csmap/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,10 @@ def _process_chunk(
csmap_chunk = csmap(chunk, params)
csmap_chunk_margin_removed = csmap_chunk[
:,
params.gf_size // 2 : -(params.gf_size // 2),
params.gf_size // 2 : -(params.gf_size // 2),
(params.gf_size + params.gf_sigma)
// 2 : -((params.gf_size + params.gf_sigma) // 2),
(params.gf_size + params.gf_sigma)
// 2 : -((params.gf_size + params.gf_sigma) // 2),
] # shape = (4, chunk_size - margin, chunk_size - margin)

if lock is None:
Expand All @@ -94,7 +96,7 @@ def process(
with rasterio.open(input_dem_path) as dem:
margin = params.gf_size + params.gf_sigma # ガウシアンフィルタのサイズ+シグマ
# チャンクごとの処理結果には「淵=margin」が生じるのでこの部分を除外する必要がある
margin_to_removed = 2 * (margin // 2) # 整数値に切り捨てた値*両端
margin_to_removed = margin // 2 # 整数値に切り捨てた値*両端

# マージンを考慮したtransform
transform = Affine(
Expand All @@ -110,8 +112,8 @@ def process(
)

# 生成されるCS立体図のサイズ
out_width = dem.shape[1] - margin_to_removed - 2
out_height = dem.shape[0] - margin_to_removed - 2
out_width = dem.shape[1] - margin_to_removed * 2 - 2
out_height = dem.shape[0] - margin_to_removed * 2 - 2

with rasterio.open(
output_path,
Expand All @@ -126,7 +128,7 @@ def process(
compress="LZW",
) as dst:
# chunkごとに処理
chunk_csmap_size = chunk_size - margin_to_removed - 2
chunk_csmap_size = chunk_size - margin_to_removed * 2 - 2

# 並列処理しない場合とする場合で処理を分ける
if max_workers == 1:
Expand Down
Binary file added tests/fixture/csmap.tif
Binary file not shown.
Binary file added tests/fixture/dem.tif
Binary file not shown.
Binary file added tests/fixture/process.tif
Binary file not shown.
104 changes: 104 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import os

import rasterio

from csmap.process import process, csmap, CsmapParams


def test_csmap():
"""リグレッションがないか確認する"""
dem_path = os.path.join(os.path.dirname(__file__), "fixture", "dem.tif")
dem = rasterio.open(dem_path).read(1)

_csmap = csmap(
dem,
params=CsmapParams(
gf_size=12,
gf_sigma=3,
curvature_size=1,
height_scale=[0, 1000],
slope_scale=[0, 1.5],
curvature_scale=[-0.1, 0.1],
),
)

csmap_fixture_path = os.path.join(os.path.dirname(__file__), "fixture", "csmap.tif")
csmap_fixture = rasterio.open(csmap_fixture_path).read([1, 2, 3, 4])

assert _csmap.shape == csmap_fixture.shape
assert _csmap.dtype == csmap_fixture.dtype

# compare all pixels
assert (_csmap == csmap_fixture).all()


def test_process_by_chunk():
"""チャンクごとに処理した結果と一度に処理した結果が一致することをテスト"""
dem_path = os.path.join(os.path.dirname(__file__), "fixture", "dem.tif")

csmap_params = CsmapParams(
gf_size=12,
gf_sigma=3,
curvature_size=1,
height_scale=[0, 1000],
slope_scale=[0, 1.5],
curvature_scale=[-0.1, 0.1],
)

csmap_by_chunk_path = os.path.join(os.path.dirname(__file__), "test_chunk.tif")
process(
input_dem_path=dem_path,
output_path=csmap_by_chunk_path,
chunk_size=256,
params=csmap_params,
max_workers=2,
)
csmap_by_chunk = rasterio.open(csmap_by_chunk_path).read([1, 2, 3, 4])

# チャンク分割なしに処理した結果と比較
csmap_fixture_path = os.path.join(
os.path.dirname(__file__), "fixture", "process.tif"
)
csmap_fixture = rasterio.open(csmap_fixture_path).read([1, 2, 3, 4])

assert csmap_by_chunk.shape == csmap_fixture.shape
assert csmap_by_chunk.dtype == csmap_fixture.dtype

# compare all pixels
assert (csmap_by_chunk == csmap_fixture).all()


def test_process_by_worker():
"""並列処理をしても結果に影響がないことをテスト"""
dem_path = os.path.join(os.path.dirname(__file__), "fixture", "dem.tif")

csmap_params = CsmapParams(
gf_size=12,
gf_sigma=3,
curvature_size=1,
height_scale=[0, 1000],
slope_scale=[0, 1.5],
curvature_scale=[-0.1, 0.1],
)

csmap_by_worker_path = os.path.join(os.path.dirname(__file__), "test_worker.tif")
process(
input_dem_path=dem_path,
output_path=csmap_by_worker_path,
chunk_size=1024,
params=csmap_params,
max_workers=2,
)
csmap_by_worker = rasterio.open(csmap_by_worker_path).read([1, 2, 3, 4])

# チャンク分割なしに処理した結果と比較
csmap_fixture_path = os.path.join(
os.path.dirname(__file__), "fixture", "process.tif"
)
csmap_fixture = rasterio.open(csmap_fixture_path).read([1, 2, 3, 4])

assert csmap_by_worker.shape == csmap_fixture.shape
assert csmap_by_worker.dtype == csmap_fixture.dtype

# compare all pixels
assert (csmap_by_worker == csmap_fixture).all()

0 comments on commit 3eb8ff7

Please sign in to comment.