diff --git a/pyproject.toml b/pyproject.toml index 88431e2..245fdac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "wanga" -version = "0.2.3" +version = "0.2.4" description = "A library for interacting with Large Language Models." authors = [{ name = "Artur Chakhvadze", email = "norpadon@gmail.com" }] license = { text = "MIT" } diff --git a/tests/test_schema.py b/tests/test_schema.py index 4248371..25328c0 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,16 +1,18 @@ import collections import collections.abc import inspect +import platform import typing from dataclasses import dataclass from datetime import datetime, timedelta +from textwrap import dedent from attrs import frozen from pydantic import BaseModel from wanga.schema.extractor import DEFAULT_SCHEMA_EXTRACTOR from wanga.schema.jsonschema import JsonSchemaFlavor -from wanga.schema.normalize import normalize_annotation, unpack_optional +from wanga.schema.normalize import normalise_aliases, normalize_annotation, unpack_optional from wanga.schema.schema import ( CallableSchema, LiteralNode, @@ -359,6 +361,36 @@ class Hoo(BaseModel): assert DEFAULT_SCHEMA_EXTRACTOR.extract_schema(Hoo) == hoo_schema +def test_type_statement(): + _, minor, _ = platform.python_version_tuple() + if int(minor) < 12: + return + + expr = r""" + type A = int | float + + assert normalise_aliases(A) == int | float # type: ignore + + def foo() -> A: # type: ignore + pass + + expected = UnionNode([PrimitiveNode(primitive_type=int), PrimitiveNode(primitive_type=float)]) + assert DEFAULT_SCHEMA_EXTRACTOR.extract_schema(foo).return_schema == expected + + def bar() -> list[A]: # type: ignore + pass + + expected = SequenceNode( + sequence_type=list, + item_schema=UnionNode([PrimitiveNode(primitive_type=int), PrimitiveNode(primitive_type=float)]), + ) + assert DEFAULT_SCHEMA_EXTRACTOR.extract_schema(bar).return_schema == expected + """ + + expr = dedent(expr) + exec(expr) + + def test_json_schema(): @frozen class Inner: diff --git a/uv.lock b/uv.lock index c912db0..abbc1e7 100644 --- a/uv.lock +++ b/uv.lock @@ -1019,7 +1019,7 @@ wheels = [ [[package]] name = "wanga" -version = "0.1.0" +version = "0.2.3" source = { editable = "." } dependencies = [ { name = "attrs" }, diff --git a/wanga/schema/normalize.py b/wanga/schema/normalize.py index 55b41c1..ed28b35 100644 --- a/wanga/schema/normalize.py +++ b/wanga/schema/normalize.py @@ -1,9 +1,18 @@ import collections import collections.abc +import platform import typing # noqa: F401 from types import NoneType, UnionType from typing import Annotated, Literal, Union, get_args, get_origin +if int(platform.python_version_tuple()[1]) >= 12: + from typing import TypeAliasType # type: ignore +else: + + class TypeAliasType: + pass + + from .utils import TypeAnnotation __all__ = [ @@ -19,6 +28,12 @@ def _fold_or(annotations: collections.abc.Sequence[TypeAnnotation]) -> type[Unio return result +def normalise_aliases(annotation: TypeAnnotation) -> TypeAnnotation: + if isinstance(annotation, TypeAliasType): + return annotation.__value__ # type: ignore + return annotation + + def unpack_optional(annotation: TypeAnnotation) -> type[UnionType] | None: r"""Unpack Optional[T] to its inner type T. @@ -92,6 +107,7 @@ def normalize_literals(annotation: TypeAnnotation) -> TypeAnnotation: def _normalize_annotation_rec(annotation: TypeAnnotation, concretize: bool = False) -> TypeAnnotation: + annotation = normalise_aliases(annotation) origin = get_origin(annotation) args = get_args(annotation) if origin is None: