Skip to content

Commit

Permalink
refactor: combine georeferencing and digitizing
Browse files Browse the repository at this point in the history
Combine georeferencing and digitizing into one Celery group.
This leads to not needing to ma a request UUID to two Celery task/group
ids.
  • Loading branch information
matthiasschaub committed Dec 19, 2024
1 parent dd34aa5 commit 3aef47f
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 147 deletions.
11 changes: 6 additions & 5 deletions sketch_map_tool/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,17 @@ def merge(fcs: list[FeatureCollection]) -> FeatureCollection:


def zip_(results: list[tuple[str, str, BytesIO]]) -> BytesIO:
"""ZIP the results of the Celery group of `georeference_sketch_map` tasks."""
"""ZIP the raster results of the Celery group of `upload_processing` tasks."""
buffer = BytesIO()
raw = set([r[1].replace("<br />", "\n") for r in results])
attributions = BytesIO("\n".join(raw).encode())
attributions = []
with ZipFile(buffer, "a") as zip_file:
for file_name, _, file in results:
for file_name, attribution, file in results:
stem = Path(file_name).stem
name = Path(stem).with_suffix(".geotiff")
zip_file.writestr(str(name), file.read())
zip_file.writestr("attributions.txt", attributions.read())
attributions.append(attribution.replace("<br />", "\n"))
file = BytesIO("\n".join(set(attributions)).encode())
zip_file.writestr("attributions.txt", file.read())
buffer.seek(0)
return buffer

