Merge pull request #210 from pytti-tools/fix_image_encode
fixed bug with init images that mainly caused issues for non-VQGAN image models
dmarx authored Jun 23, 2022
2 parents 0143fab + 052b593 commit 8a10acf
Showing 1 changed file with 33 additions and 30 deletions.
src/pytti/ImageGuide.py: 33 additions & 30 deletions
@@ -289,13 +289,16 @@ def train(
         # losses_accumulator, losses_raw_accumulator = Counter(), Counter()
         losses, losses_raw = [], []  # just... don't care
         total_loss = 0
-        if self.embedder is not None:
-            for mb_i in range(gradient_accumulation_steps):
-                # logger.debug(mb_i)
+
+        for mb_i in range(gradient_accumulation_steps):
+            # logger.debug(mb_i)
+            # logger.debug(self.image_rep.shape)
+            t = 1
+            interp_losses = [0]
+            prompt_losses = {}
+            if self.embedder is not None:
                 image_embeds, offsets, sizes = self.embedder(self.image_rep, input=z)

-                t = 1
-                interp_losses = [0]
                 if i < interp_steps:
                     t = i / interp_steps
                     interp_losses = [
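
The first hunk carries the core of the fix: the gradient-accumulation loop no longer sits inside the `if self.embedder is not None:` guard, and `t`, `interp_losses`, and `prompt_losses` get defaults at the top of every microbatch, so the loss bookkeeping later in the loop is defined even when no embedder is present. Below is a minimal, self-contained sketch of that control-flow pattern, not the repo's code: `mb_i`, `prompt_losses`, `interp_losses`, and `gradient_accumulation_steps` come from the diff, while the `embedder` callable and the printed loss are illustrative stand-ins.

from typing import Callable, Dict, Optional

def train_step(
    embedder: Optional[Callable[[], Dict[str, float]]],
    gradient_accumulation_steps: int = 1,
) -> None:
    """Illustrative sketch of the fixed control flow, not pytti's implementation."""
    for mb_i in range(gradient_accumulation_steps):
        # Defaults are set unconditionally for every microbatch, so the
        # bookkeeping below is well-defined even without an embedder.
        t = 1.0
        interp_losses = [0.0]
        prompt_losses: Dict[str, float] = {}

        # Only the embedder-dependent work stays behind the guard.
        if embedder is not None:
            prompt_losses = embedder()  # stand-in for the embed + prompt losses

        # This part now runs for every image model, embedder or not.
        total_loss_mb = t * sum(prompt_losses.values()) + sum(interp_losses)
        total_loss_mb /= gradient_accumulation_steps
        print(f"microbatch {mb_i}: loss {total_loss_mb:.3f}")

# The previously broken path: no embedder, and the loop still runs cleanly.
train_step(embedder=None, gradient_accumulation_steps=2)
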
@@ -317,33 +320,33 @@ def train(
                     for prompt in prompts
                 }

-                losses, losses_raw = zip(
-                    *map(unpack_dict, [prompt_losses, aug_losses, image_losses])
-                    # *map(unpack_dict, [prompt_losses])
-                )
-                # logger.debug(losses)
-                losses = list(losses)
-                # logger.debug(losses)
-                # losses = Counter(losses)
-                # logger.debug(losses)
-                losses_raw = list(losses_raw)
-                # losses_raw = Counter(losses_raw)
-                # losses_accumulator += losses
-                # losses_raw_accumulator += losses_raw
-
-                for v in prompt_losses.values():
-                    v[0].mul_(t)
-
-                total_loss_mb = sum(map(lambda x: sum(x.values()), losses)) + sum(
-                    interp_losses
-                )
+            losses, losses_raw = zip(
+                *map(unpack_dict, [prompt_losses, aug_losses, image_losses])
+                # *map(unpack_dict, [prompt_losses])
+            )
+            # logger.debug(losses)
+            losses = list(losses)
+            # logger.debug(losses)
+            # losses = Counter(losses)
+            # logger.debug(losses)
+            losses_raw = list(losses_raw)
+            # losses_raw = Counter(losses_raw)
+            # losses_accumulator += losses
+            # losses_raw_accumulator += losses_raw
+
+            for v in prompt_losses.values():
+                v[0].mul_(t)
+
+            total_loss_mb = sum(map(lambda x: sum(x.values()), losses)) + sum(
+                interp_losses
+            )

-                total_loss_mb /= gradient_accumulation_steps
+            total_loss_mb /= gradient_accumulation_steps

-                # total_loss_mb.backward()
-                total_loss_mb.backward(retain_graph=True)
-                # total_loss += total_loss_mb # this is causing it to break
-                # total_loss = total_loss_mb
+            # total_loss_mb.backward()
+            total_loss_mb.backward(retain_graph=True)
+            # total_loss += total_loss_mb # this is causing it to break
+            # total_loss = total_loss_mb

         # losses = [{k:v} for k,v in losses_accumulator.items()]
         # losses_raw = [{k:v} for k,v in losses_raw_accumulator.items()]
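
The second hunk is mostly re-indentation: the loss aggregation, the division by `gradient_accumulation_steps`, and the per-microbatch `backward()` move out of the embedder-only branch into the body of the microbatch loop. As a hedged sketch of the underlying gradient-accumulation pattern, with a toy model, optimizer, and data that are not pytti's classes; note the repo keeps `retain_graph=True` because its microbatches share parts of one computation graph, which this toy example does not need:

import torch

def accumulate_and_step(model, batches, optimizer, gradient_accumulation_steps=2):
    """Illustrative: scale each microbatch loss so summed grads match a full batch."""
    optimizer.zero_grad()
    for x, y in batches[:gradient_accumulation_steps]:
        loss_mb = torch.nn.functional.mse_loss(model(x), y)
        loss_mb = loss_mb / gradient_accumulation_steps  # average over microbatches
        loss_mb.backward()  # gradients accumulate in param.grad across iterations
    optimizer.step()

# Toy usage with stand-in model and data.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(2)]
accumulate_and_step(model, batches, optimizer)
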
