From 3a064595c314d5b9ca1944e9eb4541a041128262 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 23 Nov 2023 18:19:59 +0000 Subject: [PATCH] fix: match dtype of activations during Concept.forward --- README.md | 4 ++-- docs/basic_usage.rst | 4 ++-- linear_relational/Concept.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e5ca7a7..2b42382 100644 --- a/README.md +++ b/README.md @@ -188,8 +188,8 @@ matcher = ConceptMatcher(model, tokenizer, concepts=concepts) match_info = matcher.query("Beijing is a northern city", subject="Beijing") -print(match_info.best_match.name) # located in country: China -print(match_info.betch_match.score) # 0.832 +print(match_info.best_match.concept) # located in country: China +print(match_info.best_match.score) # 0.832 ``` ## Acknowledgements diff --git a/docs/basic_usage.rst b/docs/basic_usage.rst index 2abed9d..35472f1 100644 --- a/docs/basic_usage.rst +++ b/docs/basic_usage.rst @@ -172,5 +172,5 @@ We can use the ``ConceptMatcher`` class to do this matching. match_info = matcher.query("Beijing is a northern city", subject="Beijing") - print(match_info.best_match.name) # located in country: China - print(match_info.betch_match.score) # 0.832 + print(match_info.best_match.concept) # located in country: China + print(match_info.best_match.score) # 0.832 diff --git a/linear_relational/Concept.py b/linear_relational/Concept.py index bcb0a5d..db7feb2 100644 --- a/linear_relational/Concept.py +++ b/linear_relational/Concept.py @@ -34,7 +34,7 @@ def __init__( self.name = name or f"{self.relation}: {self.object}" def forward(self, activations: torch.Tensor) -> torch.Tensor: - vector = self.vector.to(activations.device) + vector = self.vector.to(activations.device, dtype=activations.dtype) if len(activations.shape) == 1: return vector @ activations return vector @ activations.T