From 7ac65a256dd428990b30b9c0058721d323198db8 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 2 Sep 2022 12:08:31 -0700 Subject: [PATCH] Fix float rounding issues (#736) * Fix rounding bugs * Mypy fixes * Undo some changes * Undo some changes * mypy fixes * Fix tests * Increase size of ETCI2021 fake data * Undo model changes * line length fix --- tests/data/etci2021/data.py | 18 +++++++++++++++--- .../florence_20190302t234651_x-0_y-0.png | Bin 0 -> 91 bytes .../florence_20190302t234651_x-0_y-0_vh.png | Bin 0 -> 91 bytes .../florence_20190302t234651_x-0_y-0_vv.png | Bin 0 -> 91 bytes .../florence_20190302t234651_x-0_y-0.png | Bin 0 -> 91 bytes ...drivernorth_20190302t234651_x-0_y-0_vh.png | Bin 0 -> 91 bytes ...drivernorth_20190302t234651_x-0_y-0_vv.png | Bin 0 -> 91 bytes .../redrivernorth_20190302t234651_x-0_y-0.png | Bin 0 -> 91 bytes .../data/etci2021/test_without_ref_labels.zip | Bin 4068 -> 6039 bytes tests/data/etci2021/train.zip | Bin 4622 -> 6822 bytes .../northal_20190302t234651_x-0_y-0.png | Bin 0 -> 91 bytes .../vh/northal_20190302t234651_x-0_y-0_vh.png | Bin 0 -> 91 bytes .../vv/northal_20190302t234651_x-0_y-0_vv.png | Bin 0 -> 91 bytes .../northal_20190302t234651_x-0_y-0.png | Bin 0 -> 91 bytes tests/data/etci2021/val_with_ref_labels.zip | Bin 4524 -> 6732 bytes tests/datamodules/test_utils.py | 10 +++++----- tests/datasets/test_etci2021.py | 8 ++++---- tests/models/test_changestar.py | 6 +++--- tests/samplers/test_single.py | 4 ++-- torchgeo/datamodules/chesapeake.py | 6 +++--- torchgeo/datamodules/etci2021.py | 2 +- torchgeo/datamodules/utils.py | 6 +++--- torchgeo/datasets/geo.py | 4 ++-- torchgeo/datasets/openbuildings.py | 4 ++-- torchgeo/datasets/so2sat.py | 2 +- torchgeo/datasets/spacenet.py | 2 +- torchgeo/samplers/single.py | 2 +- 27 files changed, 43 insertions(+), 31 deletions(-) create mode 100644 tests/data/etci2021/test/florence_20190302t234651/tiles/flood_label/florence_20190302t234651_x-0_y-0.png create mode 100644 tests/data/etci2021/test/florence_20190302t234651/tiles/vh/florence_20190302t234651_x-0_y-0_vh.png create mode 100644 tests/data/etci2021/test/florence_20190302t234651/tiles/vv/florence_20190302t234651_x-0_y-0_vv.png create mode 100644 tests/data/etci2021/test/florence_20190302t234651/tiles/water_body_label/florence_20190302t234651_x-0_y-0.png create mode 100644 tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/vh/redrivernorth_20190302t234651_x-0_y-0_vh.png create mode 100644 tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/vv/redrivernorth_20190302t234651_x-0_y-0_vv.png create mode 100644 tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/water_body_label/redrivernorth_20190302t234651_x-0_y-0.png create mode 100644 tests/data/etci2021/train/northal_20190302t234651/tiles/flood_label/northal_20190302t234651_x-0_y-0.png create mode 100644 tests/data/etci2021/train/northal_20190302t234651/tiles/vh/northal_20190302t234651_x-0_y-0_vh.png create mode 100644 tests/data/etci2021/train/northal_20190302t234651/tiles/vv/northal_20190302t234651_x-0_y-0_vv.png create mode 100644 tests/data/etci2021/train/northal_20190302t234651/tiles/water_body_label/northal_20190302t234651_x-0_y-0.png diff --git a/tests/data/etci2021/data.py b/tests/data/etci2021/data.py index 03d14d640b4..b90b7f51b93 100755 --- a/tests/data/etci2021/data.py +++ b/tests/data/etci2021/data.py @@ -15,17 +15,29 @@ { "filename": "train.zip", "directory": "train", - "subdirs": ["nebraska_20170108t002112", "bangladesh_20170314t115609"], + "subdirs": [ + "nebraska_20170108t002112", + "bangladesh_20170314t115609", + "northal_20190302t234651", + ], }, { "filename": "val_with_ref_labels.zip", "directory": "test", - "subdirs": ["florence_20180510t231343", "florence_20180522t231344"], + "subdirs": [ + "florence_20180510t231343", + "florence_20180522t231344", + "florence_20190302t234651", + ], }, { "filename": "test_without_ref_labels.zip", "directory": "test_internal", - "subdirs": ["redrivernorth_20190104t002247", "redrivernorth_20190116t002247"], + "subdirs": [ + "redrivernorth_20190104t002247", + "redrivernorth_20190116t002247", + "redrivernorth_20190302t234651", + ], }, ] diff --git a/tests/data/etci2021/test/florence_20190302t234651/tiles/flood_label/florence_20190302t234651_x-0_y-0.png b/tests/data/etci2021/test/florence_20190302t234651/tiles/flood_label/florence_20190302t234651_x-0_y-0.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/test/florence_20190302t234651/tiles/vh/florence_20190302t234651_x-0_y-0_vh.png b/tests/data/etci2021/test/florence_20190302t234651/tiles/vh/florence_20190302t234651_x-0_y-0_vh.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/test/florence_20190302t234651/tiles/vv/florence_20190302t234651_x-0_y-0_vv.png b/tests/data/etci2021/test/florence_20190302t234651/tiles/vv/florence_20190302t234651_x-0_y-0_vv.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/test/florence_20190302t234651/tiles/water_body_label/florence_20190302t234651_x-0_y-0.png b/tests/data/etci2021/test/florence_20190302t234651/tiles/water_body_label/florence_20190302t234651_x-0_y-0.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/vh/redrivernorth_20190302t234651_x-0_y-0_vh.png b/tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/vh/redrivernorth_20190302t234651_x-0_y-0_vh.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/vv/redrivernorth_20190302t234651_x-0_y-0_vv.png b/tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/vv/redrivernorth_20190302t234651_x-0_y-0_vv.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/water_body_label/redrivernorth_20190302t234651_x-0_y-0.png b/tests/data/etci2021/test_internal/redrivernorth_20190302t234651/tiles/water_body_label/redrivernorth_20190302t234651_x-0_y-0.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/test_without_ref_labels.zip b/tests/data/etci2021/test_without_ref_labels.zip index 34e062dbc3e99cb4063871a3faa59dd5e11def98..a33afb587295e15a78ceee91b23f8b184f57a7f9 100644 GIT binary patch literal 6039 zcmWIWW@Zs#0D(u_L_@(0D8UD$OHzwV;xqF~Qj78ubMymns@1_zTa=nolvxJUmtRzp z5pQH*XlY<*U{YdWU}R)sPO(XbW>hlC*ubd7$k@cp)DVxGjEM0al1U|*IjO~X%rhm+ zys`}H8(D_ONN2KKS)K??Eb&SCDV6a#iAkx5M8g42V7C_BTkj1tI2wo($+7?#3|P_@ zBJ|@cbPeJwfw(L~uOKh|HP46Q=T9d+N=it0@byapC(|W9g);|PR6=*oFJ;nH^%p-r zq02I#-{=!lWxH6-_w!HF#Sf~Q^^Qe%2 zF^r@<|Dqd-ktiw6zvvd=$iHx9uyj7?>R&^+6G_Rxa7{3O4YK?TcOTSKW!UpCTp7$8 zw8+1VOd`y<`x3x55)cS5yl@23=sgSEdO>{)2mt9Ny{AE}$z?!GP+f(sJckD)u;omw z$p?rr8PS5pYci-U0|B7W;9?}g%}6H0N=Do+2lY)5=0}lbKDeYL){3>{Sph0Jaa#fE zyC8hQz(h_cVJS&r7Gke#VOB0^OeD)f>?JEA_7GVY+_ORWAJaLY4hak_Xw+t=ivO`> z2TaF-ibs4dON2QKNH1t)U?In4*vmLfw?NA{gmbWW4InN8Wv5fbI0w-df)=^>%ty^m zQmjOH7sGs#vQsX3R*;gN4v=L91=&fFjm%)fl422|OLlg8O_p;o{7;nQNXt$K*vSeg z>`eyvXdNP;z)gO9$sM(J>f#{AIfxpS)Y|D5S>}_Jopd?L@dB`b!WT-YwNnFGR$wG2 zgoT*Z0cxT9kSq(a*G`B=H6oReSUaueA}5G3{Esa=!IKuuainFZ25xd(hP`&e7EsXa U#0o4=7#Jjh&>U!?7Y~RB0I3|%;Q#;t literal 4068 zcmWIWW@Zs#0D;%ZvxC75D8UD$OHzwV;xqF~Qj78ubMymns@1_zTa=nolvxJUmtRzp z5pQH*XlY<*U{YdWU}R)sPO(XbW`s>LBGFAHnK`M&cwA;mmU(3v)HkvWkCDz~xw1SF z7{&2P`6-q0If+TBh^XcOCyQGP?ydI*8XOJ8iDX#-3CAl8uk2=Pr--zUGQ~>+1Ag&k>r%086x(KFJ}=Cm2CRl)vcaDKev{jM(TYqmfTA z%OB$MFQO!%PX0wVj|%x0!$`{WFS?N!i4u3JCoX}YTYw|~!j%o){0n!YJ|ux)1QoXY z3)ckm7mZRqW}!r8^kC1w7~Y64qkjHnWD;S<-GTs?u|Oce@WK&9qqiq;>jkwaAONJ7 z^!5d@CYJ#%L3I`0^Z=}RahnV(u^<5C<^w<`VUrPUB3Su_+k8+91Oh18R1_ zz=FnyWH|@J|3ottC-VsiqYb5OI>LN;=O v7{mYAvJ*UM!SXt3*{O`39G3wjpD4FLEA{|yR$u@yFbD!+AJ9Sv4iFCj6@Z!& diff --git a/tests/data/etci2021/train.zip b/tests/data/etci2021/train.zip index b24e36421b26fec41318aeec09acc81671367f31..0b2a11d5d580bf6ac105574e403010bef60e9ccd 100644 GIT binary patch literal 6822 zcmd5>OKTHR6i(BJt;SYLwJWVci&n9nNt)Vje5|WBAO&BTPHb$7At}i?Z9s~HLP1=p zOA(!4pc@fH5nPEN2!bxe7loqGjff(+5bw$4=Fary-g`pZ=qYWcko(>5yuO~3#zyKJ zIs#w|4>yk$A4)%($g7z%rZW98;d^$J-;*#hlWAijnNP*Tn!Znq=mVy%M+dcG$?>l8 z@k}z2GxFyR_pq)FnVJ^X^|0jdHg-6hGgC$y&JJr4Eo_D(1A|dra<+$^6W%1@hVtoAb&mYq*!Pl+ZUp{N-9(#26La?`I zwDbA(tHas0@b_S0e#^yQuNEgejf0ok3Qet{Yac>;o?f24b8oJ1>2~|A0IhJI_em^o zE*2^;YUToTwcE%dO8suJdn;A+R+@vFvBwvVg(w_rjd=6&xw9N~i8QK2CV~g?CXbdg z&uTXwP%1Rl;K-zA2+xOWFzO}Ww8!wM#7l90XtCE)aL(Yla}iN{vPR|4r{DzM4R7WC zM;=sDL>8)uY9~J3<>+Swvyn=9G$evId9)#U*A-aqDa5E-lER&rG5kn z36o9bZ2SwIIgVt_1w!FN%@l_9r+`tVsYV|EnsG;$44G-jS8wQMunY}as7=&<50&~?G52+M?Bb~!AX)Ev919(g(Epo$7;oQ#L}Ga$!)f#uf1}>5GxA4!C?Y%4I2DOnSfx>Dov3xMA(%W z;vXcUH>@&rSbG@y;q^3Na=wbb_JiKQg36m3b}Pp8EfSO~bruZ1qhWBS=GeU((7_Ny ztBPKRg^8MDe?nf(d8Z(v>SBNUI3ZjBytTGM0{#q~-iRXrf14}M)GPzVMJ6`V2>*Br Obdj$);$3?y{Pq`@%Smql literal 4622 zcmWIWW@Zs#0D;%ZvxC75D8UA#ONtUR^YjC7s!@ciNlMI1&q+*4EzXEHGB7kZFg7$P zF*GzaGqA*?T^_DIFEyzsu{b*sqS?^EqQt#Cicar4Aa`6Q6L(h)6hPGzxXhB2XP3>UfGbLWPVm@rEEr9xORy zx=M?91MwC#q?m)RsIeg4Js24b=1a^>3#yAC)eph22l^7Xnt>>A;)r)UMu-sh35}9J zW)wsc?+xUXJBWNhN+K#FDiKk?@L^;UVaDA814aQ52r#^G1ktee7;ZhF78wM9^r!%t z1T#^9H*VuVEi?!K8J7%XLX5*zwE$ZrxD5lf*dPF8*a9GvfMKwr6}OR~790eCjC>4a z5@jTEa>eIjP>T)%K*k#}5@|fR*v0K=R8uR6HWgF|<2Du4;zPJ~AF-xl6hH_&FiQkb zMGgZC8ikmM4-Z@k2CaC;XE~@D2LlTlPZ4W5h6f2b1ysJ`a|fs`1_KKkdzpz01mrYF zsS{A+!JCChCzKK0yh8XNv;6^T6+r+fQ!ul_EXQ5$>l4n%pjHyXP!Hk_MGjVct_HPs z5EjfN-U0%d8P(i}B$$goJA+zP2={6Nd*cKPa!4p+q<2K9U}k7g1qlNS8taL-g_JfI zJ_mrBN-(gX(VCrnZ{kjVurh{F#s;-_U|>Nb69>_5K~95|I|DUNju7t*VD=_#dwdzV W2@~MW3L4;H5Cy^tpcM_AARYis9rVcn diff --git a/tests/data/etci2021/train/northal_20190302t234651/tiles/flood_label/northal_20190302t234651_x-0_y-0.png b/tests/data/etci2021/train/northal_20190302t234651/tiles/flood_label/northal_20190302t234651_x-0_y-0.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/train/northal_20190302t234651/tiles/vh/northal_20190302t234651_x-0_y-0_vh.png b/tests/data/etci2021/train/northal_20190302t234651/tiles/vh/northal_20190302t234651_x-0_y-0_vh.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/train/northal_20190302t234651/tiles/vv/northal_20190302t234651_x-0_y-0_vv.png b/tests/data/etci2021/train/northal_20190302t234651/tiles/vv/northal_20190302t234651_x-0_y-0_vv.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/train/northal_20190302t234651/tiles/water_body_label/northal_20190302t234651_x-0_y-0.png b/tests/data/etci2021/train/northal_20190302t234651/tiles/water_body_label/northal_20190302t234651_x-0_y-0.png new file mode 100644 index 0000000000000000000000000000000000000000..320c3449e5f07848665bfa3dcbaaa8891a85e6eb GIT binary patch literal 91 zcmeAS@N?(olHy`uVBq!ia0vp^4j|0I1SD0tpLGH$B~KT}kcv5P4>Ixsd50E!_s`{L d7XdP1;6t8)0|R3>!;0@95l>e?mvv4FO#l*y6GZ?3 literal 0 HcmV?d00001 diff --git a/tests/data/etci2021/val_with_ref_labels.zip b/tests/data/etci2021/val_with_ref_labels.zip index 6077f058efdc21c951bf46c07d16afc0318b7694..b033d5e084400007c8db063a492dd56c09ab3ada 100644 GIT binary patch literal 6732 zcmd5=O=uHA7~LjmY&Etb)gXcjRz$>f^C$ME;4PpP%t1VC5*t$?G_{FqBO(S5f_e~- zDm52xVsC;6g`T`9f(U}3NFgZnqGD0-;!Kj+o!xI{^KD|9Oi7c&&YSmr^WMu$rV?J? zm_o4sb zE@kJnOkuigD{iQ+p^w$iO9=vXmb^|NAxUvE326(5RIN;2`b=nUHn$r1Ui@0Sw~|Sv zzO4P48}_^#+4FkYH=ca{=#uBazLUXMcW+G=Mk7Bx<(u2D{(f7_1=B~bkCy!-fjgf9 z2VdS;SbTadwEA%CeTC^7(msQv4LZ@Ph;^&rG_nF-qkTgjtr{Jn8of(Bhk+?e7BkV> zO%^)>ziUjzkdG``>}YKz3jk>+S?nOdPpmj-QJF^UkT6+{C;BN_3{n#fjet(r%xOH& zis+KcVdG`VVlZk}z-zQ`$fH#QlEo;GHC)qsoGg~}PFu+$K;U)2~A}u z8uXaPMz!+7i)#U|woVz)L}a5{x#Ghy(5?9>-h}$e%d8$AP^|6^lb^Cg?ln`zt$T|| zM82O+e68+XMxwiNuiuYr0@(m3rdIbRkf1=`<`U3BR^conmw@gywwHOvwR@8ra6JR% z*}MaoxN~oC2uFigTc-@HoA4gV0J;~H-aOry0%+yl+rxtDE#KpA-TMZK$oJEUuhqRf zH{$rZa&HO?>c+i~k)S}{W;0s3_YZOjn|t}~ZQH#|n{XQe%HzEwV&cxd7XvsNyxKZt gV0CY3Gme4Hy`s_$_YRU+R+Qc3?-Z%8MpjUif8M4kGXMYp literal 4524 zcmWIWW@Zs#0D;%ZvxC75D8UM(OHzwV^aF4zk%K8o%gHZF%}Y*=H!?7^FfcVVC^0fN zG&V6NL$8q$M6U@Ry($FlEy>JDEyiPl0g)yEJ(iympOcuBnuEtQbs|kG%b&+=QbNLmuU`r{nJ)1uoH@v%61sDKDU+tEzxeS9U6%R$MxU4}+r?_WpMRPz zp6JlPU&+SD({q1>RUAmj^{42stGSx*-_(3FbM> zEQ%|Qpc{!e6P;=> zN*1{HU;!RqM*Wh7kx7IZcY_L8qyvEf!wW|c4Q*cG)&go^K>$dLGLQ+@f)*FZ%`IZh zipOCVJdA*i4BTdcib)6nxob9%iQiqYLKC-%pavNPfK0p#WD;W{astKYUQjU#0U*=0 z8HqC;T&Uu9GOD2k#2N}JWN{k`YQP~}x|K*nF^U?56_~jnl!0JiK_dqf(Sd<0*&(tM zxOs=qZcwWX1{O5#C(>>V?-6nds3gVb3{U|G0}C1(nTZPmUC1KKX?` None: # Test only train/val set split train_ds, val_ds = dataset_split(ds, val_pct=1 / 2) - assert len(train_ds) == num_samples // 2 - assert len(val_ds) == num_samples // 2 + assert len(train_ds) == round(num_samples / 2) + assert len(val_ds) == round(num_samples / 2) # Test train/val/test set split train_ds, val_ds, test_ds = dataset_split(ds, val_pct=1 / 3, test_pct=1 / 3) - assert len(train_ds) == num_samples // 3 - assert len(val_ds) == num_samples // 3 - assert len(test_ds) == num_samples // 3 + assert len(train_ds) == round(num_samples / 3) + assert len(val_ds) == round(num_samples / 3) + assert len(test_ds) == round(num_samples / 3) diff --git a/tests/datasets/test_etci2021.py b/tests/datasets/test_etci2021.py index b87539805c3..956784db704 100644 --- a/tests/datasets/test_etci2021.py +++ b/tests/datasets/test_etci2021.py @@ -30,19 +30,19 @@ def dataset( metadata = { "train": { "filename": "train.zip", - "md5": "ebbd2e65cd10621bc2e90a230b474b8b", + "md5": "bd55f2116e43a35d5b94a765938be2aa", "directory": "train", "url": os.path.join(data_dir, "train.zip"), }, "val": { "filename": "val_with_ref_labels.zip", - "md5": "efdd1fe6c90f5dfd267c88b86b237c2b", + "md5": "96ed69904043e514c13c14ffd3ec45cd", "directory": "test", "url": os.path.join(data_dir, "val_with_ref_labels.zip"), }, "test": { "filename": "test_without_ref_labels.zip", - "md5": "bf1180143de5705fe95fa8490835d6d1", + "md5": "1b66d85e22c8f5b0794b3542c5ea09ef", "directory": "test_internal", "url": os.path.join(data_dir, "test_without_ref_labels.zip"), }, @@ -67,7 +67,7 @@ def test_getitem(self, dataset: ETCI2021) -> None: assert x["mask"].shape[0] == 1 def test_len(self, dataset: ETCI2021) -> None: - assert len(dataset) == 2 + assert len(dataset) == 3 def test_already_downloaded(self, dataset: ETCI2021) -> None: ETCI2021(root=dataset.root, download=True) diff --git a/tests/models/test_changestar.py b/tests/models/test_changestar.py index 4c7925eb8c1..fed3f7312ba 100644 --- a/tests/models/test_changestar.py +++ b/tests/models/test_changestar.py @@ -18,7 +18,7 @@ IN_CHANNELS = [64, 128] INNNR_CHANNELS = [16, 32, 64] NC = [1, 2, 4] -SF = [4.0, 8.0, 1.0] +SF = [4, 8, 1] class TestChangeStar: @@ -65,7 +65,7 @@ def test_invalid_changestar_farseg_backbone(self) -> None: "inc,innerc,nc,sf", list(itertools.product(IN_CHANNELS, INNNR_CHANNELS, NC, SF)) ) def test_changemixin_output_size( - self, inc: int, innerc: int, nc: int, sf: float + self, inc: int, innerc: int, nc: int, sf: int ) -> None: m = ChangeMixin( in_channels=inc, inner_channels=innerc, num_convs=nc, scale_factor=sf @@ -73,7 +73,7 @@ def test_changemixin_output_size( y = m(torch.rand(3, 2, inc // 2, 32, 32)) assert y[0].shape == y[1].shape - assert y[0].shape == (3, 1, int(32 * sf), int(32 * sf)) + assert y[0].shape == (3, 1, 32 * sf, 32 * sf) @torch.no_grad() # type: ignore[misc] def test_changestar(self) -> None: diff --git a/tests/samplers/test_single.py b/tests/samplers/test_single.py index 2380bb119fd..556de8c6587 100644 --- a/tests/samplers/test_single.py +++ b/tests/samplers/test_single.py @@ -182,8 +182,8 @@ def test_iter(self, sampler: GridGeoSampler) -> None: ) def test_len(self, sampler: GridGeoSampler) -> None: - rows = int((100 - sampler.size[0]) // sampler.stride[0]) + 1 - cols = int((100 - sampler.size[1]) // sampler.stride[1]) + 1 + rows = ((100 - sampler.size[0]) // sampler.stride[0]) + 1 + cols = ((100 - sampler.size[1]) // sampler.stride[1]) + 1 length = rows * cols * 2 assert len(sampler) == length diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 2eff5fb0d04..ecc767a7af1 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -82,7 +82,7 @@ def __init__( self.patch_size = patch_size # This is a rough estimate of how large of a patch we will need to sample in # EPSG:3857 in order to guarantee a large enough patch in the local CRS. - self.original_patch_size = int(patch_size * 2.0) + self.original_patch_size = patch_size * 2 self.batch_size = batch_size self.num_workers = num_workers self.class_set = class_set @@ -151,8 +151,8 @@ def center_crop( def center_crop_inner(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: _, height, width = sample["image"].shape - y1 = (height - size) // 2 - x1 = (width - size) // 2 + y1 = round((height - size) / 2) + x1 = round((width - size) / 2) sample["image"] = sample["image"][:, y1 : y1 + size, x1 : x1 + size] sample["mask"] = sample["mask"][:, y1 : y1 + size, x1 : x1 + size] diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 5620324726c..d47fd315c6c 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -101,7 +101,7 @@ def setup(self, stage: Optional[str] = None) -> None: ) size_train_val = len(train_val_dataset) - size_train = int(0.8 * size_train_val) + size_train = round(0.8 * size_train_val) size_val = size_train_val - size_train self.train_dataset, self.val_dataset = random_split( diff --git a/torchgeo/datamodules/utils.py b/torchgeo/datamodules/utils.py index d38c678b139..d088e493312 100644 --- a/torchgeo/datamodules/utils.py +++ b/torchgeo/datamodules/utils.py @@ -28,11 +28,11 @@ def dataset_split( a list of the subset datasets. Either [train, val] or [train, val, test] """ if test_pct is None: - val_length = int(len(dataset) * val_pct) + val_length = round(len(dataset) * val_pct) train_length = len(dataset) - val_length return random_split(dataset, [train_length, val_length]) else: - val_length = int(len(dataset) * val_pct) - test_length = int(len(dataset) * test_pct) + val_length = round(len(dataset) * val_pct) + test_length = round(len(dataset) * test_pct) train_length = len(dataset) - (val_length + test_length) return random_split(dataset, [train_length, val_length, test_length]) diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py index 0790036865d..40dd3e81b5a 100644 --- a/torchgeo/datasets/geo.py +++ b/torchgeo/datasets/geo.py @@ -449,8 +449,8 @@ def _merge_files(self, filepaths: Sequence[str], query: BoundingBox) -> Tensor: bounds = (query.minx, query.miny, query.maxx, query.maxy) if len(vrt_fhs) == 1: src = vrt_fhs[0] - out_width = int(round((query.maxx - query.minx) / self.res)) - out_height = int(round((query.maxy - query.miny) / self.res)) + out_width = round((query.maxx - query.minx) / self.res) + out_height = round((query.maxy - query.miny) / self.res) out_shape = (src.count, out_height, out_width) dest = src.read( out_shape=out_shape, window=from_bounds(*bounds, src.transform) diff --git a/torchgeo/datasets/openbuildings.py b/torchgeo/datasets/openbuildings.py index 44342426b9b..8ccd760e577 100644 --- a/torchgeo/datasets/openbuildings.py +++ b/torchgeo/datasets/openbuildings.py @@ -324,11 +324,11 @@ def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: ) if shapes: masks = rasterio.features.rasterize( - shapes, out_shape=(int(height), int(width)), transform=transform + shapes, out_shape=(round(height), round(width)), transform=transform ) masks = torch.tensor(masks).unsqueeze(0) else: - masks = torch.zeros(size=(1, int(height), int(width))) + masks = torch.zeros(size=(1, round(height), round(width))) sample = {"mask": masks, "crs": self.crs, "bbox": query} diff --git a/torchgeo/datasets/so2sat.py b/torchgeo/datasets/so2sat.py index 4d9d1065ed0..1b4206cded1 100644 --- a/torchgeo/datasets/so2sat.py +++ b/torchgeo/datasets/so2sat.py @@ -197,7 +197,7 @@ def __init__( raise RuntimeError("Dataset not found or corrupted.") with h5py.File(self.fn, "r") as f: - self.size = int(f["label"].shape[0]) + self.size: int = f["label"].shape[0] def __getitem__(self, index: int) -> Dict[str, Tensor]: """Return an index within the dataset. diff --git a/torchgeo/datasets/spacenet.py b/torchgeo/datasets/spacenet.py index 36c16b01c44..f94bd7edd40 100644 --- a/torchgeo/datasets/spacenet.py +++ b/torchgeo/datasets/spacenet.py @@ -703,7 +703,7 @@ def _load_mask( speed_arr_bin = np.arange(min_speed_bin, max_speed_bin + 1) bin_size_mph = 10.0 speed_cls_arr: "np.typing.NDArray[np.int_]" = np.array( - [int(math.ceil(s / bin_size_mph)) for s in speed_arr_bin] + [math.ceil(s / bin_size_mph) for s in speed_arr_bin] ) try: diff --git a/torchgeo/samplers/single.py b/torchgeo/samplers/single.py index d671dd8b8bc..92930a24382 100644 --- a/torchgeo/samplers/single.py +++ b/torchgeo/samplers/single.py @@ -205,7 +205,7 @@ def __init__( ): self.hits.append(hit) - self.length: int = 0 + self.length = 0 for hit in self.hits: bounds = BoundingBox(*hit.bounds)