Skip to content

Commit

Permalink
Add support for type statements.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Nov 21, 2024
1 parent da2ab03 commit d8a0ec5
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "wanga"
version = "0.2.3"
version = "0.2.4"
description = "A library for interacting with Large Language Models."
authors = [{ name = "Artur Chakhvadze", email = "norpadon@gmail.com" }]
license = { text = "MIT" }
Expand Down
34 changes: 33 additions & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import collections
import collections.abc
import inspect
import platform
import typing
from dataclasses import dataclass
from datetime import datetime, timedelta
from textwrap import dedent

from attrs import frozen
from pydantic import BaseModel

from wanga.schema.extractor import DEFAULT_SCHEMA_EXTRACTOR
from wanga.schema.jsonschema import JsonSchemaFlavor
from wanga.schema.normalize import normalize_annotation, unpack_optional
from wanga.schema.normalize import normalise_aliases, normalize_annotation, unpack_optional
from wanga.schema.schema import (
CallableSchema,
LiteralNode,
Expand Down Expand Up @@ -359,6 +361,36 @@ class Hoo(BaseModel):
assert DEFAULT_SCHEMA_EXTRACTOR.extract_schema(Hoo) == hoo_schema


def test_type_statement():
_, minor, _ = platform.python_version_tuple()
if int(minor) < 12:
return

expr = r"""
type A = int | float
assert normalise_aliases(A) == int | float # type: ignore
def foo() -> A: # type: ignore
pass
expected = UnionNode([PrimitiveNode(primitive_type=int), PrimitiveNode(primitive_type=float)])
assert DEFAULT_SCHEMA_EXTRACTOR.extract_schema(foo).return_schema == expected
def bar() -> list[A]: # type: ignore
pass
expected = SequenceNode(
sequence_type=list,
item_schema=UnionNode([PrimitiveNode(primitive_type=int), PrimitiveNode(primitive_type=float)]),
)
assert DEFAULT_SCHEMA_EXTRACTOR.extract_schema(bar).return_schema == expected
"""

expr = dedent(expr)
exec(expr)


def test_json_schema():
@frozen
class Inner:
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions wanga/schema/normalize.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import collections
import collections.abc
import platform
import typing # noqa: F401
from types import NoneType, UnionType
from typing import Annotated, Literal, Union, get_args, get_origin

if int(platform.python_version_tuple()[1]) >= 12:
from typing import TypeAliasType # type: ignore
else:

class TypeAliasType:
pass


from .utils import TypeAnnotation

__all__ = [
Expand All @@ -19,6 +28,12 @@ def _fold_or(annotations: collections.abc.Sequence[TypeAnnotation]) -> type[Unio
return result


def normalise_aliases(annotation: TypeAnnotation) -> TypeAnnotation:
if isinstance(annotation, TypeAliasType):
return annotation.__value__ # type: ignore
return annotation


def unpack_optional(annotation: TypeAnnotation) -> type[UnionType] | None:
r"""Unpack Optional[T] to its inner type T.
Expand Down Expand Up @@ -92,6 +107,7 @@ def normalize_literals(annotation: TypeAnnotation) -> TypeAnnotation:


def _normalize_annotation_rec(annotation: TypeAnnotation, concretize: bool = False) -> TypeAnnotation:
annotation = normalise_aliases(annotation)
origin = get_origin(annotation)
args = get_args(annotation)
if origin is None:
Expand Down

0 comments on commit d8a0ec5

Please sign in to comment.