Qax: A library for simplifying research prototyping #16516
davisyoshida asked this question in Show and tell
🦆Qax🦆
I've been working on a tool called Qax, which I solicited feedback on a couple of weeks ago. I've made a bunch of ergonomics improvements since then, and I think it's ready for people to start using for prototyping research ideas.
The main idea of this tool is to make it easy to implement anything which falls under "thing which represents a tensor but doesn't actually instantiate it." The motivating use cases were LoRA and quantized matrices, but I've run into several more since, and @patrick-kidger has suggested a bunch as well. You can basically think of it as a convenience wrapper around `Tracer`s for use cases which don't need the full power of the tracing system.

I have a Twitter thread going over the core idea, but I'll expand on it a bit here so people can give feedback. Please do let me know if you have any suggestions. (I already have some backend changes planned based on conversations with Patrick, but they shouldn't affect the frontend API.)
Mini-LoRA
Here's a minimal example: implementing LoRA. I've golfed it in order to show what the essential parts are:
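As a rough illustration of the shape of the idea, here's a conceptual numpy stand-in (this is not Qax's actual API; the `LoraMatrix` class and `matmul` handler here are hypothetical names, and real Qax dispatches at the JAX-primitive level rather than via `isinstance` checks):

```python
import numpy as np

class LoraMatrix:
    """Represents W + A @ B without ever forming that sum."""
    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b

    def materialize(self):
        # Fallback: build the dense matrix when no cheaper path applies.
        return self.w + self.a @ self.b

def matmul(x, m):
    # Stand-in "primitive handler": intercept left-matmuls against a
    # LoraMatrix and compute x @ W + (x @ A) @ B, which never
    # instantiates the dense W + A @ B.
    if isinstance(m, LoraMatrix):
        return x @ m.w + (x @ m.a) @ m.b
    return x @ m

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 8))
a = rng.normal(size=(8, 2))
b = rng.normal(size=(2, 8))
x = rng.normal(size=(4, 8))

lora = LoraMatrix(w, a, b)
# The low-rank path agrees with materializing first.
assert np.allclose(matmul(x, lora), x @ lora.materialize())
```

The point of the sketch is only the division of labor: a bag of parameters, a `materialize()` fallback, and a handler for the one operation worth special-casing.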
The thought here is that ideas like LoRA can be fully specified by answering three questions: what data the type stores, how it materializes into a dense array, and how primitives like `dot_general` should act on it.
Slightly-less-mini-LoRA
Here's an ungolfed version which shows a few more features, as well as not crashing when `dot_general`s other than left matmuls happen:

I didn't have to specify the `shape` and `dtype` which define this array's aval here, because they're automatically derived from the `materialize()` method. In cases where they can't be derived that way (e.g. a symbolic zero needs to already know its shape/dtype in order to be materialized), they can be passed as keyword args at initialization.

The way to use the type defined above is with the `qax.use_implicit_args` transform. It transforms a function which takes JAX types into one that takes an `ImplicitArray` in any position/keyword where an `ImplicitArray` was passed.

The upshot of this is that you don't need to modify the underlying model code, and you get handed a function which is compatible with `jit`, `grad`, and `vmap`. (I'm not sure how it plays with multi-device stuff yet, since I do all my development on two GPUs which aren't the same model...)
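To make the "no model changes" idea concrete, here's a conceptual numpy stand-in (again hypothetical, not Qax's API: real Qax intercepts JAX primitives under `use_implicit_args`; here plain operator overloading plays the same role):

```python
import numpy as np

class LoraMatrix:
    # Tell numpy to defer binary ops to this class, so that
    # ndarray @ LoraMatrix falls through to __rmatmul__ below.
    __array_ufunc__ = None

    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b

    def __rmatmul__(self, x):
        # Intercept x @ self without materializing W + A @ B.
        return x @ self.w + (x @ self.a) @ self.b

    def materialize(self):
        return self.w + self.a @ self.b

def model(x, w):
    # "Unmodified model code": written as if w were a dense array.
    return np.tanh(x @ w)

rng = np.random.default_rng(1)
w, a, b = (rng.normal(size=s) for s in [(8, 8), (8, 2), (2, 8)])
x = rng.normal(size=(4, 8))
lora = LoraMatrix(w, a, b)

dense_out = model(x, lora.materialize())
implicit_out = model(x, lora)  # same model code, implicit weight
assert np.allclose(dense_out, implicit_out)
```

The design point this mirrors is that the interception happens at the operation level, so the function being wrapped never needs to know the implicit type exists.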
The code below shows how to use the above LoRA type with a HuggingFace model. Most of it is just standard, so I've marked the two Qax-related changes.
It works!
Symbolic identity matrix
As another demo, here's an identity matrix which only stores a shape but no actual data:
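A conceptual numpy stand-in of the same trick (hypothetical names again, not Qax's API; the real version hangs off JAX's `dot_general`):

```python
import numpy as np

class Eye:
    """Symbolic identity: stores only a shape and dtype, no data."""
    def __init__(self, n, dtype=np.float64):
        self.shape = (n, n)
        self.dtype = dtype

    def materialize(self):
        # Only built if some operation actually needs the dense matrix.
        return np.eye(self.shape[0], dtype=self.dtype)

def matmul(x, m):
    if isinstance(m, Eye):
        if x.shape[-1] != m.shape[0]:
            # Mirror the post's NotImplemented-style check: bail out so
            # the fallback (materialize-and-multiply) can take over.
            return NotImplemented
        return x  # identity passthrough: zero FLOPs
    return x @ m

x = np.arange(12, dtype=np.float64).reshape(3, 4)
assert matmul(x, Eye(4)) is x
assert matmul(x, Eye(5)) is NotImplemented
```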
Matmuls which pass the `NotImplemented` check above will take zero FLOPs and just pass their argument through:

Nesting ImplicitArrays
`ImplicitArray` instances can be arbitrarily nested, so combining this with `LoraMatrix` gives a representation of a low-rank update on top of the symbolic identity; no extra handlers are necessary to support this:
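In the numpy stand-in vocabulary, nesting looks like giving `LoraMatrix` a symbolic identity as its base weight, so the composite represents I + A @ B (hypothetical classes, not Qax's API):

```python
import numpy as np

class Eye:
    def __init__(self, n):
        self.shape = (n, n)

    def materialize(self):
        return np.eye(self.shape[0])

class LoraMatrix:
    def __init__(self, w, a, b):
        self.w, self.a, self.b = w, a, b

    def materialize(self):
        # Recursively materialize a nested implicit base weight.
        w = self.w.materialize() if hasattr(self.w, "materialize") else self.w
        return w + self.a @ self.b

def matmul(x, m):
    if isinstance(m, LoraMatrix):
        # Recurse on the base weight so nested implicit types compose.
        return matmul(x, m.w) + (x @ m.a) @ m.b
    if isinstance(m, Eye):
        return x  # identity passthrough, zero FLOPs
    return x @ m

rng = np.random.default_rng(2)
a = rng.normal(size=(6, 2))
b = rng.normal(size=(2, 6))
x = rng.normal(size=(3, 6))

nested = LoraMatrix(Eye(6), a, b)  # represents I + A @ B
assert np.allclose(matmul(x, nested), x @ nested.materialize())
```

The only thing making nesting work here is that the handler recurses instead of assuming its operand is dense, which is the same compositional property the post describes.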
I already used Qax to implement the quantization method I was analyzing in this repo/arxiv note as well. Being able to test stuff on random HuggingFace models or my own Haiku models without patching model code all the time was a really nice change of pace.
A couple more examples:
TODOs
👋
Thanks for taking the time to read this! Please let me know if you have any questions or comments.