Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

57 use tifffiles #58

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
**/__pycache__/
*.pyc

old_files/
pretrained/
Expand All @@ -15,4 +16,4 @@ data/mados/splits/*
!data/mados/splits/tiny_X.txt

.vscode
.idea
.idea
12 changes: 4 additions & 8 deletions datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
import torch
import pandas as pd
import pathlib
import rasterio
from tifffile import imread
from os.path import join as opj
from .utils import read_tif
import tifffile
from utils.registry import DATASET_REGISTRY

def read_imgs(multi_temporal, temp , fname, data_dir, img_size):
Expand All @@ -22,7 +19,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):

s1_filepath = data_dir.joinpath(s1_fname)
if s1_filepath.exists():
img_s1 = imread(s1_filepath)
img_s1 = tifffile.imread(s1_filepath)
m = img_s1 == -9999
img_s1 = img_s1.astype('float32')
img_s1 = np.where(m, 0, img_s1)
Expand All @@ -31,7 +28,7 @@ def read_imgs(multi_temporal, temp , fname, data_dir, img_size):

s2_filepath = data_dir.joinpath(s2_fname)
if s2_filepath.exists():
img_s2 = imread(s2_filepath)
img_s2 = tifffile.imread(s2_filepath)
img_s2 = img_s2.astype('float32')
else:
img_s2 = np.zeros((img_size, img_size) + (11,), dtype='float32')
Expand Down Expand Up @@ -77,8 +74,7 @@ def __getitem__(self, index):
fname = str(chip_id)+'_agbm.tif'

imgs_s1, imgs_s2, mask = read_imgs(self.multi_temporal, self.temp, fname, self.dir_features, self.img_size)
with rasterio.open(self.dir_labels.joinpath(fname)) as lbl:
target = lbl.read(1)
target = tifffile.imread(self.dir_labels.joinpath(fname), key=0)
target = np.nan_to_num(target)

imgs_s1 = torch.from_numpy(imgs_s1).float()
Expand Down
1 change: 0 additions & 1 deletion datasets/fivebillionpixels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import time
import torch
import numpy as np
import rasterio
import random
from glob import glob

Expand Down
10 changes: 3 additions & 7 deletions datasets/hlsburnscars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import time
import torch
import numpy as np
import rasterio
import tifffile
from glob import glob

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T

Expand Down Expand Up @@ -41,11 +40,8 @@ def __len__(self):
return len(self.image_list)

def __getitem__(self, index):
with rasterio.open(self.image_list[index]) as src:
image = src.read()
with rasterio.open(self.target_list[index]) as src:
target = src.read(1)

image = tifffile.imread(self.image_list[index])
target = tifffile.imread(self.target_list[index], key=0)
image = torch.from_numpy(image)
target = torch.from_numpy(target.astype(np.int64))

Expand Down
96 changes: 48 additions & 48 deletions datasets/mados.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
'''
"""
Adapted from: https://github.com/gkakogeorgiou/mados
'''
"""

import os
import time
Expand All @@ -10,16 +10,10 @@
import zipfile

from glob import glob
import rasterio
import tifffile
import numpy as np

import warnings

warnings.filterwarnings("ignore", category=rasterio.errors.NotGeoreferencedWarning)

import torch
import torchvision.transforms.functional as TF
import torchvision.transforms as T

from .utils import DownloadProgressBar
from utils.registry import DATASET_REGISTRY
Expand All @@ -29,39 +23,45 @@
# MADOS DATASET #
###############################################################


@DATASET_REGISTRY.register()
class MADOS(torch.utils.data.Dataset):
def __init__(self, cfg, split, is_train=True):

self.root_path = cfg['root_path']
self.data_mean = cfg['data_mean']
self.data_std = cfg['data_std']
self.classes = cfg['classes']
self.root_path = cfg["root_path"]
self.data_mean = cfg["data_mean"]
self.data_std = cfg["data_std"]
self.classes = cfg["classes"]
self.class_num = len(self.classes)
self.split = split
self.is_train = is_train

self.ROIs_split = np.genfromtxt(os.path.join(self.root_path, 'splits', f'{split}_X.txt'), dtype='str')
self.ROIs_split = np.genfromtxt(
os.path.join(self.root_path, "splits", f"{split}_X.txt"), dtype="str"
)

self.image_list = []
self.target_list = []

self.tiles = sorted(glob(os.path.join(self.root_path, '*')))
self.tiles = sorted(glob(os.path.join(self.root_path, "*")))

for tile in self.tiles:
splits = [f.split('_cl_')[-1] for f in glob(os.path.join(tile, '10', '*_cl_*'))]
splits = [
f.split("_cl_")[-1] for f in glob(os.path.join(tile, "10", "*_cl_*"))
]

for crop in splits:
crop_name = os.path.basename(tile) + '_' + crop.split('.tif')[0]
crop_name = os.path.basename(tile) + "_" + crop.split(".tif")[0]

if crop_name in self.ROIs_split:
all_bands = glob(os.path.join(tile, '*', '*L2R_rhorc*_' + crop))
all_bands = glob(os.path.join(tile, "*", "*L2R_rhorc*_" + crop))
all_bands = sorted(all_bands, key=self.get_band)
# all_bands = np.array(all_bands)

self.image_list.append(all_bands)

cl_path = os.path.join(tile, '10', os.path.basename(tile) + '_L2R_cl_' + crop)
cl_path = os.path.join(
tile, "10", os.path.basename(tile) + "_L2R_cl_" + crop
)
self.target_list.append(cl_path)

def __len__(self):
Expand All @@ -72,42 +72,39 @@ def getnames(self):

def __getitem__(self, index):

all_bands = self.image_list[index]
band_paths = self.image_list[index]
current_image = []
for c, band in enumerate(all_bands):
upscale_factor = int(os.path.basename(os.path.dirname(band))) // 10
with rasterio.open(band, mode='r') as src:
this_band = src.read(1,
out_shape=(int(src.height * upscale_factor), int(src.width * upscale_factor)),
resampling=rasterio.enums.Resampling.nearest
)
this_band = torch.from_numpy(this_band)
#this_band[torch.isnan(this_band)] = self.data_mean['optical'][c]
current_image.append(this_band)

image = torch.stack(current_image)
for path in band_paths:
upscale_factor = int(os.path.basename(os.path.dirname(path))) // 10

band = tifffile.imread(path)
band_tensor = torch.from_numpy(band)
band_tensor.unsqueeze_(0).unsqueeze_(0)
band_tensor = torch.nn.functional.interpolate(
band_tensor, scale_factor=upscale_factor, mode="nearest"
).squeeze_(0)
current_image.append(band_tensor)

image = torch.cat(current_image)
invalid_mask = torch.isnan(image)
image[invalid_mask] = 0


with rasterio.open(self.target_list[index], mode='r') as src:
target = src.read(1)
target = tifffile.imread(self.target_list[index])
target = torch.from_numpy(target.astype(np.int64))
target = target - 1

output = {
'image': {
'optical': image,
"image": {
"optical": image,
},
'target': target,
'metadata': {}
"target": target,
"metadata": {},
}

return output

@staticmethod
def get_band(path):
return int(path.split('_')[-2])
return int(path.split("_")[-2])

@staticmethod
def download(dataset_config: dict, silent=False):
Expand All @@ -128,15 +125,17 @@ def download(dataset_config: dict, silent=False):
try:
urllib.request.urlretrieve(url, output_path / temp_file_name, pbar)
except urllib.error.HTTPError as e:
print('Error while downloading dataset: The server couldn\'t fulfill the request.')
print('Error code: ', e.code)
print(
"Error while downloading dataset: The server couldn't fulfill the request."
)
print("Error code: ", e.code)
return
except urllib.error.URLError as e:
print('Error while downloading dataset: Failed to reach a server.')
print('Reason: ', e.reason)
print("Error while downloading dataset: Failed to reach a server.")
print("Reason: ", e.reason)
return

with zipfile.ZipFile(output_path / temp_file_name, 'r') as zip_ref:
with zipfile.ZipFile(output_path / temp_file_name, "r") as zip_ref:
print(f"Extracting to {output_path} ...")
# Remove top-level dir in ZIP file for nicer data dir structure
members = []
Expand All @@ -155,4 +154,5 @@ def get_splits(dataset_config):
dataset_train = MADOS(cfg=dataset_config, split="train", is_train=True)
dataset_val = MADOS(cfg=dataset_config, split="val", is_train=False)
dataset_test = MADOS(cfg=dataset_config, split="test", is_train=False)
return dataset_train, dataset_val, dataset_test
return dataset_train, dataset_val, dataset_test

14 changes: 6 additions & 8 deletions datasets/pastis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
import tifffile
import torch
from einops import rearrange
from omegaconf import OmegaConf
Expand Down Expand Up @@ -142,17 +142,15 @@ def __getitem__(self, i):

for modality in self.modalities:
if modality == "aerial":
with rasterio.open(
os.path.join(
path = os.path.join(
self.path,
"DATA_SPOT/PASTIS_SPOT6_RVB_1M00_2019/SPOT6_RVB_1M00_2019_"
+ str(name)
+ ".tif",
)
) as f:
output["aerial"] = split_image(
torch.FloatTensor(f.read()), self.nb_split, part
)
)
output["aerial"] = split_image(
torch.FloatTensor(tifffile.imread(path), self.nb_split, part)
)
elif modality == "s1-median":
modality_name = "s1a"
images = split_image(
Expand Down
14 changes: 5 additions & 9 deletions datasets/sen1floods11.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import geopandas
import numpy as np
import pandas as pd
import rasterio
import tifffile
import torch

from .utils import download_bucket_concurrently
Expand Down Expand Up @@ -59,16 +59,12 @@ def _get_date(self, index):
return date_np

def __getitem__(self, index):
with rasterio.open(self.s2_image_list[index]) as src:
s2_image = src.read()
s2_image = tifffile.imread(self.s2_image_list[index])

with rasterio.open(self.s1_image_list[index]) as src:
s1_image = src.read()
# Convert the missing values (clouds etc.)
s1_image = np.nan_to_num(s1_image)
s1_image = tifffile.imread(self.s1_image_list[index])
s1_image = np.nan_to_num(s1_image)

with rasterio.open(self.target_list[index]) as src:
target = src.read(1)
target = tifffile.imread(self.target_list[index], key=0)

timestamp = self._get_date(index)

Expand Down
15 changes: 8 additions & 7 deletions datasets/spacenet7.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import json
from glob import glob
import rasterio
import cv2
import tifffile
import numpy as np

import torch
Expand Down Expand Up @@ -132,17 +133,17 @@ def __len__(self) -> int:
def load_planet_mosaic(self, aoi_id: str, year: int, month: int) -> np.ndarray:
folder = self.root_path / 'train' / aoi_id / 'images_masked'
file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}.tif'
with rasterio.open(str(file), mode='r') as src:
img = src.read(out_shape=(1024, 1024), resampling=rasterio.enums.Resampling.nearest)
# 4th band (last oen) is alpha band
img = img[:-1]
img = tifffile.imread(file)
img = cv2.resize(img, dsize=(self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
# 4th band (last one) is alpha band
img = img.transpose(2, 0, 1)[:-1]
return img.astype(np.float32)

def load_building_label(self, aoi_id: str, year: int, month: int) -> np.ndarray:
folder = self.root_path / 'train' / aoi_id / 'labels_raster'
file = folder / f'global_monthly_{year}_{month:02d}_mosaic_{aoi_id}_Buildings.tif'
with rasterio.open(str(file), mode='r') as src:
label = src.read(out_shape=(1024, 1024), resampling=rasterio.enums.Resampling.nearest)
label = tifffile.imread(file)
label = cv2.resize(label, dsize=(self.img_size, self.img_size), interpolation=cv2.INTER_NEAREST)
label = (label > 0).squeeze()
return label.astype(np.int64)

Expand Down
Loading
Loading