Skip to content

Commit

Permalink
feat(models): add to_dict & to_json helper methods (#1305)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-bot committed Apr 9, 2024
1 parent 69cdfc3 commit 40a881d
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 14 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,10 @@ We recommend that you always instantiate a client (e.g., with `client = OpenAI()

## Using types

Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev), which provide helper methods for things like:
Nested request parameters are [TypedDicts](https://docs.python.org/3/library/typing.html#typing.TypedDict). Responses are [Pydantic models](https://docs.pydantic.dev) which also provide helper methods for things like:

- Serializing back into JSON, `model.model_dump_json(indent=2, exclude_unset=True)`
- Converting to a dictionary, `model.model_dump(exclude_unset=True)`
- Serializing back into JSON, `model.to_json()`
- Converting to a dictionary, `model.to_dict()`

Typed requests and responses provide autocomplete and documentation within your editor. If you would like to see type errors in VS Code to help catch bugs earlier, set `python.analysis.typeCheckingMode` to `basic`.

Expand Down Expand Up @@ -594,7 +594,7 @@ completion = client.chat.completions.create(
},
],
)
print(completion.model_dump_json(indent=2))
print(completion.to_json())
```

In addition to the options provided in the base `OpenAI` client, the following options are provided:
Expand Down
4 changes: 2 additions & 2 deletions examples/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
},
],
)
print(completion.model_dump_json(indent=2))
print(completion.to_json())


deployment_client = AzureOpenAI(
Expand All @@ -40,4 +40,4 @@
},
],
)
print(completion.model_dump_json(indent=2))
print(completion.to_json())
2 changes: 1 addition & 1 deletion examples/azure_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@
},
],
)
print(completion.model_dump_json(indent=2))
print(completion.to_json())
8 changes: 4 additions & 4 deletions examples/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ def sync_main() -> None:

# You can manually control iteration over the response
first = next(response)
print(f"got response data: {first.model_dump_json(indent=2)}")
print(f"got response data: {first.to_json()}")

# Or you could automatically iterate through all of data.
# Note that the for loop will not exit until *all* of the data has been processed.
for data in response:
print(data.model_dump_json())
print(data.to_json())


async def async_main() -> None:
Expand All @@ -43,12 +43,12 @@ async def async_main() -> None:
# You can manually control iteration over the response.
# In Python 3.10+ you can also use the `await anext(response)` builtin instead
first = await response.__anext__()
print(f"got response data: {first.model_dump_json(indent=2)}")
print(f"got response data: {first.to_json()}")

# Or you could automatically iterate through all of data.
# Note that the for loop will not exit until *all* of the data has been processed.
async for data in response:
print(data.model_dump_json())
print(data.to_json())


sync_main()
Expand Down
73 changes: 73 additions & 0 deletions src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,79 @@ def model_fields_set(self) -> set[str]:
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
extra: Any = pydantic.Extra.allow # type: ignore

