Showing 11 changed files with 400 additions and 0 deletions.
data/questions.csv
@@ -0,0 +1,3 @@
question_id,question_text,language,option_a,option_a_correctness,option_b,option_b_correctness,option_c,option_c_correctness,correct_answer
1528,"How many people worldwide have their basic needs met when it comes to food, water, toilets, electricity, schooling and healthcare?",en-US,Around 20%,Very Wrong,Around 50%,Wrong,Around 80%,Correct,Around 80%
13,What share of the world’s population lives in megacities (cities with at least 10 million people)?,en-US,Around 8%,Correct,Around 28%,Wrong,Around 48%,Very Wrong,Around 8%
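
Each row carries the question text, the three answer options with their correctness labels (Correct / Wrong / Very Wrong), and the text of the correct option in correct_answer, which the config below points at via expected_result_column. A quick sanity check of the file (a minimal sketch; it assumes only pandas, which the scripts in this commit already use):

import pandas as pd

# Load the generated question set and confirm the expected columns are present.
df = pd.read_csv("data/questions.csv")
print(df[["question_id", "question_text", "correct_answer"]])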
@@ -0,0 +1,118 @@
custom_wrappers:
  model_config_wrapper:
    class: ./model_config_wrapper.ModelConfigWrapper
    config_cls: ./model_config_wrapper_config.ModelConfigWrapperConfig

custom_variation_generators:
  model_config_generator:
    class: ./model_config_variation_generator.ModelConfigVariationGenerator
    config_cls: ./model_config_variation_generator_config.ModelConfigVariationGeneratorConfig

# custom_reader:
#   question_reader:
#     class: ./question_reader.QuestionReader
#     config_cls: ./question_reader_config.QuestionReaderConfig

custom_function: model_compare.model_compare

dataset:
  source_type: dataset
  file_path: "data/questions.csv"
  reader: csv_reader
  reader_config:
    expected_result_column: correct_answer

description: Configuration for question answering with expected results.
evaluators:
  # - evaluator_type: individual
  #   matching_technique: includes
  #   metric_calculators:
  #     - method: AVERAGE
  #   name: string_expected_result
  - evaluator_type: individual
    metric_calculators:
      - method: AVERAGE  # TODO: figure out what this means and what the available options are.
    name: openai_prompt_based_evaluator  # TODO: not sure which OpenAI model is used.
    prompt: |-
      You are assessing a submitted answer on a given task based on a criterion. Here is the data:
      - Task: Given a multiple choice question, pick the correct answer.
      - Is the answer correct?
      [Input]: {question_text}
      A. {option_a}
      B. {option_b}
      C. {option_c}
      Correctness of Options:
      {option_a}: {option_a_correctness}
      {option_b}: {option_b_correctness}
      {option_c}: {option_c_correctness}
      [Result]: {raw_output}
      Answer the question by selecting one of the following options:
      A It doesn't mention the answer to the multiple choice question at all.
      B It mentions the answer to the multiple choice question, but it's not in the options list. (Not in option list)
      C It mentions the answer to the multiple choice question, but the answer is far away from the correct answer. (Very Wrong)
      D It mentions the answer to the multiple choice question, but the answer is not correct, though close to the correct one. (Wrong)
      E It mentions the answer to the multiple choice question, and the answer is correct. (Correct)
    display_name: correctness
    choices: ["A", "B", "C", "D", "E"]
    description: Is the answer correct?
    scale_description: "0-4"
    choice_scores:
      A: 0
      B: 1
      C: 2
      D: 3
      E: 4

variations:
  - name: model_config
    generator_name: model_config_generator
    generator_config:
      models:
        - model_name: gpt-3.5-turbo
          params:
            temperature: 0
        - model_name: gpt-3.5-turbo
          params:
            temperature: 1
  - name: prompt_template
    variations:
      - instantiated_value: |
          Answer the following multiple choice question:
          Question: {question_text}
          A. {option_a}
          B. {option_b}
          C. {option_c}
          Answer:
        value: |
          Answer the following multiple choice question:
          Question: {question_text}
          A. {option_a}
          B. {option_b}
          C. {option_c}
          Answer:
        value_type: str
        variation_id: instruct_question
      - instantiated_value: |
          Question: {question_text}
          A. {option_a}
          B. {option_b}
          C. {option_c}
          Answer:
        value: |
          Question: {question_text}
          A. {option_a}
          B. {option_b}
          C. {option_c}
          Answer:
        value_type: str
        variation_id: simple

human_rating_configs:
  - name: correctness
    instructions: Rate whether the answer clearly states what the correct answer is
    scale: [1, 5]

  - name: coherence
    instructions: Rate whether the answer and explanation are coherent
    scale: [1, 5]
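
The {question_text} and {option_*} placeholders in both the evaluator prompt and the prompt_template variations are plain Python format fields filled from the CSV columns of the same names (model_compare.py below does exactly this with str.format). A minimal sketch of that substitution for the "simple" variation, using the second question from data/questions.csv:

# Sketch: how a prompt_template variation is instantiated for one CSV row.
template = """Question: {question_text}
A. {option_a}
B. {option_b}
C. {option_c}
Answer:"""

row = {
    "question_text": "What share of the world's population lives in megacities "
    "(cities with at least 10 million people)?",
    "option_a": "Around 8%",
    "option_b": "Around 28%",
    "option_c": "Around 48%",
}
print(template.format(**row))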
@@ -0,0 +1,37 @@
import pandas as pd
from lib.pilot.helpers import read_ai_eval_spreadsheet, get_questions


# Maps the spreadsheet's numeric correctness codes to the labels used in questions.csv.
correctness_map = {1: "Correct", 2: "Wrong", 3: "Very Wrong"}


def main():
    sheet = read_ai_eval_spreadsheet()
    questions = get_questions(sheet)

    output_list = []

    for q, opts in questions:
        output_item = {
            "question_id": q.question_id,
            "question_text": q.published_version_of_question,
            "language": q.language,
        }

        for opt in opts:
            letter = opt.letter.lower()
            output_item[f"option_{letter}"] = opt.question_option
            output_item[f"option_{letter}_correctness"] = correctness_map[
                opt.correctness_of_answer_option
            ]
            if opt.correctness_of_answer_option == 1:
                output_item["correct_answer"] = opt.question_option

        output_list.append(output_item)

    output_df = pd.DataFrame.from_records(output_list)
    output_df.to_csv("data/questions.csv", index=False)


if __name__ == "__main__":
    main()
model_compare.py
@@ -0,0 +1,94 @@
from yival.common.model_utils import llm_completion
from yival.logger.token_logger import TokenLogger
from yival.schemas.experiment_config import MultimodalOutput
from yival.schemas.model_configs import Request
from yival.states.experiment_state import ExperimentState
from yival.wrappers.string_wrapper import StringWrapper
from model_config_wrapper import ModelConfigWrapper


default_model_config = dict(model_name="gpt-3.5-turbo", params={"temperature": 0.5})


def model_compare(
    question_id: str,
    question_text: str,
    language: str,
    option_a: str,
    option_a_correctness: str,
    option_b: str,
    option_b_correctness: str,
    option_c: str,
    option_c_correctness: str,
    state: ExperimentState,
) -> MultimodalOutput:
    logger = TokenLogger()
    logger.reset()

    model = ModelConfigWrapper(
        default_model_config, name="model_config", state=state
    ).get_value()

    prompt_template_default = """Answer the following multiple choice question:
Question: {question_text}
A. {option_a}
B. {option_b}
C. {option_c}
Answer:"""
    # TODO: there might be a better way to handle variables in prompt variations.
    prompt_template = str(StringWrapper("", name="prompt_template", state=state))
    if prompt_template == "":
        prompt_template = prompt_template_default

    prompt = prompt_template.format(
        question_text=question_text,
        option_a=option_a,
        option_b=option_b,
        option_c=option_c,
    )
    response = llm_completion(
        Request(model_name=model["model_name"], prompt=prompt, params=model["params"])
    ).output
    # NOTE: we could use a template in StringWrapper instead, e.g.:
    # str(
    #     StringWrapper(
    #         template="""
    #         Generate a landing page headline for {tech_startup_business}
    #         """,
    #         variables={
    #             "tech_startup_business": tech_startup_business,
    #         },
    #         name="task",
    #     )
    # )

    res = MultimodalOutput(
        text_output=response["choices"][0]["message"]["content"],
    )
    token_usage = response["usage"]["total_tokens"]
    logger.log(token_usage)
    return res


def main():
    q = (
        "How many people worldwide have their basic needs met when it comes to food, "
        "water, toilets, electricity, schooling and healthcare?"
    )
    print(
        model_compare(
            "1",
            q,
            "en-US",
            "Around 20%",
            "Very Wrong",
            "Around 50%",
            "Wrong",
            "Around 80%",
            "Correct",
            ExperimentState(),
        )
    )


if __name__ == "__main__":
    main()
model_config_variation_generator.py
@@ -0,0 +1,51 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Iterator, List, Optional

# from yival.schemas.experiment_config import WrapperVariation
# ^ that version does not handle dict values, so a local copy is defined here.

from yival.variation_generators.base_variation_generator import BaseVariationGenerator

from model_config_variation_generator_config import ModelConfigVariationGeneratorConfig


@dataclass
class WrapperVariation:
    """
    Represents a variation within a wrapper.
    The value can be any type, but typical usages might include strings,
    numbers, configuration dictionaries, or even custom class configurations.
    """

    value_type: str  # e.g., "string", "int", "float", "ClassA", ...
    value: Any  # The actual value or parameters to initialize a value
    instantiated_value: Any = field(init=False)
    variation_id: Optional[str] = None

    def asdict(self):
        return asdict(self)

    def __post_init__(self):
        self.instantiated_value = self.instantiate()

    def instantiate(self) -> Any:
        """
        Returns an instantiated value based on value_type and params.
        """
        return self.value


class ModelConfigVariationGenerator(BaseVariationGenerator):
    def __init__(self, config: ModelConfigVariationGeneratorConfig):
        super().__init__(config)
        self.config = config

    def generate_variations(self) -> Iterator[List[WrapperVariation]]:
        if not self.config.models:
            yield []
        else:
            variations = [
                WrapperVariation(value_type="dict", value=var)
                for var in self.config.models
            ]
            yield variations
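
A minimal sketch of how this generator expands the models list from the YAML config into one variation per model config. It assumes ModelConfigVariationGeneratorConfig (defined in the next file) can be constructed with just the models argument, i.e. that the yival base config has no required fields beyond what is shown:

from model_config_variation_generator import ModelConfigVariationGenerator
from model_config_variation_generator_config import ModelConfigVariationGeneratorConfig

# Same two model configs as in the YAML `variations` section.
config = ModelConfigVariationGeneratorConfig(
    models=[
        {"model_name": "gpt-3.5-turbo", "params": {"temperature": 0}},
        {"model_name": "gpt-3.5-turbo", "params": {"temperature": 1}},
    ]
)
for batch in ModelConfigVariationGenerator(config).generate_variations():
    for variation in batch:
        print(variation.value_type, variation.value)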
model_config_variation_generator_config.py
@@ -0,0 +1,9 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from yival.schemas.varation_generator_configs import BaseVariationGeneratorConfig


@dataclass
class ModelConfigVariationGeneratorConfig(BaseVariationGeneratorConfig):
    models: Optional[List[Dict[str, Any]]] = None  # List of variations to generate
model_config_wrapper.py
@@ -0,0 +1,34 @@
from typing import Optional, Dict, Any

from model_config_wrapper_config import ModelConfigWrapperConfig
from yival.wrappers.base_wrapper import BaseWrapper
from yival.experiment.experiment_runner import ExperimentState


class ModelConfigWrapper(BaseWrapper):
    """
    A wrapper for a model configuration.
    The configuration is a dictionary containing two keys:
    - model_name: the name of the model (a string)
    - params: the model's parameters (a dictionary)
    """

    default_config = ModelConfigWrapperConfig()

    def __init__(
        self,
        value: Dict[str, Any],
        name: str,
        config: Optional[ModelConfigWrapperConfig] = None,
        state: Optional[ExperimentState] = None,
    ) -> None:
        super().__init__(name, config, state)
        self._value = value

    def get_value(self) -> Dict[str, Any]:
        # Prefer the active variation (e.g. from model_config_generator); fall back to the default value.
        variation = self.get_variation()
        if variation is not None:
            return variation
        return self._value
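
A small sketch of the fallback behaviour: when no model_config variation is active, get_value() returns the default dictionary passed in, which is how model_compare.py gets its gpt-3.5-turbo default. It assumes, as model_compare.py's own main() does, that a bare ExperimentState() has no active variations:

from yival.states.experiment_state import ExperimentState
from model_config_wrapper import ModelConfigWrapper

wrapper = ModelConfigWrapper(
    {"model_name": "gpt-3.5-turbo", "params": {"temperature": 0.5}},
    name="model_config",
    state=ExperimentState(),
)
# With no variation registered, this prints the default config dict.
print(wrapper.get_value())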
model_config_wrapper_config.py
@@ -0,0 +1,12 @@
from dataclasses import dataclass

from yival.schemas.wrapper_configs import BaseWrapperConfig


@dataclass
class ModelConfigWrapperConfig(BaseWrapperConfig):
    """
    Configuration specific to the ModelConfigWrapper.
    """

    pass
Binary file not shown.
@@ -0,0 +1,33 @@
import pandas as pd
import pickle

from yival.experiment.experiment_runner import Experiment


# Pickled Experiment result from a yival run (the binary file in this commit).
fp = "_0.pkl"

with open(fp, "rb") as f:
    data: Experiment = pickle.load(f)

# data.group_experiment_results
# result = data.group_experiment_results[0]
# rs = result.experiment_results
# rs[1].asdict()

output_list = []

for group_results in data.group_experiment_results:
    for result in group_results.experiment_results:
        result_dict = dict(
            combination=str(result.combination).replace("'", ""),
            question=result.input_data.content["question_text"],
            raw_output=result.raw_output.text_output,
        )
        for eval_output in result.evaluator_outputs:
            result_dict[eval_output.display_name] = eval_output.result

        output_list.append(result_dict)


output_df = pd.DataFrame.from_records(output_list)

output_df.to_csv("./results.csv", index=False)
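
The resulting results.csv has one row per (combination, question) pair, where combination is the model_config / prompt_template variation pair, plus one extra column per evaluator display name (correctness here). A quick comparison across combinations (a sketch, assuming the openai_prompt_based_evaluator stores its 0-4 choice score as a numeric result under that column):

import pandas as pd

results = pd.read_csv("./results.csv")
# Mean 0-4 correctness score for each model/prompt combination.
print(results.groupby("combination")["correctness"].mean().sort_values(ascending=False))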