Skip to content

Commit

Permalink
ModelSamplingFlux now takes a resolution and adjusts the shift with it.
Browse files Browse the repository at this point in the history
If you want to sample Flux dev exactly how the reference code does use
the same resolution as your image in this node.
  • Loading branch information
comfyanonymous committed Aug 4, 2024
1 parent f7a5107 commit 56f3c66
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions comfy_extras/nodes_model_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import nodes
import torch

class LCM(comfy.model_sampling.EPS):
Expand Down Expand Up @@ -174,17 +175,26 @@ class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"

CATEGORY = "advanced/model"

def patch(self, model, shift):
def patch(self, model, max_shift, base_shift, width, height):
m = model.clone()

x1 = 256
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b

sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST

Expand Down

0 comments on commit 56f3c66

Please sign in to comment.