diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index 36122ed..7e5746e 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -289,13 +289,16 @@ def train( # losses_accumulator, losses_raw_accumulator = Counter(), Counter() losses, losses_raw = [], [] # just... don't care total_loss = 0 - if self.embedder is not None: - for mb_i in range(gradient_accumulation_steps): - # logger.debug(mb_i) + + for mb_i in range(gradient_accumulation_steps): + # logger.debug(mb_i) + # logger.debug(self.image_rep.shape) + t = 1 + interp_losses = [0] + prompt_losses = {} + if self.embedder is not None: image_embeds, offsets, sizes = self.embedder(self.image_rep, input=z) - t = 1 - interp_losses = [0] if i < interp_steps: t = i / interp_steps interp_losses = [ @@ -317,33 +320,33 @@ def train( for prompt in prompts } - losses, losses_raw = zip( - *map(unpack_dict, [prompt_losses, aug_losses, image_losses]) - # *map(unpack_dict, [prompt_losses]) - ) - # logger.debug(losses) - losses = list(losses) - # logger.debug(losses) - # losses = Counter(losses) - # logger.debug(losses) - losses_raw = list(losses_raw) - # losses_raw = Counter(losses_raw) - # losses_accumulator += losses - # losses_raw_accumulator += losses_raw - - for v in prompt_losses.values(): - v[0].mul_(t) - - total_loss_mb = sum(map(lambda x: sum(x.values()), losses)) + sum( - interp_losses - ) + losses, losses_raw = zip( + *map(unpack_dict, [prompt_losses, aug_losses, image_losses]) + # *map(unpack_dict, [prompt_losses]) + ) + # logger.debug(losses) + losses = list(losses) + # logger.debug(losses) + # losses = Counter(losses) + # logger.debug(losses) + losses_raw = list(losses_raw) + # losses_raw = Counter(losses_raw) + # losses_accumulator += losses + # losses_raw_accumulator += losses_raw + + for v in prompt_losses.values(): + v[0].mul_(t) + + total_loss_mb = sum(map(lambda x: sum(x.values()), losses)) + sum( + interp_losses + ) - total_loss_mb /= gradient_accumulation_steps + total_loss_mb /= gradient_accumulation_steps - # total_loss_mb.backward() - total_loss_mb.backward(retain_graph=True) - # total_loss += total_loss_mb # this is causing it to break - # total_loss = total_loss_mb + # total_loss_mb.backward() + total_loss_mb.backward(retain_graph=True) + # total_loss += total_loss_mb # this is causing it to break + # total_loss = total_loss_mb # losses = [{k:v} for k,v in losses_accumulator.items()] # losses_raw = [{k:v} for k,v in losses_raw_accumulator.items()]