Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardizing and generalizing object-oriented transformations #628

Merged
merged 16 commits into from
Feb 22, 2024

Conversation

chaoming0625
Copy link
Collaborator

@chaoming0625 chaoming0625 commented Feb 21, 2024

This PR standardizes the customization of object-oriented transformations. The key is using brainpy.math.VariableStack and brainpy.math.eval_shape.

One OO transformation involves two steps. The first step is using brainpy.math.eval_shape to evaluate all Variables used in the target function. The second step is the actual compilation phase, to compile the model on the given target device.

For example, to customize an object-oriented JIT compilation interface, we can use:

import jax
import brainpy.math as bm


def jit(fun):
  stack: bm.VariableStack = None
  jit_fun = None

  def new_fun(vars, *args, **kwargs):
    for k, v in vars.items():
        stack[k].value = v
    ret = fun(*args, **kwargs)
    new_vars = stack.dict_data()
    return ret, new_vars

  def wrapper(*args, **kwargs):
    global stack, jit_fun

    # [first step]: find all the variables
    if stack is None:
      with bm.VariableStack() as stack:
        ret = bm.eval_shape(fun, *args, **kwargs)
        jit_fun = jax.jit(new_fun)
      if not stack.is_first_stack():
        return ret

    # [second step]: jit compilation
    ret, new_vars = jit_fun(stack.dict_data(), *args, **kwargs)
    stack.assign(new_vars)
    return ret

  return wrapper


@chaoming0625 chaoming0625 marked this pull request as ready for review February 22, 2024 05:06
@chaoming0625 chaoming0625 merged commit 4d74816 into master Feb 22, 2024
32 checks passed
@chaoming0625 chaoming0625 deleted the oo-transform branch February 22, 2024 05:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant