Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synthetic Data #9472

Merged
merged 38 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
18c4949
synthetic-data: adds simple synthetic data gen
paperMoose Aug 16, 2023
d70f573
synthetic-data: adds fc
paperMoose Aug 18, 2023
dc9870d
synthetic-data: updates
paperMoose Aug 18, 2023
bef7f04
synthetic-data: adds docstring useage examples, improves steerability…
paperMoose Aug 24, 2023
86430a4
synthetic-data: updates typing to use ChatOpenAI
paperMoose Aug 25, 2023
cc79c6f
Merge branch 'master' into synthetic-data
paperMoose Sep 8, 2023
bb9fc21
synthetic-data: linting and typing additions
paperMoose Sep 9, 2023
e8b5e7d
Merge branch 'synthetic-data' of github.com:paperMoose/langchain into…
paperMoose Sep 9, 2023
252e619
Merge branch 'master' into synthetic-data
paperMoose Sep 9, 2023
a3173b3
Merge branch 'master' into synthetic-data
paperMoose Sep 11, 2023
f263fc1
synthetic-data: fix
paperMoose Sep 11, 2023
ea04d8c
Merge branch 'synthetic-data' of github.com:paperMoose/langchain into…
paperMoose Sep 11, 2023
6df0644
Merge branch 'master' into synthetic-data
paperMoose Sep 11, 2023
f684240
Merge branch 'master' into synthetic-data
paperMoose Sep 12, 2023
2425a37
Merge branch 'master' into synthetic-data
paperMoose Sep 15, 2023
8a10cfa
synthetic-data: linting fixes
paperMoose Sep 15, 2023
87ba9cf
Merge branch 'master' into synthetic-data
paperMoose Sep 16, 2023
f170089
Merge branch 'master' into synthetic-data
paperMoose Sep 18, 2023
b42dc3c
Merge branch 'master' into synthetic-data
paperMoose Sep 18, 2023
ee0587d
synthetic-data: fixes linting issues
paperMoose Sep 27, 2023
7c09d42
Merge branch 'synthetic-data' of github.com:paperMoose/langchain into…
paperMoose Sep 27, 2023
b2ffc7b
synthetic-data: fix linting
paperMoose Sep 27, 2023
4643aac
synthetic-data: removes stray print statements
paperMoose Sep 27, 2023
1bc0c90
Merge branch 'master' into synthetic-data
paperMoose Sep 27, 2023
90663ff
synthetic-data: moves things to experimental
paperMoose Sep 27, 2023
7153006
synthetic-data: adds useage notebook
paperMoose Sep 27, 2023
306a19c
Merge branch 'master' into synthetic-data
paperMoose Sep 27, 2023
c4cdb5e
Merge branch 'master' into synthetic-data
paperMoose Sep 28, 2023
0b85d20
Merge branch 'master' into synthetic-data
paperMoose Sep 28, 2023
b9254fb
synthetic-data: fixes issue with import from main file
paperMoose Sep 28, 2023
23efd88
Merge branch 'synthetic-data' of github.com:paperMoose/langchain into…
paperMoose Sep 28, 2023
dcbf9f9
synthetic-data: address wills comments
paperMoose Sep 28, 2023
dba01d9
format
hinthornw Sep 28, 2023
1b50c2b
format
hinthornw Sep 28, 2023
2856237
?
hinthornw Sep 29, 2023
6fdf294
?
hinthornw Sep 29, 2023
4215215
fmt
baskaryan Sep 29, 2023
b7c86b3
Merge branch 'master' into synthetic-data
baskaryan Sep 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
137 changes: 137 additions & 0 deletions libs/langchain/langchain/chains/data_generation/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import asyncio
from typing import Any, Dict, List, Optional, Union

from pydantic.class_validators import root_validator
from langchain.pydantic_v1 import BaseModel

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.schema.language_model import BaseLanguageModel


