Skip to content

Commit

Permalink
Feature extraction - wrapper around schema.apply_udf (#198)
Browse files Browse the repository at this point in the history
* extract feature
* raise invalid data type error

---------

Co-authored-by: felipe207 <felipe@whylabs.ai>
Co-authored-by: Jamie Broomall <88007022+jamie256@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 27, 2023
1 parent 80a7ca7 commit 23497fa
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 3 deletions.
4 changes: 2 additions & 2 deletions langkit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from typing import Dict, List

from .extract import extract
import importlib.resources as resources


Expand Down Expand Up @@ -67,4 +67,4 @@ def package_version(package: str = __package__) -> str:

__version__ = package_version()

__ALL__ = [__version__, LangKitConfig]
__ALL__ = [__version__, LangKitConfig, extract]
20 changes: 20 additions & 0 deletions langkit/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import pandas as pd
from typing import Any, Dict, Optional, Union
from whylogs.experimental.core.udf_schema import udf_schema, UdfSchema


def extract(
data: Union[pd.DataFrame, Dict[str, Any]],
schema: Optional[UdfSchema] = None,
):
if schema is None:
schema = udf_schema()
if isinstance(data, pd.DataFrame):
df_enhanced, _ = schema.apply_udfs(pandas=data)
return df_enhanced
elif isinstance(data, dict):
_, row_enhanced = schema.apply_udfs(row=data)
return row_enhanced
raise ValueError(
f"Extract: data of type {type(data)} is invalid: supported input types are pandas dataframe or dictionary"
)
7 changes: 6 additions & 1 deletion langkit/response_hallucination.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,12 @@ def init(llm: LLMInvocationParams, num_samples=1):
def response_hallucination(text):
series_result = []
for prompt, response in zip(text[_prompt], text[_response]):
result: ConsistencyResult = checker.consistency_check(prompt, response)
if checker is not None:
result: ConsistencyResult = checker.consistency_check(prompt, response)
else:
raise Exception(
"Response Hallucination: you need to call init() before using this function"
)
series_result.append(result.final_score)
return series_result

Expand Down
48 changes: 48 additions & 0 deletions langkit/tests/test_extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import langkit
import pandas as pd
from whylogs.experimental.core.udf_schema import UdfSchema, UdfSpec


def test_extract_pandas():
from langkit import textstat

textstat.init()
df = pd.DataFrame({"prompt": ["I love you", "I hate you"]})
enhanced_df = langkit.extract(data=df)
assert "prompt.flesch_reading_ease" in enhanced_df.columns


def test_extract_row():
from langkit import regexes

regexes.init()
row = {"prompt": "I love you", "response": "address: 123 Main St."}
enhanced_row = langkit.extract(data=row)
assert enhanced_row.get("response.has_patterns") == "mailing address"
assert not enhanced_row.get("prompt.has_patterns")


def test_extract_light_metrics():
from langkit import light_metrics

light_metrics.init()

row = {"prompt": "I love you", "response": "address: 123 Main St."}
enhanced_row = langkit.extract(row)
assert enhanced_row.get("response.has_patterns") == "mailing address"
assert not enhanced_row.get("prompt.has_patterns")
assert "prompt.flesch_reading_ease" in enhanced_row.keys()


def test_extract_with_custom_schema():
schema = UdfSchema(
udf_specs=[
UdfSpec(
column_names=["prompt"],
udfs={"prompt.customfeature": lambda x: x["prompt"]},
)
],
)
row = {"prompt": "I love you", "response": "address: 123 Main St."}
enhanced_row = langkit.extract(row, schema=schema)
assert enhanced_row.get("prompt.customfeature") == "I love you"

0 comments on commit 23497fa

Please sign in to comment.