Skip to content

Commit

Permalink
Merge pull request #118 from pyiron/port_access_by_attribute
Browse files Browse the repository at this point in the history
Port access by attribute
  • Loading branch information
liamhuber authored Nov 16, 2022
2 parents 23d176f + 536cbe4 commit 12b3d17
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 69 deletions.
63 changes: 63 additions & 0 deletions ironflow/model/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,71 @@

from ryvencore import Node as NodeCore
from ryvencore.Base import Event
from ryvencore.NodePort import NodePort

from ironflow.gui.canvas_widgets import NodeWidget


class PortList(list):
"""
When used to hold a collection of `NodePort` objects, the values of these ports then become accessible by their
labels, as long as those labels do not match an existing method of the builtin list class.
Warning:
This class makes no check that these labels are unique; if multiple items have the same label, the first one
is returned.
Warning:
Accessing port values in this way side-steps ryven functionality when in exec mode or using an executor
(i.e. when `running_with_executor`).
"""

def __init__(self, seq=()):
super().__init__(self, seq=seq)
self._port_finder = PortFinder(self)
self._value_finder = ValueFinder(self)
# This additional mis-direction is necessary so that ports can have the same labels as list class methods

@property
def ports(self):
"""
Allows attribute-like access to ports by their `label_str`
"""
return self._port_finder

@property
def values(self):
"""
Allows attribute-like access to port values by their `label_str`
Calling `port_list.values.some_label` is equivalent to `port_list.ports.some_label.val`
"""
return self._value_finder

@property
def labels(self):
return [item.label_str if isinstance(item, NodePort) else None for item in self]


class PortFinder:
def __init__(self, port_list: PortList):
self._port_list = port_list

def __getattr__(self, key):
for node_port in [
item for item in self._port_list if isinstance(item, NodePort)
]:
if node_port.label_str == key:
return node_port
raise AttributeError(f"No port found with the label {key}")


class ValueFinder(PortFinder):
def __getattr__(self, key):
node_port = super().__getattr__(key)
return node_port.val


