From d51bb717e07f25fa5af0f6fe1faadfe1587299e0 Mon Sep 17 00:00:00 2001 From: Anna B <72624798+aberanger@users.noreply.github.com> Date: Wed, 15 May 2024 11:49:30 +0200 Subject: [PATCH] get_dimension_stereotypes on removed community fixed (#82) --- sinr/graph_embeddings.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sinr/graph_embeddings.py b/sinr/graph_embeddings.py index f79db88..f0d0515 100644 --- a/sinr/graph_embeddings.py +++ b/sinr/graph_embeddings.py @@ -1235,7 +1235,7 @@ def get_obj_descriptors(self, obj, topk_dim=5, topk_val=-1): def get_dimension_stereotypes(self, obj, topk=5): """Get the words with the highest values on dimension obj. - :param obj: id of a dimension, or label of a word (then turned into the id of its community) + :param obj: id of a word, or label of a word (then turned into the id of its community) :type obj: int or str :param topk: topk value to consider on the dimension (Default value = 5) :type topk: int @@ -1243,7 +1243,10 @@ def get_dimension_stereotypes(self, obj, topk=5): """ index = self._get_index(obj) - return self.get_dimension_stereotypes_idx(self.get_community_membership(index), topk) + if self.community_membership[index] != -1: + return self.get_dimension_stereotypes_idx(self.get_community_membership(index), topk) + else: + raise DimensionFilteredException("'"+self.vocab[index] + "' (id "+str(index)+') is member of a community which got removed by filtering.') def get_dimension_stereotypes_idx(self, idx, topk=5): """Get the indices of the words with the highest values on dimension obj. @@ -1469,3 +1472,7 @@ def light_model_save(self): f = open(self.name + "_light.pk", 'wb') pk.dump(data,f) f.close() + +class DimensionFilteredException(Exception): + """Exception raised when trying to access a dimension removed by filtering. """ + pass \ No newline at end of file