Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 12, 2023
1 parent 9d7829c commit d6ff3fa
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions kgcnn/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,29 @@

def scatter_reduce_sum(indices, values, shape):
indices = torch.unsqueeze(indices, dim=-1)
return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
return torch.zeros(*shape, dtype=values.dtype, device=values.device).scatter_reduce(
0, torch.broadcast_to(indices, values.shape), values, reduce='sum')


def scatter_reduce_min(indices, values, shape):
indices = torch.unsqueeze(indices, dim=-1)
return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
return torch.zeros(*shape, dtype=values.dtype, device=values.device).scatter_reduce(
0, torch.broadcast_to(indices, values.shape), values, reduce='amin', include_self=False)


def scatter_reduce_max(indices, values, shape):
indices = torch.unsqueeze(indices, dim=-1)
return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
return torch.zeros(*shape, dtype=values.dtype, device=values.device).scatter_reduce(
0, torch.broadcast_to(indices, values.shape), values, reduce='amax', include_self=False)


def scatter_reduce_mean(indices, values, shape):
indices = torch.unsqueeze(indices, dim=-1)
return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
return torch.zeros(*shape, dtype=values.dtype, device=values.device).scatter_reduce(
0, torch.broadcast_to(indices, values.shape), values, reduce='mean', include_self=False)


def scatter_reduce_softmax(indices, values, shape):
indices = torch.unsqueeze(indices, dim=-1)
return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
return torch.zeros(*shape, dtype=values.dtype, device=values.device).scatter_reduce(
0, torch.broadcast_to(indices, values.shape), values, reduce='sum')
2 changes: 1 addition & 1 deletion training_core/train_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
# The metrics from this script is added to the hyperparameter entry for metrics.
model.compile(**hyper.compile(metrics=metrics))
model.summary()
print(model._jit_compile)
print("Compiled with jit: %s." % model._jit_compile)
# Run keras model-fit and take time for training.
start = time.time()
hist = model.fit(x_train, y_train,
Expand Down

0 comments on commit d6ff3fa

Please sign in to comment.