diff --git a/hubconf.py b/hubconf.py
index 760436c0..0d638be5 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -6,9 +6,9 @@ from midas.midas_net import MidasNet
 from midas.midas_net_custom import MidasNet_small
 
 
-def DPT_BEit_L_512(pretrained=True, **kwargs):
+def DPT_BEiT_L_512(pretrained=True, **kwargs):
     """ # This docstring shows up in hub.help()
-    MiDaS DPT_BEit_L_512 model for monocular depth estimation
+    MiDaS DPT_BEiT_L_512 model for monocular depth estimation
     pretrained (bool): load pretrained weights into model
     """
 
@@ -29,9 +29,9 @@ def DPT_BEit_L_512(pretrained=True, **kwargs):
     return model
 
 
-def DPT_BEit_L_384(pretrained=True, **kwargs):
+def DPT_BEiT_L_384(pretrained=True, **kwargs):
     """ # This docstring shows up in hub.help()
-    MiDaS DPT_BEit_L_384 model for monocular depth estimation
+    MiDaS DPT_BEiT_L_384 model for monocular depth estimation
     pretrained (bool): load pretrained weights into model
     """
 
@@ -52,6 +52,29 @@ def DPT_BEit_L_384(pretrained=True, **kwargs):
     return model
 
 
+def DPT_BEiT_B_384(pretrained=True, **kwargs):
+    """ # This docstring shows up in hub.help()
+    MiDaS DPT_BEiT_B_384 model for monocular depth estimation
+    pretrained (bool): load pretrained weights into model
+    """
+
+    model = DPTDepthModel(
+        path=None,
+        backbone="beitb16_384",
+        non_negative=True,
+    )
+
+    if pretrained:
+        checkpoint = (
+            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt"
+        )
+        state_dict = torch.hub.load_state_dict_from_url(
+            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+        )
+        model.load_state_dict(state_dict)
+
+    return model
+
 def DPT_SwinV2_L_384(pretrained=True, **kwargs):
     """ # This docstring shows up in hub.help()
     MiDaS DPT_SwinV2_L_384 model for monocular depth estimation
@@ -75,6 +98,29 @@ def DPT_SwinV2_L_384(pretrained=True, **kwargs):
     return model
 
 
+def DPT_SwinV2_B_384(pretrained=True, **kwargs):
+    """ # This docstring shows up in hub.help()
+    MiDaS DPT_SwinV2_B_384 model for monocular depth estimation
+    pretrained (bool): load pretrained weights into model
+    """
+
+    model = DPTDepthModel(
+        path=None,
+        backbone="swin2b24_384",
+        non_negative=True,
+    )
+
+    if pretrained:
+        checkpoint = (
+            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt"
+        )
+        state_dict = torch.hub.load_state_dict_from_url(
+            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+        )
+        model.load_state_dict(state_dict)
+
+    return model
+
 def DPT_SwinV2_T_256(pretrained=True, **kwargs):
     """ # This docstring shows up in hub.help()
     MiDaS DPT_SwinV2_T_256 model for monocular depth estimation
@@ -98,6 +144,29 @@ def DPT_SwinV2_T_256(pretrained=True, **kwargs):
     return model
 
 
+def DPT_Swin_L_384(pretrained=True, **kwargs):
+    """ # This docstring shows up in hub.help()
+    MiDaS DPT_Swin_L_384 model for monocular depth estimation
+    pretrained (bool): load pretrained weights into model
+    """
+
+    model = DPTDepthModel(
+        path=None,
+        backbone="swinl12_384",
+        non_negative=True,
+    )
+
+    if pretrained:
+        checkpoint = (
+            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt"
+        )
+        state_dict = torch.hub.load_state_dict_from_url(
+            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
+        )
+        model.load_state_dict(state_dict)
+
+    return model
+
 def DPT_Next_ViT_L_384(pretrained=True, **kwargs):
     """ # This docstring shows up in hub.help()
     MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation
@@ -131,6 +200,8 @@ def DPT_LeViT_224(pretrained=True, **kwargs):
         path=None,
         backbone="levit_384",
         non_negative=True,
+        head_features_1=64,
+        head_features_2=8,
     )
 
     if pretrained:
@@ -289,4 +360,76 @@ def transforms():
         ]
     )
 
+    transforms.beit512_transform = Compose(
+        [
+            lambda img: {"image": img / 255.0},
+            Resize(
+                512,
+                512,
+                resize_target=None,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=32,
+                resize_method="minimal",
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            PrepareForNet(),
+            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+        ]
+    )
+
+    transforms.swin384_transform = Compose(
+        [
+            lambda img: {"image": img / 255.0},
+            Resize(
+                384,
+                384,
+                resize_target=None,
+                keep_aspect_ratio=False,
+                ensure_multiple_of=32,
+                resize_method="minimal",
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            PrepareForNet(),
+            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+        ]
+    )
+
+    transforms.swin256_transform = Compose(
+        [
+            lambda img: {"image": img / 255.0},
+            Resize(
+                256,
+                256,
+                resize_target=None,
+                keep_aspect_ratio=False,
+                ensure_multiple_of=32,
+                resize_method="minimal",
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            PrepareForNet(),
+            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+        ]
+    )
+
+    transforms.levit_transform = Compose(
+        [
+            lambda img: {"image": img / 255.0},
+            Resize(
+                224,
+                224,
+                resize_target=None,
+                keep_aspect_ratio=False,
+                ensure_multiple_of=32,
+                resize_method="minimal",
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+            PrepareForNet(),
+            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
+        ]
+    )
+
     return transforms
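
Usage note, not part of the patch: a minimal sketch of how one of the entrypoints and transforms added above could be exercised through torch.hub, assuming this hubconf.py is published on the default branch of isl-org/MiDaS and the linked v3_1 release checkpoints are downloadable; the input image path is a placeholder.

import cv2
import torch

# Load one of the new entrypoints; pretrained=True pulls the release checkpoint referenced above.
model = torch.hub.load("isl-org/MiDaS", "DPT_SwinV2_B_384", pretrained=True)
model.eval()

# The matching input transform is exposed by the existing "transforms" entrypoint.
midas_transforms = torch.hub.load("isl-org/MiDaS", "transforms")
transform = midas_transforms.swin384_transform  # square 384x384 input, as defined above

img = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)  # placeholder image path
input_batch = transform(img)  # (1, 3, 384, 384)

with torch.no_grad():
    prediction = model(input_batch)  # relative inverse depth, (1, 384, 384)
    # Resize the prediction back to the original image resolution.
    prediction = torch.nn.functional.interpolate(
        prediction.unsqueeze(1), size=img.shape[:2], mode="bicubic", align_corners=False
    ).squeeze()

The other new entrypoints (DPT_BEiT_B_384, DPT_Swin_L_384) load the same way; only the entrypoint name and the matching transform change.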