class SyntheticDataGenerator(BaseModel):
    """Generates synthetic data using the given LLM and few-shot template.

    Utilizes the provided LLM to produce synthetic data based on the
    few-shot prompt template.

    Attributes:
        template (FewShotPromptTemplate): Template for few-shot prompting.
        llm (Optional[BaseLanguageModel]): Large Language Model to use for generation.
        results (list): All examples generated so far, accumulated across calls.
        llm_chain (Optional[Chain]): LLM chain with the LLM and few-shot template.
        example_input_key (str): Key to use for storing example inputs.

    Usage Example:
        >>> template = FewShotPromptTemplate(...)
        >>> llm = BaseLanguageModel(...)
        >>> generator = SyntheticDataGenerator(template=template, llm=llm)
        >>> results = generator.generate(subject="climate change", runs=5)
    """

    template: FewShotPromptTemplate
    llm: Optional[BaseLanguageModel] = None
    results: list = []
    llm_chain: Optional[Chain] = None
    example_input_key: str = "example"

    class Config:
        validate_assignment = True

    @root_validator(pre=False, skip_on_failure=True)
    def set_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build ``llm_chain`` from ``llm`` and ``template`` when not supplied.

        Raises:
            ValueError: If neither ``llm_chain`` nor both ``llm`` and
                ``template`` are provided.
        """
        llm_chain = values.get("llm_chain")
        llm = values.get("llm")
        few_shot_template = values.get("template")

        if not llm_chain:  # If llm_chain is None or not present
            if llm is None or few_shot_template is None:
                raise ValueError(
                    "Both llm and few_shot_template must be provided if llm_chain is "
                    "not given."
                )
            values["llm_chain"] = LLMChain(llm=llm, prompt=few_shot_template)

        return values

    @staticmethod
    def _format_dict_to_string(input_dict: Dict) -> str:
        """Render a dict as a single ``"key: value, key: value"`` string."""
        # join() accepts the generator directly; no intermediate list needed.
        return ", ".join(f"{key}: {value}" for key, value in input_dict.items())

    def _update_examples(self, example: Union[BaseModel, Dict[str, Any], str]) -> None:
        """Prevents duplicates by adding previously generated examples to the few shot
        list."""
        if self.template and self.template.examples:
            if isinstance(example, BaseModel):
                formatted_example = self._format_dict_to_string(example.dict())
            elif isinstance(example, dict):
                formatted_example = self._format_dict_to_string(example)
            else:
                formatted_example = str(example)
            # Keep the few-shot window a constant size: drop the oldest
            # example and append the freshly generated one.
            self.template.examples.pop(0)
            self.template.examples.append({self.example_input_key: formatted_example})

    def generate(self, subject: str, runs: int, *args: Any, **kwargs: Any) -> List[str]:
        """Generate synthetic data using the given subject string.

        Args:
            subject (str): The subject the synthetic data will be about.
            runs (int): Number of times to generate the data.
            **kwargs: Extra prompt variables, e.g. ``extra`` for steerability
                instructions, forwarded to the underlying chain.

        Returns:
            List[str]: List of generated synthetic data.

        Usage Example:
            >>> results = generator.generate(subject="climate change", runs=5,
                                             extra="Focus on environmental impacts.")
        """
        if self.llm_chain is None:
            raise ValueError(
                "llm_chain is None; set either llm_chain or llm at generator "
                "construction."
            )
        for _ in range(runs):
            result = self.llm_chain.run(subject=subject, *args, **kwargs)
            self.results.append(result)
            # Feed each generation back into the few-shot examples to steer
            # later runs away from duplicates.
            self._update_examples(result)
        return self.results

    async def agenerate(
        self, subject: str, runs: int, extra: str = "", *args: Any, **kwargs: Any
    ) -> List[str]:
        """Generate synthetic data using the given subject asynchronously.

        Note: Since the LLM calls run concurrently,
        you may have fewer duplicates by adding specific instructions to
        the "extra" keyword argument.

        Args:
            subject (str): The subject the synthetic data will be about.
            runs (int): Number of times to generate the data asynchronously.
            extra (str): Extra instructions for steerability in data generation.

        Returns:
            List[str]: List of generated synthetic data for the given subject.

        Usage Example:
            >>> results = await generator.agenerate(subject="climate change", runs=5,
                                                    extra="Focus on env impacts.")
        """

        async def run_chain(
            subject: str, extra: str = "", *args: Any, **kwargs: Any
        ) -> None:
            if self.llm_chain is not None:
                result = await self.llm_chain.arun(
                    subject=subject, extra=extra, *args, **kwargs
                )
                self.results.append(result)

        # Forward *args/**kwargs to every concurrent run; previously they
        # were accepted by the signature but silently dropped here.
        await asyncio.gather(
            *(run_chain(subject, extra, *args, **kwargs) for _ in range(runs))
        )
        return self.results
64 changes: 64 additions & 0 deletions libs/langchain/langchain/chains/data_generation/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Any, Dict, Optional, Type, Union

from langchain.pydantic_v1 import BaseModel

from langchain import BasePromptTemplate, PromptTemplate
from langchain.chains.data_generation.base import SyntheticDataGenerator
from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseLLMOutputParser

# Pass-through example prompt: renders each stored few-shot example verbatim.
OPENAI_TEMPLATE = PromptTemplate(input_variables=["example"], template="{example}")


def create_openai_data_generator(
    output_schema: Union[Dict[str, Any], Type[BaseModel]],
    llm: ChatOpenAI,
    prompt: BasePromptTemplate,
    output_parser: Optional[BaseLLMOutputParser] = None,
    **kwargs: Any
) -> SyntheticDataGenerator:
    """Create a SyntheticDataGenerator wired up for OpenAI function calling.

    Builds a structured-output chain from the supplied schema, chat model,
    and prompt template, then wraps that chain in a SyntheticDataGenerator
    so its generate/agenerate methods emit data matching the schema.

    Args:
        output_schema: Expected output shape — either a dict holding a valid
            JsonSchema or a Pydantic ``BaseModel`` subclass.
        llm: OpenAI chat model used to generate the data.
        prompt: Prompt template used to drive generation.
        output_parser: Optional parser for model outputs; when omitted, a
            default is inferred from the function types.
        **kwargs: Forwarded verbatim to ``create_structured_output_chain``.

    Returns:
        SyntheticDataGenerator: Generator backed by the structured-output
        chain.

    Usage:
        Define the desired output schema first, then call this function to
        obtain a generator instance; use its methods to produce synthetic
        data in that shape.
    """
    # A function-calling chain constrains the model output to output_schema.
    structured_chain = create_structured_output_chain(
        output_schema, llm, prompt, output_parser=output_parser, **kwargs
    )

    return SyntheticDataGenerator(template=prompt, llm_chain=structured_chain)
13 changes: 13 additions & 0 deletions libs/langchain/langchain/chains/data_generation/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Fixed: the previous "libs.langchain.langchain.prompts.prompt" path only
# resolves when running from the monorepo root, not from an installed package.
from langchain.prompts.prompt import PromptTemplate

# Key under which each few-shot example is stored in the prompt template.
DEFAULT_INPUT_KEY = "example"
# Pass-through prompt: renders each stored example verbatim.
DEFAULT_PROMPT = PromptTemplate(
    input_variables=[DEFAULT_INPUT_KEY], template="{example}"
)

# Few-shot scaffolding: prefix introduces the task/subject, suffix asks the
# model to generate, with "{extra}" reserved for steerability instructions.
SYNTHETIC_FEW_SHOT_PREFIX = (
    "This is a test about generating synthetic data about {subject}. Examples below:"
)
SYNTHETIC_FEW_SHOT_SUFFIX = (
    """Now you generate synthetic data about {subject}. Make sure to {extra}:"""
)
2 changes: 1 addition & 1 deletion libs/langchain/langchain/prompts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def jinja2_formatter(template: str, **kwargs: Any) -> str:
def validate_jinja2(template: str, input_variables: List[str]) -> None:
"""
Validate that the input variables are valid for the template.
Issues an warning if missing or extra variables are found.
Issues a warning if missing or extra variables are found.

