Skip to content

Commit

Permalink
Fix targeting a bit (#300)
Browse files Browse the repository at this point in the history
In the GPAM code the case of targeting a bit was not handled correctly
(it should not be class classification). Moreover it seems that
TensorFlow is now doing a type check for multiplication (needed by
metrics such as MeanRank).
  • Loading branch information
wsxrdv authored Sep 30, 2024
1 parent 22f2ffb commit 5504503
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion papers/2024/GPAM/gpam_ecc_cm1.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def create_heads_outputs(x: Tensor, outputs: Dict[str, Dict],
relations = ingoing_relations.get(name, [])

# Get parameters for head creation.
dim = outputs[name]['max_val']
dim = outputs[name]['max_val'] if outputs[name]['max_val'] > 2 else 1
head = _make_head(x, heads, name, relations, dim)
heads[name] = head

Expand Down
2 changes: 1 addition & 1 deletion scaaml/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ def process_record(
max_val = data["max_val"]
if max_val == 2:
# Binary classification.
v = rec[data["ap"]][data["byte"]]
v = tf.cast(rec[data["ap"]][data["byte"]], dtype=tf.float32)
else:
# Multiple classes classification.
v = tf.one_hot(rec[data["ap"]][data["byte"]], max_val)
Expand Down

0 comments on commit 5504503

Please sign in to comment.