From 594f4e934d21d91b239472791f7165dccef25c7f Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Sat, 7 Dec 2024 10:49:21 -0800 Subject: [PATCH] Adds sketch of improving extract_fields with typeddict This is in response to #1252. We should be able to handle typeddict better. This sketches some ideas: 1. field validation should happen in .validate() not the constructor. 2. extract_fields shouldn't need fields if the typeddict is the annotation type. 3. we properly check that typeddict can be a return type. --- examples/debug/typed_dict.ipynb | 356 +++++++++++++++++++++++ hamilton/function_modifiers/expanders.py | 20 +- 2 files changed, 372 insertions(+), 4 deletions(-) create mode 100644 examples/debug/typed_dict.ipynb diff --git a/examples/debug/typed_dict.ipynb b/examples/debug/typed_dict.ipynb new file mode 100644 index 000000000..277b829cf --- /dev/null +++ b/examples/debug/typed_dict.ipynb @@ -0,0 +1,356 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "# Execute this cell to install dependencies\n", + "%pip install sf-hamilton[visualization]" + ], + "id": "7d9d428b3d95fd85" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# TODO:fix [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dagworks-inc/hamilton/blob/main/examples/hello_world/typed_dict.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/dagworks-inc/hamilton/blob/main/examples/hello_world/typed_dict.ipynb)\n", + "id": "52a18896ee3cdad6" + }, + { + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-12-07T18:42:00.062097Z", + "start_time": "2024-12-07T18:41:51.024139Z" + } + }, + "cell_type": "code", + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ +
"/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n", + " warnings.warn(\n" + ] + } + ], + "execution_count": 1, + "source": [ + "from typing_extensions import is_typeddict\n", + "%load_ext hamilton.plugins.jupyter_magic" + ], + "id": "initial_id" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:42:00.479296Z", + "start_time": "2024-12-07T18:42:00.075584Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module ef --display\n", + "\n", + "from typing import TypedDict\n", + "from hamilton.function_modifiers import extract_fields\n", + "\n", + "class MyDict(TypedDict):\n", + " foo: str\n", + " bar: int\n", + "\n", + "@extract_fields()\n", + "def some_function()->MyDict:\n", + " return MyDict(foo=\"s\", bar=1)" + ], + "id": "5dbd4e574a6eb1d6", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nfoo\n\nfoo\nstr\n\n\n\nsome_function\n\nsome_function\nMyDict\n\n\n\nsome_function->foo\n\n\n\n\n\nbar\n\nbar\nint\n\n\n\nsome_function->bar\n\n\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 2 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:42:00.792277Z", + "start_time": "2024-12-07T18:42:00.756575Z" + } + }, + "cell_type": "code", + "source": [ + "from hamilton import driver\n", + "\n", + "dr = driver.Builder().with_modules(ef).build()\n", + "dr.execute([\"foo\", \"bar\"])" + ], + "id": "b9c896af26148b65", + "outputs": [ + { + "data": { + "text/plain": [ + "{'foo': 's', 'bar': 1}" + ] + }, + "execution_count": 
3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:42:01.256392Z", + "start_time": "2024-12-07T18:42:00.868513Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module ef2 --display\n", + "\n", + "from typing import TypedDict\n", + "from hamilton.function_modifiers import extract_fields\n", + "\n", + "class MyDict(TypedDict):\n", + " foo: str\n", + " bar: int\n", + "\n", + "@extract_fields({\"foo\": str})\n", + "def some_function()->MyDict:\n", + " return MyDict(foo=\"s\", bar=1)" + ], + "id": "49b1af6cfc233929", + "outputs": [ + { + "data": { + "image/svg+xml": "\n\n\n\n\n\n\n\ncluster__legend\n\nLegend\n\n\n\nfoo\n\nfoo\nstr\n\n\n\nsome_function\n\nsome_function\nMyDict\n\n\n\nsome_function->foo\n\n\n\n\n\nfunction\n\nfunction\n\n\n\n", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:42:01.318133Z", + "start_time": "2024-12-07T18:42:01.283203Z" + } + }, + "cell_type": "code", + "source": [ + "from hamilton import driver\n", + "\n", + "dr = driver.Builder().with_modules(ef2).build()\n", + "dr.execute([\"foo\"])" + ], + "id": "85db6edbbeb40528", + "outputs": [ + { + "data": { + "text/plain": [ + "{'foo': 's'}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:42:03.211760Z", + "start_time": "2024-12-07T18:42:01.458736Z" + } + }, + "cell_type": "code", + "source": [ + "%%cell_to_module ef3 --display\n", + "\n", + "from typing import TypedDict\n", + "from hamilton.function_modifiers import extract_fields\n", + "\n", + "class MyDict(TypedDict):\n", + " foo: str\n", + " bar: int\n", + "\n", + "@extract_fields({\"foo\": int})\n", + "def some_function()->MyDict:\n", + " return 
MyDict(foo=\"s\", bar=1)" + ], + "id": "4e68a45d22decb1a", + "outputs": [ + { + "ename": "InvalidDecoratorException", + "evalue": "Error {'foo': } do not match TypedDict annotation's fields {'foo': , 'bar': }.", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mInvalidDecoratorException\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[6], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[43mget_ipython\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_cell_magic\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mcell_to_module\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mef3 --display\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mfrom typing import TypedDict\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mfrom hamilton.function_modifiers import extract_fields\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mclass MyDict(TypedDict):\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m foo: str\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m bar: int\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m@extract_fields(\u001B[39;49m\u001B[38;5;124;43m{\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mfoo\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43m: int})\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43mdef 
some_function()->MyDict:\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m return MyDict(foo=\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43ms\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43m, bar=1)\u001B[39;49m\u001B[38;5;130;43;01m\\n\u001B[39;49;00m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2478\u001B[0m, in \u001B[0;36mInteractiveShell.run_cell_magic\u001B[0;34m(self, magic_name, line, cell)\u001B[0m\n\u001B[1;32m 2476\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuiltin_trap:\n\u001B[1;32m 2477\u001B[0m args \u001B[38;5;241m=\u001B[39m (magic_arg_s, cell)\n\u001B[0;32m-> 2478\u001B[0m result \u001B[38;5;241m=\u001B[39m \u001B[43mfn\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 2480\u001B[0m \u001B[38;5;66;03m# The code below prevents the output from being displayed\u001B[39;00m\n\u001B[1;32m 2481\u001B[0m \u001B[38;5;66;03m# when using magics with decodator @output_can_be_silenced\u001B[39;00m\n\u001B[1;32m 2482\u001B[0m \u001B[38;5;66;03m# when the last Python token in the expression is a ';'.\u001B[39;00m\n\u001B[1;32m 2483\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mgetattr\u001B[39m(fn, magic\u001B[38;5;241m.\u001B[39mMAGIC_OUTPUT_CAN_BE_SILENCED, \u001B[38;5;28;01mFalse\u001B[39;00m):\n", + "File \u001B[0;32m~/dagworks/hamilton/hamilton/plugins/jupyter_magic.py:297\u001B[0m, in \u001B[0;36mHamiltonMagics.cell_to_module\u001B[0;34m(self, line, cell)\u001B[0m\n\u001B[1;32m 281\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Turn a notebook cell into a Hamilton module 
definition. This allows you to define\u001B[39;00m\n\u001B[1;32m 282\u001B[0m \u001B[38;5;124;03mand execute a dataflow from a single cell.\u001B[39;00m\n\u001B[1;32m 283\u001B[0m \n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 292\u001B[0m \u001B[38;5;124;03m```\u001B[39;00m\n\u001B[1;32m 293\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[1;32m 294\u001B[0m \u001B[38;5;66;03m# shell.ex() is equivalent to exec(), but in the user namespace (i.e. notebook context).\u001B[39;00m\n\u001B[1;32m 295\u001B[0m \u001B[38;5;66;03m# This allows imports and functions defined in the magic cell %%cell_to_module to be\u001B[39;00m\n\u001B[1;32m 296\u001B[0m \u001B[38;5;66;03m# directly accessed from the notebook\u001B[39;00m\n\u001B[0;32m--> 297\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mshell\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mex\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcell\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 299\u001B[0m args, unknown_args \u001B[38;5;241m=\u001B[39m parse_known_argstring(\n\u001B[1;32m 300\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcell_to_module, line\n\u001B[1;32m 301\u001B[0m ) \u001B[38;5;66;03m# specify how to parse by passing method\u001B[39;00m\n\u001B[1;32m 302\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mresolve_unknown_args_cell_to_module(unknown_args)\n", + "File \u001B[0;32m~/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2812\u001B[0m, in \u001B[0;36mInteractiveShell.ex\u001B[0;34m(self, cmd)\u001B[0m\n\u001B[1;32m 2810\u001B[0m \u001B[38;5;250m\u001B[39m\u001B[38;5;124;03m\"\"\"Execute a normal python statement in user namespace.\"\"\"\u001B[39;00m\n\u001B[1;32m 2811\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuiltin_trap:\n\u001B[0;32m-> 2812\u001B[0m 
\u001B[43mexec\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcmd\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43muser_global_ns\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43muser_ns\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m:10\u001B[0m\n", + "File \u001B[0;32m~/dagworks/hamilton/hamilton/function_modifiers/base.py:60\u001B[0m, in \u001B[0;36mtrack_decorator_usage..replace__call__\u001B[0;34m(self, fn)\u001B[0m\n\u001B[1;32m 58\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 59\u001B[0m DECORATOR_COUNTER[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcustom_decorator\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m=\u001B[39m DECORATOR_COUNTER[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mcustom_decorator\u001B[39m\u001B[38;5;124m\"\u001B[39m] \u001B[38;5;241m+\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[0;32m---> 60\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mcall_fn\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfn\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[0;32m~/dagworks/hamilton/hamilton/function_modifiers/base.py:102\u001B[0m, in \u001B[0;36mNodeTransformLifecycle.__call__\u001B[0;34m(self, fn)\u001B[0m\n\u001B[1;32m 94\u001B[0m \u001B[38;5;129m@track_decorator_usage\u001B[39m\n\u001B[1;32m 95\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m__call__\u001B[39m(\u001B[38;5;28mself\u001B[39m, fn: Callable):\n\u001B[1;32m 96\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"Calls the decorator by adding attributes using the get_lifecycle_name string.\u001B[39;00m\n\u001B[1;32m 97\u001B[0m \u001B[38;5;124;03m These attributes are the pointer to the decorator object itself, and used later in resolve_nodes below.\u001B[39;00m\n\u001B[1;32m 98\u001B[0m \n\u001B[1;32m 
99\u001B[0m \u001B[38;5;124;03m :param fn: Function to decorate\u001B[39;00m\n\u001B[1;32m 100\u001B[0m \u001B[38;5;124;03m :return: The function again, with the desired properties.\u001B[39;00m\n\u001B[1;32m 101\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[0;32m--> 102\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mvalidate\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfn\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 103\u001B[0m lifecycle_name \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__class__\u001B[39m\u001B[38;5;241m.\u001B[39mget_lifecycle_name()\n\u001B[1;32m 104\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mhasattr\u001B[39m(fn, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mget_lifecycle_name()):\n", + "File \u001B[0;32m~/dagworks/hamilton/hamilton/function_modifiers/expanders.py:778\u001B[0m, in \u001B[0;36mextract_fields.validate\u001B[0;34m(self, fn)\u001B[0m\n\u001B[1;32m 776\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m k, v \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfields\u001B[38;5;241m.\u001B[39mitems():\n\u001B[1;32m 777\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m typed_dict_fields\u001B[38;5;241m.\u001B[39mget(k, \u001B[38;5;28;01mNone\u001B[39;00m) \u001B[38;5;241m!=\u001B[39m v:\n\u001B[0;32m--> 778\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m base\u001B[38;5;241m.\u001B[39mInvalidDecoratorException(\n\u001B[1;32m 779\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mError \u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfields\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m do not match TypedDict annotation\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124ms fields 
\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mtyped_dict_fields\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 780\u001B[0m )\n\u001B[1;32m 781\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[1;32m 782\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m base\u001B[38;5;241m.\u001B[39mInvalidDecoratorException(\n\u001B[1;32m 783\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mFor extracting fields, output type must be a dict or typing.Dict, not: \u001B[39m\u001B[38;5;132;01m{\u001B[39;00moutput_type\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 784\u001B[0m )\n", + "\u001B[0;31mInvalidDecoratorException\u001B[0m: Error {'foo': } do not match TypedDict annotation's fields {'foo': , 'bar': }." + ] + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:33:22.671127Z", + "start_time": "2024-12-07T18:28:13.515984Z" + } + }, + "cell_type": "code", + "source": [ + "from typing import TypedDict\n", + "import typing\n", + "import typing_inspect\n", + "import typing_extensions\n", + "from hamilton.function_modifiers import extract_fields\n", + "\n", + "class MyDict(TypedDict):\n", + " foo: str\n", + " bar: int\n", + "\n", + "# @extract_fields(\n", + "# {\"foo\": str, \"bar\": int}\n", + "# )\n", + "def some_function()->MyDict:\n", + " return MyDict(foo=\"s\", bar=1)\n", + "\n", + "output_type = typing.get_type_hints(some_function).get(\"return\")" + ], + "id": "35bc82aab6e4ce1c", + "outputs": [], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:33:22.674076Z", + "start_time": "2024-12-07T18:28:13.591675Z" + } + }, + "cell_type": "code", + "source": "typing_inspect.is_generic_type(output_type)", + "id": "9258f731ec504725", + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + 
], + "execution_count": 4 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:33:22.674391Z", + "start_time": "2024-12-07T18:28:13.678468Z" + } + }, + "cell_type": "code", + "source": "typing_extensions.is_typeddict(output_type)", + "id": "74083ed2b1767a0d", + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-12-07T18:33:22.674559Z", + "start_time": "2024-12-07T18:29:47.979165Z" + } + }, + "cell_type": "code", + "source": "typing.get_type_hints(MyDict)", + "id": "59ced85a7e15d623", + "outputs": [ + { + "data": { + "text/plain": [ + "{'foo': str, 'bar': int}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 7 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "45819ccf150aada3" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 1b5e726d3..af6a0074b 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -5,6 +5,7 @@ import typing from typing import Any, Callable, Collection, Dict, Tuple, Union +import typing_extensions import typing_inspect from hamilton import node, registry @@ -733,7 +734,7 @@ def _validate_extract_fields(fields: dict): class extract_fields(base.SingleNodeNodeTransformer): """Extracts fields from a 
dictionary of output.""" - def __init__(self, fields: dict, fill_with: Any = None): + def __init__(self, fields: dict = None, fill_with: Any = None): """Constructor for a modifier that expands a single function into the following nodes: - n functions, each of which take in the original dict and output a specific field @@ -745,7 +746,6 @@ def __init__(self, fields: dict, fill_with: Any = None): field value. """ super(extract_fields, self).__init__() - _validate_extract_fields(fields) self.fields = fields self.fill_with = fill_with @@ -759,13 +759,25 @@ def validate(self, fn: Callable): if typing_inspect.is_generic_type(output_type): base_type = typing_inspect.get_origin(output_type) if base_type == dict or base_type == Dict: - pass + _validate_extract_fields(self.fields) else: raise base.InvalidDecoratorException( f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}" ) elif output_type == dict: - pass + _validate_extract_fields(self.fields) + elif typing_extensions.is_typeddict(output_type): + if self.fields is None: + self.fields = typing.get_type_hints(output_type) + _validate_extract_fields(self.fields) + else: + # check that fields is a subset of TypedDict + typed_dict_fields = typing.get_type_hints(output_type) + for k, v in self.fields.items(): + if typed_dict_fields.get(k, None) != v: + raise base.InvalidDecoratorException( + f"Error {self.fields} did not match a subset of the TypedDict annotation's fields {typed_dict_fields}." + ) else: raise base.InvalidDecoratorException( f"For extracting fields, output type must be a dict or typing.Dict, not: {output_type}"