Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first pass at removing deprecated usaged #751

Merged
merged 6 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ async def _unpack_request_config(
config_dicts = []
for config in client_sent_configs:
if isinstance(config, str):
config_dicts.append(model(**_config_from_hash(config)).dict())
config_dicts.append(model(**_config_from_hash(config)).model_dump())
elif isinstance(config, BaseModel):
config_dicts.append(config.dict())
config_dicts.append(config.model_dump())
elif isinstance(config, Mapping):
config_dicts.append(model(**config).dict())
config_dicts.append(model(**config).model_dump())
else:
raise TypeError(f"Expected a string, dict or BaseModel got {type(config)}")
config = merge_configs(*config_dicts)
Expand Down Expand Up @@ -298,7 +298,7 @@ def _unpack_input(validated_model: BaseModel) -> Any:
# This logic should be applied recursively to nested models.
return {
fieldname: _unpack_input(getattr(model, fieldname))
for fieldname in model.__fields__.keys()
for fieldname in model.model_fields.keys()
}

return model
Expand Down Expand Up @@ -330,6 +330,11 @@ def _replace_non_alphanumeric_with_underscores(s: str) -> str:
return re.sub(r"[^a-zA-Z0-9]", "_", s)


def _schema_json(model: Type[BaseModel]) -> str:
"""Return the JSON representation of the model schema."""
return json.dumps(model.model_json_schema(), sort_keys=True, indent=False)


def _resolve_model(
type_: Union[Type, BaseModel], default_name: str, namespace: str
) -> Type[BaseModel]:
Expand All @@ -339,13 +344,13 @@ def _resolve_model(
else:
model = _create_root_model(default_name, type_)

hash_ = model.schema_json()
hash_ = _schema_json(model)

if model.__name__ in _SEEN_NAMES and hash_ not in _MODEL_REGISTRY:
# If the model name has been seen before, but the model itself is different
# generate a new name for the model.
model_to_use = _rename_pydantic_model(model, namespace)
hash_ = model_to_use.schema_json()
hash_ = _schema_json(model_to_use)
else:
model_to_use = model

Expand Down Expand Up @@ -755,7 +760,7 @@ async def _get_config_and_input(
except json.JSONDecodeError:
raise RequestValidationError(errors=["Invalid JSON body"])
try:
body = InvokeRequestShallowValidator.validate(body)
body = InvokeRequestShallowValidator.model_validate(body)

# Merge the config from the path with the config from the body.
user_provided_config = await _unpack_request_config(
Expand Down Expand Up @@ -1407,7 +1412,7 @@ async def input_schema(
self._run_name, user_provided_config, request
)

return self._runnable.get_input_schema(config).schema()
return self._runnable.get_input_schema(config).model_json_schema()

async def output_schema(
self,
Expand All @@ -1434,7 +1439,7 @@ async def output_schema(
config = _update_config_with_defaults(
self._run_name, user_provided_config, request
)
return self._runnable.get_output_schema(config).schema()
return self._runnable.get_output_schema(config).model_json_schema()

async def config_schema(
self,
Expand Down Expand Up @@ -1464,7 +1469,7 @@ async def config_schema(
return (
self._runnable.with_config(config)
.config_schema(include=self._config_keys)
.schema()
.model_json_schema()
)

async def playground(
Expand Down
8 changes: 5 additions & 3 deletions langserve/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,12 @@ async def serve_playground(
if base_url.startswith("/")
else base_url,
LANGSERVE_CONFIG_SCHEMA=json.dumps(
runnable.config_schema(include=config_keys).schema()
runnable.config_schema(include=config_keys).model_json_schema()
),
LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.model_json_schema()),
LANGSERVE_OUTPUT_SCHEMA=json.dumps(
output_schema.model_json_schema()
),
LANGSERVE_INPUT_SCHEMA=json.dumps(input_schema.schema()),
LANGSERVE_OUTPUT_SCHEMA=json.dumps(output_schema.schema()),
LANGSERVE_FEEDBACK_ENABLED=json.dumps(
"true" if feedback_enabled else "false"
),
Expand Down
12 changes: 6 additions & 6 deletions langserve/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _log_error_message_once(error_message: str) -> None:
def default(obj) -> Any:
"""Default serialization for well known objects."""
if isinstance(obj, BaseModel):
return obj.dict()
return obj.model_dump()
return super().default(obj)


Expand All @@ -96,7 +96,7 @@ def _decode_lc_objects(value: Any) -> Any:
try:
obj = WellKnownLCObject.model_validate(v)
parsed = obj.root
if set(parsed.dict()) != set(value):
if set(parsed.model_dump()) != set(value):
raise ValueError("Invalid object")
return parsed
except (ValidationError, ValueError, TypeError):
Expand All @@ -121,11 +121,11 @@ def _decode_event_data(value: Any) -> Any:
"""Decode the event data from a JSON object representation."""
if isinstance(value, dict):
try:
obj = CallbackEvent.parse_obj(value)
obj = CallbackEvent.model_validate(value)
return obj.root
except ValidationError:
try:
obj = WellKnownLCObject.parse_obj(value)
obj = WellKnownLCObject.model_validate(value)
return obj.root
except ValidationError:
return {key: _decode_event_data(v) for key, v in value.items()}
Expand Down Expand Up @@ -176,7 +176,7 @@ def loads(self, s: bytes) -> Any:

def _project_top_level(model: BaseModel) -> Dict[str, Any]:
"""Project the top level of the model as dict."""
return {key: getattr(model, key) for key in model.__fields__}
return {key: getattr(model, key) for key in model.model_fields}


def load_events(events: Any) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -207,7 +207,7 @@ def load_events(events: Any) -> List[Dict[str, Any]]:

# Then validate the event
try:
full_event = CallbackEvent.parse_obj(decoded_event_data)
full_event = CallbackEvent.model_validate(decoded_event_data)
except ValidationError as e:
msg = f"Encountered an invalid event: {e}"
if "type" in decoded_event_data:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def test_document_serialization() -> None:
"""Simple test. Exhaustive tests follow below."""
doc = Document(page_content="hello")
d = doc.dict()
d = doc.model_dump()
WellKnownLCObject.model_validate(d)


Expand Down Expand Up @@ -87,7 +87,7 @@ def _get_full_representation(data: Any) -> Any:
elif isinstance(data, list):
return [_get_full_representation(value) for value in data]
elif isinstance(data, BaseModel):
return data.schema()
return data.model_json_schema()
else:
return data

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ async def test_invoke_request_with_runnables() -> None:
assert request.config.tags == ["hello"]
assert request.config.run_name == "run"
assert isinstance(request.config.configurable, BaseModel)
assert request.config.configurable.dict() == {
assert request.config.configurable.model_dump() == {
"template": "goodbye {name}",
}

Expand Down
Loading