Skip to content

Commit

Permalink
Merge pull request #9 from microsoft/dev/miallama/edge-features
Browse files Browse the repository at this point in the history
Allow using arbitrary edge features in GNNs
  • Loading branch information
Miltos authored Feb 22, 2021
2 parents bc5323d + 8dca26f commit 7a23ce4
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 39 deletions.
2 changes: 1 addition & 1 deletion ptgnn/baseneuralmodel/abstractneuralmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def build_neural_module(self) -> TNeuralModule:

# region Saving/Loading
def save(self, path: Path, model: TNeuralModule) -> None:
os.makedirs(os.path.dirname(str(path)), exist_ok=True)
os.makedirs(os.path.dirname(str(path.absolute())), exist_ok=True)
with gzip.open(path, "wb") as f:
torch.save((self, model), f)

Expand Down
6 changes: 3 additions & 3 deletions ptgnn/implementations/typilus/graph2class.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class Graph2Class(
):
def __init__(
self,
gnn_model: GraphNeuralNetworkModel[str, Any],
gnn_model: GraphNeuralNetworkModel,
max_num_classes: int = 100,
try_simplify_unks: bool = True,
):
Expand All @@ -118,7 +118,7 @@ def __init__(
self.__tensorize_samples_with_no_annotation = False
self.__tensorize_keep_original_supernode_idx = False

def __convert(self, typilus_graph: TypilusGraph) -> Tuple[GraphData[str], List[str]]:
def __convert(self, typilus_graph: TypilusGraph) -> Tuple[GraphData[str, None], List[str]]:
def get_adj_list(adjacency_dict):
for from_node_idx, to_node_idxs in adjacency_dict.items():
from_node_idx = int(from_node_idx)
Expand Down Expand Up @@ -149,7 +149,7 @@ def get_adj_list(adjacency_dict):
supernode_annotations.append(enforce_not_None(supernode_data["annotation"]))

return (
GraphData[str](
GraphData[str, None](
node_information=typilus_graph["nodes"],
edges=edges,
reference_nodes={
Expand Down
6 changes: 5 additions & 1 deletion ptgnn/implementations/typilus/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def create_mlp_mp_layers(num_edges: int):
introduce_backwards_edges=True,
add_self_edges=True,
stop_extending_minibatch_after_num_nodes=120000,
edge_dropout_rate=0.0,
),
max_num_classes=100,
)
Expand Down Expand Up @@ -149,9 +150,12 @@ def run(arguments):

initialize_metadata = True
restore_path = arguments.get("--restore-path", None)
if restore_path or (arguments["--aml"] and model_path.exists()):
if restore_path:
initialize_metadata = False
model, nn = Graph2Class.restore_model(Path(restore_path))
elif arguments["--aml"] and model_path.exists():
initialize_metadata = False
model, nn = Graph2Class.restore_model(model_path)
else:
nn = None
model = create_graph2class_gnn_model()
Expand Down
Loading

0 comments on commit 7a23ce4

Please sign in to comment.