Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cached in FeatureSimilarityModel #39

Merged
merged 1 commit into from
Sep 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ author_email = nlp-parser@baidu.com
# Project version; only versions above 1.0 are assumed to be released versions.
# When raising the project version above 1.0, the following rules should be followed:
# http://wiki.baidu.com/pages/viewpage.action?pageId=469686381
version = 0.1.5
version = 0.1.6
# A brief introduction about the project, ANY NON-ENGLISH CHARACTER IS NOT SUPPORTED!
description = baidu TrustAI
# A longer introduction about the project; you can also include a readme, change log, etc. A .md or .rst file is recommended.
Expand Down
2 changes: 1 addition & 1 deletion trustai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.
"""TrustAI"""

__version__ = "0.1.5"
__version__ = "0.1.6"
11 changes: 9 additions & 2 deletions trustai/interpretation/example_level/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,18 @@ def dot_similarity(inputs_a, inputs_b):
return paddle.sum(inputs_a * inputs_b, axis=1)


def cos_similarity(inputs_a, inputs_b):
def cos_similarity(inputs_a, inputs_b, step=500000):
"""
Calculate cosine similarity between the two inputs.
"""
return F.cosine_similarity(inputs_a, inputs_b.unsqueeze(0))
# Process inputs_a in chunks of `step` rows to avoid a paddle bug
start, end = 0, step
res = []
while start < inputs_a.shape[0]:
res.append(F.cosine_similarity(inputs_a[start:end], inputs_b.unsqueeze(0)))
start = end
end = end + step
return paddle.concat(res, axis=0)


def euc_similarity(inputs_a, inputs_b):
Expand Down
18 changes: 15 additions & 3 deletions trustai/interpretation/example_level/method/feature_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
feature-based similarity method.
cosine, dot and euc.
"""
import os
import sys
import functools
import warnings

Expand All @@ -25,6 +27,7 @@ def __init__(
device=None,
classifier_layer_name="classifier",
predict_fn=None,
cached_train_feature=None,
):
"""
Initialization.
Expand All @@ -38,9 +41,18 @@ def __init__(
ExampleBaseInterpreter.__init__(self, paddle_model, device, predict_fn, classifier_layer_name)
self.paddle_model = paddle_model
self.classifier_layer_name = classifier_layer_name
self.train_feature, _ = self.extract_feature_from_dataloader(train_dataloader)


if cached_train_feature is not None and os.path.isfile(cached_train_feature):
self.train_feature = paddle.load(cached_train_feature)
else:
self.train_feature, _ = self.extract_feature_from_dataloader(train_dataloader)
if cached_train_feature is not None:
try:
paddle.save(self.train_feature, cached_train_feature)
except IOError:
import sys
sys.stderr.write("save cached_train_feature fail")

def interpret(self, data, sample_num=3, sim_fn="cos"):
"""
Select most similar and dissimilar examples for a given data using the `sim_fn` metric.
Expand Down Expand Up @@ -87,7 +99,7 @@ def extract_feature_from_dataloader(self, dataloader):
"""
print("Extracting feature from given dataloader, it will take some time...")
features, preds = [], []

for batch in dataloader:
feature, pred = self.extract_feature(self.paddle_model, batch)
features.append(feature)
Expand Down