Args:
template: The template string.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import pytest
from langchain.pydantic_v1 import BaseModel

from langchain import FewShotPromptTemplate
from langchain.chains.data_generation.base import SyntheticDataGenerator
from langchain.chains.data_generation.openai import (
OPENAI_TEMPLATE,
create_openai_data_generator,
)
from langchain.chains.data_generation.prompts import (
SYNTHETIC_FEW_SHOT_PREFIX,
SYNTHETIC_FEW_SHOT_SUFFIX,
)
from langchain.chat_models import ChatOpenAI


# Define the desired output schema for individual medical billing record
class MedicalBilling(BaseModel):
    # Identifier for the patient this record belongs to.
    patient_id: int
    # Full name of the patient.
    patient_name: str
    # Diagnosis code (examples below use ICD-10-style codes, e.g. "J20.9").
    diagnosis_code: str
    # Procedure code (examples below use CPT-style codes, e.g. "99203").
    procedure_code: str
    # Total amount billed, in dollars.
    total_charge: float
    # Amount claimed from insurance, in dollars.
    insurance_claim_amount: float


# Seed few-shot examples, each stored under the "example" input key expected
# by OPENAI_TEMPLATE. These prime the generator with the target record format.
examples = [
    {
        "example": """Patient ID: 123456, Patient Name: John Doe, Diagnosis Code:
J20.9, Procedure Code: 99203, Total Charge: $500, Insurance Claim Amount:
$350"""
    },
    {
        "example": """Patient ID: 789012, Patient Name: Johnson Smith, Diagnosis
Code: M54.5, Procedure Code: 99213, Total Charge: $150, Insurance Claim
Amount: $120"""
    },
    {
        "example": """Patient ID: 345678, Patient Name: Emily Stone, Diagnosis Code:
E11.9, Procedure Code: 99214, Total Charge: $300, Insurance Claim Amount:
$250"""
    },
    {
        "example": """Patient ID: 901234, Patient Name: Robert Miles, Diagnosis Code:
B07.9, Procedure Code: 99204, Total Charge: $200, Insurance Claim Amount:
$160"""
    },
    {
        "example": """Patient ID: 567890, Patient Name: Clara Jensen, Diagnosis Code:
F41.9, Procedure Code: 99205, Total Charge: $450, Insurance Claim Amount:
$310"""
    },
    {
        "example": """Patient ID: 234567, Patient Name: Alan Turing, Diagnosis Code:
G40.909, Procedure Code: 99215, Total Charge: $220, Insurance Claim Amount:
$180"""
    },
]

