-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathreader.py
216 lines (179 loc) · 7.99 KB
/
reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""Reader with GCPS support."""
import warnings
import xml.etree.ElementTree as ET
from typing import List, Optional, Union
import attr
import rasterio
from rasterio._path import _parse_path
from rasterio.control import GroundControlPoint
from rasterio.crs import CRS
from rasterio.dtypes import _gdal_typename
from rasterio.enums import MaskFlags
from rasterio.io import DatasetReader
from rasterio.transform import from_gcps
from rasterio.vrt import WarpedVRT
from rio_tiler.constants import WGS84_CRS
from rio_tiler.errors import NoOverviewWarning
from rio_tiler.io import Reader
from rio_tiler.utils import has_alpha_band
@attr.s
class GCPSReader(Reader):
    """rio-tiler Reader that can georeference an image with ground control points.

    Besides the regular ``Reader`` options it accepts:

    * ``gcps`` / ``gcps_crs``: external GCPs (and the CRS of their x/y
      coordinates) injected into an in-memory VRT built around the input.
    * ``cutline``: a value forwarded to ``WarpedVRT`` as the ``cutline`` option.

    Whenever the opened dataset carries GCPs, or a cutline is requested, the
    dataset exposed on ``self.dataset`` is a ``WarpedVRT``; otherwise it is the
    plain rasterio dataset.
    """

    # External ground control points; when truthy they are written into a VRT.
    gcps: Optional[List[GroundControlPoint]] = attr.ib(default=None)
    # CRS in which the external GCP x/y coordinates are expressed.
    gcps_crs: Optional[CRS] = attr.ib(default=WGS84_CRS)
    # Optional cutline passed through to WarpedVRT.
    cutline: Optional[str] = attr.ib(default=None)

    # Set in __attrs_post_init__, not by the caller.
    dataset: Union[DatasetReader, WarpedVRT] = attr.ib(init=False)

    def __attrs_post_init__(self):
        """Open the input, wrap it in a VRT/WarpedVRT when needed, and read its info."""
        src = self._ctx_stack.enter_context(rasterio.open(self.input))

        if self.gcps:
            # Re-open the same file through a VRT document that carries the GCPs.
            src = self._ctx_stack.enter_context(
                rasterio.open(vrt_doc(src, gcps=self.gcps, gcps_crs=self.gcps_crs))
            )

        warp_kwargs = {}
        gcp_points, gcp_crs = src.gcps
        if gcp_points:
            # Georeference the warp from the GCPs rather than the dataset transform.
            warp_kwargs["src_crs"] = gcp_crs
            warp_kwargs["src_transform"] = from_gcps(gcp_points)

        if self.cutline:
            warp_kwargs["cutline"] = self.cutline

        if not warp_kwargs:
            self.dataset = src
        else:
            nodata = src.nodata
            if nodata is None:
                # No nodata value: let the warp emit an alpha band as the mask.
                warp_kwargs["add_alpha"] = True
            else:
                warp_kwargs.update(nodata=nodata, add_alpha=False, src_nodata=nodata)

            # An existing alpha band already carries validity; don't add another.
            if has_alpha_band(src):
                warp_kwargs["add_alpha"] = False

            self.dataset = self._ctx_stack.enter_context(
                WarpedVRT(src, **warp_kwargs)
            )

        self.bounds = tuple(self.dataset.bounds)
        self.crs = self.dataset.crs

        if self.colormap is None:
            self._get_colormap()

        is_large = min(self.dataset.width, self.dataset.height) > 512
        if is_large and not self.dataset.overviews(1):
            warnings.warn(
                "The dataset has no Overviews. rio-tiler performances might be impacted.",
                NoOverviewWarning,
            )
def _add_metadata(vrtdataset, tags, domain=None):
    """Append a ``<Metadata>`` element holding ``tags`` to ``vrtdataset``.

    Nothing is written when ``tags`` is empty. ``domain`` becomes the element's
    ``domain`` attribute; without it the items go to GDAL's default domain.
    """
    if not tags:
        return
    metadata = ET.SubElement(vrtdataset, "Metadata")
    if domain is not None:
        metadata.attrib["domain"] = domain
    for key, value in tags.items():
        mdi = ET.SubElement(metadata, "MDI")
        mdi.attrib["Key"] = str(key)
        mdi.text = str(value)


def _add_gcp_list(vrtdataset, gcps, gcps_crs):
    """Append a ``<GCPList>`` describing ``gcps`` whose x/y are in ``gcps_crs``."""
    gcp_list = ET.SubElement(vrtdataset, "GCPList")
    # str(None) would write the literal "None", which is not a valid SRS.
    gcp_list.attrib["Projection"] = "" if gcps_crs is None else str(gcps_crs)
    for gcp in gcps:
        element = ET.SubElement(gcp_list, "GCP")
        # ElementTree can only serialize str attribute values.
        element.attrib["Id"] = str(gcp.id)
        info = getattr(gcp, "info", None)
        if info:
            element.attrib["Info"] = str(info)
        element.attrib["Pixel"] = str(gcp.col)
        element.attrib["Line"] = str(gcp.row)
        element.attrib["X"] = str(gcp.x)
        element.attrib["Y"] = str(gcp.y)
        z = getattr(gcp, "z", None)
        if z is not None:
            element.attrib["Z"] = str(z)


def _add_simple_source(
    parent, src_dataset, src_path, source_band, dtype_name, block_shape, shared=None
):
    """Append a full-extent ``<SimpleSource>`` reading ``source_band`` of the source.

    Returns the new element so callers can append extra children to it.
    """
    width = str(src_dataset.width)
    height = str(src_dataset.height)

    source = ET.SubElement(parent, "SimpleSource")
    sourcefilename = ET.SubElement(source, "SourceFilename")
    sourcefilename.attrib["relativeToVRT"] = "0"
    if shared is not None:
        sourcefilename.attrib["shared"] = shared
    sourcefilename.text = src_path

    sourceband = ET.SubElement(source, "SourceBand")
    sourceband.text = source_band

    sourceproperties = ET.SubElement(source, "SourceProperties")
    sourceproperties.attrib["RasterXSize"] = width
    sourceproperties.attrib["RasterYSize"] = height
    sourceproperties.attrib["dataType"] = dtype_name
    sourceproperties.attrib["BlockYSize"] = str(block_shape[0])
    sourceproperties.attrib["BlockXSize"] = str(block_shape[1])

    # Source and destination windows both span the whole raster (1:1 copy).
    for tag in ("SrcRect", "DstRect"):
        rect = ET.SubElement(source, tag)
        rect.attrib["xOff"] = "0"
        rect.attrib["yOff"] = "0"
        rect.attrib["xSize"] = width
        rect.attrib["ySize"] = height

    return source


def vrt_doc(
    src_dataset,
    gcps: Optional[List[GroundControlPoint]] = None,
    gcps_crs: Optional[CRS] = WGS84_CRS,
):
    """Make a VRT XML document mirroring ``src_dataset``, optionally adding GCPs.

    Adapted from rasterio.vrt._boundless_vrt_doc function.

    Args:
        src_dataset: opened rasterio dataset to wrap. Its size, default and
            IMAGE_STRUCTURE tags (minus ``LAYOUT=COG``), SRS, geotransform,
            nodata, band data types / color interpretations, open options and,
            when every band has a per-dataset mask, its mask band are copied.
        gcps: optional ground control points, written as a ``<GCPList>``.
        gcps_crs: CRS of the GCP x/y coordinates; ``None`` writes an empty
            ``Projection`` attribute.

    Returns:
        str: the VRT XML document (ASCII; non-ASCII characters become XML
        character references).
    """
    src_path = _parse_path(src_dataset.name).as_vsi()
    nodata_value = src_dataset.nodata

    vrtdataset = ET.Element("VRTDataset")
    vrtdataset.attrib["rasterYSize"] = str(src_dataset.height)
    vrtdataset.attrib["rasterXSize"] = str(src_dataset.width)

    _add_metadata(vrtdataset, src_dataset.tags())
    # LAYOUT=COG describes the source file's layout, not this VRT's.
    image_structure = {
        key: value
        for key, value in src_dataset.tags(ns="IMAGE_STRUCTURE").items()
        if not (key == "LAYOUT" and value == "COG")
    }
    _add_metadata(vrtdataset, image_structure, domain="IMAGE_STRUCTURE")

    srs = ET.SubElement(vrtdataset, "SRS")
    srs.text = src_dataset.crs.wkt if src_dataset.crs else ""
    geotransform = ET.SubElement(vrtdataset, "GeoTransform")
    geotransform.text = ",".join(str(v) for v in src_dataset.transform.to_gdal())

    if gcps:
        _add_gcp_list(vrtdataset, gcps, gcps_crs)

    for bidx, ci, block_shape, dtype in zip(
        src_dataset.indexes,
        src_dataset.colorinterp,
        src_dataset.block_shapes,
        src_dataset.dtypes,
    ):
        dtype_name = _gdal_typename(dtype)
        vrtrasterband = ET.SubElement(vrtdataset, "VRTRasterBand")
        vrtrasterband.attrib["dataType"] = dtype_name
        vrtrasterband.attrib["band"] = str(bidx)
        if nodata_value is not None:
            nodata = ET.SubElement(vrtrasterband, "NoDataValue")
            nodata.text = str(nodata_value)
        colorinterp = ET.SubElement(vrtrasterband, "ColorInterp")
        colorinterp.text = ci.name.capitalize()

        source = _add_simple_source(
            vrtrasterband, src_dataset, src_path, str(bidx), dtype_name, block_shape
        )
        if src_dataset.options is not None:
            openoptions = ET.SubElement(source, "OpenOptions")
            for ookey, oovalue in src_dataset.options.items():
                ooi = ET.SubElement(openoptions, "OOI")
                ooi.attrib["key"] = str(ookey)
                ooi.text = str(oovalue)
        if nodata_value is not None:
            source_nodata = ET.SubElement(source, "NODATA")
            source_nodata.text = str(nodata_value)

    # Expose the source's dataset-wide mask. The `indexes` guard avoids a
    # vacuously-true all() on a band-less dataset; the block shape is taken
    # from band 1, whose mask ("mask,1") is referenced.
    if src_dataset.indexes and all(
        MaskFlags.per_dataset in flags for flags in src_dataset.mask_flag_enums
    ):
        maskband = ET.SubElement(vrtdataset, "MaskBand")
        mask_vrtband = ET.SubElement(maskband, "VRTRasterBand")
        mask_vrtband.attrib["dataType"] = "Byte"
        _add_simple_source(
            mask_vrtband,
            src_dataset,
            src_path,
            "mask,1",
            "Byte",
            src_dataset.block_shapes[0],
            shared="0",
        )

    return ET.tostring(vrtdataset).decode("ascii")