class Node(NodeCore):
"""
A parent class for all ironflow nodes. Apart from a small quality-of-life difference where outputs are
Expand Down Expand Up @@ -46,6 +107,8 @@ class Node(NodeCore):

def __init__(self, params):
super().__init__(params)
self.inputs = PortList()
self.outputs = PortList()

self.before_update = Event(self, int)
self.after_update = Event(self, int)
Expand Down
126 changes: 57 additions & 69 deletions ironflow/nodes/pyiron/atomistics_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,13 @@ def place_event(self):
self.update()

def update_event(self, inp=-1):
pr = Project(self.input(0))
pr = Project(self.inputs.values.name)
self.set_output_val(0, pr)

@property
def _project(self):
return self.output(0)

@property
def representations(self) -> dict:
return {
"name": str(self.input(0)),
"name": str(self.inputs.values.name),
# "job_table": self._project.job_table() if self._project is not None else None
# Todo: Figure out how to display this without breaking the gui size; right now it automatically grows
# the gui because the table is so wide.
Expand All @@ -124,7 +120,10 @@ def update_event(self, inp=-1):

@property
def representations(self) -> dict:
return {"plot3d": self.output(0).plot3d(), "print": self.output(0)}
return {
"plot3d": self.outputs.values.structure.plot3d(),
"print": self.outputs.values.structure,
}


class BulkStructure_Node(OutputsOnlyAtoms):
Expand Down Expand Up @@ -183,14 +182,14 @@ def update_event(self, inp=-1):
self.set_output_val(
0,
STRUCTURE_FACTORY.bulk(
self.input(0),
crystalstructure=self.input(1),
a=self.input(2),
c=self.input(3),
covera=self.input(4),
u=self.input(5),
orthorhombic=self.input(6),
cubic=self.input(7),
self.inputs.values.element,
crystalstructure=self.inputs.values.crystal_structure,
a=self.inputs.values.a,
c=self.inputs.values.c,
covera=self.inputs.values.c_over_a,
u=self.inputs.values.u,
orthorhombic=self.inputs.values.orthorhombic,
cubic=self.inputs.values.cubic,
),
)

Expand Down Expand Up @@ -220,7 +219,9 @@ class Repeat_Node(OutputsOnlyAtoms):
]

def update_event(self, inp=-1):
self.set_output_val(0, self.input(0).repeat(self.input(1)))
self.set_output_val(
0, self.inputs.values.structure.repeat(self.inputs.values.all)
)


class ApplyStrain_Node(OutputsOnlyAtoms):
Expand All @@ -243,7 +244,10 @@ class ApplyStrain_Node(OutputsOnlyAtoms):

def update_event(self, inp=-1):
self.set_output_val(
0, self.input(0).apply_strain(float(self.input(1)), return_box=True)
0,
self.inputs.values.structure.apply_strain(
float(self.inputs.values.strain), return_box=True
),
)


Expand Down Expand Up @@ -273,26 +277,10 @@ class Lammps_Node(Node):
]
color = "#5d95de"

@property
def _project(self):
return self.input(2)

@property
def _name(self):
return self.input(3)

@property
def _structure(self):
return self.input(4)

@property
def _potential(self):
return self.input(5)

def _run(self):
job = self._project.create.job.Lammps(self._name)
job.structure = self._structure
job.potential = self._potential
job = self.inputs.values.project.create.job.Lammps(self.inputs.values.name)
job.structure = self.inputs.values.structure
job.potential = self.inputs.values.potential
self._job = job
job.run()
self.set_output_val(1, job)
Expand All @@ -303,24 +291,22 @@ def _remove(self):
name = (
self._job.name
) # Remove based on the run job, not the input name which might have changed...
self._project.remove_job(name)
self.inputs.values.project.remove_job(name)
self.set_output_val(1, None)
except AttributeError:
pass

def _update_potential_choices(self):
potl_input = self.inputs[5]
last_potential = potl_input.val
structure = self.inputs[4].val
available_potentials = list_potentials(structure)
last_potential = self.inputs.values.potential
available_potentials = list_potentials(self.inputs.values.structure)

if len(available_potentials) == 0:
potl_input.val = "No valid potential"
potl_input.dtype.items = ["No valid potential"]
self.inputs.ports.potential.val = "No valid potential"
self.inputs.ports.potential.dtype.items = ["No valid potential"]
else:
if last_potential not in available_potentials:
potl_input.val = available_potentials[0]
potl_input.dtype.items = available_potentials
self.inputs.ports.potential.val = available_potentials[0]
self.inputs.ports.potential.dtype.items = available_potentials

def update_event(self, inp=-1):
if inp == 0:
Expand All @@ -332,7 +318,7 @@ def update_event(self, inp=-1):

@property
def representations(self) -> dict:
return {"job": BeautifulHasGroups(self.output(1))}
return {"job": BeautifulHasGroups(self.outputs.values.job)}


class GenericOutput_Node(Node):
Expand Down Expand Up @@ -367,23 +353,21 @@ class GenericOutput_Node(Node):
def __init__(self, params):
super().__init__(params)

@property
def _job(self):
return self.input(0)

def _update_fields(self):
if isinstance(self._job, AtomisticGenericJob):
self.inputs[1].dtype.items = self._job["output/generic"].list_nodes()
self.inputs[1].val = self.inputs[1].dtype.items[0]
if isinstance(self.inputs.values.job, AtomisticGenericJob):
self.inputs.ports.field.dtype.items = self.inputs.values.job[
"output/generic"
].list_nodes()
self.inputs.ports.field.val = self.inputs.ports.field.dtype.items[0]
else:
self.inputs[1].dtype.items = [self.init_inputs[1].dtype.default]
self.inputs.ports.field.dtype.items = [self.init_inputs[1].dtype.default]
# Note: It would be sensible to use `self.init_outputs[1].dtype.items` above, but this field gets updated
# to `self.inputs[1].dtype.items`, probably because of the mutability of lists.
self.inputs[1].val = self.init_inputs[1].dtype.default
self.inputs.ports.field.val = self.init_inputs[1].dtype.default

def _update_value(self):
if isinstance(self._job, AtomisticGenericJob):
val = self._job[f"output/generic/{self.input(1)}"]
if isinstance(self.inputs.values.job, AtomisticGenericJob):
val = self.inputs.values.job[f"output/generic/{self.inputs.values.field}"]
else:
val = None
self.set_output_val(0, val)
Expand Down Expand Up @@ -421,7 +405,9 @@ class IntRand_Node(Node):
color = "#aabb44"

def update_event(self, inp=-1):
val = np.random.randint(0, high=self.input(0), size=self.input(1))
val = np.random.randint(
0, high=self.inputs.values.high, size=self.inputs.values.length
)
self.set_output_val(0, val)


Expand Down Expand Up @@ -451,9 +437,9 @@ class JobName_Node(Node):
color = "#aabb44"

def update_event(self, inp=-1):
val = self.input(0) + f"{float(self.input(1))}".replace("-", "m").replace(
".", "p"
)
val = self.inputs.values.base + f"{float(self.inputs.values.float)}".replace(
"-", "m"
).replace(".", "p")
self.set_output_val(0, val)


Expand Down Expand Up @@ -488,7 +474,9 @@ def place_event(self):
self.update()

def update_event(self, inp=-1):
val = np.linspace(self.input(0), self.input(1), self.input(2))
val = np.linspace(
self.inputs.values.min, self.inputs.values.max, self.inputs.values.steps
)
self.set_output_val(0, val)


Expand Down Expand Up @@ -516,8 +504,8 @@ class Plot3d_Node(Node):
color = "#5d95de"

def update_event(self, inp=-1):
self.set_output_val(0, self.input(0).plot3d())
self.set_output_val(1, self.input(0))
self.set_output_val(0, self.inputs.values.structure.plot3d())
self.set_output_val(1, self.inputs.values.structure)


class Matplot_Node(Node):
Expand Down Expand Up @@ -548,7 +536,7 @@ def update_event(self, inp=-1):
plt.ioff()
fig = plt.figure()
plt.clf()
plt.plot(self.input(0), self.input(1))
plt.plot(self.inputs.values.x, self.inputs.values.y)
self.set_output_val(0, fig)
plt.ion()

Expand All @@ -575,7 +563,7 @@ class Sin_Node(Node):
color = "#5d95de"

def update_event(self, inp=-1):
self.set_output_val(0, np.sin(self.input(0)))
self.set_output_val(0, np.sin(self.inputs.values.x))


class Result_Node(Node):
Expand All @@ -601,7 +589,7 @@ def view_place_event(self):
self.main_widget().show_val(self.val)

def update_event(self, inp=-1):
self.val = self.input(0)
self.val = self.inputs.data.val
if self.session.gui:
self.main_widget().show_val(self.val)

Expand All @@ -626,8 +614,8 @@ class ForEach_Node(Node):
def update_event(self, inp=-1):
if inp == 0:
self._count += 1
if len(self.input(2)) > self._count:
e = self.input(2)[self._count]
if len(self.inputs.values.elements) > self._count:
e = self.inputs.values.elements[self._count]
self.set_output_val(1, e)
self.exec_output(0)
else:
Expand Down

0 comments on commit 12b3d17

Please sign in to comment.