# Few-shot template shared by both tests; "subject" and "extra" are filled in
# at generate() time, and the examples above supply the in-context format.
prompt_template = FewShotPromptTemplate(
    prefix=SYNTHETIC_FEW_SHOT_PREFIX,
    examples=examples,
    suffix=SYNTHETIC_FEW_SHOT_SUFFIX,
    input_variables=["subject", "extra"],
    example_prompt=OPENAI_TEMPLATE,
)


@pytest.fixture(scope="function")
def synthetic_data_generator() -> SyntheticDataGenerator:
    """Build a fresh generator per test, targeting the MedicalBilling schema."""
    model = ChatOpenAI(temperature=1)  # replace with your LLM instance
    return create_openai_data_generator(
        output_schema=MedicalBilling,
        llm=model,
        prompt=prompt_template,
    )


@pytest.mark.requires("openai")
def test_generate_synthetic(synthetic_data_generator: SyntheticDataGenerator) -> None:
    """Each of the requested runs should yield a MedicalBilling instance."""
    synthetic_results = synthetic_data_generator.generate(
        subject="medical_billing",
        extra="""the name must be chosen at random. Make it something you wouldn't
        normally choose.""",
        runs=10,
    )
    assert len(synthetic_results) == 10
    for row in synthetic_results:
        assert isinstance(row, MedicalBilling)
    # Removed stray debug print of synthetic_results left in from development.


@pytest.mark.requires("openai")
@pytest.mark.asyncio
async def test_agenerate_synthetic(
    synthetic_data_generator: SyntheticDataGenerator,
) -> None:
    """Async variant: concurrent runs should still yield MedicalBilling rows."""
    synthetic_results = await synthetic_data_generator.agenerate(
        subject="medical_billing",
        extra="""the name must be chosen at random. Make it something you wouldn't
        normally choose.""",
        runs=10,
    )
    assert len(synthetic_results) == 10
    for row in synthetic_results:
        assert isinstance(row, MedicalBilling)
    # Removed stray debug print of synthetic_results left in from development.
paperMoose marked this conversation as resolved.
Show resolved Hide resolved