From 41a20cda0158a55dd5c2dcd4ec486fc238e39988 Mon Sep 17 00:00:00 2001 From: Nils Lehmann <35272119+nilsleh@users.noreply.github.com> Date: Sun, 10 Jul 2022 00:05:10 +0200 Subject: [PATCH] Million-AID dataset (#455) * millionaid * test * separator * remove type ignore * type in test * requested changes * typos and glob pattern * task argument description * add test md5 hash * Remove download logic * Type ignore no longer needed Co-authored-by: Adam J. Stewart --- docs/api/datasets.rst | 5 + docs/api/non_geo_datasets.csv | 1 + tests/data/millionaid/data.py | 52 +++ tests/data/millionaid/test.zip | Bin 0 -> 3244 bytes .../grassland/meadow/P0115918.jpg | Bin 0 -> 1240 bytes .../test/water_area/beach/P0060208.jpg | Bin 0 -> 1238 bytes tests/data/millionaid/train.zip | Bin 0 -> 3265 bytes .../grassland/meadow/P0115918.jpg | Bin 0 -> 1254 bytes .../train/water_area/beach/P0060208.jpg | Bin 0 -> 1231 bytes tests/datasets/test_millionaid.py | 64 +++ torchgeo/datasets/__init__.py | 2 + torchgeo/datasets/millionaid.py | 371 ++++++++++++++++++ 12 files changed, 495 insertions(+) create mode 100644 tests/data/millionaid/data.py create mode 100644 tests/data/millionaid/test.zip create mode 100644 tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg create mode 100644 tests/data/millionaid/test/water_area/beach/P0060208.jpg create mode 100644 tests/data/millionaid/train.zip create mode 100644 tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg create mode 100644 tests/data/millionaid/train/water_area/beach/P0060208.jpg create mode 100644 tests/datasets/test_millionaid.py create mode 100644 torchgeo/datasets/millionaid.py diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 50741377e8d..6c06768bef0 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -224,6 +224,11 @@ LoveDA .. autoclass:: LoveDA +Million-AID +^^^^^^^^^^^ + +.. 
autoclass:: MillionAID + NASA Marine Debris ^^^^^^^^^^^^^^^^^^ diff --git a/docs/api/non_geo_datasets.csv b/docs/api/non_geo_datasets.csv index 2e37e18e21b..a7bdb09add2 100644 --- a/docs/api/non_geo_datasets.csv +++ b/docs/api/non_geo_datasets.csv @@ -16,6 +16,7 @@ Dataset,Task,Source,# Samples,# Classes,Size (px),Resolution (m),Bands `LandCover.ai`_,S,Aerial,"10,674",5,512x512,0.25--0.5,RGB `LEVIR-CD+`_,CD,Google Earth,985,2,"1,024x1,024",0.5,RGB `LoveDA`_,S,Google Earth,"5,987",7,"1,024x1,024",0.3,RGB +`Million-AID`_,C,Google Earth,1M,51--73,,0.5--153,RGB `NASA Marine Debris`_,OD,PlanetScope,707,1,256x256,3,RGB `OSCD`_,CD,Sentinel-2,24,2,"40--1,180",60,MSI `PatternNet`_,C,Google Earth,"30,400",38,256x256,0.06--5,RGB diff --git a/tests/data/millionaid/data.py b/tests/data/millionaid/data.py new file mode 100644 index 00000000000..03ea05ad5df --- /dev/null +++ b/tests/data/millionaid/data.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 + +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 

import hashlib
import os
import shutil

import numpy as np
from PIL import Image

SIZE = 32

np.random.seed(0)

# Location of every fake image relative to its split directory. The two
# entries exercise both directory depths found in the real dataset:
# scene/sub-category/class/image and scene/class/image.
_IMAGES = (
    ("agriculture_land", "grassland", "meadow", "P0115918.jpg"),
    ("water_area", "beach", "P0060208.jpg"),
)

PATHS = {
    split: [os.path.join(split, *parts) for parts in _IMAGES]
    for split in ("train", "test")
}


def create_file(path: str) -> None:
    """Save a SIZE x SIZE image of random RGB noise to ``path``.

    Args:
        path: destination file; the parent directory must already exist
    """
    noise = np.random.rand(SIZE, SIZE, 3) * 255
    Image.fromarray(noise.astype("uint8")).convert("RGB").save(path)


if __name__ == "__main__":
    for split, paths in PATHS.items():
        # Start from a clean directory so stale images never end up in the zip
        if os.path.isdir(split):
            shutil.rmtree(split)

        for path in paths:
            os.makedirs(os.path.dirname(path), exist_ok=True)
            create_file(path)

        # Archive the split directory as <split>.zip next to it
        shutil.make_archive(split, "zip", ".", split)

        # Print the archive checksum
        with open(split + ".zip", "rb") as f:
            digest = hashlib.md5(f.read()).hexdigest()
        print(f"{split}: {digest}")

;D{OCs@D{+Kv)*1lFa&&R zj5YyK8&IpQZEWrAY^?8WZ?6|hA&Ru?m57nPQ{8Ps5*Jk_smt8>UfS~=H;A)D<yP3%@5P^~Q_O;m??^n=rkd|Ry(`_V z>EL|?T?Py3(j*446f&X7%691{K`U|TWLeEKq|m{68FiPSp(+HkjAeR}Z4$cN`W)V3 zmvIC=J&!HuFB*Mvi8-FrVuv@w@FwU5w+LpD+!fa9EkeyPo|pGopR0m}yq&@;u`$un z>Cp-_ZD1}cLC8YZ+44sN#g=1>jmKdJbAXaN@?s)R=#6o#tw zuEo}~llebJSb_rjLd;+#G%vD8#oT;q&7;`z;V0dM9?k9N2WG}|ColJSDDH{ao8sh$ zMm{UkIK679W=anZUxvJmZ>Bmwj4T*T%CBI)MRN?D`%GiL(%tHlSzJ29V%cFdC@nNp#SRU1DK?`233tzli4-` zz5rb5^e+;9b+Mn@7v@>r6&79{K#Z8lUwE;^WVl4$aPQ@~8&A6_Gc&D3I}Q{c+>Id$EJGxQ{_b=C>6s(w8oWqkdd6dVH1*5Z;xgAJSPxezTsa%^aAB3>#V&`@tW^z@ zFop>Bcra^van@se$!(ay>2?DR*DMmUrs-u$7Sy3G@Nw<5>Y&i^Ire^Rm(m)IlQdI$1k%&;L_ zW%C(xFBSCK>Y7Fp5c74^o-VK!xSu^Zlp&{k<8niy$uv7fhJZM_rm{ld@y3z;0ejO6 z9x{0KQnBfFD5{5?&h=B)>$JXIk;nknAqsWG>2*JZRN98Q#v5!vl1-1?PmXqwHhg$J z18z_E_SH&honJWWdM=H$#-mJY7n@T{%b~ci6lmoFi+)#y0p@5aLu#Jb zR~msJmcjQ^7jnzw%NhvsyOaq=cQXnqLOa}(OSv-D5K|?5=o{Mz3W@EMh0`-Pm|3BNUvMZ7SRJm9 z!(ZMcW>B-?BeoEVzP9dzFV)kFYj>4y#)b0DezlPrT0sFq`=uPxz9eiv=zr3CeEBZb*dTHRNf4!bR*nT&ro1}kjO*d&D zUl9}hcG8cG>eu$*&tCz@|0Vcz2RBciKR5a5+4Dz}onPko=p?eY5dmyq#P}CjNI<~q IyN$Qs00;qLhyVZp literal 0 HcmV?d00001 diff --git a/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/test/agriculture_land/grassland/meadow/P0115918.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1d04bf8a52319264e9dabfeec47cd963b07d02fe GIT binary patch literal 1240 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L*w~*|8kG$e%teNk;dZC-MR-3t$H~B-Q!C3 zJjW-^+rEhVe0yv6K+!O-;F8PkJ=+pOHU=HGc&x4^uj8G&>f-Ip>ASXsm3_Z`%~L2c z?2*C7Q-}PDj?}E-e8OPK#b&v5{acaM*_nqYUECSFYIW<&N#`E(*mrzkS32Uctw(Ut(=0g-s^m^?ShFO}xs9c#ZwFJyJI)!d z%;H_&3Y8XwIUToVnVA~Lpk)`}Ru{xq9>Q6a`k$e*`h1q&v@0j3TTii^x?NK!PiNiD z9G>rY*w|xJ{oc>=_P%wx)@tAETi3rD_oQ2DZu=QK`^DQ;_kHs(mS($Id8h1}rM^S= zS#<8q3o3=)+RxujI>i>k&@iWXPEmn$VPU0h^YN{Du?|LS8o#xkNIT}47RG2_vAn@E 
zCUw=Nm~*;$(*w6mFS@HzbRzepZ-eX&o}B@XG4Et065gqFF?~DBb1CvzZdAGl!;N=! z;pZ-tHLW@rF~RNGxx(+Wy3OC6vyQrd?LR})%WX%3p0%5sgkCOtwe70fwx8v@-T&VN E0AVo#p8x;= literal 0 HcmV?d00001 diff --git a/tests/data/millionaid/test/water_area/beach/P0060208.jpg b/tests/data/millionaid/test/water_area/beach/P0060208.jpg new file mode 100644 index 0000000000000000000000000000000000000000..226ff70e36b3403435b268424de45b26eaafd8b6 GIT binary patch literal 1238 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L+;(ZAEvQSD|{;97%WoFqMGoR$1xh`N1 zqm=&1tuy}I)Y?&epS`Kjud8HhllKh$vxnVgcX&!C3#BM-R1Z1FpvuY4Tqu#{>wB=+ zE@k`uebZlFf1Yb?yJS-KB+cxnH{UOPyLRi_%Rx`fo8Dbas(x{yc5#lXirz{s>rHJr zY==KNGE6>kKsi@+>6W)9Yx665t9C{_>ENCG_TDiC%a9<6&s&;z+e~iM{V=Q6?v}Nu z(4*A!N!%^+duIhipHW~cV>(ssrdn(+0`czVoWx5EK@gV3_T zNe)_!Tp@w;R=&9VD|74H>ra2&chWp%crj;wY}ED3o2uV^QQMcXX#2JuTRERa$Gn}Y zd~04}P{GL)IYk1@*DVh|HuBfqGAnmu*{642>(A~zK8;!Q)<(gDCvs#`WC|N@ba9G3 zi}pTQ%9)pHzjqp=cE^DetfyEll$Ap7_Vmw|JhAL>an|%*?|}m+#L7(OZj&CTW7vn<@ECtCvI;IHQbC1g#etK?Syi;lc$y8Q8rQ}b2Us=XEYE7x`Hb#30SpD!2vzXF@q3uVMfFZ!^{+AikNM-lVk5r_TGPv^B=ko_w0 z4I;w0ToQ2fx8&b9z`~3~Kmvk+hrV00aj6F4hY0XN1tC!gn0F8)7`g&TZV40)LBbH; z5CjaeLb=*f`8PfM5RlN&Kk4?8~s+W3LpwJaW(;fKmY*fN9d>M%;)JBW3FGvTu+I?U@*6D$0XXt-o28U zX?Nu3H1BEtKrfEF5T3niY}|K(-qqXi64XIyl!8jBYinz- zGf5H$mc}OMUCWbec_R=7*vbMV0200+O@Njy$k(=6V5c;Ekqd>1uLAsZ5_-xP0`p*E;TE03fU!;K z9!;7LVIvADu5Eh7aL!f?__mgFx;cF2T6%?|-idM~by8uQJDm0lY#hh22;-U~KfUbq zAvkAQqQ)juipV7cPtoIp_X*C~))DhVBSlF$*3qA5II~GP^}FX!K~MU|iel&2i_av+ z$Hk?`Z5;-J@)0-1Y!%)1?=@Bl9elnmo`xuK58sG^JB(QWU{?O({L`^ z?k4$}wR*dcBlNQj*C4h&*x;SMk(hdt$Ia*2YGYBP!q+ACBkdm6MO`6VDxQ`e>)H4o z>D@5vtgVFJ%fNX5&M@c*8$1cccf`1NcCw0yCCYdvyM3At zxu7xQ(A6@Ro^*{BG;EMr(LdQTazP=4HH4J&PO&$O+pgn+aL-!(gO%IevSYN{XZJRl z6~Ea7jCnV?Q81rMFO7N}H5K>l;buI24(L~3r&7!Ftu_jDmYEmwe!q zcwT}4#5$<1(He~IK{OU_au;4Re?S48%{j08FgN|Y?iaRVIAsq_V|_xV;_EFr%nF1F?VXv--}@{SCvAsef~Gft;J4oIpda=cH?(7h;;89mwb-s2ji$;%R 
zs#KpLVC}}@500ytYCLWtftw2T+pZRTOoi-ra1!*PgG)oUT(fMDXLe$h+$cta`uxh6 z)~WE29OX@C7+M3P?H^uy$=~k?Mny!_6cXD0>EW=4sCLzH$n)CC54~31)~a4!yC6-Z zqdhuVTg(oE1lBec9%yz47iZykx{X8N0c&PLsXZZrO-E}K?Xgd-I6?DMtS3W6-&)Ub zFF`(=u&pBsukxO(lhqhT0ITzvbYqNi%@it4NuENIKki6ctL_sx%SE$`E6ip-gH>3u zfcQ=FH>eWUL>Ef4zn3LJqkrOotVva4U$~)+YFNR;Se4N+tD9w>_ns87-mgLXr0@;n z3@#ahb9OS(Ii~+Y^=Sg*a#>?eir0(pMtLs31N_7!}Um~&#CWF zrGanrIdvy$Y4ZX{k-wtD%Ov#_4C3Se2TceVVasCxM59CWCR%hlc65}$ZG zShB*QtyaxDxhxqx@(4E$OTkT~@EQJkh(kH|VM|VsFfO1B;IZ<-vV-!f!neBvgwzOVj?|e=Uga#cWlOO;Iwd8*e&-ce>fu($ge|V@UFa0 z*Sezdj0pF!e8m^T+DxMjs61_FSG0LDPnEor2n6(}+fT!!i4NA;rIkY!;8q!GT=$pA zeH}T;zH$gtwT$^NnmqhDwnuZvT?$UKIK^Uf3HxLusxMv=0UJO-O|qri1W^-R)Ug@2 zW{e}|0!i8Og#zu?mr&+b9(ZgdA{6dQop(Llc@`hW(jajQ9lW;j19`o(QB?w?Ew929 z8RwCNN_&-r9qEBDg0Nw%!gs?G{Om&M@>FWOJPP+@8?$%);}Cgz;@X@PPR&$a&<)my ze_TOg#BRhuI-ab7HxxLVqA=tK6Z@)g1GQzDy1VxH(#|J476s|0AgVd2p;c*nfk2+7y8h}5W=r6q~BS?a##BOR}%Wt#SqdF7ScbZTXLyiS^v9F z{mNT;ajodzU|r@~SK5Z(e+sO`*AB~e@$13!`(^&|==skw2Uj3jb`Tx1mk|Dwh)6A5 Nu3{o0Zfk^3zXGnKaIpXY literal 0 HcmV?d00001 diff --git a/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg b/tests/data/millionaid/train/agriculture_land/grassland/meadow/P0115918.jpg new file mode 100644 index 0000000000000000000000000000000000000000..afc980b49ddbdb1fa90b5b2944bd2c2e52cc9418 GIT binary patch literal 1254 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L(=`2Q>k5-%AK=(R;AxKW^rG!sxRW{zJ&Lm zPB8L@yjkbAWU|iw)cX>>v5P0PPAIdipUTL})w|@$cAp2YqNC4Omt8k@&DEXlyTkZM z^qHiS`!gRI&)LXhHfOozlgs%k-pC*y{Sb6UZL zUIyXiFaG^4yI=ic{pJgSOEq@8`(Ms}eLs5Jx^1_E>^Hmg7u{NN@n+nr4PEKmU9|+m z3c0;>4v2=`yC1N#ONKV|d6o*1^D7Q7b^9MT9{ zuH(q?$%Cz1ZA#v^?Tc^Sx4X4&du^1b>+TnCtGBFKTf5!&#g|W2rJHk2PkPU_yZg*N za^a3ykJ1yB);(}YVvCu(ic?{l5l`za!+@Ba{8ZCSMxWj4_V#pBkDCFrn1q%?5Szrc7{TbRd!~Ipwzg`*ikYdZ zQ#d(tl#Kp;erB`s)SumLcbe3i+mGL0boJ-5tKa@JsB|rsSaEyW+5G3vYsHt%dbRGh 
z{!P`7SDyWLJ>suZIEn3!yGHKKyKOo*bm!|ZF($B?C%BrJ8%D;~_TBT;TW53h^SeCG z2@Mi6CLFU1e3qxMjnTGvj!C%LNgZx(iMQ`$7(VPyO9?&D(|aIcwLoiE&*K&H#wzb+ z3S$Iix;|;l^FDD<>$Xc?3PaD+oz~Gi4luKIugi9Px$EM~+qtoKf8XA`ugl1P>74Q{ S;p<-4#!cIOy?V9%|C<14Z~>_R literal 0 HcmV?d00001 diff --git a/tests/data/millionaid/train/water_area/beach/P0060208.jpg b/tests/data/millionaid/train/water_area/beach/P0060208.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8d742f2c51b455ab2839b083325bbb563d95aee5 GIT binary patch literal 1231 zcmex=^(PF6}rMnOeST|r4lSw=>~TvNxu(8R<c1}I=;VrF4wW9Q)H;sz?% zD!{d!pzFb!U9xX3zTPI5o8roG<0MW4oqZMDikqloVbuf*=gfJ(V&YTRE(2~ znmD<{#3dx9RMpfqG__1j&CD$#L;a)+r*p1l+7wM;@>u&+%kO#6i{3w4UXwgz z*e8L|tf`TK~TKO@Z|N>AU*@ MillionAID: + root = os.path.join("tests", "data", "millionaid") + split, task = request.param + transforms = nn.Identity() + return MillionAID( + root=root, split=split, task=task, transforms=transforms, checksum=True + ) + + def test_getitem(self, dataset: MillionAID) -> None: + x = dataset[0] + assert isinstance(x, dict) + assert isinstance(x["image"], torch.Tensor) + assert isinstance(x["label"], torch.Tensor) + assert x["image"].shape[0] == 3 + assert x["image"].ndim == 3 + + def test_len(self, dataset: MillionAID) -> None: + assert len(dataset) == 2 + + def test_not_found(self, tmp_path: Path) -> None: + with pytest.raises(RuntimeError, match="Dataset not found in"): + MillionAID(str(tmp_path)) + + def test_not_extracted(self, tmp_path: Path) -> None: + url = os.path.join("tests", "data", "millionaid", "train.zip") + shutil.copy(url, tmp_path) + MillionAID(str(tmp_path)) + + def test_corrupted(self, tmp_path: Path) -> None: + with open(os.path.join(tmp_path, "train.zip"), "w") as f: + f.write("bad") + with pytest.raises(RuntimeError, match="Dataset found, but corrupted."): + MillionAID(str(tmp_path), checksum=True) + + def test_plot(self, dataset: MillionAID) -> None: + x = dataset[0].copy() + dataset.plot(x, suptitle="Test") + plt.close() + + 
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Million-AID dataset."""
import glob
import os
from typing import Any, Callable, Dict, List, Optional, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import Tensor

from torchgeo.datasets import VisionDataset

from .utils import check_integrity, extract_archive


class MillionAID(VisionDataset):
    """Million-AID Dataset.

    The MillionAID dataset consists of one million aerial images from
    Google Earth Engine that offers either a multi-class learning task
    with 51 classes or a multi-label learning task with 73 different
    possible labels. For more details please consult the accompanying
    `paper <https://ieeexplore.ieee.org/document/9393553>`_.

    Dataset features:

    * RGB aerial images with varying resolutions from 0.5 m to 153 m per pixel
    * images within classes can have different pixel dimension

    Dataset format:

    * images are three-channel jpg

    If you use this dataset in your research, please cite the following paper:

    * https://ieeexplore.ieee.org/document/9393553

    .. versionadded:: 0.3
    """

    # Reference list of the 73 labels of the multi-label task. The labels
    # actually exposed through ``self.classes`` are derived from the directory
    # names found on disk, not from this list.
    multi_label_categories = [
        "agriculture_land",
        "airport_area",
        "apartment",
        "apron",
        "arable_land",
        "bare_land",
        "baseball_field",
        "basketball_court",
        "beach",
        "bridge",
        "cemetery",
        "church",
        "commercial_area",
        "commercial_land",
        "dam",
        "desert",
        "detached_house",
        "dry_field",
        "factory_area",
        "forest",
        "golf_course",
        "grassland",
        "greenhouse",
        "ground_track_field",
        "helipad",
        "highway_area",
        "ice_land",
        "industrial_land",
        "intersection",
        "island",
        "lake",
        "leisure_land",
        "meadow",
        "mine",
        "mining_area",
        "mobile_home_park",
        "oil_field",
        "orchard",
        "paddy_field",
        "parking_lot",
        "pier",
        "port_area",
        "power_station",
        "public_service_land",
        "quarry",
        "railway",
        "railway_area",
        "religious_land",
        "residential_land",
        "river",
        "road",
        "rock_land",
        "roundabout",
        "runway",
        "solar_power_plant",
        "sparse_shrub_land",
        "special_land",
        "sports_land",
        "stadium",
        "storage_tank",
        "substation",
        "swimming_pool",
        "tennis_court",
        "terraced_field",
        "train_station",
        "transportation_land",
        "unutilized_land",
        "viaduct",
        "wastewater_plant",
        "water_area",
        "wind_turbine",
        "woodland",
        "works",
    ]

    # Reference list of the 51 classes of the multi-class task (see the note
    # on ``multi_label_categories`` above).
    multi_class_categories = [
        "apartment",
        "apron",
        "bare_land",
        "baseball_field",
        "basketball_court",
        "beach",
        "bridge",
        "cemetery",
        "church",
        "commercial_area",
        "dam",
        "desert",
        "detached_house",
        "dry_field",
        "forest",
        "golf_course",
        "greenhouse",
        "ground_track_field",
        "helipad",
        "ice_land",
        "intersection",
        "island",
        "lake",
        "meadow",
        "mine",
        "mobile_home_park",
        "oil_field",
        "orchard",
        "paddy_field",
        "parking_lot",
        "pier",
        "quarry",
        "railway",
        "river",
        "road",
        "rock_land",
        "roundabout",
        "runway",
        "solar_power_plant",
        "sparse_shrub_land",
        "stadium",
        "storage_tank",
        "substation",
        "swimming_pool",
        "tennis_court",
        "terraced_field",
        "train_station",
        "viaduct",
        "wastewater_plant",
        "wind_turbine",
        "works",
    ]

    # MD5 checksums of the archives listed in ``filenames``
    md5s = {
        "train": "1b40503cafa9b0601653ca36cd788852",
        "test": "51a63ee3eeb1351889eacff349a983d8",
    }

    # Archive expected in ``root`` for each split
    filenames = {"train": "train.zip", "test": "test.zip"}

    tasks = ["multi-class", "multi-label"]
    splits = ["train", "test"]

    def __init__(
        self,
        root: str = "data",
        task: str = "multi-class",
        split: str = "train",
        transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
        checksum: bool = False,
    ) -> None:
        """Initialize a new MillionAID dataset instance.

        Args:
            root: root directory where dataset can be found
            task: type of task, either "multi-class" or "multi-label"
            split: train or test split
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            checksum: if True, check the MD5 of the archive before extracting it
                (may be slow)

        Raises:
            AssertionError: if ``task`` or ``split`` is not a supported value
            RuntimeError: if dataset is not found, or if ``checksum`` is True and
                the archive is corrupted
        """
        self.root = root
        self.transforms = transforms
        self.checksum = checksum

        # Raised explicitly (rather than with ``assert``) so the validation is
        # not stripped when Python runs with -O; the exception type is kept.
        if task not in self.tasks:
            raise AssertionError(f"task must be one of {self.tasks}, got {task!r}")
        if split not in self.splits:
            raise AssertionError(f"split must be one of {self.splits}, got {split!r}")
        self.task = task
        self.split = split

        self._verify()

        self.files = self._load_files(self.root)

        # Class names come from the labels found on disk; their sorted order
        # defines the integer index of each class.
        self.classes = sorted({cls for f in self.files for cls in f["label"]})
        self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)}

    def __len__(self) -> int:
        """Return the number of data points in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """Return the sample at the given index.

        Args:
            index: index to return

        Returns:
            dict with the "image" tensor and the "label" tensor of class indices
            (one entry for multi-class, one entry per hierarchy level for
            multi-label), after ``transforms`` has been applied if it is set
        """
        files = self.files[index]
        image = self._load_image(files["image"])
        cls_label = [self.class_to_idx[label] for label in files["label"]]
        label = torch.tensor(cls_label, dtype=torch.long)
        sample = {"image": image, "label": label}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _load_files(self, root: str) -> List[Dict[str, Any]]:
        """Return the paths of the files in the dataset.

        Images are stored either as ``scene/class/*.jpg`` or as
        ``scene/subcategory/class/*.jpg`` below the split directory. Images of
        the shallower layout are listed first, and each group is sorted so
        that the index of a sample does not depend on the order in which the
        filesystem happens to list files.

        Args:
            root: root directory of dataset

        Returns:
            list of dicts containing paths for each pair of image, and list of labels
        """
        split_dir = os.path.join(root, self.split)

        files: List[Dict[str, Any]] = []
        # depth == number of directory levels between the split dir and the image
        for depth in (2, 3):
            pattern = os.path.join(split_dir, *(["*"] * depth), "*.jpg")
            for path in sorted(glob.glob(pattern)):
                # Directory names from the scene level down to the class level
                hierarchy = path.split(os.sep)[-(depth + 1) : -1]
                if self.task == "multi-label":
                    labels = hierarchy
                else:
                    # multi-class only uses the innermost directory
                    labels = hierarchy[-1:]
                files.append({"image": path, "label": labels})

        return files

    def _load_image(self, path: str) -> Tensor:
        """Load a single image.

        Args:
            path: path to the image

        Returns:
            the image as a CxHxW tensor with three (RGB) channels
        """
        with Image.open(path) as img:
            array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB"))
            tensor: Tensor = torch.from_numpy(array)
            # Convert from HxWxC to CxHxW
            tensor = tensor.permute((2, 0, 1))
            return tensor

    def _verify(self) -> None:
        """Verify that the dataset is available, extracting the archive if needed.

        Nothing is done if the directory of the split already exists.
        Otherwise the archive of the split is looked up in ``root``,
        checksummed if ``checksum`` is True, and extracted.

        Raises:
            RuntimeError: if the archive fails the MD5 check, or if neither the
                extracted directory nor the archive is found in ``root``
        """
        if os.path.isdir(os.path.join(self.root, self.split)):
            return

        archive = os.path.join(self.root, self.filenames[self.split])
        if os.path.isfile(archive):
            if self.checksum and not check_integrity(archive, self.md5s[self.split]):
                raise RuntimeError("Dataset found, but corrupted.")
            extract_archive(archive)
            return

        raise RuntimeError(
            f"Dataset not found in `root={self.root}` directory, either "
            "specify a different `root` directory or manually download "
            "the dataset to this directory."
        )

    def plot(
        self,
        sample: Dict[str, Tensor],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> plt.Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`; an optional
                "prediction" entry of class indices is shown next to the label
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # CxHxW -> HxWxC, the layout expected by imshow
        image = np.rollaxis(sample["image"].numpy(), 0, 3)
        labels = [self.classes[cast(int, label)] for label in sample["label"]]

        showing_predictions = "prediction" in sample
        if showing_predictions:
            prediction_labels = [
                self.classes[cast(int, label)] for label in sample["prediction"]
            ]

        fig, ax = plt.subplots(figsize=(4, 4))
        ax.imshow(image)
        ax.axis("off")
        if show_titles:
            title = f"Label: {labels}"
            if showing_predictions:
                title += f"\nPrediction: {prediction_labels}"
            ax.set_title(title)

        if suptitle is not None:
            plt.suptitle(suptitle)
        return fig