def to_dict(
self,
*,
mode: Literal["json", "python"] = "python",
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> dict[str, object]:
"""Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
mode:
If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that are set to their default value from the output.
exclude_none: Whether to exclude fields that have a value of `None` from the output.
warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
"""
return self.model_dump(
mode=mode,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)

def to_json(
self,
*,
indent: int | None = 2,
use_api_names: bool = True,
exclude_unset: bool = True,
exclude_defaults: bool = False,
exclude_none: bool = False,
warnings: bool = True,
) -> str:
"""Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).
By default, fields that were not set by the API will not be included,
and keys will match the API response, *not* the property names from the model.
For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).
Args:
indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
exclude_unset: Whether to exclude fields that have not been explicitly set.
exclude_defaults: Whether to exclude fields that have the default value.
exclude_none: Whether to exclude fields that have a value of `None`.
warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
"""
return self.model_dump_json(
indent=indent,
by_alias=use_api_names,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
warnings=warnings,
)

@override
def __str__(self) -> str:
# mypy complains about an invalid self arg
Expand Down
10 changes: 7 additions & 3 deletions src/openai/lib/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,9 +678,11 @@ def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_ac
df_train = df.sample(n=n_train, random_state=42)
df_valid = df.drop(df_train.index)
df_train[["prompt", "completion"]].to_json( # type: ignore
fnames[0], lines=True, orient="records", force_ascii=False
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
)
df_valid[["prompt", "completion"]].to_json(
fnames[1], lines=True, orient="records", force_ascii=False, indent=None
)
df_valid[["prompt", "completion"]].to_json(fnames[1], lines=True, orient="records", force_ascii=False)

n_classes, pos_class = get_classification_hyperparams(df)
additional_params += " --compute_classification_metrics"
Expand All @@ -690,7 +692,9 @@ def write_out_file(df: pd.DataFrame, fname: str, any_remediations: bool, auto_ac
additional_params += f" --classification_n_classes {n_classes}"
else:
assert len(fnames) == 1
df[["prompt", "completion"]].to_json(fnames[0], lines=True, orient="records", force_ascii=False)
df[["prompt", "completion"]].to_json(
fnames[0], lines=True, orient="records", force_ascii=False, indent=None
)

# Add -v VALID_FILE if we split the file into train / valid
files_string = ("s" if split else "") + " to `" + ("` and `".join(fnames))
Expand Down
64 changes: 64 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,42 @@ class Model(BaseModel):
assert "resource_id" in m.model_fields_set


def test_to_dict() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)

m = Model(FOO="hello")
assert m.to_dict() == {"FOO": "hello"}
assert m.to_dict(use_api_names=False) == {"foo": "hello"}

m2 = Model()
assert m2.to_dict() == {}
assert m2.to_dict(exclude_unset=False) == {"FOO": None}
assert m2.to_dict(exclude_unset=False, exclude_none=True) == {}
assert m2.to_dict(exclude_unset=False, exclude_defaults=True) == {}

m3 = Model(FOO=None)
assert m3.to_dict() == {"FOO": None}
assert m3.to_dict(exclude_none=True) == {}
assert m3.to_dict(exclude_defaults=True) == {}

if PYDANTIC_V2:

class Model2(BaseModel):
created_at: datetime

time_str = "2024-03-21T11:39:01.275859"
m4 = Model2.construct(created_at=time_str)
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
assert m4.to_dict(mode="json") == {"created_at": time_str}
else:
with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
m.to_dict(mode="json")

with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_dict(warnings=False)


def test_forwards_compat_model_dump_method() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)
Expand Down Expand Up @@ -532,6 +568,34 @@ class Model(BaseModel):
m.model_dump(warnings=False)


def test_to_json() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)

m = Model(FOO="hello")
assert json.loads(m.to_json()) == {"FOO": "hello"}
assert json.loads(m.to_json(use_api_names=False)) == {"foo": "hello"}

if PYDANTIC_V2:
assert m.to_json(indent=None) == '{"FOO":"hello"}'
else:
assert m.to_json(indent=None) == '{"FOO": "hello"}'

m2 = Model()
assert json.loads(m2.to_json()) == {}
assert json.loads(m2.to_json(exclude_unset=False)) == {"FOO": None}
assert json.loads(m2.to_json(exclude_unset=False, exclude_none=True)) == {}
assert json.loads(m2.to_json(exclude_unset=False, exclude_defaults=True)) == {}

m3 = Model(FOO=None)
assert json.loads(m3.to_json()) == {"FOO": None}
assert json.loads(m3.to_json(exclude_none=True)) == {}

if not PYDANTIC_V2:
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
m.to_json(warnings=False)


def test_forwards_compat_model_dump_json_method() -> None:
class Model(BaseModel):
foo: Optional[str] = Field(alias="FOO", default=None)
Expand Down

0 comments on commit 40a881d

Please sign in to comment.