
Commit 5230783

Merge branch 'main' into file-application-support

DhruvaBansal00 committed Nov 28, 2024
2 parents 5cbb732 + 2e3df48
Showing 108 changed files with 5,226 additions and 3,507 deletions.
18 changes: 0 additions & 18 deletions .github/workflows/black.yaml

This file was deleted.

12 changes: 12 additions & 0 deletions .github/workflows/ruff.yaml
@@ -0,0 +1,12 @@
name: Ruff Format
on: [push]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: astral-sh/ruff-action@v1
        with:
          args: "check --output-format=github"
          changed-files: "true"
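The same check can be approximated locally before pushing. A minimal sketch, assuming ruff is installed and on PATH; running it via Python only mirrors what the workflow's action does, and a plain `ruff check .` in a shell is equivalent:

    # Local approximation of the CI lint step above (assumes `ruff` is installed).
    import subprocess
    import sys

    result = subprocess.run(
        ["ruff", "check", ".", "--output-format=github"],
        check=False,  # report findings without raising, like the CI job
    )
    sys.exit(result.returncode)  # non-zero exit mirrors a failed lint job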
33 changes: 16 additions & 17 deletions benchmark/benchmark.py
@@ -1,27 +1,26 @@
# Before running this file run the following command in the same folder as this file:
# aws s3 cp --recursive s3://autolabel-benchmarking data

import ast
import json
import os
import pickle as pkl
from argparse import ArgumentParser
from rich.console import Console
from typing import List
import json
import pickle as pkl
import ast
from sklearn.metrics import f1_score
import torch

import pylcs
import os
import torch
from sklearn.metrics import f1_score

from autolabel import LabelingAgent, AutolabelConfig, AutolabelDataset
from autolabel.tasks import TaskFactory
from autolabel import AutolabelConfig, AutolabelDataset, LabelingAgent
from autolabel.metrics import BaseMetric
from autolabel.schema import LLMAnnotation, MetricResult
from autolabel.tasks import TaskFactory


class F1Metric(BaseMetric):
def compute(
llm_labels: List[LLMAnnotation], gt_labels: List[str]
llm_labels: List[LLMAnnotation], gt_labels: List[str],

Ruff check failures (GitHub Actions / lint) on line 23:
benchmark/benchmark.py:23:9: N805 First argument of a method should be named `self`
benchmark/benchmark.py:23:21: FA100 Add `from __future__ import annotations` to simplify `typing.List`
benchmark/benchmark.py:23:53: FA100 Add `from __future__ import annotations` to simplify `typing.List`
) -> List[MetricResult]:

Ruff check failure (GitHub Actions / lint): benchmark/benchmark.py:24:10: FA100 Add `from __future__ import annotations` to simplify `typing.List`
def construct_binary_preds(curr_input: List[str], positive_tokens: List[str]):

Ruff check failures (GitHub Actions / lint) on line 25:
benchmark/benchmark.py:25:13: ANN202 Missing return type annotation for private function `construct_binary_preds`
benchmark/benchmark.py:25:48: FA100 Add `from __future__ import annotations` to simplify `typing.List`
benchmark/benchmark.py:25:76: FA100 Add `from __future__ import annotations` to simplify `typing.List`
curr_token_index = 0
@@ -44,22 +43,22 @@ def construct_binary_preds(curr_input: List[str], positive_tokens: List[str]):
try:
curr_pred = " ".join(ast.literal_eval(llm_labels[i].label)).split(" ")
predictions.extend(construct_binary_preds(curr_input, curr_pred))
except Exception as e:
except Exception:

Ruff check failure (GitHub Actions / lint): benchmark/benchmark.py:46:20: BLE001 Do not catch blind exception: `Exception`
curr_pred = llm_labels[i].label.split(" ")
predictions.extend(construct_binary_preds(curr_input, curr_pred))
# print(e, llm_labels[i].label)

Ruff check failure (GitHub Actions / lint): benchmark/benchmark.py:49:17: ERA001 Found commented-out code
# predictions.extend([0 for _ in range(len(curr_input))])
try:
curr_gt = " ".join(ast.literal_eval(gt_labels[i])).split(" ")
except Exception as e:
except Exception:
curr_gt = gt_labels[i].split(" ")
gt.extend(construct_binary_preds(curr_input, curr_gt))
return [MetricResult(name="F1", value=f1_score(gt, predictions))]
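The Ruff findings above could be resolved along the following lines. This is a hedged sketch, not the repository's actual fix: the @staticmethod decorator (instead of renaming the first argument to `self`), the narrowed exception tuple, and the token-union simplification inside construct_binary_preds are assumptions.

    # Hedged sketch of a Ruff-clean token-level F1 metric (N805, FA100, ANN202, BLE001).
    from __future__ import annotations  # FA100: enables built-in `list[...]` annotations

    import ast

    from sklearn.metrics import f1_score

    from autolabel.metrics import BaseMetric
    from autolabel.schema import LLMAnnotation, MetricResult


    class F1MetricSketch(BaseMetric):  # hypothetical name, to avoid clashing with F1Metric
        @staticmethod  # N805: no instance state is used, so `self` is unnecessary
        def compute(
            llm_labels: list[LLMAnnotation],
            gt_labels: list[str],
        ) -> list[MetricResult]:
            def construct_binary_preds(
                curr_input: list[str],
                positive_tokens: list[str],
            ) -> list[int]:  # ANN202: explicit return type
                positives = set(positive_tokens)
                return [int(tok in positives) for tok in curr_input]

            predictions: list[int] = []
            gt: list[int] = []
            for llm_label, gt_label in zip(llm_labels, gt_labels):
                try:
                    pred_tokens = " ".join(ast.literal_eval(llm_label.label)).split(" ")
                except (ValueError, SyntaxError):  # BLE001: catch only what literal_eval raises
                    pred_tokens = llm_label.label.split(" ")
                gt_tokens = gt_label.split(" ")
                # Simplification (assumption): score over the union of predicted and gold
                # tokens, since the original input token sequence is not shown in this diff.
                vocab = sorted(set(pred_tokens) | set(gt_tokens))
                predictions.extend(construct_binary_preds(vocab, pred_tokens))
                gt.extend(construct_binary_preds(vocab, gt_tokens))
            return [MetricResult(name="F1", value=f1_score(gt, predictions))]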


class TextSimilarity(BaseMetric):
def compute(
llm_labels: List[LLMAnnotation], gt_labels: List[str]
llm_labels: List[LLMAnnotation], gt_labels: List[str],
) -> List[MetricResult]:
def get_similarity(str_a, str_b):
substring_lengths = pylcs.lcs_string_length(str_a, str_b)
@@ -70,8 +69,8 @@ def get_similarity(str_a, str_b):
text_similarity.append(get_similarity(llm_labels[i].label, gt_labels[i]))
return [
MetricResult(
name="TextSimilarity", value=sum(text_similarity) / len(text_similarity)
)
name="TextSimilarity", value=sum(text_similarity) / len(text_similarity),
),
]
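The body of get_similarity is truncated above; the idea is to normalize pylcs's longest-common-substring length into a 0-1 score. A minimal standalone sketch, assuming pylcs is installed; the max-length denominator is an assumption, since the diff does not show the actual normalization:

    # Hedged sketch of an LCS-based similarity in the spirit of TextSimilarity.
    import pylcs


    def lcs_similarity(str_a: str, str_b: str) -> float:
        if not str_a or not str_b:
            return 0.0
        substring_length = pylcs.lcs_string_length(str_a, str_b)  # longest common substring
        return substring_length / max(len(str_a), len(str_b))  # assumed normalization


    print(lcs_similarity("refuel autolabel", "autolabel"))  # "autolabel" is 9 of 16 chars -> 0.5625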


@@ -138,7 +137,7 @@ def main():
if dataset in LONG_DATASETS:
few_shot = 0

config = json.load(open(f"configs/{dataset}.json", "r"))
config = json.load(open(f"configs/{dataset}.json"))
config["model"]["name"] = args.model
provider = MODEL_TO_PROVIDER.get(args.model, "vllm")
config["model"]["provider"] = provider
@@ -174,7 +173,7 @@ def main():
[F1Metric, TextSimilarity] if dataset in NER_DATASETS else []
)
new_ds = agent.run(
ds, max_items=args.max_items, additional_metrics=additional_metrics
ds, max_items=args.max_items, additional_metrics=additional_metrics,
)
new_ds.df.to_csv(f"outputs_{dataset}.csv")
eval_result.append([x.dict() for x in agent.eval_result])
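Condensed, the loop body shown above does the following per dataset: load its config, point it at the requested model and provider, run the labeling agent, and persist the labeled outputs plus eval metrics. A hedged sketch of that flow; the LabelingAgent(config=config) construction and the data path layout are assumptions about lines this diff does not show:

    # Hedged sketch of one benchmark iteration; mirrors the diff above where possible.
    import json

    from autolabel import AutolabelDataset, LabelingAgent


    def run_one_dataset(dataset: str, model: str, provider: str, max_items: int) -> list:
        # Ruff's change above only drops the redundant "r" mode; a `with` block also
        # guarantees the file handle is closed.
        with open(f"configs/{dataset}.json") as f:
            config = json.load(f)
        config["model"]["name"] = model
        config["model"]["provider"] = provider

        agent = LabelingAgent(config=config)  # assumed constructor signature
        ds = AutolabelDataset(f"data/{dataset}/test.csv", config=config)  # assumed path layout
        new_ds = agent.run(ds, max_items=max_items)
        new_ds.df.to_csv(f"outputs_{dataset}.csv")
        return [x.dict() for x in agent.eval_result]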
7 changes: 4 additions & 3 deletions benchmark/results.py
@@ -1,6 +1,7 @@
import argparse
import os
import json
import os

import pandas as pd

METRICS = {
@@ -58,14 +59,14 @@ def main():
header_created = False
header = []
for file in eval_files:
d = json.load(open(f"{args.eval_dir}/{file}", "r"))
d = json.load(open(f"{args.eval_dir}/{file}"))
row = []
row.append("-".join(file.split(".")[0].split("/")))
if not header_created:
header.append("model")
for i, dataset in enumerate(DATASETS):
print(dataset)
config = json.load(open(f"configs/{dataset}.json", "r"))
config = json.load(open(f"configs/{dataset}.json"))
metrics_to_add = METRICS[config["task_type"]]
for metric_to_add in metrics_to_add:
for metric in d[i]:
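The truncated loop above builds one row per eval file and one column per (dataset, metric) pair. A hedged sketch of that aggregation using the imported pandas; the datasets/metrics arguments stand in for the module-level constants, the column naming is an assumption, and the "name"/"value" keys follow from the MetricResult dicts written by benchmark.py:

    # Hedged sketch of the results aggregation into a DataFrame.
    import json
    import os

    import pandas as pd


    def collect_results(eval_dir: str, datasets: list, metrics_by_task: dict) -> pd.DataFrame:
        rows = []
        for file in sorted(os.listdir(eval_dir)):
            with open(f"{eval_dir}/{file}") as f:
                d = json.load(f)  # one list of MetricResult dicts per dataset
            row = {"model": "-".join(file.split(".")[0].split("/"))}
            for i, dataset in enumerate(datasets):
                with open(f"configs/{dataset}.json") as f:
                    config = json.load(f)
                for metric_name in metrics_by_task[config["task_type"]]:
                    for metric in d[i]:
                        if metric.get("name") == metric_name:
                            row[f"{dataset}_{metric_name}"] = metric.get("value")
            rows.append(row)
        return pd.DataFrame(rows)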
10 changes: 6 additions & 4 deletions examples/banking/example_banking.ipynb
@@ -41,7 +41,7 @@
"import os\n",
"\n",
"# provide your own OpenAI API key here\n",
"os.environ['OPENAI_API_KEY'] = 'sk-XXXXXXXXXXXXXXXXXXXXXXXX'\n"
"os.environ[\"OPENAI_API_KEY\"] = \"sk-XXXXXXXXXXXXXXXXXXXXXXXX\"\n"
]
},
{
@@ -66,7 +66,7 @@
"source": [
"from autolabel import get_data\n",
"\n",
"get_data('banking')"
"get_data(\"banking\")"
]
},
{
@@ -123,7 +123,7 @@
"outputs": [],
"source": [
"# load the config\n",
"with open('config_banking.json', 'r') as f:\n",
"with open(\"config_banking.json\") as f:\n",
" config = json.load(f)"
]
},
@@ -477,6 +477,7 @@
"source": [
"# dry-run -- this tells us how much this will cost and shows an example prompt\n",
"from autolabel import AutolabelDataset\n",
"\n",
"ds = AutolabelDataset(\"data/banking/test.csv\", config=config)\n",
"agent.plan(ds)"
]
@@ -684,7 +685,7 @@
"outputs": [],
"source": [
"# Start computing confidence scores (using Refuel's LLMs)\n",
"os.environ['REFUEL_API_KEY'] = 'sk-xxxxxxxxxxxx'"
"os.environ[\"REFUEL_API_KEY\"] = \"sk-xxxxxxxxxxxx\""
]
},
{
@@ -849,6 +850,7 @@
],
"source": [
"from autolabel import AutolabelDataset\n",
"\n",
"ds = AutolabelDataset(\"data/banking/test.csv\", config=config)\n",
"agent.plan(ds)"
]
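Pulled together, the notebook cells shown above amount to: set the API key, fetch the example dataset, load the config, and dry-run the agent. A hedged consolidation; the LabelingAgent(config=config) construction and the final agent.run call come from cells not visible in this diff and are assumptions about the intended flow:

    # Hedged consolidation of the banking example notebook.
    import json
    import os

    from autolabel import AutolabelDataset, LabelingAgent, get_data

    os.environ["OPENAI_API_KEY"] = "sk-XXXXXXXXXXXXXXXXXXXXXXXX"  # placeholder key

    get_data("banking")  # downloads the example data, including data/banking/test.csv

    with open("config_banking.json") as f:
        config = json.load(f)

    agent = LabelingAgent(config=config)  # assumed constructor signature
    ds = AutolabelDataset("data/banking/test.csv", config=config)
    agent.plan(ds)   # dry-run: estimated cost and an example prompt
    # agent.run(ds)  # uncomment to label the dataset for real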
Remaining changed files not shown.