Expand Down
123 changes: 71 additions & 52 deletions sketch_map_tool/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
UploadLimitsExceededError,
UUIDNotFoundError,
)
from sketch_map_tool.helpers import extract_errors, merge, to_array, zip_
from sketch_map_tool.helpers import N_, extract_errors, merge, to_array, zip_
from sketch_map_tool.models import Bbox, Layer, PaperFormat, Size
from sketch_map_tool.tasks import (
cleanup_blobs,
digitize_sketches,
georeference_sketch_map,
upload_processing,
)
from sketch_map_tool.validators import (
validate_type,
Expand Down Expand Up @@ -192,11 +191,10 @@ def digitize_results_post(lang="en") -> Response:
bboxes_[uuid] = bbox
layers_[uuid] = layer

tasks_vector = []
tasks_raster = []
tasks = []
for file_id, file_name, uuid in zip(file_ids, file_names, uuids):
tasks_vector.append(
digitize_sketches.signature(
tasks.append(
upload_processing.signature(
(
file_id,
file_name,
Expand All @@ -206,40 +204,24 @@ def digitize_results_post(lang="en") -> Response:
)
)
)
tasks_raster.append(
georeference_sketch_map.signature(
(
file_id,
file_name,
map_frames[uuid],
layers_[uuid],
bboxes_[uuid],
)
)
)
async_result_raster = group(tasks_raster).apply_async()
c = chord(
group(tasks_vector),
chord_ = chord(
group(tasks),
cleanup_blobs.signature(
kwargs={"file_ids": list(set(file_ids))},
immutable=True,
),
).apply_async()
async_result_vector = c.parent
async_group_result = chord_.parent

# group results have to be saved for them to be able to be restored later
async_result_raster.save()
async_result_vector.save()

# Unique id for current request
uuid = str(uuid4())
# Mapping of request id to multiple tasks id's
map_ = {
"raster-results": str(async_result_raster.id),
"vector-results": str(async_result_vector.id),
}
db_client_flask.set_async_result_ids(uuid, map_)
return redirect(url_for("digitize_results_get", lang=lang, uuid=uuid))
async_group_result.save()
return redirect(
url_for(
"digitize_results_get",
lang=lang,
uuid=async_group_result.id,
)
)


@app.get("/digitize/results")
Expand All @@ -253,19 +235,53 @@ def digitize_results_get(lang="en", uuid: str | None = None) -> Response | str:
return render_template("digitize-results.html", lang=lang)


def get_async_result_id(uuid: str, type_: REQUEST_TYPES):
"""Get Celery Async or Group Result UUID for given request UUID.
Try to get Celery UUID for given request from datastore.
If no Celery UUID has been found the request UUID is the same as the Celery UUID.
This function exists only for legacy support which runs out on ...
"""
# TODO: Remove this function after end of legacy support on ...
try:
return db_client_flask.get_async_result_id(uuid, type_)
except UUIDNotFoundError:
return uuid


def get_async_result(uuid) -> AsyncResult | GroupResult:
"""Get Celyer `AsyncResult` or restore `GroupResult` for given Celery UUID.
Due to legacy support it is not possible to check only the request type
(e.g. `sketch-map` or `vector-results`).
In the past every Celery result was of type `AsyncResult`.
Now `/create` results are of type `AsyncResult` and `/digitze` results are
of type `GroupResult`.
"""
# TODO: Remove this function after end of legacy support on ...
group_result = celery_app.GroupResult.restore(uuid)
async_result = celery_app.AsyncResult(uuid)

if group_result is None and async_result is None:
raise UUIDNotFoundError(
N_("There are no tasks for UUID {UUID}"),
{"UUID": uuid},
)
elif group_result is not None:
return group_result
else:
return async_result


@app.get("/api/status/<uuid>/<type_>")
@app.get("/<lang>/api/status/<uuid>/<type_>")
def status(uuid: str, type_: REQUEST_TYPES, lang="en") -> Response:
validate_uuid(uuid)
validate_type(type_)

id_ = db_client_flask.get_async_result_id(uuid, type_)

# due to legacy support it is not possible to check only `type_`
# (in the past every Celery result was of type `AsyncResult`)
async_result = celery_app.GroupResult.restore(id_)
if async_result is None:
async_result = celery_app.AsyncResult(id_)
id_ = get_async_result_id(uuid, type_)
async_result = get_async_result(id_)

href = ""
info = ""
Expand Down Expand Up @@ -323,18 +339,18 @@ def download(uuid: str, type_: REQUEST_TYPES, lang="en") -> Response:
validate_uuid(uuid)
validate_type(type_)

id_ = db_client_flask.get_async_result_id(uuid, type_)
id_ = get_async_result_id(uuid, type_)
async_result = get_async_result(id_)

# due to legacy support it is not possible to check only `type_`
# (in the past every Celery result was of type `AsyncResult`)
async_result = celery_app.GroupResult.restore(id_)
if async_result is None:
async_result = celery_app.AsyncResult(id_)
if not async_result.ready() or async_result.failed():
# Abort if result not ready or failed.
# No nice error message here because user should first check /api/status.
if isinstance(async_result, GroupResult):
if not async_result.ready() or all([r.failed() for r in async_result.results]):
abort(500)
else:
if not async_result.ready() or all([r.failed() for r in async_result.results]):
if not async_result.ready() or async_result.failed():
abort(500)

match type_:
case "quality-report":
mimetype = "application/pdf"
Expand All @@ -348,16 +364,19 @@ def download(uuid: str, type_: REQUEST_TYPES, lang="en") -> Response:
mimetype = "application/zip"
download_name = type_ + ".zip"
if isinstance(async_result, GroupResult):
file: BytesIO = zip_(async_result.get(propagate=False))
results = async_result.get(propagate=False)
raster_results = [r[:-1] for r in results]
file: BytesIO = zip_(raster_results)
else:
# support legacy results
file: BytesIO = async_result.get()
case "vector-results":
mimetype = "application/geo+json"
download_name = type_ + ".geojson"
if isinstance(async_result, GroupResult):
result: list = async_result.get(propagate=False)
raw = geojson.dumps(merge(result))
results = async_result.get(propagate=False)
vector_results = [r[-1] for r in results]
raw = geojson.dumps(merge(vector_results))
file: BytesIO = BytesIO(raw.encode("utf-8"))
else:
# support legacy results
Expand Down
61 changes: 30 additions & 31 deletions sketch_map_tool/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,39 +135,14 @@ def generate_quality_report(bbox: Bbox) -> BytesIO | AsyncResult:

# 2. DIGITIZE RESULTS
#
@celery.task()
def georeference_sketch_map(
file_id: int,
file_name: str,
map_frame: NDArray,
layer: Layer,
bbox: Bbox,
) -> AsyncResult | tuple[str, str, BytesIO]:
"""Georeference uploaded Sketch Map.
Returns file name, attribution text and to the map extend clipped and georeferenced
sketch map as GeoTiff.
"""
# r = interim result
r = db_client_celery.select_file(file_id)
r = to_array(r)
r = clip(r, map_frame)
r = georeference(r, bbox)
return file_name, get_attribution(layer), r


@celery.task
def digitize_sketches(
file_id: int,
file_name: str,
map_frame: NDArray,
sketch_map_frame: NDArray,
layer: Layer,
bbox: Bbox,
) -> AsyncResult | FeatureCollection:
# r = interim result
r: BytesIO = db_client_celery.select_file(file_id) # type: ignore
r: NDArray = to_array(r) # type: ignore
r: NDArray = clip(r, map_frame) # type: ignore
) -> FeatureCollection:
if layer == "osm":
yolo_obj = yolo_obj_osm
yolo_cls = yolo_cls_osm
Expand All @@ -177,16 +152,16 @@ def digitize_sketches(
else:
raise ValueError("Unexpected layer: " + layer)

r: NDArray = detect_markings(
r,
markings: list[NDArray] = detect_markings(
sketch_map_frame,
map_frame,
yolo_obj,
yolo_cls,
sam_predictor,
) # type: ignore
)
# m = marking
l = [] # noqa: E741
for m in r:
for m in markings:
m: BytesIO = georeference(m, bbox, bgr=False) # type: ignore
m: FeatureCollection = polygonize(m, layer_name=file_name) # type: ignore
m: FeatureCollection = post_process(m, file_name)
Expand All @@ -198,6 +173,30 @@ def digitize_sketches(
return merge(l)


@celery.task
def upload_processing(
file_id: int,
file_name: str,
map_frame: NDArray,
layer: Layer,
bbox: Bbox,
) -> AsyncResult | tuple[str, str, BytesIO, FeatureCollection]:
"""Georeference and digitize given sketch map."""
sketch_map_uploaded = db_client_celery.select_file(file_id)
sketch_map_frame = clip(to_array(sketch_map_uploaded), map_frame)
sketch_map_frame_georeferenced = georeference(sketch_map_frame, bbox)
sketches = digitize_sketches(
file_id,
file_name,
map_frame,
sketch_map_frame,
layer,
bbox,
)
attribution = get_attribution(layer)
return file_name, attribution, sketch_map_frame_georeferenced, sketches


@celery.task(ignore_result=True)
def cleanup_map_frames():
"""Cleanup map frames stored in the database."""
Expand Down

Large diffs are not rendered by default.

13 changes: 4 additions & 9 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,21 +370,16 @@ def uuid_digitize(
assert url_rest == "/digitize/results"

# Wait for tasks to be finished and retrieve results (vector and raster)
with flask_app.app_context():
id_vector = db_client_flask.get_async_result_id(uuid, "vector-results")
id_raster = db_client_flask.get_async_result_id(uuid, "raster-results")
group_raster = celery_app.GroupResult.restore(id_raster)
group_vector = celery_app.GroupResult.restore(id_vector)
result_raster = group_raster.get(timeout=180)
result_vector = group_vector.get(timeout=180)
result = celery_app.GroupResult.restore(uuid).get(timeout=180)

# Write sketch map to temporary test directory
dir = tmp_path_factory.mktemp(uuid, numbered=False)
path_raster = dir / "raster.zip"
path_vector = dir / "vector.geojson"
with open(path_vector, "w") as file:
file.write(json.dumps(merge(result_vector)))
file.write(json.dumps(merge(r[-1] for r in result)))
with open(path_raster, "wb") as file:
r = zip_(result_raster)
r = zip_([r[:-1] for r in result])
file.write(r.getbuffer())
return uuid

Expand Down
11 changes: 7 additions & 4 deletions tests/integration/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def test_create_results_post(
@patch("sketch_map_tool.routes.chord")
@vcr_app.use_cassette
def test_digitize_results_post(mock_chord, sketch_map_marked, flask_client):
mock_chord.get.side_effect = Mock() # mock chord/task execution in Celery
# mock chord/task execution in Celery
mock_chord.return_value.apply_async.return_value.parent.id = uuid4()
unique_file_name = str(uuid4())
data = {"file": [(BytesIO(sketch_map_marked), unique_file_name)], "consent": "True"}
response = flask_client.post("/digitize/results", data=data, follow_redirects=True)
Expand All @@ -96,7 +97,8 @@ def test_digitize_results_post(mock_chord, sketch_map_marked, flask_client):
@patch("sketch_map_tool.routes.chord")
@vcr_app.use_cassette
def test_digitize_results_post_no_consent(mock_chord, sketch_map_marked, flask_client):
mock_chord.get.side_effect = Mock() # mock chord/task execution in Celery
# mock chord/task execution in Celery
mock_chord.return_value.apply_async.return_value.parent.id = uuid4()
# do not send consent parameter
# -> consent is a checkbox and only send if selected
unique_file_name = str(uuid4())
Expand All @@ -123,7 +125,8 @@ def test_digitize_results_legacy_2024_04_15(
flask_client,
):
"""Legacy map frames in DB do not have bbox, lon, lat and format set."""
mock_chord.get.side_effect = Mock() # mock chord/task execution in Celery
# mock chord/task execution in Celery
mock_chord.return_value.apply_async.return_value.parent.id = uuid4()
unique_file_name = str(uuid4())
data = {"file": [(BytesIO(sketch_map_marked), unique_file_name)], "consent": "True"}
response = flask_client.post("/digitize/results", data=data, follow_redirects=True)
Expand Down Expand Up @@ -165,7 +168,7 @@ def test_api_status_uuid_digitize(uuid_digitize, type_, flask_client):
@patch("sketch_map_tool.routes.chord")
def test_api_status_uuid_digitize_info(mock_chord, sketch_map_marked, flask_client):
"""Test if custom task status information is return by /status."""
mock_chord.get.side_effect = Mock() # mock chord/task execution in Celery
mock_chord.return_value.apply_async.return_value.parent.id = uuid4()
unique_file_name = str(uuid4())
data = {"file": [(BytesIO(sketch_map_marked), unique_file_name)], "consent": "True"}
response = flask_client.post("/digitize/results", data=data, follow_redirects=True)
Expand Down
Loading

0 comments on commit 3aef47f

Please sign in to comment.