Skip to content

Commit

Permalink
Specify weights while combining gradients (#271)
Browse files Browse the repository at this point in the history
* Add methods and recipe to use multiple gradients on a single ingredient

* Add gradient weights to Environment.py and Gradient.py

* Add gradient_weights parameter to Ingredient and Agent classes

* Update decay length and gradient weights in test_combined_gradient.json

* Increase strength of gradient in example recipe

* Remove unused import

* Refactor gradient weighting logic

* Add validation for gradient names in Gradient.py

* Add assertion to check for at least two gradients before combining

* Check if gradient key exists in recipe before updating ingredient info

* add validation for gradient information while creating ingredient

* add tests for gradient information validation

* Simplify ingredient gradient update by moving validation to previous step
  • Loading branch information
mogres authored Jul 16, 2024
1 parent 12276f5 commit 83cca5a
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 18 deletions.
9 changes: 2 additions & 7 deletions cellpack/autopack/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,13 +974,8 @@ def create_ingredient(self, recipe, arguments):
ingredient_type = arguments["type"]
ingredient_class = ingredient.get_ingredient_class(ingredient_type)
ingr = ingredient_class(**arguments)
if (
"gradient" in arguments
and arguments["gradient"] != ""
and arguments["gradient"] != "None"
):
ingr.gradient = arguments["gradient"]
# TODO: allow ingrdients to have multiple gradients
if "gradient" in arguments:
ingr = Gradient.update_ingredient_gradient(ingr, arguments)
if "results" in arguments:
ingr.results = arguments["results"]
ingr.initialize_mesh(self.mesh_store)
Expand Down
35 changes: 26 additions & 9 deletions cellpack/autopack/Gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ def __init__(self, gradient_data):

self.function = self.defaultFunction # lambda ?

@staticmethod
def update_ingredient_gradient(ingr, arguments):
"""
Update the ingredient gradient
"""
ingr.gradient = arguments["gradient"]
ingr.gradient_weights = None
if "gradient_weights" in arguments:
ingr.gradient_weights = arguments["gradient_weights"]

return ingr

@staticmethod
def scale_between_0_and_1(values):
"""
Expand All @@ -101,7 +113,7 @@ def scale_between_0_and_1(values):
return (values - min_value) / (max_value - min_value)

@staticmethod
def get_combined_gradient_weight(gradient_list):
def get_combined_gradient_weight(gradient_list, gradient_weights=None):
"""
Combine the gradient weights
Expand All @@ -115,11 +127,13 @@ def get_combined_gradient_weight(gradient_list):
numpy.ndarray
the combined gradient weight
"""
assert len(gradient_list) > 1, "Need at least two gradients to combine"

weight_list = numpy.zeros((len(gradient_list), len(gradient_list[0].weight)))
for i in range(len(gradient_list)):
weight_list[i] = Gradient.scale_between_0_and_1(gradient_list[i].weight)

combined_weight = numpy.mean(weight_list, axis=0)
combined_weight = numpy.average(weight_list, axis=0, weights=gradient_weights)
combined_weight = Gradient.scale_between_0_and_1(combined_weight)

return combined_weight
Expand All @@ -143,8 +157,8 @@ def pick_point_from_weight(weight, points):
the index of the picked point
"""
weights_to_use = numpy.take(weight, points)
weights_to_use = Gradient.scale_between_0_and_1(weights_to_use)
weights_to_use[numpy.isnan(weights_to_use)] = 0
weights_to_use = Gradient.scale_between_0_and_1(weights_to_use)

point_probabilities = weights_to_use / numpy.sum(weights_to_use)

Expand Down Expand Up @@ -176,13 +190,16 @@ def pick_point_for_ingredient(ingr, allIngrPts, all_gradients):
if isinstance(ingr.gradient, list):
if len(ingr.gradient) > 1:
if not hasattr(ingr, "combined_weight"):
gradient_list = [
gradient
for gradient_name, gradient in all_gradients.items()
if gradient_name in ingr.gradient
]
gradient_list = []
for gradient_name in ingr.gradient:
if gradient_name not in all_gradients:
raise ValueError(
f"Gradient {gradient_name} not found in gradient list"
)
gradient_list.append(all_gradients[gradient_name])

combined_weight = Gradient.get_combined_gradient_weight(
gradient_list
gradient_list, ingr.gradient_weights
)
ingr.combined_weight = combined_weight

Expand Down
35 changes: 35 additions & 0 deletions cellpack/autopack/ingredient/Ingredient.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class Ingredient(Agent):
"distance_function",
"force_random",
"gradient",
"gradient_weights",
"is_attractor",
"max_jitter",
"molarity",
Expand Down Expand Up @@ -199,6 +200,7 @@ def __init__(
distance_function=None,
force_random=False, # avoid any binding
gradient=None,
gradient_weights=None,
is_attractor=False,
max_jitter=(1, 1, 1),
molarity=0.0,
Expand Down Expand Up @@ -231,6 +233,7 @@ def __init__(
distance_function=distance_function,
force_random=force_random,
gradient=gradient,
gradient_weights=gradient_weights,
is_attractor=is_attractor,
overwrite_distance_function=overwrite_distance_function,
packing_mode=packing_mode,
Expand Down Expand Up @@ -407,6 +410,38 @@ def validate_ingredient_info(ingredient_info):
ingredient_info["size_options"]
)

# check if gradient information is entered correctly
if "gradient" in ingredient_info:
if not isinstance(ingredient_info["gradient"], (list, str)):
raise Exception(
(
f"Invalid gradient: {ingredient_info['gradient']} "
f"for ingredient {ingredient_info['name']}"
)
)
if (
ingredient_info["gradient"] == ""
or ingredient_info["gradient"] == "None"
):
raise Exception(
f"Missing gradient for ingredient {ingredient_info['name']}"
)

# if multiple gradients are provided with weights, check if weights are correct
if isinstance(ingredient_info["gradient"], list):
if "gradient_weights" in ingredient_info:
# check if gradient_weights are missing
if not isinstance(ingredient_info["gradient_weights"], list):
raise Exception(
f"Invalid gradient weights for ingredient {ingredient_info['name']}"
)
if len(ingredient_info["gradient"]) != len(
ingredient_info["gradient_weights"]
):
raise Exception(
f"Missing gradient weights for ingredient {ingredient_info['name']}"
)

return ingredient_info

def reset(self):
Expand Down
4 changes: 3 additions & 1 deletion cellpack/autopack/ingredient/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def __init__(
distance_expression=None,
distance_function=None,
force_random=False, # avoid any binding
gradient="",
gradient=None,
gradient_weights=None,
is_attractor=False,
overwrite_distance_function=True, # overWrite
packing_mode="random",
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
self.distance_expression = distance_expression
self.overwrite_distance_function = overwrite_distance_function
self.gradient = gradient
self.gradient_weights = gradient_weights
self.cb = None
self.radii = None
self.recipe = None # weak ref to recipe
Expand Down
1 change: 1 addition & 0 deletions cellpack/autopack/ingredient/grow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
cutoff_boundary=1.0,
cutoff_surface=0.5,
gradient=None,
gradient_weights=None,
is_attractor=False,
max_jitter=(1, 1, 1),
length=10.0,
Expand Down
1 change: 1 addition & 0 deletions cellpack/autopack/ingredient/multi_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
distance_function=None,
force_random=False, # avoid any binding
gradient=None,
gradient_weights=None,
is_attractor=False,
max_jitter=(1, 1, 1),
molarity=0.0,
Expand Down
1 change: 1 addition & 0 deletions cellpack/autopack/ingredient/multi_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
cutoff_boundary=None,
cutoff_surface=None,
gradient=None,
gradient_weights=None,
is_attractor=False,
max_jitter=(1, 1, 1),
molarity=0.0,
Expand Down
1 change: 1 addition & 0 deletions cellpack/autopack/ingredient/single_cube.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
distance_function=None,
force_random=False, # avoid any binding
gradient=None,
gradient_weights=None,
is_attractor=False,
max_jitter=(1, 1, 1),
molarity=0.0,
Expand Down
3 changes: 2 additions & 1 deletion cellpack/autopack/ingredient/single_cylinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(
distance_expression=None,
distance_function=None,
force_random=False, # avoid any binding
gradient="",
gradient=None,
gradient_weights=None,
is_attractor=False,
max_jitter=(1, 1, 1),
molarity=0.0,
Expand Down
1 change: 1 addition & 0 deletions cellpack/autopack/ingredient/single_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
distance_function=None,
force_random=False, # avoid any binding
gradient=None,
gradient_weights=None,
is_attractor=False,
max_jitter=(1, 1, 1),
molarity=0.0,
Expand Down
4 changes: 4 additions & 0 deletions cellpack/tests/recipes/v2/test_combined_gradient.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@
"gradient": [
"X_gradient",
"Y_gradient"
],
"gradient_weights": [
70,
30
]
}
},
Expand Down
38 changes: 38 additions & 0 deletions cellpack/tests/test_ingredient.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,44 @@
},
"Missing option 'min' for uniform distribution",
),
(
{
"name": "test",
"type": "single_sphere",
"count": 1,
"gradient": 3,
},
"Invalid gradient: 3 for ingredient test",
),
(
{
"name": "test",
"type": "single_sphere",
"count": 1,
"gradient": "",
},
"Missing gradient for ingredient test",
),
(
{
"name": "test",
"type": "single_sphere",
"count": 1,
"gradient": ["gradient_1", "gradient_2"],
"gradient_weights": 0.5,
},
"Invalid gradient weights for ingredient test",
),
(
{
"name": "test",
"type": "single_sphere",
"count": 1,
"gradient": ["gradient_1", "gradient_2", "gradient_3"],
"gradient_weights": [0.5, 0.5],
},
"Missing gradient weights for ingredient test",
),
],
)
def test_validate_ingredient_info(ingredient_info, output):
Expand Down

0 comments on commit 83cca5a

Please sign in to comment.