Skip to content

Commit

Permalink
fix: match dtype of activations during Concept.forward
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Nov 23, 2023
1 parent 71ab816 commit 3a06459
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ matcher = ConceptMatcher(model, tokenizer, concepts=concepts)

match_info = matcher.query("Beijing is a northern city", subject="Beijing")

print(match_info.best_match.name) # located in country: China
print(match_info.betch_match.score) # 0.832
print(match_info.best_match.concept) # located in country: China
print(match_info.best_match.score) # 0.832
```

## Acknowledgements
Expand Down
4 changes: 2 additions & 2 deletions docs/basic_usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,5 @@ We can use the ``ConceptMatcher`` class to do this matching.
match_info = matcher.query("Beijing is a northern city", subject="Beijing")
print(match_info.best_match.name) # located in country: China
print(match_info.betch_match.score) # 0.832
print(match_info.best_match.concept) # located in country: China
print(match_info.best_match.score) # 0.832
2 changes: 1 addition & 1 deletion linear_relational/Concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(
self.name = name or f"{self.relation}: {self.object}"

def forward(self, activations: torch.Tensor) -> torch.Tensor:
vector = self.vector.to(activations.device)
vector = self.vector.to(activations.device, dtype=activations.dtype)
if len(activations.shape) == 1:
return vector @ activations
return vector @ activations.T

0 comments on commit 3a06459

Please sign in to comment.