Trace over Python #354
Can we just have the trace of a PythonCall call invoke JAX and import the result with an `hlo_call`?
I'm not very familiar with …
In principle, we can automate this if the Python code is itself JAX-compatible.
So in theory, idiomatic NumPy usage should be enough to make it work with JAX. In practice, I fear that some things, like indexing, could cause problems, as described in https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates
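So in principle, once #364 lands, we should be able to overload all uses of … as follows:
```julia
using Reactant
using Reactant: Ops
using PythonCall

jax = pyimport("jax")
numpy = pyimport("numpy")

# Intercept Python calls whose arguments are traced Reactant arrays.
# kwargs are ignored in this sketch.
function PythonCall.pycall(f::Py, args::Reactant.TracedRArray...; kwargs...)
    # Describe each argument to JAX by shape and dtype only; the traced
    # arrays themselves cannot be passed to `jax.jit(f).lower`.
    # TODO: derive the dtype from eltype(arg) and reverse the shape
    # (Julia is column-major, NumPy/JAX are row-major).
    specs = map(args) do arg
        jax.ShapeDtypeStruct(size(arg), numpy.float32)
    end
    lowered = jax.jit(f).lower(specs...)

    # Splice the lowered StableHLO module into the trace, fed with the traced arrays.
    return Ops.hlo_call(
        pyconvert(String, lowered.as_text()),
        args...,
    )
end

x = Reactant.to_rarray(Float32[1, 2, 3])
sum_compiled = @compile jax.numpy.sum(x)
```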
Of course, here we should also parse the `TracedRArray` eltype, transpose shapes, and pass the …
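The eltype and shape handling can be sketched in plain Python as below. The helpers `julia_array_spec` and `julia_index_to_jax` and the dtype table are hypothetical (not part of Reactant, PythonCall, or JAX), and the sketch assumes the Julia side hands over the eltype name and `size(arg)`. The key point is that Julia arrays are column-major while NumPy/JAX default to row-major, so the same buffer is described by the reversed shape, and Julia's `A[i, j]` corresponds to element `[j-1, i-1]` on the JAX side.

```python
# Hypothetical table: Julia element type name -> NumPy/JAX dtype name.
JULIA_TO_NUMPY_DTYPE = {
    "Float16": "float16",
    "Float32": "float32",
    "Float64": "float64",
    "Int32": "int32",
    "Int64": "int64",
    "Bool": "bool",
    "ComplexF32": "complex64",
    "ComplexF64": "complex128",
}


def julia_array_spec(eltype_name, julia_size):
    """Describe a Julia array as (dtype name, shape) from the NumPy/JAX side.

    Julia is column-major and NumPy/JAX default to row-major, so a Julia array
    of size (m, n) has the same memory layout as a row-major array of shape (n, m).
    """
    if eltype_name not in JULIA_TO_NUMPY_DTYPE:
        raise TypeError(f"unsupported Julia eltype: {eltype_name}")
    return JULIA_TO_NUMPY_DTYPE[eltype_name], tuple(reversed(julia_size))


def julia_index_to_jax(index):
    """Map a 1-based Julia index (i, j, ...) to the 0-based index into the reversed-shape array."""
    return tuple(i - 1 for i in reversed(index))


print(julia_array_spec("Float32", (2, 3)))  # ('float32', (3, 2))
print(julia_index_to_jax((2, 3)))  # (2, 1)
```

The resulting dtype name and shape are what the `TODO` in the overload above would feed into `jax.ShapeDtypeStruct`.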
Alternatively, as a worst case (and perhaps useful regardless), we can do something like the following:
```julia
# Reactant.jl
function Ops.python_call end

# ReactantPythonCallExt.jl
function Reactant.Ops.python_call(python_string, args...)
    ...
end
```
I've seen @sefffal's examples on Discourse, and I'm impressed by how well it works! This is the first step toward tracing over Python and getting compiled versions of functions without dependencies, which is super cool. But in my case, I need something more. Consider the following two cases:
For the first point, we can linearize, but then the user's Python code becomes a function whose arguments are arrays. The second point is similar, but with delinearization, only that … We can do several things here, e.g., if there is already a conversion from the Python object to a Julia object (with …). I would like to give an example, but it doesn't entirely work. In Tenet, we have a way to convert Qiskit parametric circuits in Python into Tenet tensor networks in Julia, and what I want is that the compiled function...
All these steps should be compiled by Reactant. The current problem is that Qiskit only allows …
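The linearize/delinearize step can be sketched in plain Python: flatten a nested structure of parameters into a flat list of leaves plus a description of the structure, trace the Python function on the flat leaves, and rebuild the structure afterwards. `linearize` and `delinearize` are hypothetical names; JAX's `jax.tree_util.tree_flatten` and `tree_unflatten` implement the same idea for registered pytree types. Objects such as Qiskit circuits are not handled here; they would need their own flattening rules.

```python
_MISSING = object()


def linearize(obj):
    """Flatten nested dicts/lists/tuples into (leaves, treedef).

    Dict entries are visited in sorted key order so the leaf order is deterministic.
    """
    leaves = []

    def walk(node):
        if isinstance(node, dict):
            keys = sorted(node)
            return ("dict", keys, [walk(node[k]) for k in keys])
        if isinstance(node, (list, tuple)):
            kind = "tuple" if isinstance(node, tuple) else "list"
            return (kind, None, [walk(child) for child in node])
        leaves.append(node)  # anything else (e.g. an array) is a leaf
        return ("leaf", None, None)

    treedef = walk(obj)
    return leaves, treedef


def delinearize(treedef, leaves):
    """Rebuild the nested structure described by treedef from a flat list of leaves."""
    it = iter(leaves)

    def build(node):
        kind, keys, children = node
        if kind == "leaf":
            value = next(it, _MISSING)
            if value is _MISSING:
                raise ValueError("too few leaves for this structure")
            return value
        built = [build(child) for child in children]
        if kind == "dict":
            return dict(zip(keys, built))
        return tuple(built) if kind == "tuple" else built

    result = build(treedef)
    if next(it, _MISSING) is not _MISSING:
        raise ValueError("too many leaves for this structure")
    return result


params = {"theta": [0.1, 0.2], "phi": (0.3, {"a": 0.4})}
leaves, treedef = linearize(params)
print(leaves)  # [0.3, 0.4, 0.1, 0.2]
print(delinearize(treedef, [1, 2, 3, 4]))  # {'phi': (1, {'a': 2}), 'theta': [3, 4]}
```

Because the structure lives entirely in `treedef`, only the flat `leaves` would need to cross into the traced region.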
Sometimes I need to use a Python package that Reactant is unable to trace over, so I have to translate it to Julia first and then trace over the translation.
For example, in my VQE simulation the circuit is defined in Qiskit, so I need to translate the circuit to Yao and then trace over Yao.
The problem is that this translation is not always possible because...
This could be solved by adding a way to trace over PythonCall. I think we can do this by implementing the `ConcreteRArray` and `TracedRArray` types as Python classes that record every operation applied to them by emitting MLIR directly, as we do in Julia (unlike JAX, which first constructs a computational graph). The MLIR emission can be performed by calling back into Reactant in Julia, and calls can be intercepted with NumPy's dispatch mechanism. It doesn't need to be perfect; just having something would ease the interaction between Python and Reactant.
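To make the eager-emission idea concrete, here is a minimal, self-contained Python sketch. Everything in it is hypothetical: `Emitter`, `Tracer`, and `trace` are illustrative names, and "emitting" just appends StableHLO-like text to a list instead of calling back into Reactant. Interception uses plain operator overloading; with NumPy, the same hooks would go in `__array_ufunc__` and `__array_function__` so that calls like `np.add(x, y)` reach the emitter. Each operation is written out at the moment it executes, with no intermediate graph.

```python
def tensor_type(shape):
    """Format a StableHLO-style tensor type, e.g. (3,) -> tensor<3xf32>."""
    dims = "".join(f"{d}x" for d in shape)
    return f"tensor<{dims}f32>"


class Emitter:
    """Hypothetical stand-in for the Reactant callback: records ops as they run."""

    def __init__(self):
        self.lines = []
        self._num_values = 0
        self._num_args = 0

    def argument(self, shape):
        name = f"%arg{self._num_args}"
        self._num_args += 1
        return Tracer(self, name, shape)

    def emit(self, op, operands, shape):
        # Emit immediately, in execution order; no graph is built.
        result = f"%{self._num_values}"
        self._num_values += 1
        operand_names = ", ".join(o.name for o in operands)
        self.lines.append(f"{result} = stablehlo.{op} {operand_names} : {tensor_type(shape)}")
        return Tracer(self, result, shape)


class Tracer:
    """Hypothetical Python-side TracedRArray: holds an SSA value name and a shape."""

    def __init__(self, emitter, name, shape):
        self.emitter = emitter
        self.name = name
        self.shape = tuple(shape)

    def _binary(self, op, other):
        if not isinstance(other, Tracer):
            return NotImplemented  # scalars/constants are not handled in this sketch
        if other.shape != self.shape:
            raise ValueError("shape mismatch (broadcasting is not handled in this sketch)")
        return self.emitter.emit(op, (self, other), self.shape)

    def __add__(self, other):
        return self._binary("add", other)

    def __sub__(self, other):
        return self._binary("subtract", other)

    def __mul__(self, other):
        return self._binary("multiply", other)

    def __neg__(self):
        return self.emitter.emit("negate", (self,), self.shape)


def trace(fn, *shapes):
    """Run fn on tracer arguments and return the emitted lines."""
    emitter = Emitter()
    args = [emitter.argument(s) for s in shapes]
    out = fn(*args)
    emitter.lines.append(f"return {out.name} : {tensor_type(out.shape)}")
    return emitter.lines


for line in trace(lambda x, y: -(x + y) * x, (3,), (3,)):
    print(line)
# %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
# %1 = stablehlo.negate %0 : tensor<3xf32>
# %2 = stablehlo.multiply %1, %arg0 : tensor<3xf32>
# return %2 : tensor<3xf32>
```

In a real implementation, `Emitter.emit` would be the callback into Reactant, which already knows how to build the corresponding MLIR operations.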
P.S.: My intention is not to reimplement JAX in Python by calling Reactant, haha. I don't want to start a war with them; I just want to solve some integration problems.