Is your feature request related to a problem? Please describe.
Lately I've been working with LDMs to generate samples conditioned on a specific conditional vector. My model learned to generate samples in an axial 2D fashion, and the conditioning embeddings are based on 2D axial slices as well. However, during inference I need to work with entire volumes, so I need a proper way to aggregate the generated 2D slices. Following the tutorial here, I noticed that SliceInferer might come in handy. However, as the documentation suggests, you can only tweak the network method used for inference, not the external conditions. As its name suggests, SliceInferer passes the input to the network slice by slice, but the condition is not processed in the same way.
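For context, the unconditioned slice-by-slice aggregation described above can be sketched without any framework. This is a simplified illustration, not MONAI's implementation (SliceInferer builds on sliding-window inference). It assumes a NumPy volume laid out as (C, H, W, Z) with the axial axis last, and `network_2d` is a hypothetical 2D model callable.

```python
import numpy as np


def slice_by_slice_inference(volume, network_2d, axis=-1):
    """Run a 2D model on every slice of `volume` along `axis` and
    re-stack the 2D outputs into a volume. No conditioning is involved."""
    # Bring the slicing axis to the front so we can iterate over slices.
    slices = np.moveaxis(volume, axis, 0)
    # Each slice is processed independently by the 2D model.
    outputs = [network_2d(s) for s in slices]
    # Stack the per-slice outputs back along the original axis.
    return np.stack(outputs, axis=axis)
```

For example, `slice_by_slice_inference(vol, lambda s: s * 2.0)` on a (1, 4, 4, 3) volume returns a (1, 4, 4, 3) volume. Nothing in this loop knows about a per-slice condition, which is exactly the gap described in this issue.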
Describe the solution you'd like
It would be nice to adjust the code to support model conditioning. One idea would be to add an explicit parameter to SliceInferer (e.g., `condition`) that calls another SlidingWindowInferer dedicated to the condition, so that the condition is sliced in the same way as the input.
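A minimal, framework-agnostic sketch of the proposed behavior: the condition is sliced along the same axis and with the same batching as the input, so each batch of b slices is paired with exactly its b condition rows. The function name `conditioned_slice_inference`, the (C, H, W, Z) volume layout, the (Z, D) condition layout, and the `network_2d(x, cond)` signature are all assumptions for illustration, not an existing MONAI API.

```python
import numpy as np


def conditioned_slice_inference(volume, condition, network_2d, sw_batch_size=1):
    """Hypothetical sketch: slice `volume` (C, H, W, Z) and `condition` (Z, D)
    in lockstep so every batch of slices receives its matching condition rows."""
    num_slices = volume.shape[-1]
    if condition.shape[0] != num_slices:
        raise ValueError(
            f"condition has {condition.shape[0]} rows but volume has {num_slices} slices"
        )
    # (C, H, W, Z) -> (Z, C, H, W): one 2D sample per axial slice.
    slices = np.moveaxis(volume, -1, 0)
    outputs = []
    for start in range(0, num_slices, sw_batch_size):
        stop = min(start + sw_batch_size, num_slices)
        batch = slices[start:stop]    # (b, C, H, W)
        cond = condition[start:stop]  # (b, D), rows aligned with the slices
        outputs.append(network_2d(batch, cond))  # assumed to return (b, C', H, W)
    # (Z, C', H, W) -> (C', H, W, Z): re-assemble the volume.
    return np.moveaxis(np.concatenate(outputs, axis=0), 0, -1)
```

Indexing the condition with the same slice range as the input is one way to obtain the pairing that a separate, condition-dedicated inferer would aim for, and it stays aligned even when several slices are batched together.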
Additional context
Given Z axial slices, and a network bottleneck whose output matches the dimensionality D of the conditional vector, the main error when using SliceInferer is that the model expects a condition of shape [1, D] but receives one of shape [Z, D], which is to be expected.
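The mismatch can be reproduced with plain array shapes. In this sketch Z, D, the spatial size, and the (C, H, W, Z) layout are arbitrary illustrative assumptions; the point is that the k-th slice should be paired with `condition[k:k+1]`, not with the whole condition.

```python
import numpy as np

Z, D = 5, 8                        # illustrative slice count and condition size
volume = np.zeros((1, 16, 16, Z))  # (C, H, W, Z), assumed layout
condition = np.zeros((Z, D))       # one conditioning vector per axial slice

k = 0  # index of the slice being processed
# A single axial slice as a batch of one: (1, C, H, W).
slice_k = np.moveaxis(volume, -1, 0)[k:k + 1]
# Forwarding the condition untouched gives (Z, D): the reported mismatch.
naive_cond = condition
# Slicing the condition in step with the input gives the expected (1, D).
matched_cond = condition[k:k + 1]
```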
Thanks for helping me out!