Commit 44683fd
refactored preprocessor/merged mt feature extraction in encoder/fixed some bugs
LeungTsang committed Oct 4, 2024
1 parent abf5e9a commit 44683fd
Showing 17 changed files with 33 additions and 34 deletions.
6 changes: 3 additions & 3 deletions pangaea/encoders/base.py

@@ -118,7 +118,7 @@ def freeze(self) -> None:
         for param in self.parameters():
             param.requires_grad = False
 
-    def native_forward(self, img: dict[str, torch.Tensor]) -> list[torch.Tensor]:
+    def simple_forward(self, img: dict[str, torch.Tensor]) -> list[torch.Tensor]:
         """Compute the forward pass of the encoder.
 
         Args:
@@ -135,10 +135,10 @@ def native_forward(self, img: dict[str, torch.Tensor]) -> list[torch.Tensor]:
 
     def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
         b, c, t, h, w = image[list(image.keys())[0]].shape
         if self.multi_temporal:
-            return self.native_forward(image)
+            return self.simple_forward(image)
         else:
             if t == 1:
-                return self.native_forward(image)
+                return self.simple_forward(image)
             else:
                 return self.naive_multi_temporal_forward(image)
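Note: this hunk is the heart of the rename. A minimal sketch of the dispatch around simple_forward, with the helper bodies assumed (only their names appear in the diff, not their implementations):

    import torch

    class EncoderSketch:
        """Sketch of the temporal dispatch; not the repository's actual base class."""

        multi_temporal: bool = False  # True for encoders with native temporal support

        def simple_forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
            raise NotImplementedError  # provided by each concrete encoder

        def squeeze_temporal_dimension(self, image: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
            # Assumed behavior: drop the singleton time axis, B C 1 H W -> B C H W.
            return {k: v.squeeze(2) for k, v in image.items()}

        def naive_multi_temporal_forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
            # Assumed behavior: run the single-frame path on each time step,
            # then stack each encoder stage's features along a new time axis.
            t = next(iter(image.values())).shape[2]
            per_frame = [
                self.simple_forward({k: v[:, :, i : i + 1] for k, v in image.items()})
                for i in range(t)
            ]
            return [torch.stack(stage, dim=2) for stage in zip(*per_frame)]

        def forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
            t = next(iter(image.values())).shape[2]
            if self.multi_temporal or t == 1:
                return self.simple_forward(image)
            return self.naive_multi_temporal_forward(image)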
4 changes: 2 additions & 2 deletions pangaea/encoders/croma_encoder.py

@@ -83,7 +83,7 @@ def __init__(
             dim=self.embed_dim, depth=self.encoder_depth, in_channels=self.s2_channels
         )
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
 
         image = self.squeeze_temporal_dimension(image)
 
@@ -202,7 +202,7 @@ def __init__(
             in_channels=self.s1_channels,
         )
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         # output = []
         image = self.squeeze_temporal_dimension(image)
2 changes: 1 addition & 1 deletion pangaea/encoders/dofa_encoder.py

@@ -242,7 +242,7 @@ def __init__(
             ]
         )
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         # embed patches
         image = self.squeeze_temporal_dimension(image)
2 changes: 1 addition & 1 deletion pangaea/encoders/gfmswin_encoder.py

@@ -756,7 +756,7 @@ def load_encoder_weights(self, logger: Logger) -> None:
         self.load_state_dict(pretrained_encoder, strict=False)
         self.parameters_warning(missing, incompatible_shape, logger)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         image = self.squeeze_temporal_dimension(image)
         x = self.patch_embed(image["optical"])
         if self.ape:
2 changes: 1 addition & 1 deletion pangaea/encoders/prithvi_encoder.py

@@ -159,7 +159,7 @@ def _init_weights(self, m):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         # embed patches
         x = self.patch_embed(image["optical"])
2 changes: 1 addition & 1 deletion pangaea/encoders/remoteclip_encoder.py

@@ -458,7 +458,7 @@ def load_encoder_weights(self, logger: Logger) -> None:
         self.load_state_dict(pretrained_encoder, strict=False)
         self.parameters_warning(missing, incompatible_shape, logger)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         image = self.squeeze_temporal_dimension(image)
 
         x = self.conv1(image["optical"])  # shape = [*, width, grid, grid]
4 changes: 3 additions & 1 deletion pangaea/encoders/satlasnet_encoder.py

@@ -540,8 +540,10 @@ def load_encoder_weights(self, logger: Logger) -> None:
 
         self.parameters_warning(missing, incompatible_shape, logger)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         # Define forward pass
+        if not self.multi_temporal:
+            image = self.squeeze_temporal_dimension(image)
 
         x = self.backbone(image["optical"])
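Unlike the other encoder hunks in this commit, which only rename native_forward to simple_forward, this one also changes behavior: SatlasNet can be configured as a multi-temporal encoder, so the singleton time axis is squeezed away only when multi_temporal is False, and a multi-temporal stack passes through to the backbone intact.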
2 changes: 1 addition & 1 deletion pangaea/encoders/scalemae_encoder.py

@@ -149,7 +149,7 @@ def load_encoder_weights(self, logger: Logger) -> None:
         self.load_state_dict(pretrained_encoder, strict=False)
         self.parameters_warning(missing, incompatible_shape, logger)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         image = self.squeeze_temporal_dimension(image)
 
         B, _, h, w = image["optical"].shape
2 changes: 1 addition & 1 deletion pangaea/encoders/spectralgpt_encoder.py

@@ -162,7 +162,7 @@ def load_encoder_weights(self, logger: Logger) -> None:
         self.load_state_dict(pretrained_encoder, strict=False)
         self.parameters_warning(missing, incompatible_shape, logger)
 
-    def native_forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
+    def simple_forward(self, image: dict[str, torch.Tensor]) -> list[torch.Tensor]:
         # input image of shape B C H W
         x = image["optical"]  # .unsqueeze(-3)  # B C H W -> B C 1 H W
2 changes: 1 addition & 1 deletion pangaea/encoders/ssl4eo_data2vec_encoder.py

@@ -483,7 +483,7 @@ def _init_weights(self, m):
     def no_weight_decay(self):
        return {"pos_embed", "cls_token"}
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         image = self.squeeze_temporal_dimension(image)
         x = self.patch_embed(image["optical"])
         batch_size, seq_len, _ = x.size()
2 changes: 1 addition & 1 deletion pangaea/encoders/ssl4eo_dino_encoder.py

@@ -363,7 +363,7 @@ def prepare_tokens(self, x):
 
         return self.pos_drop(x)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         image = self.squeeze_temporal_dimension(image)
 
         x = self.prepare_tokens(image["optical"])
2 changes: 1 addition & 1 deletion pangaea/encoders/ssl4eo_mae_encoder.py

@@ -226,7 +226,7 @@ def __init__(
         self.multi_temporal = False
         self.output_dim = embed_dim
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         # embed patches
         image = self.squeeze_temporal_dimension(image)
2 changes: 1 addition & 1 deletion pangaea/encoders/ssl4eo_moco_encoder.py

@@ -138,7 +138,7 @@ def load_encoder_weights(self, logger: Logger) -> None:
         self.load_state_dict(pretrained_encoder, strict=False)
         self.parameters_warning(missing, incompatible_shape, logger)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         image = self.squeeze_temporal_dimension(image)
 
         x = self.patch_embed(image["optical"])
2 changes: 1 addition & 1 deletion pangaea/encoders/unet_encoder.py

@@ -46,7 +46,7 @@ def __init__(
         self.in_conv = InConv(self.in_channels, self.topology[0], DoubleConv)
         self.encoder = UNet_Encoder(self.topology)
 
-    def native_forward(self, image):
+    def simple_forward(self, image):
         image = self.squeeze_temporal_dimension(image)
 
         feat = self.in_conv(image["optical"])
13 changes: 7 additions & 6 deletions pangaea/engine/evaluator.py

@@ -3,6 +3,8 @@
 import time
 from pathlib import Path
 import math
+import wandb
+
 
 import torch
 import torch.nn.functional as F
@@ -46,6 +48,7 @@ def __init__(
         sliding_inference_batch: int = None,
         use_wandb: bool = False,
     ) -> None:
+        self.rank = int(os.environ["RANK"])
         self.val_loader = val_loader
         self.logger = logging.getLogger()
         self.exp_dir = exp_dir
@@ -57,13 +60,10 @@ def __init__(
         self.ignore_index = self.val_loader.dataset.ignore_index
         self.num_classes = len(self.classes)
         self.max_name_len = max([len(name) for name in self.classes])
+        self.use_wandb = use_wandb
 
-        self.use_wandb = use_wandb
-
-        if use_wandb:
-            import wandb
-
-            self.wandb = wandb
-
     def evaluate(
         self,
@@ -183,6 +183,7 @@ def evaluate(self, model, model_name='model', model_ckpt_path=None):
         )
 
         for batch_idx, data in enumerate(tqdm(self.val_loader, desc=tag)):
+
             image, target = data["image"], data["target"]
             image = {k: v.to(self.device) for k, v in image.items()}
             target = target.to(self.device)
@@ -291,7 +292,7 @@ def format_metric(name, values, mean_value):
         self.logger.info(macc_str)
 
         if self.use_wandb:
-            self.wandb.log(
+            wandb.log(
                 {
                     f"{self.split}_mIoU": metrics["mIoU"],
                     f"{self.split}_mF1": metrics["mF1"],
@@ -396,4 +397,4 @@ def log_metrics(self, metrics):
         self.logger.info(header+mse+rmse)
 
         if self.use_wandb:
-            self.wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
+            wandb.log({f"{self.split}_MSE": metrics["MSE"], f"{self.split}_RMSE": metrics["RMSE"]})
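Taken together, the evaluator changes replace the lazily imported self.wandb handle with a module-level import and record the process rank. A minimal sketch of the resulting pattern (class abbreviated; the metric keys are placeholders, not the repository's):

    import os
    import wandb  # module-level import, as introduced by this commit

    class EvaluatorSketch:
        def __init__(self, use_wandb: bool = False) -> None:
            self.rank = int(os.environ["RANK"])  # distributed rank, set by the launcher
            # run.py (below) forces this flag to False on non-zero ranks,
            # so no rank check is needed at the logging call sites.
            self.use_wandb = use_wandb

        def log_metrics(self, metrics: dict[str, float]) -> None:
            if self.use_wandb:
                wandb.log({"val_MSE": metrics["MSE"], "val_RMSE": metrics["RMSE"]})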
9 changes: 2 additions & 7 deletions pangaea/engine/trainer.py

@@ -12,7 +12,7 @@
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 from pangaea.utils.logger import RunningAverageMeter, sec_to_hm
-
+import wandb
 
 class Trainer:
     def __init__(
@@ -88,11 +88,6 @@ def __init__(
 
         self.start_epoch = 0
 
-        if self.use_wandb:
-            import wandb
-
-            self.wandb = wandb
-
     def train(self) -> None:
         """Train the model for n_epochs then evaluate the model and save the best model."""
         # end_time = time.time()
@@ -162,7 +157,7 @@ def train_one_epoch(self, epoch: int) -> None:
             self.lr_scheduler.step()
 
         if self.use_wandb and self.rank == 0:
-            self.wandb.log(
+            wandb.log(
                 {
                     "train_loss": loss.item(),
                     "learning_rate": self.optimizer.param_groups[0]["lr"],
9 changes: 5 additions & 4 deletions pangaea/run.py

@@ -3,6 +3,7 @@
 import pprint
 import random
 import time
+import wandb
 
 import hydra
 import torch
@@ -56,6 +57,8 @@ def main(cfg: DictConfig) -> None:
     fix_seed(cfg.seed)
     # distributed training variables
     rank = int(os.environ["RANK"])
+    cfg.task.trainer.use_wandb = cfg.task.trainer.use_wandb and rank == 0
+    cfg.task.evaluator.use_wandb = cfg.task.evaluator.use_wandb and rank == 0
     local_rank = int(os.environ["LOCAL_RANK"])
     device = torch.device("cuda", local_rank)
 
@@ -72,8 +75,7 @@ def main(cfg: DictConfig) -> None:
     config_log_dir = exp_dir / "configs"
     config_log_dir.mkdir(exist_ok=True)
     # init wandb
-    if cfg.task.trainer.use_wandb and rank == 0:
-        import wandb
+    if cfg.task.trainer.use_wandb:
 
         wandb_cfg = OmegaConf.to_container(cfg, resolve=True)
         wandb.init(
@@ -91,8 +93,7 @@ def main(cfg: DictConfig) -> None:
     # load training config
     cfg_path = exp_dir / "configs" / "config.yaml"
     cfg = OmegaConf.load(cfg_path)
-    if cfg.task.trainer.use_wandb and rank == 0:
-        import wandb
+    if cfg.task.trainer.use_wandb:
 
         wandb_cfg = OmegaConf.to_container(cfg, resolve=True)
         wandb.init(
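With the two lines added in main(), the rank check lives in exactly one place: the config flags are rewritten once at startup, and every downstream consumer (Trainer, Evaluator, and the wandb.init guards above) only tests the flag. A sketch of the init path under that assumption (the project name is a placeholder, not a value from the repo):

    import wandb
    from omegaconf import DictConfig, OmegaConf

    def maybe_init_wandb(cfg: DictConfig) -> None:
        # Runs on every rank, but the gated flag is True only on rank 0.
        if cfg.task.trainer.use_wandb:
            wandb.init(project="pangaea", config=OmegaConf.to_container(cfg, resolve=True))

One trade-off of the module-level import: every process now imports wandb even when logging is disabled, in exchange for dropping the per-class import-and-stash boilerplate.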
