From 1782dcf023c7c14f095ba67da1f3577c2c756f06 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 7 Dec 2024 18:52:22 -0800 Subject: [PATCH] Adds typeddict tests --- examples/debug/typed_dict.ipynb | 356 --------------------- hamilton/function_modifiers/expanders.py | 2 +- tests/function_modifiers/test_expanders.py | 35 +- 3 files changed, 34 insertions(+), 359 deletions(-) delete mode 100644 examples/debug/typed_dict.ipynb diff --git a/examples/debug/typed_dict.ipynb b/examples/debug/typed_dict.ipynb deleted file mode 100644 index 277b829cf..000000000 --- a/examples/debug/typed_dict.ipynb +++ /dev/null @@ -1,356 +0,0 @@ -{ - "cells": [ - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "# Execute this cell to install dependencies\n", - "%pip install sf-hamilton[visualization]" - ], - "id": "7d9d428b3d95fd85" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "# TODO:fix [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/hello_world/typed_dict.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/hello_world/typed_dict.ipynb)\n", - "id": "52a18896ee3cdad6" - }, - { - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2024-12-07T18:42:00.062097Z", - "start_time": "2024-12-07T18:41:51.024139Z" - } - }, - "cell_type": "code", - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", - " warnings.warn(\n" - ] - } - ], - "execution_count": 1, - "source": [ - "from typing_extensions import is_typeddict\n", - "%load_ext hamilton.plugins.jupyter_magic" - ], - "id": "initial_id" - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:42:00.479296Z", - "start_time": "2024-12-07T18:42:00.075584Z" - } - }, - "cell_type": "code", - "source": [ - "%%cell_to_module ef --display\n", - "\n", - "from typing import TypedDict\n", - "from hamilton.function_modifiers import extract_fields\n", - "\n", - "class MyDict(TypedDict):\n", - " foo: str\n", - " bar: int\n", - "\n", - "@extract_fields()\n", - "def some_function()->MyDict:\n", - " return MyDict(foo=\"s\", bar=1)" - ], - "id": "5dbd4e574a6eb1d6", - "outputs": [ - { - "data": { - "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nfoo\n\nfoo\nstr\n\n\n\nsome_function\n\nsome_function\nMyDict\n\n\n\nsome_function->foo\n\n\n\n\n\nbar\n\nbar\nint\n\n\n\nsome_function->bar\n\n\n\n\n\nfunction\n\nfunction\n\n\n\n", - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "execution_count": 2 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:42:00.792277Z", - "start_time": "2024-12-07T18:42:00.756575Z" - } - }, - "cell_type": "code", - "source": [ - "from hamilton import driver\n", - "\n", - "dr = driver.Builder().with_modules(ef).build()\n", - "dr.execute([\"foo\", \"bar\"])" - ], - "id": "b9c896af26148b65", - "outputs": [ - { - "data": { - "text/plain": [ - "{'foo': 's', 'bar': 1}" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 3 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:42:01.256392Z", - "start_time": "2024-12-07T18:42:00.868513Z" - } - }, - "cell_type": "code", - "source": [ - "%%cell_to_module ef2 --display\n", - "\n", - "from typing import TypedDict\n", - "from hamilton.function_modifiers import extract_fields\n", - "\n", - "class MyDict(TypedDict):\n", - " foo: str\n", - " bar: int\n", - "\n", - "@extract_fields({\"foo\": str})\n", - "def some_function()->MyDict:\n", - " return MyDict(foo=\"s\", bar=1)" - ], - "id": "49b1af6cfc233929", - "outputs": [ - { - "data": { - "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nfoo\n\nfoo\nstr\n\n\n\nsome_function\n\nsome_function\nMyDict\n\n\n\nsome_function->foo\n\n\n\n\n\nfunction\n\nfunction\n\n\n\n", - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "execution_count": 4 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:42:01.318133Z", - "start_time": "2024-12-07T18:42:01.283203Z" - } - }, - "cell_type": "code", - "source": [ - "from hamilton import driver\n", - "\n", - "dr = driver.Builder().with_modules(ef2).build()\n", - "dr.execute([\"foo\"])" - ], - "id": "85db6edbbeb40528", - "outputs": [ - { - "data": { - "text/plain": [ - "{'foo': 's'}" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 5 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:42:03.211760Z", - "start_time": "2024-12-07T18:42:01.458736Z" - } - }, - "cell_type": "code", - "source": [ - "%%cell_to_module ef3 --display\n", - "\n", - "from typing import TypedDict\n", - "from hamilton.function_modifiers import extract_fields\n", - "\n", - "class MyDict(TypedDict):\n", - " foo: str\n", - " bar: int\n", - "\n", - "@extract_fields({\"foo\": int})\n", - "def some_function()->MyDict:\n", - " return MyDict(foo=\"s\", bar=1)" - ], - "id": "4e68a45d22decb1a", - "outputs": [ - { - "ename": "InvalidDecoratorException", - "evalue": "Error {'foo': } do not match TypedDict annotation's fields {'foo': , 'bar': }.", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mInvalidDecoratorException\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[6], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[43mget_ipython\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_cell_magic\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mcell_to_module\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mef3 --display\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mfrom typing import TypedDict\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mfrom hamilton.function_modifiers import extract_fields\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mclass MyDict(TypedDict):\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m foo: str\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m bar: int\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m@extract_fields(\u001B[39;49m\u001B[38;5;124;43m{\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mfoo\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43m: int})\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mdef some_function()->MyDict:\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m return MyDict(foo=\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43ms\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43m, bar=1)\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2478\u001B[0m, in \u001B[0;36mInteractiveShell.run_cell_magic\u001B[0;34m(self, magic_name, line, cell)\u001B[0m\n\u001B[1;32m 2476\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuiltin_trap:\n\u001B[1;32m 2477\u001B[0m args \u001B[38;5;241m=\u001B[39m (magic_arg_s, cell)\n\u001B[0;32m-> 2478\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[43mfn\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 2480\u001B[0m \u001B[38;5;66;03m# The code below prevents the output from being displayed\u001B[39;00m\n\u001B[1;32m 2481\u001B[0m \u001B[38;5;66;03m# when using magics with decodator @output_can_be_silenced\u001B[39;00m\n\u001B[1;32m 2482\u001B[0m \u001B[38;5;66;03m# when the last Python token in the expression is a ';'.\u001B[39;00m\n\u001B[1;32m 2483\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mgetattr\u001B[39m(fn, magic\u001B[38;5;241m.\u001B[39mMAGIC_OUTPUT_CAN_BE_SILENCED, \u001B[38;5;28;01mFalse\u001B[39;00m):\n", - "File \u001B[0;32m~/dagworks/hamilton/hamilton/plugins/jupyter_magic.py:297\u001B[0m, in \u001B[0;36mHamiltonMagics.cell_to_module\u001B[0;34m(self, line, cell)\u001B[0m\n\u001B[1;32m 281\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Turn a notebook cell into a Hamilton module definition. This allows you to define\u001B[39;00m\n\u001B[1;32m 282\u001B[0m \u001B[38;5;124;03mand execute a dataflow from a single cell.\u001B[39;00m\n\u001B[1;32m 283\u001B[0m \n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 292\u001B[0m \u001B[38;5;124;03m```\u001B[39;00m\n\u001B[1;32m 293\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 294\u001B[0m \u001B[38;5;66;03m# shell.ex() is equivalent to exec(), but in the user namespace (i.e. notebook context).\u001B[39;00m\n\u001B[1;32m 295\u001B[0m \u001B[38;5;66;03m# This allows imports and functions defined in the magic cell %%cell_to_module to be\u001B[39;00m\n\u001B[1;32m 296\u001B[0m \u001B[38;5;66;03m# directly accessed from the notebook\u001B[39;00m\n\u001B[0;32m--> 297\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mshell\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mex\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcell\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 299\u001B[0m args, unknown_args \u001B[38;5;241m=\u001B[39m parse_known_argstring(\n\u001B[1;32m 300\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcell_to_module, line\n\u001B[1;32m 301\u001B[0m ) \u001B[38;5;66;03m# specify how to parse by passing method\u001B[39;00m\n\u001B[1;32m 302\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mresolve_unknown_args_cell_to_module(unknown_args)\n", - "File \u001B[0;32m~/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2812\u001B[0m, in \u001B[0;36mInteractiveShell.ex\u001B[0;34m(self, cmd)\u001B[0m\n\u001B[1;32m 2810\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Execute a normal python statement in user namespace.\"\"\"\u001B[39;00m\n\u001B[1;32m 2811\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuiltin_trap:\n\u001B[0;32m-> 2812\u001B[0m \u001B[43mexec\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcmd\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43muser_global_ns\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43muser_ns\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m:10\u001B[0m\n", - "File \u001B[0;32m~/dagworks/hamilton/hamilton/function_modifiers/base.py:60\u001B[0m, in \u001B[0;36mtrack_decorator_usage..replace__call__\u001B[0;34m(self, fn)\u001B[0m\n\u001B[1;32m 58\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 59\u001B[0m DECORATOR_COUNTER[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcustom_decorator\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m=\u001B[39m DECORATOR_COUNTER[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcustom_decorator\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[0;32m---> 60\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mcall_fn\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfn\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/dagworks/hamilton/hamilton/function_modifiers/base.py:102\u001B[0m, in \u001B[0;36mNodeTransformLifecycle.__call__\u001B[0;34m(self, fn)\u001B[0m\n\u001B[1;32m 94\u001B[0m \u001B[38;5;129m@track_decorator_usage\u001B[39m\n\u001B[1;32m 95\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__call__\u001B[39m(\u001B[38;5;28mself\u001B[39m, fn: Callable):\n\u001B[1;32m 96\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"Calls the decorator by adding attributes using the get_lifecycle_name string.\u001B[39;00m\n\u001B[1;32m 97\u001B[0m \u001B[38;5;124;03m These attributes are the pointer to the decorator object itself, and used later in resolve_nodes below.\u001B[39;00m\n\u001B[1;32m 98\u001B[0m \n\u001B[1;32m 99\u001B[0m \u001B[38;5;124;03m :param fn: Function to decorate\u001B[39;00m\n\u001B[1;32m 100\u001B[0m \u001B[38;5;124;03m :return: The function again, with the desired properties.\u001B[39;00m\n\u001B[1;32m 101\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[0;32m--> 102\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvalidate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfn\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 103\u001B[0m lifecycle_name \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__class__\u001B[39m\u001B[38;5;241m.\u001B[39mget_lifecycle_name()\n\u001B[1;32m 104\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mhasattr\u001B[39m(fn, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mget_lifecycle_name()):\n", - "File \u001B[0;32m~/dagworks/hamilton/hamilton/function_modifiers/expanders.py:778\u001B[0m, in \u001B[0;36mextract_fields.validate\u001B[0;34m(self, fn)\u001B[0m\n\u001B[1;32m 776\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m k, v \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfields\u001B[38;5;241m.\u001B[39mitems():\n\u001B[1;32m 777\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m typed_dict_fields\u001B[38;5;241m.\u001B[39mget(k, \u001B[38;5;28;01mNone\u001B[39;00m) \u001B[38;5;241m!=\u001B[39m v:\n\u001B[0;32m--> 778\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m base\u001B[38;5;241m.\u001B[39mInvalidDecoratorException(\n\u001B[1;32m 779\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mError \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfields\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m do not match TypedDict annotation\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124ms fields \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtyped_dict_fields\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 780\u001B[0m )\n\u001B[1;32m 781\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 782\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m base\u001B[38;5;241m.\u001B[39mInvalidDecoratorException(\n\u001B[1;32m 783\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mFor extracting fields, output type must be a dict or typing.Dict, not: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00moutput_type\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 784\u001B[0m )\n", - "\u001B[0;31mInvalidDecoratorException\u001B[0m: Error {'foo': } do not match TypedDict annotation's fields {'foo': , 'bar': }." - ] - } - ], - "execution_count": 6 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:33:22.671127Z", - "start_time": "2024-12-07T18:28:13.515984Z" - } - }, - "cell_type": "code", - "source": [ - "from typing import TypedDict\n", - "import typing\n", - "import typing_inspect\n", - "import typing_extensions\n", - "from hamilton.function_modifiers import extract_fields\n", - "\n", - "class MyDict(TypedDict):\n", - " foo: str\n", - " bar: int\n", - "\n", - "# @extract_fields(\n", - "# {\"foo\": str, \"bar\": int}\n", - "# )\n", - "def some_function()->MyDict:\n", - " return MyDict(foo=\"s\", bar=1)\n", - "\n", - "output_type = typing.get_type_hints(some_function).get(\"return\")" - ], - "id": "35bc82aab6e4ce1c", - "outputs": [], - "execution_count": 3 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:33:22.674076Z", - "start_time": "2024-12-07T18:28:13.591675Z" - } - }, - "cell_type": "code", - "source": "typing_inspect.is_generic_type(output_type)", - "id": "9258f731ec504725", - "outputs": [ - { - "data": { - "text/plain": [ - "False" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 4 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:33:22.674391Z", - "start_time": "2024-12-07T18:28:13.678468Z" - } - }, - "cell_type": "code", - "source": "typing_extensions.is_typeddict(output_type)", - "id": "74083ed2b1767a0d", - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 5 - }, - { - "metadata": { - "ExecuteTime": { - "end_time": "2024-12-07T18:33:22.674559Z", - "start_time": "2024-12-07T18:29:47.979165Z" - } - }, - "cell_type": "code", - "source": "typing.get_type_hints(MyDict)", - "id": "59ced85a7e15d623", - "outputs": [ - { - "data": { - "text/plain": [ - "{'foo': str, 'bar': int}" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "execution_count": 7 - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": "", - "id": "45819ccf150aada3" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index af6a0074b..faf2acaac 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -771,7 +771,7 @@ def validate(self, fn: Callable): self.fields = typing.get_type_hints(output_type) _validate_extract_fields(self.fields) else: - # check that fields is a subset of TypedDict + # check that fields is a subset of TypedDict that is defined typed_dict_fields = typing.get_type_hints(output_type) for k, v in self.fields.items(): if typed_dict_fields.get(k, None) != v: diff --git a/tests/function_modifiers/test_expanders.py b/tests/function_modifiers/test_expanders.py index ed5b3463e..82792ff14 100644 --- a/tests/function_modifiers/test_expanders.py +++ b/tests/function_modifiers/test_expanders.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, TypedDict import numpy as np import pandas as pd @@ -313,6 +313,15 @@ def test_extract_fields_constructor_happy(fields): expanders._validate_extract_fields(fields) +class MyDict(TypedDict): + test: int + test2: str + + +class MyDictBad(TypedDict): + test2: str + + @pytest.mark.parametrize( "return_type", [ @@ -320,6 +329,7 @@ def test_extract_fields_constructor_happy(fields): Dict, Dict[str, str], Dict[str, Any], + MyDict, ], ) def test_extract_fields_validate_happy(return_type): @@ -330,7 +340,10 @@ def return_dict() -> return_type: annotation.validate(return_dict) -@pytest.mark.parametrize("return_type", [(int), (list), (np.ndarray), (pd.DataFrame)]) +@pytest.mark.parametrize( + "return_type", + [(int), (list), (np.ndarray), (pd.DataFrame), (MyDictBad)], +) def test_extract_fields_validate_errors(return_type): def return_dict() -> return_type: return {} @@ -340,6 +353,24 @@ def return_dict() -> return_type: annotation.validate(return_dict) +def test_extract_fields_typeddict_empty_fields(): + def return_dict() -> MyDict: + return {} + + # don't need fields for TypedDict + annotation = function_modifiers.extract_fields() + annotation.validate(return_dict) + + +def test_extract_fields_typeddict_subset(): + def return_dict() -> MyDict: + return {} + + # test that a subset of fields is fine + annotation = function_modifiers.extract_fields({"test2": str}) + annotation.validate(return_dict) + + def test_valid_extract_fields(): """Tests whole extract_fields decorator.""" annotation = function_modifiers.extract_fields(