Fix AUROC Bug and allow custom Example Selectors #531

Merged: 1 commit, Aug 12, 2023
4 changes: 0 additions & 4 deletions src/autolabel/dataset/dataset.py
@@ -110,10 +110,6 @@ def process_labels(
             x.successfully_labeled for x in llm_labels
         ]

-        self.df[self.generate_label_name("annotation")] = [
-            pickle.dumps(x) for x in llm_labels
-        ]
-
         # Add row level LLM metrics to the dataframe
         if metrics is not None:
             for metric in metrics:
17 changes: 10 additions & 7 deletions src/autolabel/labeler.py
@@ -15,7 +15,7 @@
 from autolabel.dataset import AutolabelDataset
 from autolabel.data_models import AnnotationModel, TaskRunModel
 from autolabel.database import StateManager
-from autolabel.few_shot import ExampleSelectorFactory
+from autolabel.few_shot import ExampleSelectorFactory, BaseExampleSelector
 from autolabel.models import BaseModel, ModelFactory
 from autolabel.metrics import BaseMetric
 from autolabel.transforms import BaseTransform, TransformFactory
@@ -52,6 +52,7 @@ def __init__(
         self,
         config: Union[AutolabelConfig, str, dict],
         cache: Optional[bool] = True,
+        example_selector: Optional[BaseExampleSelector] = None,
     ) -> None:
         self.db = StateManager()
         self.generation_cache = SQLAlchemyGenerationCache() if cache else None
@@ -67,6 +68,7 @@ def __init__(
         self.confidence = ConfidenceCalculator(
             score_type="logprob_average", llm=self.llm
         )
+        self.example_selector = example_selector

         if in_notebook():
             import nest_asyncio
@@ -139,12 +141,13 @@ def run(
                     f"Explanation column {self.config.explanation_column()} not found in dataset.\nMake sure that explanations were generated using labeler.generate_explanations(seed_file)."
                 )

-        self.example_selector = ExampleSelectorFactory.initialize_selector(
-            self.config,
-            seed_examples,
-            dataset.df.keys().tolist(),
-            cache=self.generation_cache is not None,
-        )
+        if self.example_selector is None:
+            self.example_selector = ExampleSelectorFactory.initialize_selector(
+                self.config,
+                seed_examples,
+                dataset.df.keys().tolist(),
+                cache=self.generation_cache is not None,
+            )

         num_failures = 0
         current_index = self.task_run.current_index
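With this change a caller can inject its own few-shot example selector instead of always having `run` build one from the config. A minimal usage sketch, assuming the labeler class exported by the package is `LabelingAgent`, that `BaseExampleSelector` follows the LangChain-style `add_example`/`select_examples` interface, and that `seed_examples`, `config.json`, and `test.csv` are placeholder inputs:

```python
from typing import Dict, List

from autolabel import LabelingAgent, AutolabelDataset  # assumed import path
from autolabel.few_shot import BaseExampleSelector


class FirstKExampleSelector(BaseExampleSelector):
    """Hypothetical selector: always returns the first k seed examples."""

    def __init__(self, examples: List[Dict], k: int = 4):
        self.examples = examples
        self.k = k

    def add_example(self, example: Dict) -> None:
        self.examples.append(example)

    def select_examples(self, input_variables: Dict) -> List[Dict]:
        # Ignore the input row and return a fixed prefix of the seed set
        return self.examples[: self.k]


seed_examples = [...]  # placeholder: rows loaded from the config's seed file
agent = LabelingAgent(
    config="config.json",
    example_selector=FirstKExampleSelector(seed_examples, k=4),
)
ds = agent.run(AutolabelDataset("test.csv", config="config.json"))
```

Because `run` only calls `ExampleSelectorFactory.initialize_selector` when `self.example_selector` is `None`, the injected selector is used unchanged for the entire labeling run.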
7 changes: 6 additions & 1 deletion src/autolabel/metrics/auroc.py
@@ -1,6 +1,7 @@
 from typing import List

 from sklearn.metrics import roc_auc_score
+import numpy as np

 from autolabel.metrics import BaseMetric
 from autolabel.schema import LLMAnnotation, MetricResult, MetricType
@@ -26,7 +27,11 @@ def compute(
         ]
         confidence = [llm_label.confidence_score for llm_label in filtered_llm_labels]

-        auroc = roc_auc_score(match, confidence)
+        if np.unique(match).shape[0] == 1:
+            # all labels are the same
+            auroc = 1 if match[0] == 1 else 0
+        else:
+            auroc = roc_auc_score(match, confidence)

         value = [
             MetricResult(
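The new branch sidesteps a crash rather than changing the metric in the common case: `sklearn.metrics.roc_auc_score` raises a `ValueError` when `y_true` contains only one class, which happens whenever every row matches (or every row misses) the ground truth. A small standalone sketch of the guarded computation, using placeholder data:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Placeholder inputs: 1 = prediction matched the ground truth, 0 = it did not
match = [1, 1, 1, 1]
confidence = [0.9, 0.7, 0.8, 0.95]

if np.unique(match).shape[0] == 1:
    # Degenerate case: only one class present, so ROC AUC is undefined.
    # Treat "all correct" as 1 and "all wrong" as 0 instead of raising.
    auroc = 1 if match[0] == 1 else 0
else:
    auroc = roc_auc_score(match, confidence)

print(auroc)  # -> 1
```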