diff --git a/prepare/metrics/rag_context_correctness.py b/prepare/metrics/rag_context_correctness.py index 3a3bbc324..fc8128b33 100644 --- a/prepare/metrics/rag_context_correctness.py +++ b/prepare/metrics/rag_context_correctness.py @@ -7,6 +7,7 @@ ("map", "metrics.rag.map"), ("mrr", "metrics.rag.mrr"), ("mrr", "metrics.rag.context_correctness"), + ("retrieval_at_k", "metrics.rag.retrieval_at_k"), ]: metric = MetricPipeline( main_score="score", @@ -43,6 +44,52 @@ {"mrr": 1.0, "score": 1.0, "score_name": "mrr"}, {"mrr": 0.5, "score": 0.5, "score_name": "mrr"}, ] + retrieval_at_k_instance_targets = [ + { + "match_at_1": 1.0, + "match_at_3": 1.0, + "match_at_5": 1.0, + "match_at_10": 1.0, + "match_at_20": 1.0, + "match_at_40": 1.0, + "precision_at_1": 1.0, + "precision_at_3": 0.67, + "precision_at_5": 0.67, + "precision_at_10": 0.67, + "precision_at_20": 0.67, + "precision_at_40": 0.67, + "recall_at_1": 0.5, + "recall_at_3": 1.0, + "recall_at_5": 1.0, + "recall_at_10": 1.0, + "recall_at_20": 1.0, + "recall_at_40": 1.0, + "score": 1.0, + "score_name": "match_at_1", + }, + { + "match_at_1": 0.0, + "match_at_10": 1.0, + "match_at_20": 1.0, + "match_at_3": 1.0, + "match_at_40": 1.0, + "match_at_5": 1.0, + "precision_at_1": 0.0, + "precision_at_10": 0.5, + "precision_at_20": 0.5, + "precision_at_3": 0.5, + "precision_at_40": 0.5, + "precision_at_5": 0.5, + "recall_at_1": 0.0, + "recall_at_10": 1.0, + "recall_at_20": 1.0, + "recall_at_3": 1.0, + "recall_at_40": 1.0, + "recall_at_5": 1.0, + "score": 0.0, + "score_name": "match_at_1", + }, + ] map_global_target = { "map": 0.67, @@ -62,11 +109,56 @@ "score_ci_low": 0.5, "score_name": "mrr", } + retrieval_at_k_global_target = { + "match_at_1": 0.5, + "match_at_1_ci_high": 1.0, + "match_at_1_ci_low": 0.0, + "match_at_3": 1.0, + "match_at_5": 1.0, + "match_at_10": 1.0, + "match_at_20": 1.0, + "match_at_40": 1.0, + "precision_at_1": 0.5, + "precision_at_1_ci_high": 1.0, + "precision_at_1_ci_low": 0.0, + "precision_at_3": 0.58, + "precision_at_3_ci_high": 0.67, + "precision_at_3_ci_low": 0.5, + "precision_at_5": 0.58, + "precision_at_5_ci_high": 0.67, + "precision_at_5_ci_low": 0.5, + "precision_at_10": 0.58, + "precision_at_10_ci_high": 0.67, + "precision_at_10_ci_low": 0.5, + "precision_at_20": 0.58, + "precision_at_20_ci_high": 0.67, + "precision_at_20_ci_low": 0.5, + "precision_at_40": 0.58, + "precision_at_40_ci_high": 0.67, + "precision_at_40_ci_low": 0.5, + "recall_at_1": 0.25, + "recall_at_1_ci_high": 0.5, + "recall_at_1_ci_low": 0.0, + "recall_at_3": 1.0, + "recall_at_5": 1.0, + "recall_at_10": 1.0, + "recall_at_20": 1.0, + "recall_at_40": 1.0, + "score": 0.5, + "score_ci_high": 1.0, + "score_ci_low": 0.0, + "score_name": "match_at_1", + } for catalog_name, global_target, instance_targets in [ ("metrics.rag.map", map_global_target, map_instance_targets), ("metrics.rag.mrr", mrr_global_target, mrr_instance_targets), ("metrics.rag.context_correctness", mrr_global_target, mrr_instance_targets), + ( + "metrics.rag.retrieval_at_k", + retrieval_at_k_global_target, + retrieval_at_k_instance_targets, + ), ]: # test the evaluate call test_evaluate( diff --git a/src/unitxt/catalog/metrics/rag/retrieval_at_k.json b/src/unitxt/catalog/metrics/rag/retrieval_at_k.json new file mode 100644 index 000000000..2dc1a82a5 --- /dev/null +++ b/src/unitxt/catalog/metrics/rag/retrieval_at_k.json @@ -0,0 +1,18 @@ +{ + "__type__": "metric_pipeline", + "main_score": "score", + "preprocess_steps": [ + { + "__type__": "copy", + "field": "context_ids", + "to_field": "prediction" + }, + { + "__type__": "wrap", + "field": "ground_truths_context_ids", + "inside": "list", + "to_field": "references" + } + ], + "metric": "metrics.retrieval_at_k" +}