
Commit

Applying ruff formatting
DhruvaBansal00 committed Nov 27, 2024
1 parent 8cb7f22 commit f8e255b
Showing 108 changed files with 5,336 additions and 3,595 deletions.
18 changes: 0 additions & 18 deletions .github/workflows/black.yaml

This file was deleted.

12 changes: 12 additions & 0 deletions .github/workflows/ruff.yaml
@@ -0,0 +1,12 @@
name: Ruff Format
on: [push]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: astral-sh/ruff-action@v1
        with:
          args: "check --output-format=github"
          changed-files: "true"
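The workflow above lints every push with ruff. For reference, the same checks can be reproduced locally before pushing; this is a minimal sketch, assuming ruff is installed in the environment, and is not part of the commit:

# Local equivalent of the CI step above (sketch; assumes `ruff` is on PATH,
# e.g. installed with `pip install ruff`).
import subprocess

# Lint with the same output format the workflow requests.
subprocess.run(["ruff", "check", "--output-format=github", "."], check=True)

# Apply the formatter whose changes make up the rest of this commit.
subprocess.run(["ruff", "format", "."], check=True)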
33 changes: 16 additions & 17 deletions benchmark/benchmark.py
@@ -1,27 +1,26 @@
# Before running this file run the following command in the same folder as this file:
# aws s3 cp --recursive s3://autolabel-benchmarking data

+import ast
+import json
+import os
+import pickle as pkl
from argparse import ArgumentParser
-from rich.console import Console
from typing import List
-import json
-import pickle as pkl
-import ast
-from sklearn.metrics import f1_score
-import torch

import pylcs
-import os
+import torch
+from sklearn.metrics import f1_score

-from autolabel import LabelingAgent, AutolabelConfig, AutolabelDataset
-from autolabel.tasks import TaskFactory
+from autolabel import AutolabelConfig, AutolabelDataset, LabelingAgent
from autolabel.metrics import BaseMetric
from autolabel.schema import LLMAnnotation, MetricResult
+from autolabel.tasks import TaskFactory


class F1Metric(BaseMetric):
    def compute(
-        llm_labels: List[LLMAnnotation], gt_labels: List[str]
+        llm_labels: List[LLMAnnotation], gt_labels: List[str],
    ) -> List[MetricResult]:
        def construct_binary_preds(curr_input: List[str], positive_tokens: List[str]):
            curr_token_index = 0
@@ -44,22 +43,22 @@ def construct_binary_preds(curr_input: List[str], positive_tokens: List[str]):
            try:
                curr_pred = " ".join(ast.literal_eval(llm_labels[i].label)).split(" ")
                predictions.extend(construct_binary_preds(curr_input, curr_pred))
-            except Exception as e:
+            except Exception:
                curr_pred = llm_labels[i].label.split(" ")
                predictions.extend(construct_binary_preds(curr_input, curr_pred))
                # print(e, llm_labels[i].label)
                # predictions.extend([0 for _ in range(len(curr_input))])
            try:
                curr_gt = " ".join(ast.literal_eval(gt_labels[i])).split(" ")
-            except Exception as e:
+            except Exception:
                curr_gt = gt_labels[i].split(" ")
            gt.extend(construct_binary_preds(curr_input, curr_gt))
        return [MetricResult(name="F1", value=f1_score(gt, predictions))]
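For context, F1Metric reduces NER labels to per-token binary indicators and scores them with scikit-learn's f1_score. A small illustrative example; the token values are hypothetical, not taken from the benchmark:

# 1 marks a token inside an entity span, 0 a token outside one.
from sklearn.metrics import f1_score

gold = [1, 1, 0, 0, 1]
pred = [1, 0, 0, 0, 1]
print(f1_score(gold, pred))  # 0.8 (precision 1.0, recall 2/3)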


class TextSimilarity(BaseMetric):
    def compute(
-        llm_labels: List[LLMAnnotation], gt_labels: List[str]
+        llm_labels: List[LLMAnnotation], gt_labels: List[str],
    ) -> List[MetricResult]:
        def get_similarity(str_a, str_b):
            substring_lengths = pylcs.lcs_string_length(str_a, str_b)
@@ -70,8 +69,8 @@ def get_similarity(str_a, str_b):
            text_similarity.append(get_similarity(llm_labels[i].label, gt_labels[i]))
        return [
            MetricResult(
-                name="TextSimilarity", value=sum(text_similarity) / len(text_similarity)
-            )
+                name="TextSimilarity", value=sum(text_similarity) / len(text_similarity),
+            ),
        ]
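The body of get_similarity is truncated in this hunk. Below is a sketch of the normalized longest-common-substring similarity it appears to compute; the normalization by the longer string's length is an assumption, not something visible in the diff:

import pylcs

def get_similarity(str_a: str, str_b: str) -> float:
    # Length of the longest common substring, scaled to [0, 1] (assumed).
    substring_length = pylcs.lcs_string_length(str_a, str_b)
    return substring_length / max(len(str_a), len(str_b))

print(get_similarity("refuel autolabel", "refuel auto label"))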


@@ -138,7 +137,7 @@ def main():
        if dataset in LONG_DATASETS:
            few_shot = 0

-        config = json.load(open(f"configs/{dataset}.json", "r"))
+        config = json.load(open(f"configs/{dataset}.json"))
        config["model"]["name"] = args.model
        provider = MODEL_TO_PROVIDER.get(args.model, "vllm")
        config["model"]["provider"] = provider
@@ -174,7 +173,7 @@ def main():
            [F1Metric, TextSimilarity] if dataset in NER_DATASETS else []
        )
        new_ds = agent.run(
-            ds, max_items=args.max_items, additional_metrics=additional_metrics
+            ds, max_items=args.max_items, additional_metrics=additional_metrics,
        )
        new_ds.df.to_csv(f"outputs_{dataset}.csv")
        eval_result.append([x.dict() for x in agent.eval_result])
7 changes: 4 additions & 3 deletions benchmark/results.py
@@ -1,6 +1,7 @@
import argparse
-import os
import json
+import os

import pandas as pd

METRICS = {
@@ -58,14 +59,14 @@ def main():
    header_created = False
    header = []
    for file in eval_files:
-        d = json.load(open(f"{args.eval_dir}/{file}", "r"))
+        d = json.load(open(f"{args.eval_dir}/{file}"))
        row = []
        row.append("-".join(file.split(".")[0].split("/")))
        if not header_created:
            header.append("model")
        for i, dataset in enumerate(DATASETS):
            print(dataset)
-            config = json.load(open(f"configs/{dataset}.json", "r"))
+            config = json.load(open(f"configs/{dataset}.json"))
            metrics_to_add = METRICS[config["task_type"]]
            for metric_to_add in metrics_to_add:
                for metric in d[i]:
10 changes: 6 additions & 4 deletions examples/banking/example_banking.ipynb
@@ -41,7 +41,7 @@
"import os\n",
"\n",
"# provide your own OpenAI API key here\n",
"os.environ['OPENAI_API_KEY'] = 'sk-XXXXXXXXXXXXXXXXXXXXXXXX'\n"
"os.environ[\"OPENAI_API_KEY\"] = \"sk-XXXXXXXXXXXXXXXXXXXXXXXX\"\n"
]
},
{
@@ -66,7 +66,7 @@
"source": [
"from autolabel import get_data\n",
"\n",
"get_data('banking')"
"get_data(\"banking\")"
]
},
{
@@ -123,7 +123,7 @@
"outputs": [],
"source": [
"# load the config\n",
"with open('config_banking.json', 'r') as f:\n",
"with open(\"config_banking.json\") as f:\n",
" config = json.load(f)"
]
},
@@ -477,6 +477,7 @@
"source": [
"# dry-run -- this tells us how much this will cost and shows an example prompt\n",
"from autolabel import AutolabelDataset\n",
"\n",
"ds = AutolabelDataset(\"data/banking/test.csv\", config=config)\n",
"agent.plan(ds)"
]
@@ -684,7 +685,7 @@
"outputs": [],
"source": [
"# Start computing confidence scores (using Refuel's LLMs)\n",
"os.environ['REFUEL_API_KEY'] = 'sk-xxxxxxxxxxxx'"
"os.environ[\"REFUEL_API_KEY\"] = \"sk-xxxxxxxxxxxx\""
]
},
{
@@ -849,6 +850,7 @@
],
"source": [
"from autolabel import AutolabelDataset\n",
"\n",
"ds = AutolabelDataset(\"data/banking/test.csv\", config=config)\n",
"agent.plan(ds)"
]
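Putting the notebook cells in this diff together, the labeling flow is roughly the sketch below. The cell that constructs the agent is collapsed in this view, so LabelingAgent(config) is an assumption based on the import in benchmark.py; the paths and config name are the notebook's own:

import json

from autolabel import AutolabelDataset, LabelingAgent

with open("config_banking.json") as f:
    config = json.load(f)

agent = LabelingAgent(config)  # assumed constructor; not shown in this diff
ds = AutolabelDataset("data/banking/test.csv", config=config)

agent.plan(ds)  # dry run: estimated cost and an example prompt
agent.run(ds)   # label the dataset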
