From 056bca0f6356bf76c2ae6d50c60d79b0d730c129 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 25 Feb 2021 17:58:03 -0800 Subject: [PATCH] autodidax: jit, multi-output, pytrees, DeviceArrays --- docs/autodidax.ipynb | 2627 ++++++++++++++++++++++++++++++++---------- docs/autodidax.md | 1124 +++++++++++++++--- docs/autodidax.py | 1032 +++++++++++++---- 3 files changed, 3820 insertions(+), 963 deletions(-) diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index 758b108760ce..7294a3466d9a 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -19,19 +19,35 @@ "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License.\n", + "\n", "---" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "# TODO remove me\n", + "import pdb, sys, traceback\n", + "def info(type, value, tb):\n", + " traceback.print_exception(type, value, tb)\n", + " pdb.pm()\n", + "sys.excepthook = info" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Autodidax: JAX core from scratch\n", "\n", - "Ever want to learn how JAX works, but the implementation seemed too\n", - "impenetrable? Well, you're in luck! By reading this tutorial, you'll learn\n", - "every big idea in JAX's core system. You'll even get clued into our weird\n", - "jargon!" + "Ever want to learn how JAX works, but the implementation seemed impenetrable?\n", + "Well, you're in luck! By reading this tutorial, you'll learn every big idea in\n", + "JAX's core system. You'll even get clued into our weird jargon!" ] }, { @@ -44,7 +60,7 @@ "\n", "```python\n", "def f(x):\n", - " y = sin(x) * 2\n", + " y = sin(x) * 2.\n", " z = - y + x\n", " return z\n", "```\n", @@ -54,14 +70,13 @@ "atomic units of processing rather than compositions.\n", "\n", "\"Transform\" means \"interpret differently.\" Instead of standard interpretation\n", - "where we apply primitive functions to numerical inputs to produce numerical\n", + "where we apply primitive operations to numerical inputs to produce numerical\n", "outputs, we want to override primitive application and let different values\n", "flow through our program. For example, we might want to replace the\n", "application of every primitive with an application of [its JVP\n", "rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html),\n", - "and let primal-tangent pairs flow through our program. Moreover, we want to\n", - "apply a composition of multiple transformations, leading to stacks of\n", - "interpreters." + "and let primal-tangent pairs flow through our program. Moreover, we want to be\n", + "able to comopse multiple transformations, leading to stacks of interpreters." ] }, { @@ -80,7 +95,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "from typing import NamedTuple\n", @@ -95,14 +113,22 @@ "cos_p = Primitive(\"cos\")\n", "reduce_sum_p = Primitive(\"reduce_sum\")\n", "greater_p = Primitive(\"greater\")\n", - "\n", - "def add(x, y): return bind(add_p, x, y)\n", - "def mul(x, y): return bind(mul_p, x, y)\n", - "def neg(x): return bind(neg_p, x)\n", - "def sin(x): return bind(sin_p, x)\n", - "def cos(x): return bind(cos_p, x)\n", - "def reduce_sum(x, axis=None): return bind(reduce_sum_p, x, axis=axis)\n", - "def greater(x, y): return bind(greater_p, x, y)" + "transpose_p = Primitive(\"transpose\")\n", + "broadcast_p = Primitive(\"broadcast\")\n", + "\n", + "def add(x, y): return bind1(add_p, x, y)\n", + "def mul(x, y): return bind1(mul_p, x, y)\n", + "def neg(x): return bind1(neg_p, x)\n", + "def sin(x): return bind1(sin_p, x)\n", + "def cos(x): return bind1(cos_p, x)\n", + "def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis)\n", + "def greater(x, y): return bind1(greater_p, x, y)\n", + "def transpose(x, perm): return bind1(transpose_p, perm=perm)\n", + "def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes)\n", + "\n", + "def bind1(prim, *args, **params):\n", + " out, = bind(prim, *args, **params)\n", + " return out" ] }, { @@ -136,19 +162,49 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "from contextlib import contextmanager\n", - "from typing import Type, List, Optional, Any\n", - "\n", + "from typing import Type, List, Optional, Any" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "class MainTrace(NamedTuple):\n", " level: int\n", " trace_type: Type['Trace']\n", - " global_data: Optional[Any]\n", - "\n", + " global_data: Optional[Any]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "trace_stack: List[MainTrace] = []\n", - "\n", + "dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "@contextmanager\n", "def new_main(trace_type: Type['Trace'], global_data=None):\n", " level = len(trace_stack)\n", @@ -165,14 +221,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "When we're about to apply a transformed function, we'll push another\n", - "interpreter onto the stack using `new_main`. Then, as we apply primitives in\n", - "the function, we can think of the `bind` first being interpreted by the trace\n", - "at the top of the stack (i.e. with the highest level). If that first\n", - "interpreter itself binds other primitives in its interpretation rule for the\n", - "primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`,\n", - "then those `bind` calls will be handled by the interpreter at the next level\n", - "down.\n", + "When we're about to apply a transformation, we'll push another interpreter\n", + "onto the stack using `new_main`. Then, as we apply primitives in the function,\n", + "we can think of the `bind` first being interpreted by the trace at the top of\n", + "the stack (i.e. with the highest level). If that first interpreter itself\n", + "binds other primitives in its interpretation rule for the primitive, like how\n", + "the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, then those `bind`\n", + "calls will be handled by the interpreter at the next level down.\n", "\n", "What goes at the bottom of the interpreter stack? At the bottom, we know all\n", "the transformation interpreters are finished, and we just want to do standard\n", @@ -187,7 +242,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "class Trace:\n", @@ -227,12 +284,23 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "import numpy as np\n", - "from typing import Tuple\n", - "\n", + "from typing import Tuple" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "class Tracer:\n", " _trace: Trace\n", "\n", @@ -258,8 +326,17 @@ " try:\n", " return getattr(self.aval, name)\n", " except AttributeError:\n", - " raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")\n", - "\n", + " raise AttributeError(f\"{self.__class__.__name__} has no attribute {name}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "class ShapedArray:\n", " array_abstraction_level = 1\n", " shape: Tuple[int]\n", @@ -291,6 +368,22 @@ " def str_short(self):\n", " return f'{self.dtype.name}[{\",\".join(str(d) for d in self.shape)}]'\n", "\n", + " def __hash__(self):\n", + " return hash((self.shape, self.dtype))\n", + "\n", + " def __eq__(self, other):\n", + " return (type(self) is type(other) and\n", + " self.shape == other.shape and self.dtype == other.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "class ConcreteArray(ShapedArray):\n", " array_abstraction_level = 2\n", " val: np.ndarray\n", @@ -306,8 +399,18 @@ "\n", " @staticmethod\n", " def _nonzero(tracer):\n", - " return bool(tracer.aval.val)\n", - "\n", + " return bool(tracer.aval.val)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def get_aval(x):\n", " if isinstance(x, Tracer):\n", " return x.aval\n", @@ -324,21 +427,23 @@ "possible arrays with a given shape and dtype. A `ConcreteArray` represents a\n", "singleton set consisting of a single array value.\n", "\n", - "Now that we've set up the trace stack, the Trace/Tracer API for interpreters,\n", - "and abstract values, we can come back to implement `bind`:" + "Now that we've set up the interpreter stack, the Trace/Tracer API for\n", + "interpreters, and abstract values, we can come back to implement `bind`:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "def bind(prim, *args, **params):\n", " top_trace = find_top_trace(args)\n", " tracers = [full_raise(top_trace, arg) for arg in args]\n", - " out = top_trace.process_primitive(prim, tracers, params)\n", - " return full_lower(out)" + " outs = top_trace.process_primitive(prim, tracers, params)\n", + " return [full_lower(out) for out in outs]" ] }, { @@ -346,8 +451,7 @@ "metadata": {}, "source": [ "The main action is that we call `find_top_trace` to figure out which\n", - "interpreter should handle this primitive application as a function of the\n", - "arguments and the active traces on the trace stack. We then call that top\n", + "interpreter should handle this primitive application. We then call that top\n", "trace's `process_primitive` so that the trace can apply its interpretation\n", "rule. The calls to `full_raise` just ensure that the inputs are boxed in the\n", "top trace's `Tracer` instances, and the call to `full_lower` is an optional\n", @@ -357,14 +461,27 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "from operator import attrgetter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "from operator import attrgetter\n", - "\n", "def find_top_trace(xs) -> Trace:\n", " top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),\n", " default=trace_stack[0], key=attrgetter('level'))\n", + " if dynamic_trace and dynamic_trace.level > top_main.level:\n", + " top_main = dynamic_trace\n", " return top_main.trace_type(top_main)" ] }, @@ -372,27 +489,49 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In words, `find_top_trace` returns the highest-level interpreter associated\n", - "with the `Tracer`s on its inputs, and otherwise returns the interpreter at the\n", - "bottom of the stack (which is always an evaluation trace, at least for now).\n", - "This corresponds to JAX transformations mostly working by data dependence\n", - "_except_ for the special bottom-of-the-stack interpreter, which interprets\n", - "everything." + "In words, ignoring the `dynamic_trace` step until Part 3, `find_top_trace`\n", + "returns the highest-level interpreter associated with the `Tracer`s on its\n", + "inputs, and otherwise returns the interpreter at the bottom of the stack\n", + "(which is always an evaluation trace, at least for now). This is a deviation\n", + "from the description above, where we always start by running the interpreter\n", + "at the top of the stack and then work our way down, applying every interpreter\n", + "in the stack. Instead, we're only applying an interpreter when the input\n", + "arguments to a primitive bind are boxed in a `Tracer` corresponding to that\n", + "interpreter. This optimization lets us skip irrelevant transformations, but\n", + "bakes in an assumption that transformations mostly follow data dependence\n", + "(except for the special bottom-of-the-stack interpreter, which interprets\n", + "everything).\n", + "\n", + "An alternative would be to have every interpreter in the stack interpret every\n", + "operation. That's worth exploring! JAX is designed around data dependence in\n", + "large part because that's so natural for automatic differentiation, and JAX's\n", + "roots are in autodiff. But it may be over-fit." ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "def full_lower(val):\n", + "def full_lower(val: Any):\n", " if isinstance(val, Tracer):\n", " return val.full_lower()\n", " else:\n", - " return val\n", - "\n", - "def full_raise(trace, val) -> Tracer:\n", + " return val" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def full_raise(trace: Trace, val: Any) -> Tracer:\n", " if not isinstance(val, Tracer):\n", " return trace.pure(val)\n", " level = trace.main.level\n", @@ -433,25 +572,67 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "class EvalTrace(Trace):\n", " pure = lift = lambda self, x: x # no boxing in Tracers needed\n", "\n", " def process_primitive(self, primitive, tracers, params):\n", - " return impl_rules[primitive](*tracers, **params)\n", - "\n", - "trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack\n", - "\n", - "impl_rules = {}\n", - "impl_rules[add_p] = np.add\n", - "impl_rules[mul_p] = np.multiply\n", - "impl_rules[neg_p] = np.negative\n", - "impl_rules[sin_p] = np.sin\n", - "impl_rules[cos_p] = np.cos\n", - "impl_rules[reduce_sum_p] = np.sum\n", - "impl_rules[greater_p] = np.greater" + " return impl_rules[primitive](*tracers, **params)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "impl_rules = {}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "impl_rules[add_p] = lambda x, y: [np.add(x, y)]\n", + "impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)]\n", + "impl_rules[neg_p] = lambda x: [np.negative(x)]\n", + "impl_rules[sin_p] = lambda x: [np.sin(x)]\n", + "impl_rules[cos_p] = lambda x: [np.cos(x)]\n", + "impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)]\n", + "impl_rules[greater_p] = lambda x, y: [np.greater(x, y)]\n", + "impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def broadcast_impl(x, *, shape, axes):\n", + " return [np.broadcast_to(np.expand_dims(x, axes), shape)]\n", + "impl_rules[broadcast_p] = broadcast_impl" ] }, { @@ -464,7 +645,23 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = - y + x\n", + " return z" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [ { "name": "stdout", @@ -475,11 +672,6 @@ } ], "source": [ - "def f(x):\n", - " y = sin(x) * 2\n", - " z = - y + x\n", - " return z\n", - "\n", "print(f(3.0))" ] }, @@ -497,18 +689,29 @@ "source": [ "### Forward-mode autodiff with `jvp`\n", "\n", - "First, a couple of helper functions:" + "First, a few helper functions:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "def zeros_like(val):\n", - " return np.zeros_like(val)\n", - "\n", + " return np.zeros_like(val)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def unzip2(pairs):\n", " lst1, lst2 = [], []\n", " for x1, x2 in pairs:\n", @@ -517,6 +720,19 @@ " return lst1, lst2" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "map_ = map\n", + "def map(f, *xs):\n", + " return list(map_(f, *xs))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -528,7 +744,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ "class JVPTracer(Tracer):\n", @@ -539,17 +757,35 @@ "\n", " @property\n", " def aval(self):\n", - " return get_aval(self.primal)\n", - "\n", + " return get_aval(self.primal)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "class JVPTrace(Trace):\n", " pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val))\n", "\n", " def process_primitive(self, primitive, tracers, params):\n", " primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)\n", " jvp_rule = jvp_rules[primitive]\n", - " primal_out, tangent_out = jvp_rule(primals_in, tangents_in, **params)\n", - " return JVPTracer(self, primal_out, tangent_out)\n", - "\n", + " primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params)\n", + " return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "jvp_rules = {}" ] }, @@ -558,56 +794,107 @@ "metadata": {}, "source": [ "Notice both `lift` and `sublift` package a value into a `JVPTracer` with the\n", - "minimal amount of context, which is a zero tangent value." + "minimal amount of context, which is a zero tangent value.\n", + "\n", + "Let's add some JVP rules for primitives:" ] }, { - "cell_type": "markdown", - "metadata": {}, + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], "source": [ - "Let's add some JVP rules for primitives:" + "def add_jvp(primals, tangents):\n", + " (x, y), (x_dot, y_dot) = primals, tangents\n", + " return [x + y], [x_dot + y_dot]\n", + "jvp_rules[add_p] = add_jvp" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "def add_jvp(primals, tangents):\n", - " (x, y), (x_dot, y_dot) = primals, tangents\n", - " return x + y, x_dot + y_dot\n", - "jvp_rules[add_p] = add_jvp\n", - "\n", "def mul_jvp(primals, tangents):\n", " (x, y), (x_dot, y_dot) = primals, tangents\n", - " return x * y, x_dot * y + x * y_dot\n", - "jvp_rules[mul_p] = mul_jvp\n", - "\n", + " return [x * y], [x_dot * y + x * y_dot]\n", + "jvp_rules[mul_p] = mul_jvp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def sin_jvp(primals, tangents):\n", " (x,), (x_dot,) = primals, tangents\n", - " return sin(x), cos(x) * x_dot\n", - "jvp_rules[sin_p] = sin_jvp\n", - "\n", + " return [sin(x)], [cos(x) * x_dot]\n", + "jvp_rules[sin_p] = sin_jvp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def cos_jvp(primals, tangents):\n", " (x,), (x_dot,) = primals, tangents\n", - " return cos(x), -sin(x) * x_dot\n", - "jvp_rules[cos_p] = cos_jvp\n", - "\n", + " return [cos(x)], [-sin(x) * x_dot]\n", + "jvp_rules[cos_p] = cos_jvp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def neg_jvp(primals, tangents):\n", " (x,), (x_dot,) = primals, tangents\n", - " return neg(x), neg(x_dot)\n", - "jvp_rules[neg_p] = neg_jvp\n", - "\n", + " return [neg(x)], [neg(x_dot)]\n", + "jvp_rules[neg_p] = neg_jvp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def reduce_sum_jvp(primals, tangents, *, axis):\n", " (x,), (x_dot,) = primals, tangents\n", - " return reduce_sum(x, axis), reduce_sum(x_dot, axis)\n", - "jvp_rules[reduce_sum_p] = reduce_sum_jvp\n", - "\n", + " return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)]\n", + "jvp_rules[reduce_sum_p] = reduce_sum_jvp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ "def greater_jvp(primals, tangents):\n", " (x, y), _ = primals, tangents\n", " out_primal = greater(x, y)\n", - " return out_primal, zeros_like(out_primal)\n", + " return [out_primal], [zeros_like(out_primal)]\n", "jvp_rules[greater_p] = greater_jvp" ] }, @@ -621,10 +908,12 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "def jvp(f, primals, tangents):\n", + "def jvp_v1(f, primals, tangents):\n", " with new_main(JVPTrace) as main:\n", " trace = JVPTrace(main)\n", " tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n", @@ -657,7 +946,7 @@ ], "source": [ "x = 3.0\n", - "y, sin_deriv_at_3 = jvp(sin, (x,), (1.0,))\n", + "y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,))\n", "print(sin_deriv_at_3)\n", "print(cos(3.0))" ] @@ -665,7 +954,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [ { "name": "stdout", @@ -678,12 +969,12 @@ ], "source": [ "def f(x):\n", - " y = sin(x) * 2\n", + " y = sin(x) * 2.\n", " z = - y + x\n", " return z\n", "\n", "x, xdot = 3., 1.\n", - "y, ydot = jvp(f, (x,), (xdot,))\n", + "y, ydot = jvp_v1(f, (x,), (xdot,))\n", "print(y)\n", "print(ydot)" ] @@ -691,7 +982,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [ { "name": "stdout", @@ -706,7 +999,7 @@ ], "source": [ "def deriv(f):\n", - " return lambda x: jvp(f, (x,), (1.,))[1]\n", + " return lambda x: jvp_v1(f, (x,), (1.,))[1]\n", "\n", "print(deriv(sin)(3.))\n", "print(deriv(deriv(sin))(3.))\n", @@ -717,7 +1010,10 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, "outputs": [ { "name": "stdout", @@ -743,665 +1039,1800 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Vectorized batching with `vmap`\n", - "\n", - "First, a couple helper functions, one for producing mapped abstract values\n", - "from unmapped ones (by removing an axis), and one for moving batch dimensions\n", - "around:" + "## Pytrees and flattening user functions' inputs and outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A limitation with `jvp_v1` is that it assumes the user function accepts arrays\n", + "as positional arguments and produces a single array as output. What if it\n", + "produced a list as output? Or accepted nested containers as inputs? It would\n", + "be a pain to deal with all the possible containers in inputs and outputs at\n", + "every layer of the stack. Instead, we can wrap the user function so that the\n", + "wrapped version accepts arrays as inputs and returns a flat list of arrays as\n", + "output. The wrapper just needs to unflatten its input, call the user function,\n", + "and flatten the output.\n", + "\n", + "Here's how we'd like to write `jvp`, assuming the user always gives us\n", + "functions that take arrays as inputs and produces a flat list of arrays as\n", + "outputs:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "def mapped_aval(batch_dim, aval):\n", - " shape = list(aval.shape)\n", - " del shape[batch_dim]\n", - " return ShapedArray(tuple(shape), aval.dtype)\n", - "\n", - "def move_batch_axis(axis_size, src, dst, x):\n", - " if src is not_mapped:\n", - " target_shape = list(np.shape(x))\n", - " target_shape.insert(dst, axis_size)\n", - " return np.broadcast_to(np.expand_dims(x, dst), target_shape)\n", - " else:\n", - " return np.moveaxis(x, src, dst)" + "def jvp_flat(f, primals, tangents):\n", + " with new_main(JVPTrace) as main:\n", + " trace = JVPTrace(main)\n", + " tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)]\n", + " outs = f(*tracers_in)\n", + " tracers_out = [full_raise(trace, out) for out in outs]\n", + " primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out)\n", + " return primals_out, tangents_out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The `Tracer` for vectorized batching carries a batched value and an optional\n", - "integer indicating which axis (if any) is the batch axis." + "To support user functions that have arbitrary containers in the inputs and\n", + "outputs, here's how we'd write the user-facing `jvp` wrapper:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "from typing import Union\n", - "\n", - "class NotMapped: pass\n", - "not_mapped = NotMapped()\n", - "\n", - "class BatchTracer(Tracer):\n", - " def __init__(self, trace, val, batch_dim: Union[NotMapped, int]):\n", - " self._trace = trace\n", - " self.val = val\n", - " self.batch_dim = batch_dim\n", - "\n", - " @property\n", - " def aval(self):\n", - " if self.batch_dim is not_mapped:\n", - " return get_aval(self.val)\n", - " else:\n", - " return mapped_aval(self.batch_dim, get_aval(self.val))\n", - "\n", - " def full_lower(self):\n", - " if self.batch_dim is not_mapped:\n", - " return full_lower(self.val)\n", - " else:\n", - " return self\n", - "\n", - "class BatchTrace(Trace):\n", - " pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)\n", - "\n", - " def process_primitive(self, primitive, tracers, params):\n", - " vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)\n", - " vmap_rule = vmap_rules[primitive]\n", - " val_out, bdim_out = vmap_rule(self.axis_size, vals_in, bdims_in, **params)\n", - " return BatchTracer(self, val_out, bdim_out)\n", - "\n", - " @property\n", - " def axis_size(self):\n", - " return self.main.global_data\n", - "\n", - "vmap_rules = {}" + "def jvp(f, primals, tangents):\n", + " primals_flat, in_tree = tree_flatten(primals)\n", + " tangents_flat, in_tree2 = tree_flatten(tangents)\n", + " if in_tree != in_tree2: raise TypeError\n", + " f, out_tree = flatten_fun(f, in_tree)\n", + " primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat)\n", + " primals_out = tree_unflatten(out_tree(), primals_out_flat)\n", + " tangents_out = tree_unflatten(out_tree(), tangents_out_flat)\n", + " return primals_out, tangents_out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Here we've implemented the optional `Tracer.full_lower` method, which lets us\n", - "peel off a batching tracer if it's not needed because it doesn't represent a\n", - "batched value.\n", - "\n", - "For `BatchTrace`, analogous to `JVPTrace`, the methods `pure` and `lift` just\n", - "box a value in a `BatchTracer` with the minimal amount of context, which in\n", - "this case is a `batch_dim` taking the sentinel value `not_mapped`. Notice we\n", - "use the `MainTrace`'s interpreter-global data field to store the batch axis\n", - "size.\n", - "\n", - "Next we can define batching interpreter rules for each primitive:" + "Notice that we had to plumb the tree structure of the user function output\n", + "back to the caller of `flatten_fun`. That information isn't available until we\n", + "actually run the user function, so `flatten_fun` just returns a reference to a\n", + "mutable cell, represented as a thunk. These side-effects are safe because we\n", + "always run the user function exactly once. (This safe regime is the reason for\n", + "the \"linear\" name in `linear_util.py`, in the sense of [linear\n", + "types](https://en.wikipedia.org/wiki/Substructural_type_system).)\n", + "\n", + "All that remains is to write `tree_flatten`, `tree_unflatten`, and\n", + "`flatten_fun`:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "from functools import partial\n", - "\n", - "def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):\n", - " (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n", - " if x_bdim != y_bdim:\n", - " y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n", - " return op(x, y), x_bdim\n", - "vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)\n", - "vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)\n", + "def flatten_fun(f, in_tree):\n", + " store = Store()\n", "\n", - "def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):\n", - " (x,), (x_bdim,) = vals_in, dims_in\n", - " return op(x), x_bdim\n", - "vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)\n", - "vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)\n", - "vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)\n", + " def flat_fun(*args_flat):\n", + " pytree_args = tree_unflatten(in_tree, args_flat)\n", + " out = f(*pytree_args)\n", + " out_flat, out_tree = tree_flatten(out)\n", + " store.set_value(out_tree)\n", + " return out_flat\n", "\n", - "def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n", - " (x,), (x_bdim,) = vals_in, dims_in\n", - " new_axis = axis + (x_bdim <= axis)\n", - " out_bdim = x_bdim - (new_axis < x_bdim)\n", - " return reduce_sum(x, new_axis), out_bdim\n", - "vmap_rules[reduce_sum_p] = reduce_sum_batching_rule" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, we add a transformation API to kick off the trace:" + " return flat_fun, store" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "def vmap(f, in_axes, out_axis):\n", - " def batched_f(*args):\n", - " axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)\n", - " if ax is not None}\n", - " with new_main(BatchTrace, axis_size) as main:\n", - " trace = BatchTrace(main)\n", - " tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x\n", - " for x, ax in zip(args, in_axes)]\n", - " out = f(*tracers_in)\n", - " tracer_out = full_raise(trace, out)\n", - " val_out, batch_dim_out = tracer_out.val, tracer_out.batch_dim\n", - " return move_batch_axis(axis_size, batch_dim_out, out_axis, val_out)\n", - " return batched_f" + "class Empty: pass\n", + "empty = Empty()" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0. 1. 2.]\n", - "[1. 2. 3.]\n" - ] - } - ], + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], "source": [ - "def add_one_to_a_scalar(scalar):\n", - " assert np.ndim(scalar) == 0\n", - " return 1 + scalar\n", + "class Store:\n", + " val = empty\n", "\n", - "vector_in = np.arange(3.)\n", - "vector_out = vmap(add_one_to_a_scalar, (0,), 0)(vector_in)\n", + " def set_value(self, val):\n", + " assert self.val is empty\n", + " self.val = val\n", "\n", - "print(vector_in)\n", - "print(vector_out)" + " def __call__(self):\n", + " return self.val" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[ 1. , 0. , -0. ],\n", - " [ 0. , 0.54030231, -0. ],\n", - " [ 0. , 0. , -0.41614684]])" - ] - }, - "execution_count": 172, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 2 + }, + "outputs": [], "source": [ - "def jacfwd(f, x):\n", - " pushfwd = lambda v: jvp(f, (x,), (v,))[1]\n", - " vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)\n", - " return vmap(pushfwd, (0,), 0)(vecs_in)\n", + "import itertools as it\n", + "from typing import Callable, Type, Hashable, Dict, Iterable, Iterator\n", + "\n", + "class NodeType(NamedTuple):\n", + " to_iterable: Callable\n", + " from_iterable: Callable\n", + "\n", + "node_types: Dict[Type, NodeType] = {\n", + " tuple: NodeType(lambda t: (None, t), lambda _, xs: tuple(xs)),\n", + " list: NodeType( lambda l: (None, l), lambda _, xs: list(xs)),\n", + " dict: NodeType(lambda d: map(tuple, unzip2(sorted(d.items()))),\n", + " lambda keys, vals: dict(zip(keys, vals))),\n", + "}\n", + "\n", + "class PyTreeDef(NamedTuple):\n", + " node_type: NodeType\n", + " node_metadata: Hashable\n", + " child_treedefs: Tuple['PyTreeDef']\n", + "\n", + "class Leaf: pass\n", + "leaf = Leaf()\n", + "\n", + "def tree_flatten(x: Any) -> Tuple[List[Any], PyTreeDef]:\n", + " children_iter, treedef = _tree_flatten(x)\n", + " return list(children_iter), treedef\n", + "\n", + "def _tree_flatten(x: Any) -> Tuple[Iterable, PyTreeDef]:\n", + " node_type = node_types.get(type(x))\n", + " if node_type:\n", + " node_metadata, children = node_type.to_iterable(x)\n", + " children_flat, child_trees = unzip2(map(_tree_flatten, children))\n", + " flattened = it.chain.from_iterable(children_flat)\n", + " return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees))\n", + " else:\n", + " return [x], leaf\n", "\n", - "def f(x):\n", - " return sin(x)\n", + "def tree_unflatten(treedef: PyTreeDef, xs: List[Any]) -> Any:\n", + " return _tree_unflatten(treedef, iter(xs))\n", "\n", - "jacfwd(f, np.arange(3.))" + "def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any:\n", + " if treedef is leaf:\n", + " return next(xs)\n", + " else:\n", + " children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs)\n", + " return treedef.node_type.from_iterable(treedef.node_metadata, children)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "That's it for `jvp` and `vmap`! Before moving on, let's highlight a few\n", - "simplifications in what we've seen so far compared to the full JAX\n", - "implementation:\n", - "1. **Fewer, simpler primitives.** More primitives means more interpretation\n", - "rules, and for more complex primitives (like for convolution or advanced\n", - "indexing) each rule is harder to write. But the overarching design is no\n", - "different.\n", - "1. **Transformations expect arrays in, single array out.**\n", - "2. **No symbolic zeros in autodiff.**\n", - "3. **No special call primitives yet.** The core machinery needs to be\n", - " generalized to handle the most flexible kind of higher-order primitive,\n", - " used by `jax.custom_jvp` and `jax.custom_vjp`." + "With this pytree-handling `jvp` impelmentation, we can now handle arbitrary\n", + "input and output containers. That'll come in handy with future transformations\n", + "too!" ] }, { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Part 2: Jaxprs, for `jit` and `vjp`\n", - "\n", - "The next transformations are the horizon are `jit` for just-in-time\n", - "compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small\n", - "wrapper around `vjp`.) For `jvp` and `vmap` we only needed each `Tracer` to\n", - "carry a little bit of extra context, but for both `jit` and `vjp` we need\n", - "much richer context: we need to represent _programs_. That is, we need jaxprs!\n", - "\n", - "Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are\n", - "an explicitly typed, functional, first-order language. We need a program\n", - "representation for `jit` because the purpose of `jit` is to stage computation\n", - "out of Python. For any computation we want to stage out, we need to be able to\n", - "represent it as data, and build it up as we trace a Python function.\n", - "Similarly, `vjp` needs a way to represent the computation for the backward\n", - "pass of reverse-mode autodiff. We use the same jaxpr program representation\n", - "for both needs.\n", + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = - y + x\n", + " return {'hi': z, 'there': [x, y]}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "x, xdot = 3., 1.\n", + "y, ydot = jvp(f, (x,), (xdot,))\n", + "print(y)\n", + "print(ydot)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Vectorized batching with `vmap`\n", + "\n", + "First, a couple helper functions, one for producing mapped abstract values\n", + "from unmapped ones (by removing an axis), and one for moving batch dimensions\n", + "around:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def mapped_aval(batch_dim, aval):\n", + " shape = list(aval.shape)\n", + " del shape[batch_dim]\n", + " return ShapedArray(tuple(shape), aval.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def move_batch_axis(axis_size, src, dst, x):\n", + " if src is not_mapped:\n", + " target_shape = list(np.shape(x))\n", + " target_shape.insert(dst, axis_size)\n", + " return broadcast(x, target_shape, [dst])\n", + " elif src == dst:\n", + " return x\n", + " else:\n", + " return moveaxis(x, src, dst)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def moveaxis(x, src: int, dst: int):\n", + " perm = [i for i in range(np.ndim(x)) if i != src]\n", + " perm.insert(dst, src)\n", + " return transpose(x, perm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `Tracer` for vectorized batching carries a batched value and an optional\n", + "integer indicating which axis (if any) is the batch axis." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "from typing import Union" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "class NotMapped: pass\n", + "not_mapped = NotMapped()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "BatchAxis = Union[NotMapped, int]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "class BatchTracer(Tracer):\n", + " def __init__(self, trace, val, batch_dim: BatchAxis):\n", + " self._trace = trace\n", + " self.val = val\n", + " self.batch_dim = batch_dim\n", + "\n", + " @property\n", + " def aval(self):\n", + " if self.batch_dim is not_mapped:\n", + " return get_aval(self.val)\n", + " else:\n", + " return mapped_aval(self.batch_dim, get_aval(self.val))\n", + "\n", + " def full_lower(self):\n", + " if self.batch_dim is not_mapped:\n", + " return full_lower(self.val)\n", + " else:\n", + " return self" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "class BatchTrace(Trace):\n", + " pure = lift = lambda self, val: BatchTracer(self, val, not_mapped)\n", + "\n", + " def process_primitive(self, primitive, tracers, params):\n", + " vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers)\n", + " vmap_rule = vmap_rules[primitive]\n", + " val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params)\n", + " return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)]\n", + "\n", + " @property\n", + " def axis_size(self):\n", + " return self.main.global_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vmap_rules = {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we've implemented the optional `Tracer.full_lower` method, which lets us\n", + "peel off a batching tracer if it's not needed because it doesn't represent a\n", + "batched value.\n", + "\n", + "For `BatchTrace`, analogous to `JVPTrace`, the methods `pure` and `lift` just\n", + "box a value in a `BatchTracer` with the minimal amount of context, which in\n", + "this case is a `batch_dim` taking the sentinel value `not_mapped`. Notice we\n", + "use the `MainTrace`'s interpreter-global data field to store the batch axis\n", + "size.\n", + "\n", + "Next we can define batching interpreter rules for each primitive:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "from functools import partial" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in):\n", + " (x, y), (x_bdim, y_bdim) = vals_in, dims_in\n", + " if x_bdim != y_bdim:\n", + " y = move_batch_axis(axis_size, y_bdim, x_bdim, y)\n", + " return [op(x, y)], [x_bdim]\n", + "vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add)\n", + "vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in):\n", + " (x,), (x_bdim,) = vals_in, dims_in\n", + " return [op(x)], [x_bdim]\n", + "vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin)\n", + "vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos)\n", + "vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis):\n", + " (x,), (x_bdim,) = vals_in, dims_in\n", + " new_axis = axis + (x_bdim <= axis)\n", + " out_bdim = x_bdim - (new_axis < x_bdim)\n", + " return [reduce_sum(x, new_axis)], [out_bdim]\n", + "vmap_rules[reduce_sum_p] = reduce_sum_batching_rule" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we add a transformation API to kick off the trace:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def vmap_flat(f, in_axes, *args):\n", + " axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes)\n", + " if ax is not not_mapped}\n", + " with new_main(BatchTrace, axis_size) as main:\n", + " trace = BatchTrace(main)\n", + " tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x\n", + " for x, ax in zip(args, in_axes)]\n", + " outs = f(*tracers_in)\n", + " tracers_out = [full_raise(trace, out) for out in outs]\n", + " vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out)\n", + " outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out)\n", + " for val_out, bdim in zip(vals_out, bdims_out)]\n", + " return outs_transposed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def vmap(f, in_axes):\n", + " def batched_f(*args):\n", + " args_flat, in_tree = tree_flatten(args)\n", + " in_axes_flat, in_tree2 = tree_flatten(in_axes)\n", + " if in_tree != in_tree2: raise TypeError\n", + " f_flat, out_tree = flatten_fun(f, in_tree)\n", + " outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat)\n", + " return tree_unflatten(out_tree(), outs_flat)\n", + " return batched_f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def add_one_to_a_scalar(scalar):\n", + " assert np.ndim(scalar) == 0\n", + " return 1 + scalar\n", + "\n", + "vector_in = np.arange(3.)\n", + "vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in)\n", + "\n", + "print(vector_in)\n", + "print(vector_out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def jacfwd(f, x):\n", + " pushfwd = lambda v: jvp(f, (x,), (v,))[1]\n", + " vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2)\n", + " return vmap(pushfwd, (0,))(vecs_in)\n", + "\n", + "def f(x):\n", + " return sin(x)\n", + "\n", + "jacfwd(f, np.arange(3.))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "That's it for `jvp` and `vmap`! Before moving on, let's highlight a few\n", + "simplifications in what we've seen so far compared to the full JAX\n", + "implementation:\n", + "1. **Fewer, simpler primitives.** More primitives means more interpretation\n", + "rules, and for more complex primitives (like for convolution or advanced\n", + "indexing) each rule is harder to write. But the overarching design is no\n", + "different.\n", + "2. **No pytrees.** Transformations expect arrays in, and either a single array\n", + " out or a flat list of arrays out.\n", + "3. **Missing optimization: no symbolic zeros in autodiff.**\n", + "4. **No special call primitives yet.** The core machinery needs to be\n", + " generalized to handle the most flexible kind of higher-order primitive,\n", + " used by `jax.custom_jvp` and `jax.custom_vjp`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 2: Jaxprs\n", + "\n", + "The next transformations are the horizon are `jit` for just-in-time\n", + "compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small\n", + "wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to\n", + "carry a little bit of extra context, for both `jit` and `vjp` we need much\n", + "richer context: we need to represent _programs_. That is, we need jaxprs!\n", + "\n", + "Jaxprs are JAX's internal intermediate representation of programs. They are\n", + "explicitly typed, functional, first-order, and in ANF form. We need a\n", + "program representation for `jit` because the purpose of `jit` is to stage\n", + "computation out of Python. For any computation we want to stage out, we need\n", + "to be able to represent it as data, and build it up as we trace a Python\n", + "function. Similarly, `vjp` needs a way to represent the computation for the\n", + "backward pass of reverse-mode autodiff. We use the same jaxpr program\n", + "representation for both needs.\n", + "\n", + "(Building a program representation is the most\n", + "[free](https://en.wikipedia.org/wiki/Free_object) kind of\n", + "trace-transformation, and so except for issues around handling native Python\n", + "control flow, any transformation could be implemented by first tracing to a\n", + "jaxpr and then interpreting the jaxpr.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Jaxpr data strutures\n", + "\n", + "The jaxpr term syntax is roughly:\n", + "\n", + "```\n", + "jaxpr ::=\n", + " { lambda , ... .\n", + " let \n", + " ...\n", + " in ( , ... ) }\n", + "\n", + "binder ::= :\n", + "var ::= a | b | c | ...\n", + "atom ::= | \n", + "literal ::= | \n", + "\n", + "eqn ::= , ... = [ ] , ...\n", + "```\n", + "\n", + "The syntax of types is:\n", + "\n", + "```\n", + "jaxpr_type ::= [ , ... ] -> [ , ... ]\n", + "array_type ::= []\n", + "dtype ::= f32 | f64 | i32 | i64\n", + "shape ::= , ...\n", + "```\n", + "\n", + "How do we represent these as Python data structures? We reuse ShapedArrays to\n", + "represent types, and we can represent the term syntax with a few Python\n", + "structs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "from typing import Set\n", + "\n", + "class Var:\n", + " aval: ShapedArray\n", + " def __init__(self, aval): self.aval = aval\n", + "\n", + "class Lit:\n", + " val: Any\n", + " aval: ShapedArray\n", + "\n", + " def __init__(self, val):\n", + " self.val = val\n", + " self.aval = raise_to_shaped(get_aval(self.val))\n", + "\n", + "Atom = Union[Var, Lit]\n", + "\n", + "class JaxprEqn(NamedTuple):\n", + " primitive: Primitive\n", + " inputs: List[Atom]\n", + " params: Dict[str, Any]\n", + " out_binders: List[Var]\n", + "\n", + "class Jaxpr(NamedTuple):\n", + " in_binders: List[Var]\n", + " eqns: List[JaxprEqn]\n", + " outs: List[Atom]\n", + "\n", + "def raise_to_shaped(aval):\n", + " return ShapedArray(aval.shape, aval.dtype)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "Type-checking a jaxpr involves checking that there are no unbound variables,\n", + "that variables are only bound once, and that for each equation the type of\n", + "the primitive application matches the type of the output binders." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "class JaxprType:\n", + " in_types: List[ShapedArray]\n", + " out_type: List[ShapedArray]\n", + "\n", + " def __init__(self, in_types, out_types):\n", + " self.in_types = in_types\n", + " self.out_types = out_types\n", + "\n", + " def __repr__(self):\n", + " in_types = ', '.join(aval.str_short() for aval in self.in_types)\n", + " out_types = ', '.join(aval.str_short() for aval in self.out_types)\n", + " return f'({in_types}) -> ({out_types})'\n", + "\n", + "def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:\n", + " env: Set[Var] = set()\n", + "\n", + " for v in jaxpr.in_binders:\n", + " if v in env: raise TypeError\n", + " env.add(v)\n", + "\n", + " for eqn in jaxpr.eqns:\n", + " in_types = [typecheck_atom(env, x) for x in eqn.inputs]\n", + " out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)\n", + " for out_binder, out_type in zip(eqn.out_binders, out_types):\n", + " if not types_equal(out_type, out_binder.aval): raise TypeError\n", + " for out_binder in eqn.out_binders:\n", + " if out_binder in env: raise TypeError\n", + " env.add(out_binder)\n", + "\n", + " in_types = [v.aval for v in jaxpr.in_binders]\n", + " out_types = [typecheck_atom(env, x) for x in jaxpr.outs]\n", + " return JaxprType(in_types, out_types)\n", + "\n", + "def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray:\n", + " if isinstance(x, Var):\n", + " if x not in env: raise TypeError(\"unbound variable\")\n", + " return x.aval\n", + " elif isinstance(x, Lit):\n", + " return raise_to_shaped(get_aval(x.val))\n", + " else:\n", + " assert False\n", + "\n", + "def types_equal(a: ShapedArray, b: ShapedArray) -> bool:\n", + " return a.shape == b.shape and a.dtype == b.dtype" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can apply the function represented by a jaxpr to arguments with a simple\n", + "interpreter." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]:\n", + " env: Dict[Var, Any] = {}\n", "\n", - "(Building a program representation is the most\n", - "[free](https://en.wikipedia.org/wiki/Free_object) kind of\n", - "trace- transformation, and so except for issues around handling native Python\n", - "control flow, any transformation could be implemented by first tracing to a\n", - "jaxpr and then interpreting the jaxpr.)\n", + " def read(x: Atom) -> Any:\n", + " return env[x] if type(x) is Var else x.val\n", "\n", - "The jaxpr term syntax is roughly:\n", + " def write(v: Var, val: Any) -> None:\n", + " env[v] = val\n", "\n", - "```\n", - "jaxpr ::=\n", - " { lambda , ... .\n", - " let \n", - " ...\n", - " in }\n", + " map(write, jaxpr.in_binders, args)\n", + " for eqn in jaxpr.eqns:\n", + " in_vals = map(read, eqn.inputs)\n", + " outs = bind(eqn.primitive, *in_vals, **eqn.params)\n", + " map(write, eqn.out_binders, outs)\n", + " return map(read, jaxpr.outs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def jaxpr_as_fun(jaxpr: Jaxpr):\n", + " return lambda *args: eval_jaxpr(jaxpr, args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By using `bind` in the interpreter, this interpreter itself is traceable." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Building jaxprs with tracing\n", "\n", - "binder ::= :\n", - "var ::= a | b | c | ...\n", - "atom ::= | \n", - "literal ::= | \n", + "Now that we have jaxprs as a data structure, we need ways to produce these\n", + "from tracing Python code. In general there are two variants of how we trace to\n", + "a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one\n", + "used by `jit`, which is also used by control flow primitives like `lax.cond`,\n", + "`lax.while_loop`, and `lax.scan`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "# NB: the analogous class in JAX is called 'DynamicJaxprTracer'\n", + "class JaxprTracer(Tracer):\n", + " __slots__ = ['aval']\n", + " aval: ShapedArray\n", + "\n", + " def __init__(self, trace, aval):\n", + " self._trace = trace\n", + " self.aval = aval\n", + "\n", + "# NB: the analogous class in JAX is called 'DynamicJaxprTrace'\n", + "class JaxprTrace(Trace):\n", + " def new_arg(self, aval: ShapedArray) -> JaxprTracer:\n", + " aval = raise_to_shaped(aval)\n", + " tracer = self.builder.new_tracer(self, aval)\n", + " self.builder.tracer_to_var[id(tracer)] = Var(aval)\n", + " return tracer\n", + "\n", + " def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:\n", + " tracer = self.builder.const_tracers.get(id(val))\n", + " if tracer is None:\n", + " tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val)))\n", + " self.builder.add_const(tracer, val)\n", + " return tracer\n", + " pure = lift = get_or_make_const_tracer\n", + "\n", + " def process_primitive(self, primitive, tracers, params):\n", + " avals_in = [t.aval for t in tracers]\n", + " avals_out = abstract_eval_rules[primitive](*avals_in, **params)\n", + " out_tracers = [self.builder.new_tracer(self, a) for a in avals_out]\n", + " inputs = [self.builder.getvar(t) for t in tracers]\n", + " outvars = [self.builder.add_var(t) for t in out_tracers]\n", + " self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars))\n", + " return out_tracers\n", + "\n", + " @property\n", + " def builder(self):\n", + " return self.main.global_data\n", + "\n", + "# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance\n", + "abstract_eval_rules = {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that we keep as interpreter-global data a builder object, which keeps\n", + "track of variables, constants, and eqns as we build up the jaxpr." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "class JaxprBuilder:\n", + " eqns: List[JaxprEqn]\n", + " tracer_to_var: Dict[int, Var]\n", + " const_tracers: Dict[int, JaxprTracer]\n", + " constvals: Dict[Var, Any]\n", + " tracers: List[JaxprTracer]\n", + "\n", + " def __init__(self):\n", + " self.eqns = []\n", + " self.tracer_to_var = {}\n", + " self.const_tracers = {}\n", + " self.constvals = {}\n", + " self.tracers = []\n", + "\n", + " def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer:\n", + " tracer = JaxprTracer(trace, aval)\n", + " self.tracers.append(tracer)\n", + " return tracer\n", + "\n", + " def add_eqn(self, eqn: JaxprEqn) -> None:\n", + " self.eqns.append(eqn)\n", + "\n", + " def add_var(self, tracer: JaxprTracer) -> Var:\n", + " assert id(tracer) not in self.tracer_to_var\n", + " var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)\n", + " return var\n", + "\n", + " def getvar(self, tracer: JaxprTracer) -> Var:\n", + " var = self.tracer_to_var.get(id(tracer))\n", + " assert var is not None\n", + " return var\n", + "\n", + " def add_const(self, tracer: JaxprTracer, val: Any) -> Var:\n", + " var = self.add_var(tracer)\n", + " self.const_tracers[id(val)] = tracer\n", + " self.constvals[var] = val\n", + " return var\n", + "\n", + " def build(self, in_tracers: List[JaxprTracer], out_tracers: List[JaxprTracer]\n", + " ) -> Tuple[Jaxpr, List[Any]]:\n", + " constvars, constvals = unzip2(self.constvals.items())\n", + " t2v = lambda t: self.tracer_to_var[id(t)]\n", + " in_binders = constvars + [t2v(t) for t in in_tracers]\n", + " out_vars = [t2v(t) for t in out_tracers]\n", + " jaxpr = Jaxpr(in_binders, self.eqns, out_vars)\n", + " typecheck_jaxpr(jaxpr)\n", + " return jaxpr, constvals" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The rules we need for `JaxprTrace.process_primitive` are essentially typing\n", + "rules for primitive applications: given the primitive, its parameters, and\n", + "types for the inputs, the rule must produce a type for the output, which is\n", + "then packaged with the output `JaxprTracer`. We can use abstract evaluation\n", + "rules for this same purpose, even though they can be more general (since\n", + "abstract evaluation rules must accept ConcreteArray inputs, and since they\n", + "need only return an upper bound on the set of possible outputs, they can\n", + "produce ConcreteArray outputs as well). We'll reuse these abstract evaluation\n", + "rules for the other jaxpr-producing trace machinery, where the potential extra\n", + "generality is useful." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def broadcast_shapes(*shapes):\n", + " assert len(shapes) > 1\n", + " for sizes in zip(*shapes):\n", + " sizes = [d for d in sizes if d != 1]\n", + " if sizes[:-1] != sizes[1:]:\n", + " raise Exception\n", + " return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))\n", + "\n", + "def broadcasting_binop_abstract_eval_rule(*avals_in):\n", + " out_dtype = np.result_type(*map(np.result_type, avals_in))\n", + " out_shape = broadcast_shapes(*map(np.shape, avals_in))\n", + " return [ShapedArray(out_shape, out_dtype)]\n", + "\n", + "abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule\n", + "abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule\n", + "\n", + "def vectorized_unop_abstract_eval_rule(aval_in):\n", + " return [ShapedArray(np.shape(aval_in), np.result_type(aval_in))]\n", + "\n", + "abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule\n", + "abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule\n", + "abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule\n", + "\n", + "def reduce_sum_abstract_eval_rule(aval_in, *, axis):\n", + " new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]\n", + " return [ShapedArray(tuple(new_shape), aval_in.dtype)]\n", + "abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule\n", + "\n", + "def broadcast_abstract_eval(x, *, shape, axes):\n", + " return [ShapedArray(tuple(shape), np.result_type(x))]\n", + "abstract_eval_rules[broadcast_p] = broadcast_abstract_eval" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To check our implementation of jaxprs, we can add a `make_jaxpr`\n", + "transformation and a pretty-printer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "from functools import lru_cache" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@lru_cache()\n", + "def make_jaxpr_v1(f, *avals_in):\n", + " avals_in, in_tree = tree_flatten(avals_in)\n", + " f, out_tree = flatten_fun(f, in_tree)\n", + "\n", + " builder = JaxprBuilder()\n", + " with new_main(JaxprTrace, builder) as main:\n", + " trace = JaxprTrace(main)\n", + " tracers_in = [trace.new_arg(aval) for aval in avals_in]\n", + " outs = f(*tracers_in)\n", + " tracers_out = [full_raise(trace, out) for out in outs]\n", + " jaxpr, consts = builder.build(tracers_in, tracers_out)\n", + " return jaxpr, consts, out_tree()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "import string\n", + "\n", + "class PPrint:\n", + " lines: List[Tuple[int, str]]\n", + "\n", + " def __init__(self, lines):\n", + " self.lines = lines\n", + "\n", + " def indent(self, indent: int) -> 'PPrint':\n", + " return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])\n", + "\n", + " def __add__(self, rhs: 'PPrint') -> 'PPrint':\n", + " return PPrint(self.lines + rhs.lines)\n", + "\n", + " def __rshift__(self, rhs: 'PPrint') -> 'PPrint':\n", + " if not rhs.lines: return self\n", + " if not self.lines: return rhs\n", + " indent, s = self.lines[-1]\n", + " indented_block = rhs.indent(indent + len(s))\n", + " common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]\n", + " return PPrint(self.lines[:-1]\n", + " + [(indent, common_line)]\n", + " + indented_block.lines[1:])\n", + "\n", + " def __str__(self) -> str:\n", + " return '\\n'.join(' ' * indent + s for indent, s in self.lines)\n", + "\n", + "def pp(s: Any) -> PPrint:\n", + " return PPrint([(0, line) for line in str(s).splitlines()])\n", + "\n", + "def vcat(ps: List[PPrint]) -> PPrint:\n", + " return sum(ps, pp(''))\n", "\n", - "eqn ::= = [ ] , ...\n", - "```\n", + "def pp_jaxpr(jaxpr: Jaxpr):\n", + " namegen = (''.join(s) for r in it.count(1)\n", + " for s in it.permutations(string.ascii_lowercase, r))\n", + " names = defaultdict(lambda: next(namegen))\n", + " in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)\n", + " eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])\n", + " outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val)\n", + " for v in jaxpr.outs)\n", + " return (pp(f'{{ lambda {in_binders} .') +\n", + " ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2))\n", "\n", - "The syntax of types is:\n", + "def var_str(names: Dict[Var, str], v: Var) -> str:\n", + " return f'{names[v]}:{v.aval.str_short()}'\n", "\n", - "```\n", - "jaxpr_type ::= [, ...] -> [, ...]\n", - "array_type ::= []\n", - "dtype ::= f32 | f64 | i32 | i64\n", - "shape ::= , ...\n", - "```\n", + "def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:\n", + " lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders))\n", + " rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n", + " pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n", + " for x in eqn.inputs)))\n", + " return lhs >> pp(' = ') >> rhs\n", "\n", - "How do we represent these as Python data structures? We reuse ShapedArrays to\n", - "represent types, and we can represent the term syntax with a few Python\n", - "structs:" + "def pp_params(params: Dict[str, Any]) -> PPrint:\n", + " items = sorted(params.items())\n", + " if items:\n", + " return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')\n", + " else:\n", + " return pp(' ')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.)))\n", + "print(pp_jaxpr(jaxpr))\n", + "print(typecheck_jaxpr(jaxpr))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "But there's a limitation here: because of how `find_top_trace` operates by\n", + "data dependence, `make_jaxpr_v1` can't stage out all the primitive operations\n", + "performed by the Python callable it's given. For example:" ] }, { "cell_type": "code", "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.))\n", + "print(pp_jaxpr(jaxpr))" + ] + }, + { + "cell_type": "markdown", "metadata": {}, + "source": [ + "This is precisely the issue that\n", + "[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n", + "We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n", + "applied, regardless of whether any inputs to `bind` are boxed in corresponding\n", + "`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace`\n", + "global defined in Part 1:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "from typing import Dict, Set\n", + "@contextmanager\n", + "def new_dynamic(main: MainTrace):\n", + " global dynamic_trace\n", + " prev_dynamic_trace, dynamic_trace = dynamic_trace, main\n", + " try:\n", + " yield\n", + " finally:\n", + " dynamic_trace = prev_dynamic_trace" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@lru_cache() # ShapedArrays are hashable\n", + "def make_jaxpr(f, *avals_in):\n", + " avals_in, in_tree = tree_flatten(avals_in)\n", + " f, out_tree = flatten_fun(f, in_tree)\n", "\n", - "class Var:\n", - " aval: ShapedArray\n", - " def __init__(self, aval): self.aval = aval\n", + " builder = JaxprBuilder()\n", + " with new_main(JaxprTrace, builder) as main:\n", + " with new_dynamic(main):\n", + " trace = JaxprTrace(main)\n", + " tracers_in = [trace.new_arg(aval) for aval in avals_in]\n", + " outs = f(*tracers_in)\n", + " tracers_out = [full_raise(trace, out) for out in outs]\n", + " jaxpr, consts = builder.build(tracers_in, tracers_out)\n", + " return jaxpr, consts, out_tree()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.))\n", + "print(pp_jaxpr(jaxpr))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Using `dynamic_trace` this way is conceptually the same as stashing the\n", + "current interpreter stack and starting a new one with the `JaxprTrace` at the\n", + "bottom. That is, no interpreters lower in the stack than the `dynamic_trace`\n", + "are applied (since `JaxprTrace.process_primitive` doesn't call `bind`), though\n", + "if the Python callable being traced to a jaxpr itself uses transformations\n", + "then those can be pushed onto the interpreter stack above the `JaxprTrace`.\n", + "But temporarily stashing the interpreter stack would break up the system\n", + "state. The `dynamic_trace` tag achieves the same goals while keeping the\n", + "system state simpler." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "That's it for jaxprs! With jaxprs in hand, we can implement the remaining\n", + "major JAX features. But before moving on, let's highlight some\n", + "simplifications we've made:\n", + "1. **Single-output primitives and jaxprs.**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Part 3: `jit`, simplified\n", "\n", - "class Lit:\n", + "While `jit` has a transformation-like API in that it accepts a Python callable\n", + "as an argument, under the hood it's really a higher-order primitive rather\n", + "than a transformation. A primitive is _higher-order_ when it's parameterized\n", + "by a function." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### \"Final style\" and \"initial style\"\n", + "\n", + "There are two options for how to handle higher-order primitives. Each requires\n", + "a different approach to tracing and engenders different tradeoffs:\n", + "1. **`bind` takes a Python callable as an argument.** We defer forming a jaxpr\n", + " until as late as possible, namely until we're running the final interpreter\n", + " at the bottom of the interpreter stack. That way we can swap a `JaxprTrace`\n", + " in at the bottom of the interpreter stack and thus stage out rather than\n", + " execute all primitive operations. With this approach, transformations in\n", + " the stack get applied as we execute the Python callable as usual. This\n", + " approach can be very tricky to implement, but it's as general as possible\n", + " because it allows higher-order primitives not to raise the abstraction\n", + " level of their arguments and thus allows data-dependent Python control\n", + " flow. We refer to this approach as using a \"final-style higher-order\n", + " primitive\" employing the discharge-at-tracing-time \"final-style\n", + " transformations\" we've used so far.\n", + "2. **`bind` takes a jaxpr as an argument.** Before we call `bind`, in the\n", + " primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and\n", + " be done with the Python callable entirely. In this case, `make_jaxpr` puts\n", + " its `JaxprTrace` at the top of the interpreter stack, and no\n", + " transformations lower in the stack, which might enter via closed-over\n", + " Tracers, are applied to the Python callable as we trace it.\n", + " (Transformations applied within the Python callable are applied as usual,\n", + " being added to the stack above the JaxprTrace.) Instead, the\n", + " transformations lower in the stack are later applied to the call primitive,\n", + " and the call primitive's rules must then transform the jaxpr itself.\n", + " Because we trace to a jaxpr up-front, this approach can't support\n", + " data-dependent Python control flow, but it is more straightforward to\n", + " implement. We refer to this kind of higher-order primitive as an\n", + " \"initial-style higher-order primitive\", and say that its jaxpr-processing\n", + " transformation rules are \"initial-style transformation rules.\"\n", + "\n", + "The latter approach fits for `jit` because we don't need to support\n", + "data-dependent Python control flow in the user-provided Python callable, as\n", + "the whole purpose of `jit` is to stage computation out of Python to be\n", + "executed by XLA. (In contrast, `custom_jvp` is a higher-order primitive in\n", + "which we want to support data-dependent Python control flow.)\n", + "\n", + "Historically, we started using the \"initial-style\" and \"final-style\"\n", + "terminology after reading the [typed tagless final\n", + "interpreters](http://okmij.org/ftp/tagless-final/index.html) paper, and\n", + "jokingly referring to JAX as an implementation of \"untyped tagful final\n", + "interpreters.\" We don't claim to carry over (or understand) any deep meaning\n", + "behind these terms; we loosely use \"initial style\" to mean \"build an AST and\n", + "then transform it\", and we use \"final style\" to mean \"transform as we trace.\"\n", + "But it's just imprecise yet sticky jargon." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the initial-style approach, here's the user-facing `jit` wrapper:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def jit(f):\n", + " def f_jitted(*args):\n", + " avals_in = [raise_to_shaped(get_aval(x)) for x in args]\n", + " jaxpr, consts, out_tree = make_jaxpr(f, *avals_in)\n", + " outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts))\n", + " return tree_unflatten(out_tree, outs)\n", + " return f_jitted\n", + "\n", + "xla_call_p = Primitive('xla_call')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With any new primitive, we need to give it transformation rules, starting with\n", + "its evaluation rule. When we evaluate an application of the `xla_call`\n", + "primitive, we want to stage out out the computation to XLA. That involves\n", + "translating the jaxpr to an XLA HLO program, transferring the argument values\n", + "to the XLA device, executing the XLA program, and transferring back the\n", + "results. We'll cache the XLA HLO compilation so that for each `jit`ted\n", + "function it only needs to be performed once per argument shape and dtype\n", + "signature.\n", + "\n", + "First, some utilities." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "class IDHashable:\n", " val: Any\n", - " aval: ShapedArray\n", "\n", " def __init__(self, val):\n", " self.val = val\n", - " self.aval = raise_to_shaped(get_aval(self.val))\n", - "\n", - "Atom = Union[Var, Lit]\n", - "\n", - "class JaxprEqn(NamedTuple):\n", - " primitive: Primitive\n", - " inputs: List[Atom]\n", - " params: Dict[str, Any]\n", - " out_binder: Var\n", - "\n", - "class Jaxpr(NamedTuple):\n", - " in_binders: List[Var]\n", - " eqns: List[JaxprEqn]\n", - " out: Atom\n", "\n", + " def __hash__(self) -> int:\n", + " return id(self.val)\n", "\n", - "def raise_to_shaped(aval):\n", - " return ShapedArray(aval.shape, aval.dtype)" + " def __eq__(self, other):\n", + " return type(other) is IDHashable and id(self.val) == id(other.val)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll define the evaluation rule for `xla_call`:" ] }, { "cell_type": "code", "execution_count": null, - "id": "composite-dinner", + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "from jax.lib import xla_bridge as xb\n", + "from jax.lib import xla_client as xc\n", + "xe = xc._xla\n", + "xops = xc._xla.ops\n", + "\n", + "def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int):\n", + " consts, args = args[:num_consts], args[num_consts:]\n", + " hashable_consts = tuple(map(IDHashable, consts))\n", + " execute = xla_callable(IDHashable(jaxpr), hashable_consts)\n", + " return execute(*args)\n", + "impl_rules[xla_call_p] = xla_call_impl\n", + "\n", + "@lru_cache()\n", + "def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]):\n", + " jaxpr: Jaxpr = hashable_jaxpr.val\n", + " consts = [x.val for x in hashable_consts]\n", + " in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]]\n", + " c = xb.make_computation_builder('xla_call')\n", + " xla_consts = _xla_consts(c, consts)\n", + " xla_params = _xla_params(c, in_avals)\n", + " outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params)\n", + " out = xops.Tuple(c, outs)\n", + " compiled = xb.get_backend(None).compile(c.build(out))\n", + " return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs])\n", + "\n", + "def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]:\n", + " unique_consts = {id(cnst): cnst for cnst in consts}\n", + " xla_consts = {\n", + " id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()}\n", + " return [xla_consts[id(cnst)] for cnst in consts]\n", + "\n", + "def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]:\n", + " return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)]\n", + "\n", + "def _xla_shape(aval: ShapedArray) -> xe.Shape:\n", + " return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape)" + ] + }, + { + "cell_type": "markdown", "metadata": {}, + "source": [ + "The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO\n", + "program using `jaxpr_subcomp`, then returns a callable which executes the\n", + "compiled program:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "class JaxprType:\n", - " in_types: List[ShapedArray]\n", - " out_type: ShapedArray\n", - "\n", - " def __init__(self, in_types, out_type):\n", - " self.in_types = in_types\n", - " self.out_type = out_type\n", + "def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp]\n", + " ) -> xe.XlaOp:\n", + " env: Dict[Var, xe.XlaOp] = {}\n", "\n", - " def __repr__(self):\n", - " in_types = ', '.join(aval.str_short() for aval in self.in_types)\n", - " out_type = self.out_type.str_short()\n", - " return f'({in_types}) -> {out_type}'\n", - "\n", - "\n", - "def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType:\n", - " env: Set[Var] = set()\n", + " def read(x: Atom) -> xe.XlaOp:\n", + " return env[x] if type(x) is Var else xb.constant(c, x.val)\n", "\n", - " for v in jaxpr.in_binders:\n", - " env.add(v)\n", + " def write(v: Var, val: xe.XlaOp) -> None:\n", + " env[v] = val\n", "\n", + " map(write, jaxpr.in_binders, args)\n", " for eqn in jaxpr.eqns:\n", - " in_types = [typecheck_atom(env, x) for x in eqn.inputs]\n", - " out_type = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params)\n", - " if not types_equal(out_type, eqn.out_binder.aval): raise TypeError\n", - " env.add(eqn.out_binder)\n", - "\n", - " out_type = typecheck_atom(env, jaxpr.out)\n", - " return JaxprType([v.aval for v in jaxpr.in_binders], out_type)\n", - "\n", - "def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray:\n", - " if isinstance(x, Var):\n", - " if x not in env: raise TypeError(\"unbound variable\")\n", - " return x.aval\n", - " elif isinstance(x, Lit):\n", - " return raise_to_shaped(get_aval(x.val))\n", - " else:\n", - " assert False\n", - "\n", - "def types_equal(a: ShapedArray, b: ShapedArray) -> bool:\n", - " return a.shape == b.shape and a.dtype == b.dtype" + " in_avals = [x.aval for x in eqn.inputs]\n", + " in_vals = map(read, eqn.inputs)\n", + " rule = xla_translations[eqn.primitive]\n", + " out_vals = rule(c, in_avals, in_vals, **eqn.params)\n", + " map(write, eqn.out_binders, out_vals)\n", + " return map(read, jaxpr.outs)\n", + "\n", + "def execute_compiled(compiled, out_avals, *args):\n", + " input_bufs = [input_handlers[type(x)](x) for x in args]\n", + " out_bufs = compiled.execute(input_bufs)\n", + " return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)]\n", + "\n", + "input_handlers = {\n", + " int: xb.get_backend(None).buffer_from_pyval,\n", + " float: xb.get_backend(None).buffer_from_pyval,\n", + " np.ndarray: xb.get_backend(None).buffer_from_pyval,\n", + "}\n", + "\n", + "def handle_result(aval: ShapedArray, buf):\n", + " del aval # Unused for now.\n", + " return buf.to_py()\n", + "\n", + "xla_translations = {}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Notice that `jaxpr_subcomp` has the structure of a simple interpreter. That's\n", + "a common pattern: the way we process jaxprs is usually with an interpreter.\n", + "And as with any interpreter, we need an interpretation rule for each\n", + "primitive:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def direct_translation(op, c, in_avals, in_vals):\n", + " del c, in_avals\n", + " return [op(*in_vals)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "xla_translations[add_p] = partial(direct_translation, xops.Add)\n", + "xla_translations[mul_p] = partial(direct_translation, xops.Mul)\n", + "xla_translations[neg_p] = partial(direct_translation, xops.Neg)\n", + "xla_translations[sin_p] = partial(direct_translation, xops.Sin)\n", + "xla_translations[cos_p] = partial(direct_translation, xops.Cos)\n", + "xla_translations[greater_p] = partial(direct_translation, xops.Gt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def reduce_sum_translation(c, in_avals, in_vals, *, axis):\n", + " (x_aval,), (x,) = in_avals, in_vals\n", + " zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype))\n", + " subc = xb.make_computation_builder('add')\n", + " shape = _xla_shape(ShapedArray((), x_aval.dtype))\n", + " xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape))\n", + " return [xops.Reduce(c, [x], [zero], subc.build(), [axis])]\n", + "xla_translations[reduce_sum_p] = reduce_sum_translation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def broadcast_translation(c, in_avals, in_vals, *, shape, axes):\n", + " x, = in_vals\n", + " dims_complement = [i for i in range(len(shape)) if i not in axes]\n", + " return [xops.BroadcastInDim(x, shape, dims_complement)]\n", + "xla_translations[broadcast_p] = broadcast_translation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now that we have jaxprs as a data structure, we need ways to produce these\n", - "from tracing Python code. In general there are two variants of how we trace to\n", - "a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one\n", - "used by `jit`, which is also used by control flow primitives like\n", - "`lax.cond`, `lax.while_loop`, and `lax.scan`." + "With that, we can now use `jit` to stage out, compile, and execute programs\n", + "with XLA!" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "# NB: the analogous class in JAX is called 'DynamicJaxprTracer'\n", - "class JaxprTracer(Tracer):\n", - " __slots__ = ['aval']\n", - " aval: ShapedArray\n", - "\n", - " def __init__(self, trace, aval):\n", - " self._trace = trace\n", - " self.aval = aval\n", - "\n", - "# NB: the analogous class in JAX is called 'DynamicJaxprTrace'\n", - "class JaxprTrace(Trace):\n", - " def new_arg(self, aval: ShapedArray) -> JaxprTracer:\n", - " aval = raise_to_shaped(aval)\n", - " tracer = JaxprTracer(self, aval)\n", - " self.builder.tracer_to_var[id(tracer)] = Var(aval)\n", - " return tracer\n", - "\n", - " def get_or_make_const_tracer(self, val: Any) -> JaxprTracer:\n", - " tracer = self.builder.const_tracers.get(id(val))\n", - " if tracer is None:\n", - " tracer = JaxprTracer(self, raise_to_shaped(get_aval(val)))\n", - " self.builder.add_const(tracer, val)\n", - " return tracer\n", - " pure = lift = get_or_make_const_tracer\n", - "\n", - " def process_primitive(self, primitive, tracers, params):\n", - " avals_in = [t.aval for t in tracers]\n", - " aval_out = abstract_eval_rules[primitive](*avals_in, **params)\n", - " out_tracer = JaxprTracer(self, aval_out)\n", - " inputs = [self.builder.getvar(t) for t in tracers]\n", - " outvar = self.builder.add_var(out_tracer)\n", - " self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvar))\n", - " return out_tracer\n", - "\n", - " @property\n", - " def builder(self):\n", - " return self.main.global_data\n", - "\n", - "# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance\n", - "abstract_eval_rules = {}" + "@jit\n", + "def f(x, y):\n", + " print('tracing!')\n", + " return sin(x) * cos(y)" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "Notice that we keep as interpreter-global data a builder object, which keeps\n", - "track of variables, constants, and eqns as we build up the jaxpr." + "z = f(3., 4.) # 'tracing!' prints the first time\n", + "print(z)" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "class JaxprBuilder:\n", - " eqns: List[JaxprEqn]\n", - " tracer_to_var: Dict[int, Var]\n", - " const_tracers: Dict[int, JaxprTracer]\n", - " constvals: Dict[Var, Any]\n", - "\n", - " def __init__(self):\n", - " self.eqns = []\n", - " self.tracer_to_var = {}\n", - " self.const_tracers = {}\n", - " self.constvals = {}\n", - "\n", - " def add_eqn(self, eqn: JaxprEqn) -> None:\n", - " self.eqns.append(eqn)\n", - "\n", - " def add_var(self, tracer: JaxprTracer) -> Var:\n", - " var = self.tracer_to_var.get(id(tracer))\n", - " assert var is None\n", - " var = self.tracer_to_var[id(tracer)] = Var(tracer.aval)\n", - " return var\n", + "z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit!\n", + "print(z)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "@jit\n", + "def f(x):\n", + " return reduce_sum(x, axis=0)\n", "\n", - " def getvar(self, tracer: JaxprTracer) -> Var:\n", - " var = self.tracer_to_var.get(id(tracer))\n", - " assert var is not None\n", - " return var\n", + "print(f(np.array([1., 2., 3.])))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 1 + }, + "outputs": [], + "source": [ + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = - y + x\n", + " return z\n", "\n", - " def add_const(self, tracer: JaxprTracer, val: Any) -> Var:\n", - " var = self.add_var(tracer)\n", - " self.const_tracers[id(val)] = tracer\n", - " self.constvals[var] = val\n", - " return var\n", + "def deriv(f):\n", + " return lambda x: jvp(f, (x,), (1.,))[1]\n", "\n", - " def build(self, in_tracers: List[JaxprTracer], out_tracer: JaxprTracer\n", - " ) -> Tuple[Jaxpr, List[Any]]:\n", - " constvars, constvals = unzip2(self.constvals.items())\n", - " t2v = lambda t: self.tracer_to_var[id(t)]\n", - " in_binders = constvars + [t2v(t) for t in in_tracers]\n", - " jaxpr = Jaxpr(in_binders, self.eqns, t2v(out_tracer))\n", - " typecheck_jaxpr(jaxpr)\n", - " return jaxpr, constvals" + "print( deriv(deriv(f))(3.))\n", + "print(jit(deriv(deriv(f)))(3.))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The rules we need for `JaxprTrace.process_primitive` are essentially typing\n", - "rules for primitive applications: given the primitive, its parameters, and\n", - "types for the inputs, the rule must produce a type for the output, which is\n", - "then packaged with the output `JaxprTracer`. We can use abstract evaluation\n", - "rules for this same purpose, even though they can be more general (since\n", - "abstract evaluation rules need to work on ConcreteArray inputs as well). We'll\n", - "reuse these abstract evaluation rules for the other jaxpr-producing trace\n", - "machinery, where the potential extra generality is useful." + "Instead of implementing `jit` to first trace to a jaxpr and then to lower the\n", + "jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step\n", + "and just lowered to HLO while tracing. That is, perhaps we could have instead\n", + "implemented `jit` with a `Trace` and `Tracer` that appended to the XLA HLO\n", + "graph incrementally on each primitive bind. That's correct for now, but won't\n", + "be possible when we introduce compiled SPMD computations because there we must\n", + "know the number of replicas needed before compiling the program." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We haven't yet defined any transformation rules for `xla_call_p` other than\n", + "its evaluation rule. That is, we can't yet do `vmap`-of-`jit` or\n", + "`jvp`-of-`jit` or even `jit`-of`-jit`. Instead `jit` has to be at the \"top\n", + "level.\" Let's fix that!" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "def broadcast_shapes(*shapes):\n", - " assert len(shapes) > 1\n", - " for sizes in zip(*shapes):\n", - " sizes = [d for d in sizes if d != 1]\n", - " if sizes[:-1] != sizes[1:]:\n", - " raise Exception\n", - " return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes))" + "def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts):\n", + " del num_consts # Unused.\n", + " new_jaxpr, new_consts = jvp_jaxpr(jaxpr)\n", + " outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr,\n", + " num_consts=len(new_consts))\n", + " n = len(outs) // 2\n", + " primals_out, tangents_out = outs[:n], outs[n:]\n", + " return primals_out, tangents_out\n", + "jvp_rules[xla_call_p] = xla_call_jvp_rule" ] }, { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "lines_to_next_cell": 1 + }, "outputs": [], "source": [ - "def broadcasting_binop_abstract_eval_rule(*avals_in):\n", - " out_dtype = np.result_type(*map(np.result_type, avals_in))\n", - " out_shape = broadcast_shapes(*map(np.shape, avals_in))\n", - " return ShapedArray(out_shape, out_dtype)\n", - "\n", - "abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule\n", - "abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule\n", - "\n", - "def vectorized_unop_abstract_eval_rule(aval_in):\n", - " return ShapedArray(np.shape(aval_in), np.result_type(aval_in))\n", - "\n", - "abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule\n", - "abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule\n", - "abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule\n", - "\n", - "def reduce_sum_abstract_eval_rule(aval_in, *, axis):\n", - " new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis]\n", - " return ShapedArray(tuple(new_shape), aval_in.dtype)\n", - "abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule" + "def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]:\n", + " def jvp_traceable(*primals_and_tangents):\n", + " n = len(primals_and_tangents) // 2\n", + " primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:]\n", + " return jvp(jaxpr_as_fun(jaxpr), primals, tangents)\n", + "\n", + " in_avals = [v.aval for v in jaxpr.in_binders]\n", + " new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals)\n", + " return new_jaxpr, new_consts" ] }, { - "cell_type": "markdown", - "metadata": {}, + "cell_type": "code", + "execution_count": null, + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], "source": [ - "To check our implementation, we can add a `make_jaxpr` transformation and\n", - "first pretty-printer:" + "def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts):\n", + " del num_consts # Unused.\n", + " new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in)\n", + " outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr,\n", + " num_consts=len(new_consts))\n", + " return outs, [0] * len(outs)\n", + "vmap_rules[xla_call_p] = xla_call_vmap_rule\n", + "\n", + "def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis]\n", + " ) -> Tuple[Jaxpr, List[Any]]:\n", + " vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in))\n", + " in_avals = [unmapped_aval(axis_size, d, v.aval)\n", + " for v, d in zip(jaxpr.in_binders, bdims_in)]\n", + " new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals)\n", + " return new_jaxpr, new_consts\n", + "\n", + "def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray\n", + " ) -> ShapedArray:\n", + " if batch_dim is not_mapped:\n", + " return aval\n", + " else:\n", + " shape = list(aval.shape)\n", + " shape.insert(batch_dim, axis_size)\n", + " return ShapedArray(tuple(shape), aval.dtype)" ] }, { "cell_type": "code", "execution_count": null, - "id": "defensive-ownership", "metadata": { - "lines_to_next_cell": 1 + "lines_to_end_of_cell_marker": 0, + "lines_to_next_cell": 2 }, "outputs": [], "source": [ - "def make_jaxpr(f, avals_in):\n", - " builder = JaxprBuilder()\n", - " with new_main(JaxprTrace, builder) as main:\n", - " trace = JaxprTrace(main)\n", - " tracers_in = [trace.new_arg(aval) for aval in avals_in]\n", - " out = f(*tracers_in)\n", - " tracer_out = full_raise(trace, out)\n", - " return builder.build(tracers_in, tracer_out)" + "@jit\n", + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = - y + x\n", + " return z\n", + "\n", + "x, xdot = 3., 1.\n", + "y, ydot = jvp(f, (x,), (xdot,))\n", + "print(y)\n", + "print(ydot)\n", + "\n", + "ys = vmap(f, (0,))(np.arange(3.))\n", + "print(ys)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One piece missing is device memory persistence for arrays. That is, we've\n", + "defined `handle_result` to transfer results back to CPU memory as NumPy\n", + "arrays, but it's often preferrable to avoid transferring results just to\n", + "transfer them back for the next operation. We can do that by introducing a\n", + "`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type\n", + "`numpy.ndarray`s:" ] }, { "cell_type": "code", "execution_count": null, - "id": "adopted-month", "metadata": { "lines_to_next_cell": 1 }, "outputs": [], "source": [ - "from collections import defaultdict\n", - "import itertools as it\n", - "import string\n", - "\n", - "class PPrint:\n", - " lines: List[Tuple[int, str]]\n", - "\n", - " def __init__(self, lines):\n", - " self.lines = lines\n", - "\n", - " def indent(self, indent: int) -> 'PPrint':\n", - " return PPrint([(indent + orig_indent, s) for orig_indent, s in self.lines])\n", - "\n", - " def __add__(self, rhs: 'PPrint') -> 'PPrint':\n", - " return PPrint(self.lines + rhs.lines)\n", - "\n", - " def __rshift__(self, rhs: 'PPrint') -> 'PPrint':\n", - " if not rhs.lines: return self\n", - " if not self.lines: return rhs\n", - " indent, s = self.lines[-1]\n", - " indented_block = rhs.indent(indent + len(s))\n", - " common_line = s + ' ' * rhs.lines[0][0] + rhs.lines[0][1]\n", - " return PPrint(self.lines[:-1]\n", - " + [(indent, common_line)]\n", - " + indented_block.lines[1:])\n", - "\n", - " def __str__(self) -> str:\n", - " return '\\n'.join(' ' * indent + s for indent, s in self.lines)\n", - "\n", - "def pp(s: Any) -> PPrint:\n", - " return PPrint([(0, line) for line in str(s).splitlines()])\n", + "def handle_result(aval: ShapedArray, buf):\n", + " return DeviceArray(aval, buf)\n", "\n", - "def vcat(ps: List[PPrint]) -> PPrint:\n", - " return sum(ps, pp(''))\n", + "class DeviceArray:\n", + " buf: Any\n", + " aval: ShapedArray\n", "\n", - "def pp_jaxpr(jaxpr: Jaxpr):\n", - " namegen = (''.join(s) for r in it.count(1)\n", - " for s in it.permutations(string.ascii_lowercase, r))\n", - " names = defaultdict(lambda: next(namegen))\n", - " in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders)\n", - " eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns])\n", - " out = names[jaxpr.out] if isinstance(jaxpr.out, Var) else str(jaxpr.out.val)\n", - " return (pp(f'{{ lambda {in_binders} .') +\n", - " ((pp('let ') >> eqns) + pp(f'in {out} }}')).indent(2))\n", + " def __init__(self, aval, buf):\n", + " self.aval = aval\n", + " self.buf = buf\n", "\n", - "def var_str(names: Dict[Var, str], v: Var) -> str:\n", - " return f'{names[v]}:{v.aval.str_short()}'\n", + " dtype = property(lambda self: self.aval.dtype)\n", + " shape = property(lambda self: self.aval.shape)\n", + " ndim = property(lambda self: self.aval.ndim)\n", "\n", - "def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint:\n", - " lhs = pp(var_str(names, eqn.out_binder))\n", - " rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >>\n", - " pp(' '.join(names[x] if isinstance(x, Var) else str(x.val)\n", - " for x in eqn.inputs)))\n", - " return lhs >> pp(' = ') >> rhs\n", + " def __array__(self): return self.buf.to_py()\n", + " def __repr__(self): return repr(self.buf.to_py())\n", + " def __str__(self): return str(self.buf.to_py())\n", "\n", - "def pp_params(params: Dict[str, Any]) -> PPrint:\n", - " items = sorted(params.items())\n", - " if items:\n", - " return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ')\n", - " else:\n", - " return pp(' ')" + " _neg = staticmethod(neg)\n", + " _add = staticmethod(add)\n", + " _radd = staticmethod(add)\n", + " _mul = staticmethod(mul)\n", + " _rmul = staticmethod(mul)\n", + " _gt = staticmethod(greater)\n", + "input_handlers[DeviceArray] = lambda x: x.buf" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "lines_to_next_cell": 1 - }, + "metadata": {}, "outputs": [], "source": [ - "jaxpr, consts = make_jaxpr(lambda x: 2. * x, [raise_to_shaped(get_aval(3.))])\n", - "print(pp_jaxpr(jaxpr))\n", - "print(typecheck_jaxpr(jaxpr))" + "@jit\n", + "def f(x):\n", + " y = sin(x) * 2.\n", + " z = - y + x\n", + " return z\n", + "\n", + "x, xdot = 3., 1.\n", + "y, ydot = jvp(f, (x,), (xdot,))\n", + "print(y)\n", + "print(ydot)" ] } ], diff --git a/docs/autodidax.md b/docs/autodidax.md index 01bee7e9dd14..9ca94e269c91 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -28,15 +28,24 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + --- ``` +```{code-cell} ipython3 +# TODO remove me +import pdb, sys, traceback +def info(type, value, tb): + traceback.print_exception(type, value, tb) + pdb.pm() +sys.excepthook = info +``` + # Autodidax: JAX core from scratch -Ever want to learn how JAX works, but the implementation seemed too -impenetrable? Well, you're in luck! By reading this tutorial, you'll learn -every big idea in JAX's core system. You'll even get clued into our weird -jargon! +Ever want to learn how JAX works, but the implementation seemed impenetrable? +Well, you're in luck! By reading this tutorial, you'll learn every big idea in +JAX's core system. You'll even get clued into our weird jargon! +++ @@ -46,7 +55,7 @@ We want to transform functions that look like this: ```python def f(x): - y = sin(x) * 2 + y = sin(x) * 2. z = - y + x return z ``` @@ -56,14 +65,13 @@ infix operators (`mul`, `add`, and `neg`) as primitive operations, meaning atomic units of processing rather than compositions. "Transform" means "interpret differently." Instead of standard interpretation -where we apply primitive functions to numerical inputs to produce numerical +where we apply primitive operations to numerical inputs to produce numerical outputs, we want to override primitive application and let different values flow through our program. For example, we might want to replace the application of every primitive with an application of [its JVP rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), -and let primal-tangent pairs flow through our program. Moreover, we want to -apply a composition of multiple transformations, leading to stacks of -interpreters. +and let primal-tangent pairs flow through our program. Moreover, we want to be +able to comopse multiple transformations, leading to stacks of interpreters. +++ @@ -86,14 +94,22 @@ sin_p = Primitive("sin") cos_p = Primitive("cos") reduce_sum_p = Primitive("reduce_sum") greater_p = Primitive("greater") - -def add(x, y): return bind(add_p, x, y) -def mul(x, y): return bind(mul_p, x, y) -def neg(x): return bind(neg_p, x) -def sin(x): return bind(sin_p, x) -def cos(x): return bind(cos_p, x) -def reduce_sum(x, axis=None): return bind(reduce_sum_p, x, axis=axis) -def greater(x, y): return bind(greater_p, x, y) +transpose_p = Primitive("transpose") +broadcast_p = Primitive("broadcast") + +def add(x, y): return bind1(add_p, x, y) +def mul(x, y): return bind1(mul_p, x, y) +def neg(x): return bind1(neg_p, x) +def sin(x): return bind1(sin_p, x) +def cos(x): return bind1(cos_p, x) +def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis) +def greater(x, y): return bind1(greater_p, x, y) +def transpose(x, perm): return bind1(transpose_p, perm=perm) +def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes) + +def bind1(prim, *args, **params): + out, = bind(prim, *args, **params) + return out ``` We'll set up array data types and infix operator methods in a moment. @@ -122,14 +138,21 @@ more descriptive. ```{code-cell} ipython3 from contextlib import contextmanager from typing import Type, List, Optional, Any +``` +```{code-cell} ipython3 class MainTrace(NamedTuple): level: int trace_type: Type['Trace'] global_data: Optional[Any] +``` +```{code-cell} ipython3 trace_stack: List[MainTrace] = [] +dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 +``` +```{code-cell} ipython3 @contextmanager def new_main(trace_type: Type['Trace'], global_data=None): level = len(trace_stack) @@ -142,14 +165,13 @@ def new_main(trace_type: Type['Trace'], global_data=None): trace_stack.pop() ``` -When we're about to apply a transformed function, we'll push another -interpreter onto the stack using `new_main`. Then, as we apply primitives in -the function, we can think of the `bind` first being interpreted by the trace -at the top of the stack (i.e. with the highest level). If that first -interpreter itself binds other primitives in its interpretation rule for the -primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, -then those `bind` calls will be handled by the interpreter at the next level -down. +When we're about to apply a transformation, we'll push another interpreter +onto the stack using `new_main`. Then, as we apply primitives in the function, +we can think of the `bind` first being interpreted by the trace at the top of +the stack (i.e. with the highest level). If that first interpreter itself +binds other primitives in its interpretation rule for the primitive, like how +the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, then those `bind` +calls will be handled by the interpreter at the next level down. What goes at the bottom of the interpreter stack? At the bottom, we know all the transformation interpreters are finished, and we just want to do standard @@ -193,7 +215,9 @@ like arrays.) ```{code-cell} ipython3 import numpy as np from typing import Tuple +``` +```{code-cell} ipython3 class Tracer: _trace: Trace @@ -220,7 +244,9 @@ class Tracer: return getattr(self.aval, name) except AttributeError: raise AttributeError(f"{self.__class__.__name__} has no attribute {name}") +``` +```{code-cell} ipython3 class ShapedArray: array_abstraction_level = 1 shape: Tuple[int] @@ -252,6 +278,15 @@ class ShapedArray: def str_short(self): return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]' + def __hash__(self): + return hash((self.shape, self.dtype)) + + def __eq__(self, other): + return (type(self) is type(other) and + self.shape == other.shape and self.dtype == other.dtype) +``` + +```{code-cell} ipython3 class ConcreteArray(ShapedArray): array_abstraction_level = 2 val: np.ndarray @@ -268,7 +303,9 @@ class ConcreteArray(ShapedArray): @staticmethod def _nonzero(tracer): return bool(tracer.aval.val) +``` +```{code-cell} ipython3 def get_aval(x): if isinstance(x, Tracer): return x.aval @@ -281,20 +318,19 @@ different levels of abstraction. A `ShapedArray` represents the set of all possible arrays with a given shape and dtype. A `ConcreteArray` represents a singleton set consisting of a single array value. -Now that we've set up the trace stack, the Trace/Tracer API for interpreters, -and abstract values, we can come back to implement `bind`: +Now that we've set up the interpreter stack, the Trace/Tracer API for +interpreters, and abstract values, we can come back to implement `bind`: ```{code-cell} ipython3 def bind(prim, *args, **params): top_trace = find_top_trace(args) tracers = [full_raise(top_trace, arg) for arg in args] - out = top_trace.process_primitive(prim, tracers, params) - return full_lower(out) + outs = top_trace.process_primitive(prim, tracers, params) + return [full_lower(out) for out in outs] ``` The main action is that we call `find_top_trace` to figure out which -interpreter should handle this primitive application as a function of the -arguments and the active traces on the trace stack. We then call that top +interpreter should handle this primitive application. We then call that top trace's `process_primitive` so that the trace can apply its interpretation rule. The calls to `full_raise` just ensure that the inputs are boxed in the top trace's `Tracer` instances, and the call to `full_lower` is an optional @@ -302,28 +338,45 @@ optimization so that we unbox values out of `Tracer`s as much as possible. ```{code-cell} ipython3 from operator import attrgetter +``` +```{code-cell} ipython3 def find_top_trace(xs) -> Trace: top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)), default=trace_stack[0], key=attrgetter('level')) + if dynamic_trace and dynamic_trace.level > top_main.level: + top_main = dynamic_trace return top_main.trace_type(top_main) ``` -In words, `find_top_trace` returns the highest-level interpreter associated -with the `Tracer`s on its inputs, and otherwise returns the interpreter at the -bottom of the stack (which is always an evaluation trace, at least for now). -This corresponds to JAX transformations mostly working by data dependence -_except_ for the special bottom-of-the-stack interpreter, which interprets -everything. +In words, ignoring the `dynamic_trace` step until Part 3, `find_top_trace` +returns the highest-level interpreter associated with the `Tracer`s on its +inputs, and otherwise returns the interpreter at the bottom of the stack +(which is always an evaluation trace, at least for now). This is a deviation +from the description above, where we always start by running the interpreter +at the top of the stack and then work our way down, applying every interpreter +in the stack. Instead, we're only applying an interpreter when the input +arguments to a primitive bind are boxed in a `Tracer` corresponding to that +interpreter. This optimization lets us skip irrelevant transformations, but +bakes in an assumption that transformations mostly follow data dependence +(except for the special bottom-of-the-stack interpreter, which interprets +everything). + +An alternative would be to have every interpreter in the stack interpret every +operation. That's worth exploring! JAX is designed around data dependence in +large part because that's so natural for automatic differentiation, and JAX's +roots are in autodiff. But it may be over-fit. ```{code-cell} ipython3 -def full_lower(val): +def full_lower(val: Any): if isinstance(val, Tracer): return val.full_lower() else: return val +``` -def full_raise(trace, val) -> Tracer: +```{code-cell} ipython3 +def full_raise(trace: Trace, val: Any) -> Tracer: if not isinstance(val, Tracer): return trace.pure(val) level = trace.main.level @@ -359,27 +412,43 @@ class EvalTrace(Trace): def process_primitive(self, primitive, tracers, params): return impl_rules[primitive](*tracers, **params) +``` +```{code-cell} ipython3 trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack +``` +```{code-cell} ipython3 impl_rules = {} -impl_rules[add_p] = np.add -impl_rules[mul_p] = np.multiply -impl_rules[neg_p] = np.negative -impl_rules[sin_p] = np.sin -impl_rules[cos_p] = np.cos -impl_rules[reduce_sum_p] = np.sum -impl_rules[greater_p] = np.greater +``` + +```{code-cell} ipython3 +impl_rules[add_p] = lambda x, y: [np.add(x, y)] +impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)] +impl_rules[neg_p] = lambda x: [np.negative(x)] +impl_rules[sin_p] = lambda x: [np.sin(x)] +impl_rules[cos_p] = lambda x: [np.cos(x)] +impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)] +impl_rules[greater_p] = lambda x, y: [np.greater(x, y)] +impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)] +``` + +```{code-cell} ipython3 +def broadcast_impl(x, *, shape, axes): + return [np.broadcast_to(np.expand_dims(x, axes), shape)] +impl_rules[broadcast_p] = broadcast_impl ``` With this interpreter, we can evaluate user functions: ```{code-cell} ipython3 def f(x): - y = sin(x) * 2 + y = sin(x) * 2. z = - y + x return z +``` +```{code-cell} ipython3 print(f(3.0)) ``` @@ -390,12 +459,14 @@ that now we can add some real transformations. ### Forward-mode autodiff with `jvp` -First, a couple of helper functions: +First, a few helper functions: ```{code-cell} ipython3 def zeros_like(val): return np.zeros_like(val) +``` +```{code-cell} ipython3 def unzip2(pairs): lst1, lst2 = [], [] for x1, x2 in pairs: @@ -404,6 +475,12 @@ def unzip2(pairs): return lst1, lst2 ``` +```{code-cell} ipython3 +map_ = map +def map(f, *xs): + return list(map_(f, *xs)) +``` + The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The `Trace` applies JVP rules. @@ -417,68 +494,82 @@ class JVPTracer(Tracer): @property def aval(self): return get_aval(self.primal) +``` +```{code-cell} ipython3 class JVPTrace(Trace): pure = lift = lambda self, val: JVPTracer(self, val, zeros_like(val)) def process_primitive(self, primitive, tracers, params): primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) jvp_rule = jvp_rules[primitive] - primal_out, tangent_out = jvp_rule(primals_in, tangents_in, **params) - return JVPTracer(self, primal_out, tangent_out) + primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params) + return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)] +``` +```{code-cell} ipython3 jvp_rules = {} ``` Notice both `lift` and `sublift` package a value into a `JVPTracer` with the minimal amount of context, which is a zero tangent value. -+++ - Let's add some JVP rules for primitives: ```{code-cell} ipython3 def add_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents - return x + y, x_dot + y_dot + return [x + y], [x_dot + y_dot] jvp_rules[add_p] = add_jvp +``` +```{code-cell} ipython3 def mul_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents - return x * y, x_dot * y + x * y_dot + return [x * y], [x_dot * y + x * y_dot] jvp_rules[mul_p] = mul_jvp +``` +```{code-cell} ipython3 def sin_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents - return sin(x), cos(x) * x_dot + return [sin(x)], [cos(x) * x_dot] jvp_rules[sin_p] = sin_jvp +``` +```{code-cell} ipython3 def cos_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents - return cos(x), -sin(x) * x_dot + return [cos(x)], [-sin(x) * x_dot] jvp_rules[cos_p] = cos_jvp +``` +```{code-cell} ipython3 def neg_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents - return neg(x), neg(x_dot) + return [neg(x)], [neg(x_dot)] jvp_rules[neg_p] = neg_jvp +``` +```{code-cell} ipython3 def reduce_sum_jvp(primals, tangents, *, axis): (x,), (x_dot,) = primals, tangents - return reduce_sum(x, axis), reduce_sum(x_dot, axis) + return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)] jvp_rules[reduce_sum_p] = reduce_sum_jvp +``` +```{code-cell} ipython3 def greater_jvp(primals, tangents): (x, y), _ = primals, tangents out_primal = greater(x, y) - return out_primal, zeros_like(out_primal) + return [out_primal], [zeros_like(out_primal)] jvp_rules[greater_p] = greater_jvp ``` Finally, we add a transformation API to kick off the trace: ```{code-cell} ipython3 -def jvp(f, primals, tangents): +def jvp_v1(f, primals, tangents): with new_main(JVPTrace) as main: trace = JVPTrace(main) tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)] @@ -492,26 +583,26 @@ And with that, we can differentiate! ```{code-cell} ipython3 x = 3.0 -y, sin_deriv_at_3 = jvp(sin, (x,), (1.0,)) +y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,)) print(sin_deriv_at_3) print(cos(3.0)) ``` ```{code-cell} ipython3 def f(x): - y = sin(x) * 2 + y = sin(x) * 2. z = - y + x return z x, xdot = 3., 1. -y, ydot = jvp(f, (x,), (xdot,)) +y, ydot = jvp_v1(f, (x,), (xdot,)) print(y) print(ydot) ``` ```{code-cell} ipython3 def deriv(f): - return lambda x: jvp(f, (x,), (1.,))[1] + return lambda x: jvp_v1(f, (x,), (1.,))[1] print(deriv(sin)(3.)) print(deriv(deriv(sin))(3.)) @@ -530,6 +621,157 @@ print(deriv(f)(3.)) print(deriv(f)(-3.)) ``` +## Pytrees and flattening user functions' inputs and outputs + ++++ + +A limitation with `jvp_v1` is that it assumes the user function accepts arrays +as positional arguments and produces a single array as output. What if it +produced a list as output? Or accepted nested containers as inputs? It would +be a pain to deal with all the possible containers in inputs and outputs at +every layer of the stack. Instead, we can wrap the user function so that the +wrapped version accepts arrays as inputs and returns a flat list of arrays as +output. The wrapper just needs to unflatten its input, call the user function, +and flatten the output. + +Here's how we'd like to write `jvp`, assuming the user always gives us +functions that take arrays as inputs and produces a flat list of arrays as +outputs: + +```{code-cell} ipython3 +def jvp_flat(f, primals, tangents): + with new_main(JVPTrace) as main: + trace = JVPTrace(main) + tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)] + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out) + return primals_out, tangents_out +``` + +To support user functions that have arbitrary containers in the inputs and +outputs, here's how we'd write the user-facing `jvp` wrapper: + +```{code-cell} ipython3 +def jvp(f, primals, tangents): + primals_flat, in_tree = tree_flatten(primals) + tangents_flat, in_tree2 = tree_flatten(tangents) + if in_tree != in_tree2: raise TypeError + f, out_tree = flatten_fun(f, in_tree) + primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat) + primals_out = tree_unflatten(out_tree(), primals_out_flat) + tangents_out = tree_unflatten(out_tree(), tangents_out_flat) + return primals_out, tangents_out +``` + +Notice that we had to plumb the tree structure of the user function output +back to the caller of `flatten_fun`. That information isn't available until we +actually run the user function, so `flatten_fun` just returns a reference to a +mutable cell, represented as a thunk. These side-effects are safe because we +always run the user function exactly once. (This safe regime is the reason for +the "linear" name in `linear_util.py`, in the sense of [linear +types](https://en.wikipedia.org/wiki/Substructural_type_system).) + +All that remains is to write `tree_flatten`, `tree_unflatten`, and +`flatten_fun`: + +```{code-cell} ipython3 +def flatten_fun(f, in_tree): + store = Store() + + def flat_fun(*args_flat): + pytree_args = tree_unflatten(in_tree, args_flat) + out = f(*pytree_args) + out_flat, out_tree = tree_flatten(out) + store.set_value(out_tree) + return out_flat + + return flat_fun, store +``` + +```{code-cell} ipython3 +class Empty: pass +empty = Empty() +``` + +```{code-cell} ipython3 +class Store: + val = empty + + def set_value(self, val): + assert self.val is empty + self.val = val + + def __call__(self): + return self.val +``` + +```{code-cell} ipython3 +import itertools as it +from typing import Callable, Type, Hashable, Dict, Iterable, Iterator + +class NodeType(NamedTuple): + to_iterable: Callable + from_iterable: Callable + +node_types: Dict[Type, NodeType] = { + tuple: NodeType(lambda t: (None, t), lambda _, xs: tuple(xs)), + list: NodeType( lambda l: (None, l), lambda _, xs: list(xs)), + dict: NodeType(lambda d: map(tuple, unzip2(sorted(d.items()))), + lambda keys, vals: dict(zip(keys, vals))), +} + +class PyTreeDef(NamedTuple): + node_type: NodeType + node_metadata: Hashable + child_treedefs: Tuple['PyTreeDef'] + +class Leaf: pass +leaf = Leaf() + +def tree_flatten(x: Any) -> Tuple[List[Any], PyTreeDef]: + children_iter, treedef = _tree_flatten(x) + return list(children_iter), treedef + +def _tree_flatten(x: Any) -> Tuple[Iterable, PyTreeDef]: + node_type = node_types.get(type(x)) + if node_type: + node_metadata, children = node_type.to_iterable(x) + children_flat, child_trees = unzip2(map(_tree_flatten, children)) + flattened = it.chain.from_iterable(children_flat) + return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees)) + else: + return [x], leaf + +def tree_unflatten(treedef: PyTreeDef, xs: List[Any]) -> Any: + return _tree_unflatten(treedef, iter(xs)) + +def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any: + if treedef is leaf: + return next(xs) + else: + children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) + return treedef.node_type.from_iterable(treedef.node_metadata, children) +``` + +With this pytree-handling `jvp` impelmentation, we can now handle arbitrary +input and output containers. That'll come in handy with future transformations +too! + +```{code-cell} ipython3 +def f(x): + y = sin(x) * 2. + z = - y + x + return {'hi': z, 'there': [x, y]} +``` + +```{code-cell} ipython3 +x, xdot = 3., 1. +y, ydot = jvp(f, (x,), (xdot,)) +print(y) +print(ydot) +``` + ### Vectorized batching with `vmap` First, a couple helper functions, one for producing mapped abstract values @@ -541,14 +783,25 @@ def mapped_aval(batch_dim, aval): shape = list(aval.shape) del shape[batch_dim] return ShapedArray(tuple(shape), aval.dtype) +``` +```{code-cell} ipython3 def move_batch_axis(axis_size, src, dst, x): if src is not_mapped: target_shape = list(np.shape(x)) target_shape.insert(dst, axis_size) - return np.broadcast_to(np.expand_dims(x, dst), target_shape) + return broadcast(x, target_shape, [dst]) + elif src == dst: + return x else: - return np.moveaxis(x, src, dst) + return moveaxis(x, src, dst) +``` + +```{code-cell} ipython3 +def moveaxis(x, src: int, dst: int): + perm = [i for i in range(np.ndim(x)) if i != src] + perm.insert(dst, src) + return transpose(x, perm) ``` The `Tracer` for vectorized batching carries a batched value and an optional @@ -556,12 +809,20 @@ integer indicating which axis (if any) is the batch axis. ```{code-cell} ipython3 from typing import Union +``` +```{code-cell} ipython3 class NotMapped: pass not_mapped = NotMapped() +``` +```{code-cell} ipython3 +BatchAxis = Union[NotMapped, int] +``` + +```{code-cell} ipython3 class BatchTracer(Tracer): - def __init__(self, trace, val, batch_dim: Union[NotMapped, int]): + def __init__(self, trace, val, batch_dim: BatchAxis): self._trace = trace self.val = val self.batch_dim = batch_dim @@ -578,20 +839,24 @@ class BatchTracer(Tracer): return full_lower(self.val) else: return self +``` +```{code-cell} ipython3 class BatchTrace(Trace): pure = lift = lambda self, val: BatchTracer(self, val, not_mapped) def process_primitive(self, primitive, tracers, params): vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers) vmap_rule = vmap_rules[primitive] - val_out, bdim_out = vmap_rule(self.axis_size, vals_in, bdims_in, **params) - return BatchTracer(self, val_out, bdim_out) + val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params) + return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)] @property def axis_size(self): return self.main.global_data +``` +```{code-cell} ipython3 vmap_rules = {} ``` @@ -609,45 +874,67 @@ Next we can define batching interpreter rules for each primitive: ```{code-cell} ipython3 from functools import partial +``` +```{code-cell} ipython3 def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in): (x, y), (x_bdim, y_bdim) = vals_in, dims_in if x_bdim != y_bdim: y = move_batch_axis(axis_size, y_bdim, x_bdim, y) - return op(x, y), x_bdim + return [op(x, y)], [x_bdim] vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add) vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul) +``` +```{code-cell} ipython3 def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in): (x,), (x_bdim,) = vals_in, dims_in - return op(x), x_bdim + return [op(x)], [x_bdim] vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin) vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos) vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg) +``` +```{code-cell} ipython3 def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis): (x,), (x_bdim,) = vals_in, dims_in new_axis = axis + (x_bdim <= axis) out_bdim = x_bdim - (new_axis < x_bdim) - return reduce_sum(x, new_axis), out_bdim + return [reduce_sum(x, new_axis)], [out_bdim] vmap_rules[reduce_sum_p] = reduce_sum_batching_rule ``` +- + ++++ + Finally, we add a transformation API to kick off the trace: ```{code-cell} ipython3 -def vmap(f, in_axes, out_axis): +def vmap_flat(f, in_axes, *args): + axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) + if ax is not not_mapped} + with new_main(BatchTrace, axis_size) as main: + trace = BatchTrace(main) + tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x + for x, ax in zip(args, in_axes)] + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out) + outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out) + for val_out, bdim in zip(vals_out, bdims_out)] + return outs_transposed +``` + +```{code-cell} ipython3 +def vmap(f, in_axes): def batched_f(*args): - axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) - if ax is not None} - with new_main(BatchTrace, axis_size) as main: - trace = BatchTrace(main) - tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x - for x, ax in zip(args, in_axes)] - out = f(*tracers_in) - tracer_out = full_raise(trace, out) - val_out, batch_dim_out = tracer_out.val, tracer_out.batch_dim - return move_batch_axis(axis_size, batch_dim_out, out_axis, val_out) + args_flat, in_tree = tree_flatten(args) + in_axes_flat, in_tree2 = tree_flatten(in_axes) + if in_tree != in_tree2: raise TypeError + f_flat, out_tree = flatten_fun(f, in_tree) + outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat) + return tree_unflatten(out_tree(), outs_flat) return batched_f ``` @@ -657,7 +944,7 @@ def add_one_to_a_scalar(scalar): return 1 + scalar vector_in = np.arange(3.) -vector_out = vmap(add_one_to_a_scalar, (0,), 0)(vector_in) +vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in) print(vector_in) print(vector_out) @@ -667,7 +954,7 @@ print(vector_out) def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,))[1] vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2) - return vmap(pushfwd, (0,), 0)(vecs_in) + return vmap(pushfwd, (0,))(vecs_in) def f(x): return sin(x) @@ -682,37 +969,42 @@ implementation: rules, and for more complex primitives (like for convolution or advanced indexing) each rule is harder to write. But the overarching design is no different. -1. **Transformations expect arrays in, single array out.** -2. **No symbolic zeros in autodiff.** -3. **No special call primitives yet.** The core machinery needs to be +2. **No pytrees.** Transformations expect arrays in, and either a single array + out or a flat list of arrays out. +3. **Missing optimization: no symbolic zeros in autodiff.** +4. **No special call primitives yet.** The core machinery needs to be generalized to handle the most flexible kind of higher-order primitive, used by `jax.custom_jvp` and `jax.custom_vjp`. +++ -## Part 2: Jaxprs, for `jit` and `vjp` +## Part 2: Jaxprs The next transformations are the horizon are `jit` for just-in-time compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small -wrapper around `vjp`.) For `jvp` and `vmap` we only needed each `Tracer` to -carry a little bit of extra context, but for both `jit` and `vjp` we need -much richer context: we need to represent _programs_. That is, we need jaxprs! - -Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are -an explicitly typed, functional, first-order language. We need a program -representation for `jit` because the purpose of `jit` is to stage computation -out of Python. For any computation we want to stage out, we need to be able to -represent it as data, and build it up as we trace a Python function. -Similarly, `vjp` needs a way to represent the computation for the backward -pass of reverse-mode autodiff. We use the same jaxpr program representation -for both needs. +wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to +carry a little bit of extra context, for both `jit` and `vjp` we need much +richer context: we need to represent _programs_. That is, we need jaxprs! + +Jaxprs are JAX's internal intermediate representation of programs. They are +explicitly typed, functional, first-order, and in ANF form. We need a +program representation for `jit` because the purpose of `jit` is to stage +computation out of Python. For any computation we want to stage out, we need +to be able to represent it as data, and build it up as we trace a Python +function. Similarly, `vjp` needs a way to represent the computation for the +backward pass of reverse-mode autodiff. We use the same jaxpr program +representation for both needs. (Building a program representation is the most [free](https://en.wikipedia.org/wiki/Free_object) kind of -trace- transformation, and so except for issues around handling native Python +trace-transformation, and so except for issues around handling native Python control flow, any transformation could be implemented by first tracing to a jaxpr and then interpreting the jaxpr.) ++++ + +### Jaxpr data strutures + The jaxpr term syntax is roughly: ``` @@ -720,20 +1012,20 @@ jaxpr ::= { lambda , ... . let ... - in } + in ( , ... ) } binder ::= : var ::= a | b | c | ... atom ::= | literal ::= | -eqn ::= = [ ] , ... +eqn ::= , ... = [ ] , ... ``` The syntax of types is: ``` -jaxpr_type ::= [, ...] -> [, ...] +jaxpr_type ::= [ , ... ] -> [ , ... ] array_type ::= [] dtype ::= f32 | f64 | i32 | i64 shape ::= , ... @@ -744,7 +1036,7 @@ represent types, and we can represent the term syntax with a few Python structs: ```{code-cell} ipython3 -from typing import Dict, Set +from typing import Set class Var: aval: ShapedArray @@ -764,47 +1056,54 @@ class JaxprEqn(NamedTuple): primitive: Primitive inputs: List[Atom] params: Dict[str, Any] - out_binder: Var + out_binders: List[Var] class Jaxpr(NamedTuple): in_binders: List[Var] eqns: List[JaxprEqn] - out: Atom - + outs: List[Atom] def raise_to_shaped(aval): return ShapedArray(aval.shape, aval.dtype) ``` +Type-checking a jaxpr involves checking that there are no unbound variables, +that variables are only bound once, and that for each equation the type of +the primitive application matches the type of the output binders. + ```{code-cell} ipython3 class JaxprType: in_types: List[ShapedArray] - out_type: ShapedArray + out_type: List[ShapedArray] - def __init__(self, in_types, out_type): + def __init__(self, in_types, out_types): self.in_types = in_types - self.out_type = out_type + self.out_types = out_types def __repr__(self): in_types = ', '.join(aval.str_short() for aval in self.in_types) - out_type = self.out_type.str_short() - return f'({in_types}) -> {out_type}' - + out_types = ', '.join(aval.str_short() for aval in self.out_types) + return f'({in_types}) -> ({out_types})' def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType: env: Set[Var] = set() for v in jaxpr.in_binders: + if v in env: raise TypeError env.add(v) for eqn in jaxpr.eqns: in_types = [typecheck_atom(env, x) for x in eqn.inputs] - out_type = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params) - if not types_equal(out_type, eqn.out_binder.aval): raise TypeError - env.add(eqn.out_binder) + out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params) + for out_binder, out_type in zip(eqn.out_binders, out_types): + if not types_equal(out_type, out_binder.aval): raise TypeError + for out_binder in eqn.out_binders: + if out_binder in env: raise TypeError + env.add(out_binder) - out_type = typecheck_atom(env, jaxpr.out) - return JaxprType([v.aval for v in jaxpr.in_binders], out_type) + in_types = [v.aval for v in jaxpr.in_binders] + out_types = [typecheck_atom(env, x) for x in jaxpr.outs] + return JaxprType(in_types, out_types) def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray: if isinstance(x, Var): @@ -819,11 +1118,43 @@ def types_equal(a: ShapedArray, b: ShapedArray) -> bool: return a.shape == b.shape and a.dtype == b.dtype ``` +We can apply the function represented by a jaxpr to arguments with a simple +interpreter. + +```{code-cell} ipython3 +def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: + env: Dict[Var, Any] = {} + + def read(x: Atom) -> Any: + return env[x] if type(x) is Var else x.val + + def write(v: Var, val: Any) -> None: + env[v] = val + + map(write, jaxpr.in_binders, args) + for eqn in jaxpr.eqns: + in_vals = map(read, eqn.inputs) + outs = bind(eqn.primitive, *in_vals, **eqn.params) + map(write, eqn.out_binders, outs) + return map(read, jaxpr.outs) +``` + +```{code-cell} ipython3 +def jaxpr_as_fun(jaxpr: Jaxpr): + return lambda *args: eval_jaxpr(jaxpr, args) +``` + +By using `bind` in the interpreter, this interpreter itself is traceable. + ++++ + +### Building jaxprs with tracing + Now that we have jaxprs as a data structure, we need ways to produce these from tracing Python code. In general there are two variants of how we trace to a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one -used by `jit`, which is also used by control flow primitives like -`lax.cond`, `lax.while_loop`, and `lax.scan`. +used by `jit`, which is also used by control flow primitives like `lax.cond`, +`lax.while_loop`, and `lax.scan`. ```{code-cell} ipython3 # NB: the analogous class in JAX is called 'DynamicJaxprTracer' @@ -839,26 +1170,26 @@ class JaxprTracer(Tracer): class JaxprTrace(Trace): def new_arg(self, aval: ShapedArray) -> JaxprTracer: aval = raise_to_shaped(aval) - tracer = JaxprTracer(self, aval) + tracer = self.builder.new_tracer(self, aval) self.builder.tracer_to_var[id(tracer)] = Var(aval) return tracer def get_or_make_const_tracer(self, val: Any) -> JaxprTracer: tracer = self.builder.const_tracers.get(id(val)) if tracer is None: - tracer = JaxprTracer(self, raise_to_shaped(get_aval(val))) + tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val))) self.builder.add_const(tracer, val) return tracer pure = lift = get_or_make_const_tracer def process_primitive(self, primitive, tracers, params): avals_in = [t.aval for t in tracers] - aval_out = abstract_eval_rules[primitive](*avals_in, **params) - out_tracer = JaxprTracer(self, aval_out) + avals_out = abstract_eval_rules[primitive](*avals_in, **params) + out_tracers = [self.builder.new_tracer(self, a) for a in avals_out] inputs = [self.builder.getvar(t) for t in tracers] - outvar = self.builder.add_var(out_tracer) - self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvar)) - return out_tracer + outvars = [self.builder.add_var(t) for t in out_tracers] + self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars)) + return out_tracers @property def builder(self): @@ -869,7 +1200,7 @@ abstract_eval_rules = {} ``` Notice that we keep as interpreter-global data a builder object, which keeps -track of variables, constants, and eqns as we build up the jaxpr. +track of variables, constants, and eqns as we build up the jaxpr. ```{code-cell} ipython3 class JaxprBuilder: @@ -877,19 +1208,25 @@ class JaxprBuilder: tracer_to_var: Dict[int, Var] const_tracers: Dict[int, JaxprTracer] constvals: Dict[Var, Any] + tracers: List[JaxprTracer] def __init__(self): self.eqns = [] self.tracer_to_var = {} self.const_tracers = {} self.constvals = {} + self.tracers = [] + + def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer: + tracer = JaxprTracer(trace, aval) + self.tracers.append(tracer) + return tracer def add_eqn(self, eqn: JaxprEqn) -> None: self.eqns.append(eqn) def add_var(self, tracer: JaxprTracer) -> Var: - var = self.tracer_to_var.get(id(tracer)) - assert var is None + assert id(tracer) not in self.tracer_to_var var = self.tracer_to_var[id(tracer)] = Var(tracer.aval) return var @@ -904,24 +1241,27 @@ class JaxprBuilder: self.constvals[var] = val return var - def build(self, in_tracers: List[JaxprTracer], out_tracer: JaxprTracer + def build(self, in_tracers: List[JaxprTracer], out_tracers: List[JaxprTracer] ) -> Tuple[Jaxpr, List[Any]]: constvars, constvals = unzip2(self.constvals.items()) t2v = lambda t: self.tracer_to_var[id(t)] in_binders = constvars + [t2v(t) for t in in_tracers] - jaxpr = Jaxpr(in_binders, self.eqns, t2v(out_tracer)) + out_vars = [t2v(t) for t in out_tracers] + jaxpr = Jaxpr(in_binders, self.eqns, out_vars) typecheck_jaxpr(jaxpr) return jaxpr, constvals ``` The rules we need for `JaxprTrace.process_primitive` are essentially typing -rules for primitive applications: given the primitive, its parameters, and +rules for primitive applications: given the primitive, its parameters, and types for the inputs, the rule must produce a type for the output, which is -then packaged with the output `JaxprTracer`. We can use abstract evaluation -rules for this same purpose, even though they can be more general (since -abstract evaluation rules need to work on ConcreteArray inputs as well). We'll -reuse these abstract evaluation rules for the other jaxpr-producing trace -machinery, where the potential extra generality is useful. +then packaged with the output `JaxprTracer`. We can use abstract evaluation +rules for this same purpose, even though they can be more general (since +abstract evaluation rules must accept ConcreteArray inputs, and since they +need only return an upper bound on the set of possible outputs, they can +produce ConcreteArray outputs as well). We'll reuse these abstract evaluation +rules for the other jaxpr-producing trace machinery, where the potential extra +generality is useful. ```{code-cell} ipython3 def broadcast_shapes(*shapes): @@ -931,19 +1271,17 @@ def broadcast_shapes(*shapes): if sizes[:-1] != sizes[1:]: raise Exception return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes)) -``` -```{code-cell} ipython3 def broadcasting_binop_abstract_eval_rule(*avals_in): out_dtype = np.result_type(*map(np.result_type, avals_in)) out_shape = broadcast_shapes(*map(np.shape, avals_in)) - return ShapedArray(out_shape, out_dtype) + return [ShapedArray(out_shape, out_dtype)] abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule def vectorized_unop_abstract_eval_rule(aval_in): - return ShapedArray(np.shape(aval_in), np.result_type(aval_in)) + return [ShapedArray(np.shape(aval_in), np.result_type(aval_in))] abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule @@ -951,27 +1289,41 @@ abstract_eval_rules[neg_p] = vectorized_unop_abstract_eval_rule def reduce_sum_abstract_eval_rule(aval_in, *, axis): new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis] - return ShapedArray(tuple(new_shape), aval_in.dtype) + return [ShapedArray(tuple(new_shape), aval_in.dtype)] abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule + +def broadcast_abstract_eval(x, *, shape, axes): + return [ShapedArray(tuple(shape), np.result_type(x))] +abstract_eval_rules[broadcast_p] = broadcast_abstract_eval ``` -To check our implementation, we can add a `make_jaxpr` transformation and -first pretty-printer: +To check our implementation of jaxprs, we can add a `make_jaxpr` +transformation and a pretty-printer: ```{code-cell} ipython3 -def make_jaxpr(f, avals_in): +from functools import lru_cache +``` + +```{code-cell} ipython3 +@lru_cache() +def make_jaxpr_v1(f, *avals_in): + avals_in, in_tree = tree_flatten(avals_in) + f, out_tree = flatten_fun(f, in_tree) + builder = JaxprBuilder() with new_main(JaxprTrace, builder) as main: trace = JaxprTrace(main) tracers_in = [trace.new_arg(aval) for aval in avals_in] - out = f(*tracers_in) - tracer_out = full_raise(trace, out) - return builder.build(tracers_in, tracer_out) + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + jaxpr, consts = builder.build(tracers_in, tracers_out) + return jaxpr, consts, out_tree() ``` ```{code-cell} ipython3 +:tags: [hide-input] + from collections import defaultdict -import itertools as it import string class PPrint: @@ -1011,15 +1363,16 @@ def pp_jaxpr(jaxpr: Jaxpr): names = defaultdict(lambda: next(namegen)) in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders) eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns]) - out = names[jaxpr.out] if isinstance(jaxpr.out, Var) else str(jaxpr.out.val) + outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val) + for v in jaxpr.outs) return (pp(f'{{ lambda {in_binders} .') + - ((pp('let ') >> eqns) + pp(f'in {out} }}')).indent(2)) + ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2)) def var_str(names: Dict[Var, str], v: Var) -> str: return f'{names[v]}:{v.aval.str_short()}' def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint: - lhs = pp(var_str(names, eqn.out_binder)) + lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders)) rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >> pp(' '.join(names[x] if isinstance(x, Var) else str(x.val) for x in eqn.inputs))) @@ -1034,7 +1387,466 @@ def pp_params(params: Dict[str, Any]) -> PPrint: ``` ```{code-cell} ipython3 -jaxpr, consts = make_jaxpr(lambda x: 2. * x, [raise_to_shaped(get_aval(3.))]) +jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.))) print(pp_jaxpr(jaxpr)) print(typecheck_jaxpr(jaxpr)) ``` + +But there's a limitation here: because of how `find_top_trace` operates by +data dependence, `make_jaxpr_v1` can't stage out all the primitive operations +performed by the Python callable it's given. For example: + +```{code-cell} ipython3 +jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.)) +print(pp_jaxpr(jaxpr)) +``` + +This is precisely the issue that +[omnistaging](https://github.com/google/jax/pull/3370) fixed. +We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always +applied, regardless of whether any inputs to `bind` are boxed in corresponding +`JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` +global defined in Part 1: + +```{code-cell} ipython3 +@contextmanager +def new_dynamic(main: MainTrace): + global dynamic_trace + prev_dynamic_trace, dynamic_trace = dynamic_trace, main + try: + yield + finally: + dynamic_trace = prev_dynamic_trace +``` + +```{code-cell} ipython3 +@lru_cache() # ShapedArrays are hashable +def make_jaxpr(f, *avals_in): + avals_in, in_tree = tree_flatten(avals_in) + f, out_tree = flatten_fun(f, in_tree) + + builder = JaxprBuilder() + with new_main(JaxprTrace, builder) as main: + with new_dynamic(main): + trace = JaxprTrace(main) + tracers_in = [trace.new_arg(aval) for aval in avals_in] + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + jaxpr, consts = builder.build(tracers_in, tracers_out) + return jaxpr, consts, out_tree() +``` + +```{code-cell} ipython3 +jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.)) +print(pp_jaxpr(jaxpr)) +``` + +Using `dynamic_trace` this way is conceptually the same as stashing the +current interpreter stack and starting a new one with the `JaxprTrace` at the +bottom. That is, no interpreters lower in the stack than the `dynamic_trace` +are applied (since `JaxprTrace.process_primitive` doesn't call `bind`), though +if the Python callable being traced to a jaxpr itself uses transformations +then those can be pushed onto the interpreter stack above the `JaxprTrace`. +But temporarily stashing the interpreter stack would break up the system +state. The `dynamic_trace` tag achieves the same goals while keeping the +system state simpler. + ++++ + +That's it for jaxprs! With jaxprs in hand, we can implement the remaining +major JAX features. But before moving on, let's highlight some +simplifications we've made: +1. **Single-output primitives and jaxprs.** + ++++ + +## Part 3: `jit`, simplified + +While `jit` has a transformation-like API in that it accepts a Python callable +as an argument, under the hood it's really a higher-order primitive rather +than a transformation. A primitive is _higher-order_ when it's parameterized +by a function. + ++++ + +### "Final style" and "initial style" + +There are two options for how to handle higher-order primitives. Each requires +a different approach to tracing and engenders different tradeoffs: +1. **`bind` takes a Python callable as an argument.** We defer forming a jaxpr + until as late as possible, namely until we're running the final interpreter + at the bottom of the interpreter stack. That way we can swap a `JaxprTrace` + in at the bottom of the interpreter stack and thus stage out rather than + execute all primitive operations. With this approach, transformations in + the stack get applied as we execute the Python callable as usual. This + approach can be very tricky to implement, but it's as general as possible + because it allows higher-order primitives not to raise the abstraction + level of their arguments and thus allows data-dependent Python control + flow. We refer to this approach as using a "final-style higher-order + primitive" employing the discharge-at-tracing-time "final-style + transformations" we've used so far. +2. **`bind` takes a jaxpr as an argument.** Before we call `bind`, in the + primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and + be done with the Python callable entirely. In this case, `make_jaxpr` puts + its `JaxprTrace` at the top of the interpreter stack, and no + transformations lower in the stack, which might enter via closed-over + Tracers, are applied to the Python callable as we trace it. + (Transformations applied within the Python callable are applied as usual, + being added to the stack above the JaxprTrace.) Instead, the + transformations lower in the stack are later applied to the call primitive, + and the call primitive's rules must then transform the jaxpr itself. + Because we trace to a jaxpr up-front, this approach can't support + data-dependent Python control flow, but it is more straightforward to + implement. We refer to this kind of higher-order primitive as an + "initial-style higher-order primitive", and say that its jaxpr-processing + transformation rules are "initial-style transformation rules." + +The latter approach fits for `jit` because we don't need to support +data-dependent Python control flow in the user-provided Python callable, as +the whole purpose of `jit` is to stage computation out of Python to be +executed by XLA. (In contrast, `custom_jvp` is a higher-order primitive in +which we want to support data-dependent Python control flow.) + +Historically, we started using the "initial-style" and "final-style" +terminology after reading the [typed tagless final +interpreters](http://okmij.org/ftp/tagless-final/index.html) paper, and +jokingly referring to JAX as an implementation of "untyped tagful final +interpreters." We don't claim to carry over (or understand) any deep meaning +behind these terms; we loosely use "initial style" to mean "build an AST and +then transform it", and we use "final style" to mean "transform as we trace." +But it's just imprecise yet sticky jargon. + ++++ + +With the initial-style approach, here's the user-facing `jit` wrapper: + +```{code-cell} ipython3 +def jit(f): + def f_jitted(*args): + avals_in = [raise_to_shaped(get_aval(x)) for x in args] + jaxpr, consts, out_tree = make_jaxpr(f, *avals_in) + outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts)) + return tree_unflatten(out_tree, outs) + return f_jitted + +xla_call_p = Primitive('xla_call') +``` + +With any new primitive, we need to give it transformation rules, starting with +its evaluation rule. When we evaluate an application of the `xla_call` +primitive, we want to stage out out the computation to XLA. That involves +translating the jaxpr to an XLA HLO program, transferring the argument values +to the XLA device, executing the XLA program, and transferring back the +results. We'll cache the XLA HLO compilation so that for each `jit`ted +function it only needs to be performed once per argument shape and dtype +signature. + +First, some utilities. + +```{code-cell} ipython3 +class IDHashable: + val: Any + + def __init__(self, val): + self.val = val + + def __hash__(self) -> int: + return id(self.val) + + def __eq__(self, other): + return type(other) is IDHashable and id(self.val) == id(other.val) +``` + +Next, we'll define the evaluation rule for `xla_call`: + +```{code-cell} ipython3 +from jax.lib import xla_bridge as xb +from jax.lib import xla_client as xc +xe = xc._xla +xops = xc._xla.ops + +def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): + consts, args = args[:num_consts], args[num_consts:] + hashable_consts = tuple(map(IDHashable, consts)) + execute = xla_callable(IDHashable(jaxpr), hashable_consts) + return execute(*args) +impl_rules[xla_call_p] = xla_call_impl + +@lru_cache() +def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]): + jaxpr: Jaxpr = hashable_jaxpr.val + consts = [x.val for x in hashable_consts] + in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]] + c = xb.make_computation_builder('xla_call') + xla_consts = _xla_consts(c, consts) + xla_params = _xla_params(c, in_avals) + outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params) + out = xops.Tuple(c, outs) + compiled = xb.get_backend(None).compile(c.build(out)) + return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) + +def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]: + unique_consts = {id(cnst): cnst for cnst in consts} + xla_consts = { + id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()} + return [xla_consts[id(cnst)] for cnst in consts] + +def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]: + return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)] + +def _xla_shape(aval: ShapedArray) -> xe.Shape: + return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape) +``` + +The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO +program using `jaxpr_subcomp`, then returns a callable which executes the +compiled program: + +```{code-cell} ipython3 +def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp] + ) -> xe.XlaOp: + env: Dict[Var, xe.XlaOp] = {} + + def read(x: Atom) -> xe.XlaOp: + return env[x] if type(x) is Var else xb.constant(c, x.val) + + def write(v: Var, val: xe.XlaOp) -> None: + env[v] = val + + map(write, jaxpr.in_binders, args) + for eqn in jaxpr.eqns: + in_avals = [x.aval for x in eqn.inputs] + in_vals = map(read, eqn.inputs) + rule = xla_translations[eqn.primitive] + out_vals = rule(c, in_avals, in_vals, **eqn.params) + map(write, eqn.out_binders, out_vals) + return map(read, jaxpr.outs) + +def execute_compiled(compiled, out_avals, *args): + input_bufs = [input_handlers[type(x)](x) for x in args] + out_bufs = compiled.execute(input_bufs) + return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)] + +input_handlers = { + int: xb.get_backend(None).buffer_from_pyval, + float: xb.get_backend(None).buffer_from_pyval, + np.ndarray: xb.get_backend(None).buffer_from_pyval, +} + +def handle_result(aval: ShapedArray, buf): + del aval # Unused for now. + return buf.to_py() + +xla_translations = {} +``` + +Notice that `jaxpr_subcomp` has the structure of a simple interpreter. That's +a common pattern: the way we process jaxprs is usually with an interpreter. +And as with any interpreter, we need an interpretation rule for each +primitive: + +```{code-cell} ipython3 +def direct_translation(op, c, in_avals, in_vals): + del c, in_avals + return [op(*in_vals)] +``` + +```{code-cell} ipython3 +xla_translations[add_p] = partial(direct_translation, xops.Add) +xla_translations[mul_p] = partial(direct_translation, xops.Mul) +xla_translations[neg_p] = partial(direct_translation, xops.Neg) +xla_translations[sin_p] = partial(direct_translation, xops.Sin) +xla_translations[cos_p] = partial(direct_translation, xops.Cos) +xla_translations[greater_p] = partial(direct_translation, xops.Gt) +``` + +```{code-cell} ipython3 +def reduce_sum_translation(c, in_avals, in_vals, *, axis): + (x_aval,), (x,) = in_avals, in_vals + zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype)) + subc = xb.make_computation_builder('add') + shape = _xla_shape(ShapedArray((), x_aval.dtype)) + xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape)) + return [xops.Reduce(c, [x], [zero], subc.build(), [axis])] +xla_translations[reduce_sum_p] = reduce_sum_translation +``` + +```{code-cell} ipython3 +def broadcast_translation(c, in_avals, in_vals, *, shape, axes): + x, = in_vals + dims_complement = [i for i in range(len(shape)) if i not in axes] + return [xops.BroadcastInDim(x, shape, dims_complement)] +xla_translations[broadcast_p] = broadcast_translation +``` + +With that, we can now use `jit` to stage out, compile, and execute programs +with XLA! + +```{code-cell} ipython3 +@jit +def f(x, y): + print('tracing!') + return sin(x) * cos(y) +``` + +```{code-cell} ipython3 +z = f(3., 4.) # 'tracing!' prints the first time +print(z) +``` + +```{code-cell} ipython3 +z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit! +print(z) +``` + +```{code-cell} ipython3 +@jit +def f(x): + return reduce_sum(x, axis=0) + +print(f(np.array([1., 2., 3.]))) +``` + +```{code-cell} ipython3 +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +def deriv(f): + return lambda x: jvp(f, (x,), (1.,))[1] + +print( deriv(deriv(f))(3.)) +print(jit(deriv(deriv(f)))(3.)) +``` + +Instead of implementing `jit` to first trace to a jaxpr and then to lower the +jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step +and just lowered to HLO while tracing. That is, perhaps we could have instead +implemented `jit` with a `Trace` and `Tracer` that appended to the XLA HLO +graph incrementally on each primitive bind. That's correct for now, but won't +be possible when we introduce compiled SPMD computations because there we must +know the number of replicas needed before compiling the program. + ++++ + +We haven't yet defined any transformation rules for `xla_call_p` other than +its evaluation rule. That is, we can't yet do `vmap`-of-`jit` or +`jvp`-of-`jit` or even `jit`-of`-jit`. Instead `jit` has to be at the "top +level." Let's fix that! + +```{code-cell} ipython3 +def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): + del num_consts # Unused. + new_jaxpr, new_consts = jvp_jaxpr(jaxpr) + outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr, + num_consts=len(new_consts)) + n = len(outs) // 2 + primals_out, tangents_out = outs[:n], outs[n:] + return primals_out, tangents_out +jvp_rules[xla_call_p] = xla_call_jvp_rule +``` + +```{code-cell} ipython3 +def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]: + def jvp_traceable(*primals_and_tangents): + n = len(primals_and_tangents) // 2 + primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:] + return jvp(jaxpr_as_fun(jaxpr), primals, tangents) + + in_avals = [v.aval for v in jaxpr.in_binders] + new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals) + return new_jaxpr, new_consts +``` + +```{code-cell} ipython3 +def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): + del num_consts # Unused. + new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in) + outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr, + num_consts=len(new_consts)) + return outs, [0] * len(outs) +vmap_rules[xla_call_p] = xla_call_vmap_rule + +def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis] + ) -> Tuple[Jaxpr, List[Any]]: + vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) + in_avals = [unmapped_aval(axis_size, d, v.aval) + for v, d in zip(jaxpr.in_binders, bdims_in)] + new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals) + return new_jaxpr, new_consts + +def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray + ) -> ShapedArray: + if batch_dim is not_mapped: + return aval + else: + shape = list(aval.shape) + shape.insert(batch_dim, axis_size) + return ShapedArray(tuple(shape), aval.dtype) +``` + +```{code-cell} ipython3 +@jit +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +x, xdot = 3., 1. +y, ydot = jvp(f, (x,), (xdot,)) +print(y) +print(ydot) + +ys = vmap(f, (0,))(np.arange(3.)) +print(ys) +``` + +One piece missing is device memory persistence for arrays. That is, we've +defined `handle_result` to transfer results back to CPU memory as NumPy +arrays, but it's often preferrable to avoid transferring results just to +transfer them back for the next operation. We can do that by introducing a +`DeviceArray` class, which can wrap XLA buffers and otherwise duck-type +`numpy.ndarray`s: + +```{code-cell} ipython3 +def handle_result(aval: ShapedArray, buf): + return DeviceArray(aval, buf) + +class DeviceArray: + buf: Any + aval: ShapedArray + + def __init__(self, aval, buf): + self.aval = aval + self.buf = buf + + dtype = property(lambda self: self.aval.dtype) + shape = property(lambda self: self.aval.shape) + ndim = property(lambda self: self.aval.ndim) + + def __array__(self): return self.buf.to_py() + def __repr__(self): return repr(self.buf.to_py()) + def __str__(self): return str(self.buf.to_py()) + + _neg = staticmethod(neg) + _add = staticmethod(add) + _radd = staticmethod(add) + _mul = staticmethod(mul) + _rmul = staticmethod(mul) + _gt = staticmethod(greater) +input_handlers[DeviceArray] = lambda x: x.buf +``` + +```{code-cell} ipython3 +@jit +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +x, xdot = 3., 1. +y, ydot = jvp(f, (x,), (xdot,)) +print(y) +print(ydot) +``` diff --git a/docs/autodidax.py b/docs/autodidax.py index 678db30c92f4..c8fb3d938dae 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# # jupyter: # jupytext: # formats: ipynb,md:myst,py @@ -25,12 +26,19 @@ # name: python3 # --- +# TODO remove me +import pdb, sys, traceback +def info(type, value, tb): + traceback.print_exception(type, value, tb) + pdb.pm() +sys.excepthook = info + + # # Autodidax: JAX core from scratch # -# Ever want to learn how JAX works, but the implementation seemed too -# impenetrable? Well, you're in luck! By reading this tutorial, you'll learn -# every big idea in JAX's core system. You'll even get clued into our weird -# jargon! +# Ever want to learn how JAX works, but the implementation seemed impenetrable? +# Well, you're in luck! By reading this tutorial, you'll learn every big idea in +# JAX's core system. You'll even get clued into our weird jargon! # ## Part 1: Transformations as interpreters: standard evaluation, `jvp`, and `vmap` # @@ -38,7 +46,7 @@ # # ```python # def f(x): -# y = sin(x) * 2 +# y = sin(x) * 2. # z = - y + x # return z # ``` @@ -48,14 +56,13 @@ # atomic units of processing rather than compositions. # # "Transform" means "interpret differently." Instead of standard interpretation -# where we apply primitive functions to numerical inputs to produce numerical +# where we apply primitive operations to numerical inputs to produce numerical # outputs, we want to override primitive application and let different values # flow through our program. For example, we might want to replace the # application of every primitive with an application of [its JVP # rule](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), -# and let primal-tangent pairs flow through our program. Moreover, we want to -# apply a composition of multiple transformations, leading to stacks of -# interpreters. +# and let primal-tangent pairs flow through our program. Moreover, we want to be +# able to comopse multiple transformations, leading to stacks of interpreters. # ### JAX core machinery # @@ -77,16 +84,22 @@ class Primitive(NamedTuple): cos_p = Primitive("cos") reduce_sum_p = Primitive("reduce_sum") greater_p = Primitive("greater") - -def add(x, y): return bind(add_p, x, y) -def mul(x, y): return bind(mul_p, x, y) -def neg(x): return bind(neg_p, x) -def sin(x): return bind(sin_p, x) -def cos(x): return bind(cos_p, x) -def reduce_sum(x, axis=None): return bind(reduce_sum_p, x, axis=axis) -def greater(x, y): return bind(greater_p, x, y) - - +transpose_p = Primitive("transpose") +broadcast_p = Primitive("broadcast") + +def add(x, y): return bind1(add_p, x, y) +def mul(x, y): return bind1(mul_p, x, y) +def neg(x): return bind1(neg_p, x) +def sin(x): return bind1(sin_p, x) +def cos(x): return bind1(cos_p, x) +def reduce_sum(x, axis=None): return bind1(reduce_sum_p, x, axis=axis) +def greater(x, y): return bind1(greater_p, x, y) +def transpose(x, perm): return bind1(transpose_p, perm=perm) +def broadcast(x, shape, axes): return bind1(broadcast_p, x, shape=shape, axes=axes) + +def bind1(prim, *args, **params): + out, = bind(prim, *args, **params) + return out # - # We'll set up array data types and infix operator methods in a moment. @@ -112,7 +125,6 @@ def greater(x, y): return bind(greater_p, x, y) # needs. We call each element a `MainTrace`, though maybe "Interpreter" would be # more descriptive. -# + from contextlib import contextmanager from typing import Type, List, Optional, Any @@ -122,6 +134,7 @@ class MainTrace(NamedTuple): global_data: Optional[Any] trace_stack: List[MainTrace] = [] +dynamic_trace: Optional[MainTrace] = None # to be employed in Part 3 @contextmanager def new_main(trace_type: Type['Trace'], global_data=None): @@ -134,17 +147,13 @@ def new_main(trace_type: Type['Trace'], global_data=None): finally: trace_stack.pop() - -# - - -# When we're about to apply a transformed function, we'll push another -# interpreter onto the stack using `new_main`. Then, as we apply primitives in -# the function, we can think of the `bind` first being interpreted by the trace -# at the top of the stack (i.e. with the highest level). If that first -# interpreter itself binds other primitives in its interpretation rule for the -# primitive, like how the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, -# then those `bind` calls will be handled by the interpreter at the next level -# down. +# When we're about to apply a transformation, we'll push another interpreter +# onto the stack using `new_main`. Then, as we apply primitives in the function, +# we can think of the `bind` first being interpreted by the trace at the top of +# the stack (i.e. with the highest level). If that first interpreter itself +# binds other primitives in its interpretation rule for the primitive, like how +# the JVP rule of `sin_p` might bind `cos_p` and `mul_p`, then those `bind` +# calls will be handled by the interpreter at the next level down. # # What goes at the bottom of the interpreter stack? At the bottom, we know all # the transformation interpreters are finished, and we just want to do standard @@ -167,7 +176,6 @@ def lift(self, val): assert False # must override def process_primitive(self, primitive, tracers, params): assert False # must override - # The first two methods are about boxing up values in `Tracer`s, which are the # objects that flow through the Python programs we transform. The last method is # the callback we'll use to interpret primitive application. @@ -184,7 +192,6 @@ def process_primitive(self, primitive, tracers, params): # `Tracer` per transformation, and at least one `AbstractValue` per base type, # like arrays.) -# + import numpy as np from typing import Tuple @@ -246,6 +253,13 @@ def _nonzero(tracer): def str_short(self): return f'{self.dtype.name}[{",".join(str(d) for d in self.shape)}]' + def __hash__(self): + return hash((self.shape, self.dtype)) + + def __eq__(self, other): + return (type(self) is type(other) and + self.shape == other.shape and self.dtype == other.dtype) + class ConcreteArray(ShapedArray): array_abstraction_level = 2 val: np.ndarray @@ -269,58 +283,61 @@ def get_aval(x): else: return ConcreteArray(np.asarray(x)) - -# - - # Notice that we actually have two `AbstractValue`s for arrays, representing # different levels of abstraction. A `ShapedArray` represents the set of all # possible arrays with a given shape and dtype. A `ConcreteArray` represents a # singleton set consisting of a single array value. # -# Now that we've set up the trace stack, the Trace/Tracer API for interpreters, -# and abstract values, we can come back to implement `bind`: +# Now that we've set up the interpreter stack, the Trace/Tracer API for +# interpreters, and abstract values, we can come back to implement `bind`: def bind(prim, *args, **params): top_trace = find_top_trace(args) tracers = [full_raise(top_trace, arg) for arg in args] - out = top_trace.process_primitive(prim, tracers, params) - return full_lower(out) - + outs = top_trace.process_primitive(prim, tracers, params) + return [full_lower(out) for out in outs] # The main action is that we call `find_top_trace` to figure out which -# interpreter should handle this primitive application as a function of the -# arguments and the active traces on the trace stack. We then call that top +# interpreter should handle this primitive application. We then call that top # trace's `process_primitive` so that the trace can apply its interpretation # rule. The calls to `full_raise` just ensure that the inputs are boxed in the # top trace's `Tracer` instances, and the call to `full_lower` is an optional # optimization so that we unbox values out of `Tracer`s as much as possible. -# + from operator import attrgetter def find_top_trace(xs) -> Trace: top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)), default=trace_stack[0], key=attrgetter('level')) + if dynamic_trace and dynamic_trace.level > top_main.level: + top_main = dynamic_trace return top_main.trace_type(top_main) +# In words, ignoring the `dynamic_trace` step until Part 3, `find_top_trace` +# returns the highest-level interpreter associated with the `Tracer`s on its +# inputs, and otherwise returns the interpreter at the bottom of the stack +# (which is always an evaluation trace, at least for now). This is a deviation +# from the description above, where we always start by running the interpreter +# at the top of the stack and then work our way down, applying every interpreter +# in the stack. Instead, we're only applying an interpreter when the input +# arguments to a primitive bind are boxed in a `Tracer` corresponding to that +# interpreter. This optimization lets us skip irrelevant transformations, but +# bakes in an assumption that transformations mostly follow data dependence +# (except for the special bottom-of-the-stack interpreter, which interprets +# everything). +# +# An alternative would be to have every interpreter in the stack interpret every +# operation. That's worth exploring! JAX is designed around data dependence in +# large part because that's so natural for automatic differentiation, and JAX's +# roots are in autodiff. But it may be over-fit. -# - - -# In words, `find_top_trace` returns the highest-level interpreter associated -# with the `Tracer`s on its inputs, and otherwise returns the interpreter at the -# bottom of the stack (which is always an evaluation trace, at least for now). -# This corresponds to JAX transformations mostly working by data dependence -# _except_ for the special bottom-of-the-stack interpreter, which interprets -# everything. - -# + -def full_lower(val): +def full_lower(val: Any): if isinstance(val, Tracer): return val.full_lower() else: return val -def full_raise(trace, val) -> Tracer: +def full_raise(trace: Trace, val: Any) -> Tracer: if not isinstance(val, Tracer): return trace.pure(val) level = trace.main.level @@ -333,9 +350,6 @@ def full_raise(trace, val) -> Tracer: else: # val._trace.level == level raise Exception(f"Different traces at same level: {val._trace}, {trace}.") - -# - - # The logic in `full_raise` serves to box values into `Tracer`s for a particular # `Trace`, calling different methods on the `Trace` based on context: # `Trace.pure` is called on non-`Tracer` constants, and `Trace.lift` is called @@ -350,7 +364,6 @@ def full_raise(trace, val) -> Tracer: # We'll start with the simplest interpreter: the evaluation interpreter that # will sit at the bottom of the interpreter stack. -# + class EvalTrace(Trace): pure = lift = lambda self, x: x # no boxing in Tracers needed @@ -360,38 +373,36 @@ def process_primitive(self, primitive, tracers, params): trace_stack.append(MainTrace(0, EvalTrace, None)) # special bottom of the stack impl_rules = {} -impl_rules[add_p] = np.add -impl_rules[mul_p] = np.multiply -impl_rules[neg_p] = np.negative -impl_rules[sin_p] = np.sin -impl_rules[cos_p] = np.cos -impl_rules[reduce_sum_p] = np.sum -impl_rules[greater_p] = np.greater +impl_rules[add_p] = lambda x, y: [np.add(x, y)] +impl_rules[mul_p] = lambda x, y: [np.multiply(x, y)] +impl_rules[neg_p] = lambda x: [np.negative(x)] +impl_rules[sin_p] = lambda x: [np.sin(x)] +impl_rules[cos_p] = lambda x: [np.cos(x)] +impl_rules[reduce_sum_p] = lambda x, *, axis: [np.sum(x, axis)] +impl_rules[greater_p] = lambda x, y: [np.greater(x, y)] +impl_rules[transpose_p] = lambda x, *, perm: [np.transpose(x, perm)] -# - +def broadcast_impl(x, *, shape, axes): + return [np.broadcast_to(np.expand_dims(x, axes), shape)] +impl_rules[broadcast_p] = broadcast_impl # With this interpreter, we can evaluate user functions: -# + def f(x): - y = sin(x) * 2 + y = sin(x) * 2. z = - y + x return z print(f(3.0)) - -# - - # Woo! Like going around in a big circle. But the point of this indirection is # that now we can add some real transformations. # ### Forward-mode autodiff with `jvp` # -# First, a couple of helper functions: +# First, a few helper functions: -# + def zeros_like(val): return np.zeros_like(val) @@ -402,13 +413,13 @@ def unzip2(pairs): lst2.append(x2) return lst1, lst2 - -# - +map_ = map +def map(f, *xs): + return list(map_(f, *xs)) # The `Tracer` for forward-mode autodiff carries a primal-tangent pair. The # `Trace` applies JVP rules. -# + class JVPTracer(Tracer): def __init__(self, trace, primal, tangent): self._trace = trace @@ -425,62 +436,55 @@ class JVPTrace(Trace): def process_primitive(self, primitive, tracers, params): primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) jvp_rule = jvp_rules[primitive] - primal_out, tangent_out = jvp_rule(primals_in, tangents_in, **params) - return JVPTracer(self, primal_out, tangent_out) + primal_outs, tangent_outs = jvp_rule(primals_in, tangents_in, **params) + return [JVPTracer(self, x, t) for x, t in zip(primal_outs, tangent_outs)] jvp_rules = {} - -# - - # Notice both `lift` and `sublift` package a value into a `JVPTracer` with the # minimal amount of context, which is a zero tangent value. - +# # Let's add some JVP rules for primitives: -# + def add_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents - return x + y, x_dot + y_dot + return [x + y], [x_dot + y_dot] jvp_rules[add_p] = add_jvp def mul_jvp(primals, tangents): (x, y), (x_dot, y_dot) = primals, tangents - return x * y, x_dot * y + x * y_dot + return [x * y], [x_dot * y + x * y_dot] jvp_rules[mul_p] = mul_jvp def sin_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents - return sin(x), cos(x) * x_dot + return [sin(x)], [cos(x) * x_dot] jvp_rules[sin_p] = sin_jvp def cos_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents - return cos(x), -sin(x) * x_dot + return [cos(x)], [-sin(x) * x_dot] jvp_rules[cos_p] = cos_jvp def neg_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents - return neg(x), neg(x_dot) + return [neg(x)], [neg(x_dot)] jvp_rules[neg_p] = neg_jvp def reduce_sum_jvp(primals, tangents, *, axis): (x,), (x_dot,) = primals, tangents - return reduce_sum(x, axis), reduce_sum(x_dot, axis) + return [reduce_sum(x, axis)], [reduce_sum(x_dot, axis)] jvp_rules[reduce_sum_p] = reduce_sum_jvp def greater_jvp(primals, tangents): (x, y), _ = primals, tangents out_primal = greater(x, y) - return out_primal, zeros_like(out_primal) + return [out_primal], [zeros_like(out_primal)] jvp_rules[greater_p] = greater_jvp - -# - - # Finally, we add a transformation API to kick off the trace: -def jvp(f, primals, tangents): +def jvp_v1(f, primals, tangents): with new_main(JVPTrace) as main: trace = JVPTrace(main) tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)] @@ -489,37 +493,34 @@ def jvp(f, primals, tangents): primal_out, tangent_out = tracer_out.primal, tracer_out.tangent return primal_out, tangent_out - # And with that, we can differentiate! x = 3.0 -y, sin_deriv_at_3 = jvp(sin, (x,), (1.0,)) +y, sin_deriv_at_3 = jvp_v1(sin, (x,), (1.0,)) print(sin_deriv_at_3) print(cos(3.0)) # + def f(x): - y = sin(x) * 2 + y = sin(x) * 2. z = - y + x return z x, xdot = 3., 1. -y, ydot = jvp(f, (x,), (xdot,)) +y, ydot = jvp_v1(f, (x,), (xdot,)) print(y) print(ydot) - # + def deriv(f): - return lambda x: jvp(f, (x,), (1.,))[1] + return lambda x: jvp_v1(f, (x,), (1.,))[1] print(deriv(sin)(3.)) print(deriv(deriv(sin))(3.)) print(deriv(deriv(deriv(sin)))(3.)) print(deriv(deriv(deriv(deriv(sin))))(3.)) - # + def f(x): if x > 0.: # Python control flow @@ -529,17 +530,151 @@ def f(x): print(deriv(f)(3.)) print(deriv(f)(-3.)) +# - + +# ## Pytrees and flattening user functions' inputs and outputs + +# A limitation with `jvp_v1` is that it assumes the user function accepts arrays +# as positional arguments and produces a single array as output. What if it +# produced a list as output? Or accepted nested containers as inputs? It would +# be a pain to deal with all the possible containers in inputs and outputs at +# every layer of the stack. Instead, we can wrap the user function so that the +# wrapped version accepts arrays as inputs and returns a flat list of arrays as +# output. The wrapper just needs to unflatten its input, call the user function, +# and flatten the output. +# +# Here's how we'd like to write `jvp`, assuming the user always gives us +# functions that take arrays as inputs and produces a flat list of arrays as +# outputs: +def jvp_flat(f, primals, tangents): + with new_main(JVPTrace) as main: + trace = JVPTrace(main) + tracers_in = [JVPTracer(trace, x, t) for x, t in zip(primals, tangents)] + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + primals_out, tangents_out = unzip2((t.primal, t.tangent) for t in tracers_out) + return primals_out, tangents_out +# To support user functions that have arbitrary containers in the inputs and +# outputs, here's how we'd write the user-facing `jvp` wrapper: + +def jvp(f, primals, tangents): + primals_flat, in_tree = tree_flatten(primals) + tangents_flat, in_tree2 = tree_flatten(tangents) + if in_tree != in_tree2: raise TypeError + f, out_tree = flatten_fun(f, in_tree) + primals_out_flat, tangents_out_flat = jvp_flat(f, primals_flat, tangents_flat) + primals_out = tree_unflatten(out_tree(), primals_out_flat) + tangents_out = tree_unflatten(out_tree(), tangents_out_flat) + return primals_out, tangents_out + +# Notice that we had to plumb the tree structure of the user function output +# back to the caller of `flatten_fun`. That information isn't available until we +# actually run the user function, so `flatten_fun` just returns a reference to a +# mutable cell, represented as a thunk. These side-effects are safe because we +# always run the user function exactly once. (This safe regime is the reason for +# the "linear" name in `linear_util.py`, in the sense of [linear +# types](https://en.wikipedia.org/wiki/Substructural_type_system).) +# +# All that remains is to write `tree_flatten`, `tree_unflatten`, and +# `flatten_fun`: + +def flatten_fun(f, in_tree): + store = Store() + + def flat_fun(*args_flat): + pytree_args = tree_unflatten(in_tree, args_flat) + out = f(*pytree_args) + out_flat, out_tree = tree_flatten(out) + store.set_value(out_tree) + return out_flat + + return flat_fun, store + +class Empty: pass +empty = Empty() + +class Store: + val = empty + + def set_value(self, val): + assert self.val is empty + self.val = val + + def __call__(self): + return self.val + +# + +import itertools as it +from typing import Callable, Type, Hashable, Dict, Iterable, Iterator + +class NodeType(NamedTuple): + to_iterable: Callable + from_iterable: Callable + +node_types: Dict[Type, NodeType] = { + tuple: NodeType(lambda t: (None, t), lambda _, xs: tuple(xs)), + list: NodeType( lambda l: (None, l), lambda _, xs: list(xs)), + dict: NodeType(lambda d: map(tuple, unzip2(sorted(d.items()))), + lambda keys, vals: dict(zip(keys, vals))), +} + +class PyTreeDef(NamedTuple): + node_type: NodeType + node_metadata: Hashable + child_treedefs: Tuple['PyTreeDef'] + +class Leaf: pass +leaf = Leaf() + +def tree_flatten(x: Any) -> Tuple[List[Any], PyTreeDef]: + children_iter, treedef = _tree_flatten(x) + return list(children_iter), treedef + +def _tree_flatten(x: Any) -> Tuple[Iterable, PyTreeDef]: + node_type = node_types.get(type(x)) + if node_type: + node_metadata, children = node_type.to_iterable(x) + children_flat, child_trees = unzip2(map(_tree_flatten, children)) + flattened = it.chain.from_iterable(children_flat) + return flattened, PyTreeDef(node_type, node_metadata, tuple(child_trees)) + else: + return [x], leaf + +def tree_unflatten(treedef: PyTreeDef, xs: List[Any]) -> Any: + return _tree_unflatten(treedef, iter(xs)) + +def _tree_unflatten(treedef: PyTreeDef, xs: Iterator) -> Any: + if treedef is leaf: + return next(xs) + else: + children = (_tree_unflatten(t, xs) for t in treedef.child_treedefs) + return treedef.node_type.from_iterable(treedef.node_metadata, children) # - + +# With this pytree-handling `jvp` impelmentation, we can now handle arbitrary +# input and output containers. That'll come in handy with future transformations +# too! + +def f(x): + y = sin(x) * 2. + z = - y + x + return {'hi': z, 'there': [x, y]} + + +x, xdot = 3., 1. +y, ydot = jvp(f, (x,), (xdot,)) +print(y) +print(ydot) + # ### Vectorized batching with `vmap` # # First, a couple helper functions, one for producing mapped abstract values # from unmapped ones (by removing an axis), and one for moving batch dimensions # around: -# + def mapped_aval(batch_dim, aval): shape = list(aval.shape) del shape[batch_dim] @@ -549,24 +684,29 @@ def move_batch_axis(axis_size, src, dst, x): if src is not_mapped: target_shape = list(np.shape(x)) target_shape.insert(dst, axis_size) - return np.broadcast_to(np.expand_dims(x, dst), target_shape) + return broadcast(x, target_shape, [dst]) + elif src == dst: + return x else: - return np.moveaxis(x, src, dst) + return moveaxis(x, src, dst) - -# - +def moveaxis(x, src: int, dst: int): + perm = [i for i in range(np.ndim(x)) if i != src] + perm.insert(dst, src) + return transpose(x, perm) # The `Tracer` for vectorized batching carries a batched value and an optional # integer indicating which axis (if any) is the batch axis. -# + from typing import Union class NotMapped: pass not_mapped = NotMapped() +BatchAxis = Union[NotMapped, int] + class BatchTracer(Tracer): - def __init__(self, trace, val, batch_dim: Union[NotMapped, int]): + def __init__(self, trace, val, batch_dim: BatchAxis): self._trace = trace self.val = val self.batch_dim = batch_dim @@ -590,15 +730,14 @@ class BatchTrace(Trace): def process_primitive(self, primitive, tracers, params): vals_in, bdims_in = unzip2((t.val, t.batch_dim) for t in tracers) vmap_rule = vmap_rules[primitive] - val_out, bdim_out = vmap_rule(self.axis_size, vals_in, bdims_in, **params) - return BatchTracer(self, val_out, bdim_out) + val_outs, bdim_outs = vmap_rule(self.axis_size, vals_in, bdims_in, **params) + return [BatchTracer(self, x, bd) for x, bd in zip(val_outs, bdim_outs)] @property def axis_size(self): return self.main.global_data vmap_rules = {} -# - # Here we've implemented the optional `Tracer.full_lower` method, which lets us # peel off a batching tracer if it's not needed because it doesn't represent a @@ -612,20 +751,19 @@ def axis_size(self): # # Next we can define batching interpreter rules for each primitive: -# + from functools import partial def broadcasting_binop_batching_rule(op, axis_size, vals_in, dims_in): (x, y), (x_bdim, y_bdim) = vals_in, dims_in if x_bdim != y_bdim: y = move_batch_axis(axis_size, y_bdim, x_bdim, y) - return op(x, y), x_bdim + return [op(x, y)], [x_bdim] vmap_rules[add_p] = partial(broadcasting_binop_batching_rule, add) vmap_rules[mul_p] = partial(broadcasting_binop_batching_rule, mul) def vectorized_unop_batching_rule(op, axis_size, vals_in, dims_in): (x,), (x_bdim,) = vals_in, dims_in - return op(x), x_bdim + return [op(x)], [x_bdim] vmap_rules[sin_p] = partial(vectorized_unop_batching_rule, sin) vmap_rules[cos_p] = partial(vectorized_unop_batching_rule, cos) vmap_rules[neg_p] = partial(vectorized_unop_batching_rule, neg) @@ -634,7 +772,7 @@ def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis): (x,), (x_bdim,) = vals_in, dims_in new_axis = axis + (x_bdim <= axis) out_bdim = x_bdim - (new_axis < x_bdim) - return reduce_sum(x, new_axis), out_bdim + return [reduce_sum(x, new_axis)], [out_bdim] vmap_rules[reduce_sum_p] = reduce_sum_batching_rule @@ -642,28 +780,37 @@ def reduce_sum_batching_rule(axis_size, vals_in, dims_in, *, axis): # Finally, we add a transformation API to kick off the trace: -def vmap(f, in_axes, out_axis): +def vmap_flat(f, in_axes, *args): + axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) + if ax is not not_mapped} + with new_main(BatchTrace, axis_size) as main: + trace = BatchTrace(main) + tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x + for x, ax in zip(args, in_axes)] + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + vals_out, bdims_out = unzip2((t.val, t.batch_dim) for t in tracers_out) + outs_transposed = [move_batch_axis(axis_size, bdim, 0, val_out) + for val_out, bdim in zip(vals_out, bdims_out)] + return outs_transposed + +def vmap(f, in_axes): def batched_f(*args): - axis_size, = {x.shape[ax] for x, ax in zip(args, in_axes) - if ax is not None} - with new_main(BatchTrace, axis_size) as main: - trace = BatchTrace(main) - tracers_in = [BatchTracer(trace, x, ax) if ax is not None else x - for x, ax in zip(args, in_axes)] - out = f(*tracers_in) - tracer_out = full_raise(trace, out) - val_out, batch_dim_out = tracer_out.val, tracer_out.batch_dim - return move_batch_axis(axis_size, batch_dim_out, out_axis, val_out) + args_flat, in_tree = tree_flatten(args) + in_axes_flat, in_tree2 = tree_flatten(in_axes) + if in_tree != in_tree2: raise TypeError + f_flat, out_tree = flatten_fun(f, in_tree) + outs_flat = vmap_flat(f_flat, in_axes_flat, *args_flat) + return tree_unflatten(out_tree(), outs_flat) return batched_f - # + def add_one_to_a_scalar(scalar): assert np.ndim(scalar) == 0 return 1 + scalar vector_in = np.arange(3.) -vector_out = vmap(add_one_to_a_scalar, (0,), 0)(vector_in) +vector_out = vmap(add_one_to_a_scalar, (0,))(vector_in) print(vector_in) print(vector_out) @@ -673,7 +820,7 @@ def add_one_to_a_scalar(scalar): def jacfwd(f, x): pushfwd = lambda v: jvp(f, (x,), (v,))[1] vecs_in = np.eye(np.size(x)).reshape(np.shape(x) * 2) - return vmap(pushfwd, (0,), 0)(vecs_in) + return vmap(pushfwd, (0,))(vecs_in) def f(x): return sin(x) @@ -688,34 +835,38 @@ def f(x): # rules, and for more complex primitives (like for convolution or advanced # indexing) each rule is harder to write. But the overarching design is no # different. -# 1. **Transformations expect arrays in, single array out.** -# 2. **No symbolic zeros in autodiff.** -# 3. **No special call primitives yet.** The core machinery needs to be +# 2. **No pytrees.** Transformations expect arrays in, and either a single array +# out or a flat list of arrays out. +# 3. **Missing optimization: no symbolic zeros in autodiff.** +# 4. **No special call primitives yet.** The core machinery needs to be # generalized to handle the most flexible kind of higher-order primitive, # used by `jax.custom_jvp` and `jax.custom_vjp`. -# ## Part 2: Jaxprs, for `jit` and `vjp` + +# ## Part 2: Jaxprs # # The next transformations are the horizon are `jit` for just-in-time # compilation and `vjp` for reverse-mode autodiff. (`grad` is just a small -# wrapper around `vjp`.) For `jvp` and `vmap` we only needed each `Tracer` to -# carry a little bit of extra context, but for both `jit` and `vjp` we need -# much richer context: we need to represent _programs_. That is, we need jaxprs! +# wrapper around `vjp`.) Whereas `jvp` and `vmap` only needed each `Tracer` to +# carry a little bit of extra context, for both `jit` and `vjp` we need much +# richer context: we need to represent _programs_. That is, we need jaxprs! # -# Jaxprs are JAX's internal intermediate representation of programs. Jaxprs are -# an explicitly typed, functional, first-order language. We need a program -# representation for `jit` because the purpose of `jit` is to stage computation -# out of Python. For any computation we want to stage out, we need to be able to -# represent it as data, and build it up as we trace a Python function. -# Similarly, `vjp` needs a way to represent the computation for the backward -# pass of reverse-mode autodiff. We use the same jaxpr program representation -# for both needs. +# Jaxprs are JAX's internal intermediate representation of programs. They are +# explicitly typed, functional, first-order, and in ANF form. We need a +# program representation for `jit` because the purpose of `jit` is to stage +# computation out of Python. For any computation we want to stage out, we need +# to be able to represent it as data, and build it up as we trace a Python +# function. Similarly, `vjp` needs a way to represent the computation for the +# backward pass of reverse-mode autodiff. We use the same jaxpr program +# representation for both needs. # # (Building a program representation is the most # [free](https://en.wikipedia.org/wiki/Free_object) kind of -# trace- transformation, and so except for issues around handling native Python +# trace-transformation, and so except for issues around handling native Python # control flow, any transformation could be implemented by first tracing to a # jaxpr and then interpreting the jaxpr.) + +# ### Jaxpr data strutures # # The jaxpr term syntax is roughly: # @@ -724,20 +875,20 @@ def f(x): # { lambda , ... . # let # ... -# in } +# in ( , ... ) } # # binder ::= : # var ::= a | b | c | ... # atom ::= | # literal ::= | # -# eqn ::= = [ ] , ... +# eqn ::= , ... = [ ] , ... # ``` # # The syntax of types is: # # ``` -# jaxpr_type ::= [, ...] -> [, ...] +# jaxpr_type ::= [ , ... ] -> [ , ... ] # array_type ::= [] # dtype ::= f32 | f64 | i32 | i64 # shape ::= , ... @@ -748,7 +899,7 @@ def f(x): # structs: # + -from typing import Dict, Set +from typing import Set class Var: aval: ShapedArray @@ -768,47 +919,55 @@ class JaxprEqn(NamedTuple): primitive: Primitive inputs: List[Atom] params: Dict[str, Any] - out_binder: Var + out_binders: List[Var] class Jaxpr(NamedTuple): in_binders: List[Var] eqns: List[JaxprEqn] - out: Atom - + outs: List[Atom] def raise_to_shaped(aval): return ShapedArray(aval.shape, aval.dtype) +# - + +# Type-checking a jaxpr involves checking that there are no unbound variables, +# that variables are only bound once, and that for each equation the type of +# the primitive application matches the type of the output binders. # + class JaxprType: in_types: List[ShapedArray] - out_type: ShapedArray + out_type: List[ShapedArray] - def __init__(self, in_types, out_type): + def __init__(self, in_types, out_types): self.in_types = in_types - self.out_type = out_type + self.out_types = out_types def __repr__(self): in_types = ', '.join(aval.str_short() for aval in self.in_types) - out_type = self.out_type.str_short() - return f'({in_types}) -> {out_type}' - + out_types = ', '.join(aval.str_short() for aval in self.out_types) + return f'({in_types}) -> ({out_types})' def typecheck_jaxpr(jaxpr: Jaxpr) -> JaxprType: env: Set[Var] = set() for v in jaxpr.in_binders: + if v in env: raise TypeError env.add(v) for eqn in jaxpr.eqns: in_types = [typecheck_atom(env, x) for x in eqn.inputs] - out_type = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params) - if not types_equal(out_type, eqn.out_binder.aval): raise TypeError - env.add(eqn.out_binder) + out_types = abstract_eval_rules[eqn.primitive](*in_types, **eqn.params) + for out_binder, out_type in zip(eqn.out_binders, out_types): + if not types_equal(out_type, out_binder.aval): raise TypeError + for out_binder in eqn.out_binders: + if out_binder in env: raise TypeError + env.add(out_binder) - out_type = typecheck_atom(env, jaxpr.out) - return JaxprType([v.aval for v in jaxpr.in_binders], out_type) + in_types = [v.aval for v in jaxpr.in_binders] + out_types = [typecheck_atom(env, x) for x in jaxpr.outs] + return JaxprType(in_types, out_types) def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray: if isinstance(x, Var): @@ -821,15 +980,39 @@ def typecheck_atom(env: Set[Var], x: Atom) -> ShapedArray: def types_equal(a: ShapedArray, b: ShapedArray) -> bool: return a.shape == b.shape and a.dtype == b.dtype +# - +# We can apply the function represented by a jaxpr to arguments with a simple +# interpreter. -# - +def eval_jaxpr(jaxpr: Jaxpr, args: List[Any]) -> List[Any]: + env: Dict[Var, Any] = {} + + def read(x: Atom) -> Any: + return env[x] if type(x) is Var else x.val + + def write(v: Var, val: Any) -> None: + env[v] = val + + map(write, jaxpr.in_binders, args) + for eqn in jaxpr.eqns: + in_vals = map(read, eqn.inputs) + outs = bind(eqn.primitive, *in_vals, **eqn.params) + map(write, eqn.out_binders, outs) + return map(read, jaxpr.outs) + +def jaxpr_as_fun(jaxpr: Jaxpr): + return lambda *args: eval_jaxpr(jaxpr, args) +# By using `bind` in the interpreter, this interpreter itself is traceable. + +# ### Building jaxprs with tracing +# # Now that we have jaxprs as a data structure, we need ways to produce these # from tracing Python code. In general there are two variants of how we trace to # a jaxpr; `jit` uses one and `vjp` uses the other. We'll start with the one -# used by `jit`, which is also used by control flow primitives like -# `lax.cond`, `lax.while_loop`, and `lax.scan`. +# used by `jit`, which is also used by control flow primitives like `lax.cond`, +# `lax.while_loop`, and `lax.scan`. # + # NB: the analogous class in JAX is called 'DynamicJaxprTracer' @@ -845,26 +1028,26 @@ def __init__(self, trace, aval): class JaxprTrace(Trace): def new_arg(self, aval: ShapedArray) -> JaxprTracer: aval = raise_to_shaped(aval) - tracer = JaxprTracer(self, aval) + tracer = self.builder.new_tracer(self, aval) self.builder.tracer_to_var[id(tracer)] = Var(aval) return tracer def get_or_make_const_tracer(self, val: Any) -> JaxprTracer: tracer = self.builder.const_tracers.get(id(val)) if tracer is None: - tracer = JaxprTracer(self, raise_to_shaped(get_aval(val))) + tracer = self.builder.new_tracer(self, raise_to_shaped(get_aval(val))) self.builder.add_const(tracer, val) return tracer pure = lift = get_or_make_const_tracer def process_primitive(self, primitive, tracers, params): avals_in = [t.aval for t in tracers] - aval_out = abstract_eval_rules[primitive](*avals_in, **params) - out_tracer = JaxprTracer(self, aval_out) + avals_out = abstract_eval_rules[primitive](*avals_in, **params) + out_tracers = [self.builder.new_tracer(self, a) for a in avals_out] inputs = [self.builder.getvar(t) for t in tracers] - outvar = self.builder.add_var(out_tracer) - self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvar)) - return out_tracer + outvars = [self.builder.add_var(t) for t in out_tracers] + self.builder.add_eqn(JaxprEqn(primitive, inputs, params, outvars)) + return out_tracers @property def builder(self): @@ -872,31 +1055,35 @@ def builder(self): # NB: in JAX, instead of a dict we attach impl rules to the Primitive instance abstract_eval_rules = {} - - # - # Notice that we keep as interpreter-global data a builder object, which keeps -# track of variables, constants, and eqns as we build up the jaxpr. +# track of variables, constants, and eqns as we build up the jaxpr. class JaxprBuilder: eqns: List[JaxprEqn] tracer_to_var: Dict[int, Var] const_tracers: Dict[int, JaxprTracer] constvals: Dict[Var, Any] + tracers: List[JaxprTracer] def __init__(self): self.eqns = [] self.tracer_to_var = {} self.const_tracers = {} self.constvals = {} + self.tracers = [] + + def new_tracer(self, trace: JaxprTrace, aval: ShapedArray) -> JaxprTracer: + tracer = JaxprTracer(trace, aval) + self.tracers.append(tracer) + return tracer def add_eqn(self, eqn: JaxprEqn) -> None: self.eqns.append(eqn) def add_var(self, tracer: JaxprTracer) -> Var: - var = self.tracer_to_var.get(id(tracer)) - assert var is None + assert id(tracer) not in self.tracer_to_var var = self.tracer_to_var[id(tracer)] = Var(tracer.aval) return var @@ -911,25 +1098,28 @@ def add_const(self, tracer: JaxprTracer, val: Any) -> Var: self.constvals[var] = val return var - def build(self, in_tracers: List[JaxprTracer], out_tracer: JaxprTracer + def build(self, in_tracers: List[JaxprTracer], out_tracers: List[JaxprTracer] ) -> Tuple[Jaxpr, List[Any]]: constvars, constvals = unzip2(self.constvals.items()) t2v = lambda t: self.tracer_to_var[id(t)] in_binders = constvars + [t2v(t) for t in in_tracers] - jaxpr = Jaxpr(in_binders, self.eqns, t2v(out_tracer)) + out_vars = [t2v(t) for t in out_tracers] + jaxpr = Jaxpr(in_binders, self.eqns, out_vars) typecheck_jaxpr(jaxpr) return jaxpr, constvals - # The rules we need for `JaxprTrace.process_primitive` are essentially typing -# rules for primitive applications: given the primitive, its parameters, and +# rules for primitive applications: given the primitive, its parameters, and # types for the inputs, the rule must produce a type for the output, which is -# then packaged with the output `JaxprTracer`. We can use abstract evaluation -# rules for this same purpose, even though they can be more general (since -# abstract evaluation rules need to work on ConcreteArray inputs as well). We'll -# reuse these abstract evaluation rules for the other jaxpr-producing trace -# machinery, where the potential extra generality is useful. +# then packaged with the output `JaxprTracer`. We can use abstract evaluation +# rules for this same purpose, even though they can be more general (since +# abstract evaluation rules must accept ConcreteArray inputs, and since they +# need only return an upper bound on the set of possible outputs, they can +# produce ConcreteArray outputs as well). We'll reuse these abstract evaluation +# rules for the other jaxpr-producing trace machinery, where the potential extra +# generality is useful. +# + def broadcast_shapes(*shapes): assert len(shapes) > 1 for sizes in zip(*shapes): @@ -938,18 +1128,16 @@ def broadcast_shapes(*shapes): raise Exception return tuple(next((d for d in sizes if d != 1), 1) for sizes in zip(*shapes)) - -# + def broadcasting_binop_abstract_eval_rule(*avals_in): out_dtype = np.result_type(*map(np.result_type, avals_in)) out_shape = broadcast_shapes(*map(np.shape, avals_in)) - return ShapedArray(out_shape, out_dtype) + return [ShapedArray(out_shape, out_dtype)] abstract_eval_rules[add_p] = broadcasting_binop_abstract_eval_rule abstract_eval_rules[mul_p] = broadcasting_binop_abstract_eval_rule def vectorized_unop_abstract_eval_rule(aval_in): - return ShapedArray(np.shape(aval_in), np.result_type(aval_in)) + return [ShapedArray(np.shape(aval_in), np.result_type(aval_in))] abstract_eval_rules[sin_p] = vectorized_unop_abstract_eval_rule abstract_eval_rules[cos_p] = vectorized_unop_abstract_eval_rule @@ -957,27 +1145,35 @@ def vectorized_unop_abstract_eval_rule(aval_in): def reduce_sum_abstract_eval_rule(aval_in, *, axis): new_shape = [d for i, d in enumerate(aval_in.shape) if i != axis] - return ShapedArray(tuple(new_shape), aval_in.dtype) + return [ShapedArray(tuple(new_shape), aval_in.dtype)] abstract_eval_rules[reduce_sum_p] = reduce_sum_abstract_eval_rule - +def broadcast_abstract_eval(x, *, shape, axes): + return [ShapedArray(tuple(shape), np.result_type(x))] +abstract_eval_rules[broadcast_p] = broadcast_abstract_eval # - -# To check our implementation, we can add a `make_jaxpr` transformation and -# first pretty-printer: +# To check our implementation of jaxprs, we can add a `make_jaxpr` +# transformation and a pretty-printer: + +from functools import lru_cache + +@lru_cache() +def make_jaxpr_v1(f, *avals_in): + avals_in, in_tree = tree_flatten(avals_in) + f, out_tree = flatten_fun(f, in_tree) -def make_jaxpr(f, avals_in): builder = JaxprBuilder() with new_main(JaxprTrace, builder) as main: trace = JaxprTrace(main) tracers_in = [trace.new_arg(aval) for aval in avals_in] - out = f(*tracers_in) - tracer_out = full_raise(trace, out) - return builder.build(tracers_in, tracer_out) + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + jaxpr, consts = builder.build(tracers_in, tracers_out) + return jaxpr, consts, out_tree() -# + +# + tags=["hide-input"] from collections import defaultdict -import itertools as it import string class PPrint: @@ -1017,15 +1213,16 @@ def pp_jaxpr(jaxpr: Jaxpr): names = defaultdict(lambda: next(namegen)) in_binders = ', '.join(var_str(names, x) for x in jaxpr.in_binders) eqns = vcat([pp_eqn(names, e) for e in jaxpr.eqns]) - out = names[jaxpr.out] if isinstance(jaxpr.out, Var) else str(jaxpr.out.val) + outs = ', '.join(names[v] if isinstance(v, Var) else str(v.val) + for v in jaxpr.outs) return (pp(f'{{ lambda {in_binders} .') + - ((pp('let ') >> eqns) + pp(f'in {out} }}')).indent(2)) + ((pp('let ') >> eqns) + pp(f'in ( {outs} ) }}')).indent(2)) def var_str(names: Dict[Var, str], v: Var) -> str: return f'{names[v]}:{v.aval.str_short()}' def pp_eqn(names: Dict[Var, str], eqn: JaxprEqn) -> PPrint: - lhs = pp(var_str(names, eqn.out_binder)) + lhs = pp(' '.join(var_str(names, v) for v in eqn.out_binders)) rhs = (pp(eqn.primitive.name) >> pp_params(eqn.params) >> pp(' '.join(names[x] if isinstance(x, Var) else str(x.val) for x in eqn.inputs))) @@ -1037,10 +1234,427 @@ def pp_params(params: Dict[str, Any]) -> PPrint: return pp(' [ ') >> vcat([pp(f'{k}={v}') for k, v in items]) >> pp(' ] ') else: return pp(' ') - - # - -jaxpr, consts = make_jaxpr(lambda x: 2. * x, [raise_to_shaped(get_aval(3.))]) +jaxpr, consts, _ = make_jaxpr_v1(lambda x: 2. * x, raise_to_shaped(get_aval(3.))) print(pp_jaxpr(jaxpr)) print(typecheck_jaxpr(jaxpr)) + +# But there's a limitation here: because of how `find_top_trace` operates by +# data dependence, `make_jaxpr_v1` can't stage out all the primitive operations +# performed by the Python callable it's given. For example: + +jaxpr, consts, _ = make_jaxpr_v1(lambda: mul(2., 2.)) +print(pp_jaxpr(jaxpr)) + +# This is precisely the issue that +# [omnistaging](https://github.com/google/jax/pull/3370) fixed. +# We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always +# applied, regardless of whether any inputs to `bind` are boxed in corresponding +# `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` +# global defined in Part 1: + +@contextmanager +def new_dynamic(main: MainTrace): + global dynamic_trace + prev_dynamic_trace, dynamic_trace = dynamic_trace, main + try: + yield + finally: + dynamic_trace = prev_dynamic_trace + +@lru_cache() # ShapedArrays are hashable +def make_jaxpr(f, *avals_in): + avals_in, in_tree = tree_flatten(avals_in) + f, out_tree = flatten_fun(f, in_tree) + + builder = JaxprBuilder() + with new_main(JaxprTrace, builder) as main: + with new_dynamic(main): + trace = JaxprTrace(main) + tracers_in = [trace.new_arg(aval) for aval in avals_in] + outs = f(*tracers_in) + tracers_out = [full_raise(trace, out) for out in outs] + jaxpr, consts = builder.build(tracers_in, tracers_out) + return jaxpr, consts, out_tree() + +jaxpr, consts, _ = make_jaxpr(lambda: mul(2., 2.)) +print(pp_jaxpr(jaxpr)) + +# Using `dynamic_trace` this way is conceptually the same as stashing the +# current interpreter stack and starting a new one with the `JaxprTrace` at the +# bottom. That is, no interpreters lower in the stack than the `dynamic_trace` +# are applied (since `JaxprTrace.process_primitive` doesn't call `bind`), though +# if the Python callable being traced to a jaxpr itself uses transformations +# then those can be pushed onto the interpreter stack above the `JaxprTrace`. +# But temporarily stashing the interpreter stack would break up the system +# state. The `dynamic_trace` tag achieves the same goals while keeping the +# system state simpler. + +# That's it for jaxprs! With jaxprs in hand, we can implement the remaining +# major JAX features. But before moving on, let's highlight some +# simplifications we've made: +# 1. **Single-output primitives and jaxprs.** + +# ## Part 3: `jit`, simplified +# +# While `jit` has a transformation-like API in that it accepts a Python callable +# as an argument, under the hood it's really a higher-order primitive rather +# than a transformation. A primitive is _higher-order_ when it's parameterized +# by a function. + +# ### "Final style" and "initial style" +# +# There are two options for how to handle higher-order primitives. Each requires +# a different approach to tracing and engenders different tradeoffs: +# 1. **`bind` takes a Python callable as an argument.** We defer forming a jaxpr +# until as late as possible, namely until we're running the final interpreter +# at the bottom of the interpreter stack. That way we can swap a `JaxprTrace` +# in at the bottom of the interpreter stack and thus stage out rather than +# execute all primitive operations. With this approach, transformations in +# the stack get applied as we execute the Python callable as usual. This +# approach can be very tricky to implement, but it's as general as possible +# because it allows higher-order primitives not to raise the abstraction +# level of their arguments and thus allows data-dependent Python control +# flow. We refer to this approach as using a "final-style higher-order +# primitive" employing the discharge-at-tracing-time "final-style +# transformations" we've used so far. +# 2. **`bind` takes a jaxpr as an argument.** Before we call `bind`, in the +# primitive wrapper we can just use `make_jaxpr` to form a jaxpr up-front and +# be done with the Python callable entirely. In this case, `make_jaxpr` puts +# its `JaxprTrace` at the top of the interpreter stack, and no +# transformations lower in the stack, which might enter via closed-over +# Tracers, are applied to the Python callable as we trace it. +# (Transformations applied within the Python callable are applied as usual, +# being added to the stack above the JaxprTrace.) Instead, the +# transformations lower in the stack are later applied to the call primitive, +# and the call primitive's rules must then transform the jaxpr itself. +# Because we trace to a jaxpr up-front, this approach can't support +# data-dependent Python control flow, but it is more straightforward to +# implement. We refer to this kind of higher-order primitive as an +# "initial-style higher-order primitive", and say that its jaxpr-processing +# transformation rules are "initial-style transformation rules." +# +# The latter approach fits for `jit` because we don't need to support +# data-dependent Python control flow in the user-provided Python callable, as +# the whole purpose of `jit` is to stage computation out of Python to be +# executed by XLA. (In contrast, `custom_jvp` is a higher-order primitive in +# which we want to support data-dependent Python control flow.) +# +# Historically, we started using the "initial-style" and "final-style" +# terminology after reading the [typed tagless final +# interpreters](http://okmij.org/ftp/tagless-final/index.html) paper, and +# jokingly referring to JAX as an implementation of "untyped tagful final +# interpreters." We don't claim to carry over (or understand) any deep meaning +# behind these terms; we loosely use "initial style" to mean "build an AST and +# then transform it", and we use "final style" to mean "transform as we trace." +# But it's just imprecise yet sticky jargon. + +# With the initial-style approach, here's the user-facing `jit` wrapper: + +# + +def jit(f): + def f_jitted(*args): + avals_in = [raise_to_shaped(get_aval(x)) for x in args] + jaxpr, consts, out_tree = make_jaxpr(f, *avals_in) + outs = bind(xla_call_p, *consts, *args, jaxpr=jaxpr, num_consts=len(consts)) + return tree_unflatten(out_tree, outs) + return f_jitted + +xla_call_p = Primitive('xla_call') +# - + +# With any new primitive, we need to give it transformation rules, starting with +# its evaluation rule. When we evaluate an application of the `xla_call` +# primitive, we want to stage out out the computation to XLA. That involves +# translating the jaxpr to an XLA HLO program, transferring the argument values +# to the XLA device, executing the XLA program, and transferring back the +# results. We'll cache the XLA HLO compilation so that for each `jit`ted +# function it only needs to be performed once per argument shape and dtype +# signature. +# +# First, some utilities. + +class IDHashable: + val: Any + + def __init__(self, val): + self.val = val + + def __hash__(self) -> int: + return id(self.val) + + def __eq__(self, other): + return type(other) is IDHashable and id(self.val) == id(other.val) + +# Next, we'll define the evaluation rule for `xla_call`: + +# + +from jax.lib import xla_bridge as xb +from jax.lib import xla_client as xc +xe = xc._xla +xops = xc._xla.ops + +def xla_call_impl(*args, jaxpr: Jaxpr, num_consts: int): + consts, args = args[:num_consts], args[num_consts:] + hashable_consts = tuple(map(IDHashable, consts)) + execute = xla_callable(IDHashable(jaxpr), hashable_consts) + return execute(*args) +impl_rules[xla_call_p] = xla_call_impl + +@lru_cache() +def xla_callable(hashable_jaxpr: IDHashable, hashable_consts: Tuple[IDHashable]): + jaxpr: Jaxpr = hashable_jaxpr.val + consts = [x.val for x in hashable_consts] + in_avals = [v.aval for v in jaxpr.in_binders[len(consts):]] + c = xb.make_computation_builder('xla_call') + xla_consts = _xla_consts(c, consts) + xla_params = _xla_params(c, in_avals) + outs = jaxpr_subcomp(c, jaxpr, xla_consts + xla_params) + out = xops.Tuple(c, outs) + compiled = xb.get_backend(None).compile(c.build(out)) + return partial(execute_compiled, compiled, [v.aval for v in jaxpr.outs]) + +def _xla_consts(c: xe.XlaBuilder, consts: List[Any]) -> List[xe.XlaOp]: + unique_consts = {id(cnst): cnst for cnst in consts} + xla_consts = { + id_: xops.ConstantLiteral(c, cnst) for id_, cnst in unique_consts.items()} + return [xla_consts[id(cnst)] for cnst in consts] + +def _xla_params(c: xe.XlaBuilder, avals_in: List[ShapedArray]) -> List[xe.XlaOp]: + return [xb.parameter(c, i, _xla_shape(a)) for i, a in enumerate(avals_in)] + +def _xla_shape(aval: ShapedArray) -> xe.Shape: + return xc.Shape.array_shape(xc.dtype_to_etype(aval.dtype), aval.shape) +# - + +# The main action is in `xla_callable`, which compiles a jaxpr into an XLA HLO +# program using `jaxpr_subcomp`, then returns a callable which executes the +# compiled program: + +# + +def jaxpr_subcomp(c: xe.XlaBuilder, jaxpr: Jaxpr, args: List[xe.XlaOp] + ) -> xe.XlaOp: + env: Dict[Var, xe.XlaOp] = {} + + def read(x: Atom) -> xe.XlaOp: + return env[x] if type(x) is Var else xb.constant(c, x.val) + + def write(v: Var, val: xe.XlaOp) -> None: + env[v] = val + + map(write, jaxpr.in_binders, args) + for eqn in jaxpr.eqns: + in_avals = [x.aval for x in eqn.inputs] + in_vals = map(read, eqn.inputs) + rule = xla_translations[eqn.primitive] + out_vals = rule(c, in_avals, in_vals, **eqn.params) + map(write, eqn.out_binders, out_vals) + return map(read, jaxpr.outs) + +def execute_compiled(compiled, out_avals, *args): + input_bufs = [input_handlers[type(x)](x) for x in args] + out_bufs = compiled.execute(input_bufs) + return [handle_result(aval, buf) for aval, buf in zip(out_avals, out_bufs)] + +input_handlers = { + int: xb.get_backend(None).buffer_from_pyval, + float: xb.get_backend(None).buffer_from_pyval, + np.ndarray: xb.get_backend(None).buffer_from_pyval, +} + +def handle_result(aval: ShapedArray, buf): + del aval # Unused for now. + return buf.to_py() + +xla_translations = {} +# - + +# Notice that `jaxpr_subcomp` has the structure of a simple interpreter. That's +# a common pattern: the way we process jaxprs is usually with an interpreter. +# And as with any interpreter, we need an interpretation rule for each +# primitive: + +def direct_translation(op, c, in_avals, in_vals): + del c, in_avals + return [op(*in_vals)] + +xla_translations[add_p] = partial(direct_translation, xops.Add) +xla_translations[mul_p] = partial(direct_translation, xops.Mul) +xla_translations[neg_p] = partial(direct_translation, xops.Neg) +xla_translations[sin_p] = partial(direct_translation, xops.Sin) +xla_translations[cos_p] = partial(direct_translation, xops.Cos) +xla_translations[greater_p] = partial(direct_translation, xops.Gt) + +def reduce_sum_translation(c, in_avals, in_vals, *, axis): + (x_aval,), (x,) = in_avals, in_vals + zero = xops.ConstantLiteral(c, np.array(0, x_aval.dtype)) + subc = xb.make_computation_builder('add') + shape = _xla_shape(ShapedArray((), x_aval.dtype)) + xops.Add(xops.Parameter(subc, 0, shape), xops.Parameter(subc, 1, shape)) + return [xops.Reduce(c, [x], [zero], subc.build(), [axis])] +xla_translations[reduce_sum_p] = reduce_sum_translation + +def broadcast_translation(c, in_avals, in_vals, *, shape, axes): + x, = in_vals + dims_complement = [i for i in range(len(shape)) if i not in axes] + return [xops.BroadcastInDim(x, shape, dims_complement)] +xla_translations[broadcast_p] = broadcast_translation + +# With that, we can now use `jit` to stage out, compile, and execute programs +# with XLA! + +@jit +def f(x, y): + print('tracing!') + return sin(x) * cos(y) + +z = f(3., 4.) # 'tracing!' prints the first time +print(z) + +z = f(4., 5.) # 'tracing!' doesn't print, compilation cache hit! +print(z) + +# + +@jit +def f(x): + return reduce_sum(x, axis=0) + +print(f(np.array([1., 2., 3.]))) + +# + +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +def deriv(f): + return lambda x: jvp(f, (x,), (1.,))[1] + +print( deriv(deriv(f))(3.)) +print(jit(deriv(deriv(f)))(3.)) +# - + +# Instead of implementing `jit` to first trace to a jaxpr and then to lower the +# jaxpr to XLA HLO, it might appear that we could have skipped the jaxpr step +# and just lowered to HLO while tracing. That is, perhaps we could have instead +# implemented `jit` with a `Trace` and `Tracer` that appended to the XLA HLO +# graph incrementally on each primitive bind. That's correct for now, but won't +# be possible when we introduce compiled SPMD computations because there we must +# know the number of replicas needed before compiling the program. + +# We haven't yet defined any transformation rules for `xla_call_p` other than +# its evaluation rule. That is, we can't yet do `vmap`-of-`jit` or +# `jvp`-of-`jit` or even `jit`-of`-jit`. Instead `jit` has to be at the "top +# level." Let's fix that! + +def xla_call_jvp_rule(primals, tangents, *, jaxpr, num_consts): + del num_consts # Unused. + new_jaxpr, new_consts = jvp_jaxpr(jaxpr) + outs = bind(xla_call_p, *new_consts, *primals, *tangents, jaxpr=new_jaxpr, + num_consts=len(new_consts)) + n = len(outs) // 2 + primals_out, tangents_out = outs[:n], outs[n:] + return primals_out, tangents_out +jvp_rules[xla_call_p] = xla_call_jvp_rule + +def jvp_jaxpr(jaxpr: Jaxpr) -> Tuple[Jaxpr, List[Any]]: + def jvp_traceable(*primals_and_tangents): + n = len(primals_and_tangents) // 2 + primals, tangents = primals_and_tangents[:n], primals_and_tangents[n:] + return jvp(jaxpr_as_fun(jaxpr), primals, tangents) + + in_avals = [v.aval for v in jaxpr.in_binders] + new_jaxpr, new_consts, _ = make_jaxpr(jvp_traceable, *in_avals, *in_avals) + return new_jaxpr, new_consts + +# + +def xla_call_vmap_rule(axis_size, vals_in, dims_in, *, jaxpr, num_consts): + del num_consts # Unused. + new_jaxpr, new_consts = vmap_jaxpr(jaxpr, axis_size, dims_in) + outs = bind(xla_call_p, *new_consts, *vals_in, jaxpr=new_jaxpr, + num_consts=len(new_consts)) + return outs, [0] * len(outs) +vmap_rules[xla_call_p] = xla_call_vmap_rule + +def vmap_jaxpr(jaxpr: Jaxpr, axis_size: int, bdims_in: List[BatchAxis] + ) -> Tuple[Jaxpr, List[Any]]: + vmap_traceable = vmap(jaxpr_as_fun(jaxpr), tuple(bdims_in)) + in_avals = [unmapped_aval(axis_size, d, v.aval) + for v, d in zip(jaxpr.in_binders, bdims_in)] + new_jaxpr, new_consts, _ = make_jaxpr(vmap_traceable, *in_avals) + return new_jaxpr, new_consts + +def unmapped_aval(axis_size: int, batch_dim: BatchAxis, aval: ShapedArray + ) -> ShapedArray: + if batch_dim is not_mapped: + return aval + else: + shape = list(aval.shape) + shape.insert(batch_dim, axis_size) + return ShapedArray(tuple(shape), aval.dtype) + + +# + +@jit +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +x, xdot = 3., 1. +y, ydot = jvp(f, (x,), (xdot,)) +print(y) +print(ydot) + +ys = vmap(f, (0,))(np.arange(3.)) +print(ys) +# - + + +# One piece missing is device memory persistence for arrays. That is, we've +# defined `handle_result` to transfer results back to CPU memory as NumPy +# arrays, but it's often preferrable to avoid transferring results just to +# transfer them back for the next operation. We can do that by introducing a +# `DeviceArray` class, which can wrap XLA buffers and otherwise duck-type +# `numpy.ndarray`s: + +# + +def handle_result(aval: ShapedArray, buf): + return DeviceArray(aval, buf) + +class DeviceArray: + buf: Any + aval: ShapedArray + + def __init__(self, aval, buf): + self.aval = aval + self.buf = buf + + dtype = property(lambda self: self.aval.dtype) + shape = property(lambda self: self.aval.shape) + ndim = property(lambda self: self.aval.ndim) + + def __array__(self): return self.buf.to_py() + def __repr__(self): return repr(self.buf.to_py()) + def __str__(self): return str(self.buf.to_py()) + + _neg = staticmethod(neg) + _add = staticmethod(add) + _radd = staticmethod(add) + _mul = staticmethod(mul) + _rmul = staticmethod(mul) + _gt = staticmethod(greater) +input_handlers[DeviceArray] = lambda x: x.buf + +# + +@jit +def f(x): + y = sin(x) * 2. + z = - y + x + return z + +x, xdot = 3., 1. +y, ydot = jvp(f, (x,), (xdot,)) +print(y) +print(ydot)