Trace over Python #354

Open
mofeing opened this issue Dec 9, 2024 · 10 comments

Comments

@mofeing
Collaborator

mofeing commented Dec 9, 2024

There are times when I need to use a Python package that Reactant is unable to trace over, so I first have to translate it to Julia and then trace over the translation.
For example, in my VQE simulation the circuit is defined in Qiskit, so I need to translate the circuit to Yao and then trace over Yao.

The problem is that this translation is not always possible because...

  1. there might be libraries in Python which do not have an equivalent in Julia
  2. even if there exists an equivalent, the translation might not be perfect
  3. coding the alternative yourself can be time- and energy-consuming

This could be solved if we add a way to trace over PythonCall. I think we can do this by implementing the ConcreteRArray and TracedRArray types as Python classes that record every operation applied to them by emitting MLIR directly, as we do in Julia (unlike JAX, which first constructs a computational graph). The MLIR emission can be performed by calling back into Reactant in Julia, and the calls can be intercepted with NumPy's dispatch mechanism (the `__array_ufunc__` and `__array_function__` protocols).
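The key idea is eager emission: each operation on a traced value writes its instruction immediately instead of building a graph first. A minimal sketch of that idea follows, written in Julia with operator overloading standing in for NumPy's interception protocols. `Emitter`, `Traced`, `fresh!`, and `emit!` are hypothetical names, and the emitted text only imitates StableHLO op names; it is not valid MLIR and not Reactant's actual output.

```julia
# Toy eager emitter: every operation on a `Traced` value appends one
# instruction right away; no computational graph is built first.
# Op names imitate StableHLO, but the text is schematic, not valid MLIR.
mutable struct Emitter
    lines::Vector{String}   # emitted instructions, in program order
    counter::Int            # last SSA value number handed out
end
Emitter() = Emitter(String[], 0)

# A traced value is just the SSA name of its result plus the emitter it lives in
struct Traced
    name::String
    em::Emitter
end

# Allocate a fresh SSA name such as "%1", "%2", ...
function fresh!(em::Emitter)
    em.counter += 1
    return "%$(em.counter)"
end

# Emit a binary op immediately and return a traced handle to its result
function emit!(op::String, a::Traced, b::Traced)
    name = fresh!(a.em)
    push!(a.em.lines, "$name = stablehlo.$op $(a.name), $(b.name)")
    return Traced(name, a.em)
end

# Overloading the operators is what "intercepts" the calls
Base.:+(a::Traced, b::Traced) = emit!("add", a, b)
Base.:*(a::Traced, b::Traced) = emit!("multiply", a, b)

em = Emitter()
x = Traced("%arg0", em)
y = Traced("%arg1", em)
z = x * y + x
foreach(println, em.lines)
```

Running it prints `%1 = stablehlo.multiply %arg0, %arg1` followed by `%2 = stablehlo.add %1, %arg0`. In the proposal, the Python classes would play the role of `Traced`, and `emit!` would be a callback into Reactant.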

It doesn't need to be perfect; just having something would ease the interaction between Python and Reactant.

P.S.: My intention is not to reimplement JAX in Python by calling Reactant, hahaha. I don't want to start a war with them; I just want to solve some integration problems.

@wsmoses
Member

wsmoses commented Dec 9, 2024

Can we just have trace of pythoncall call Jax and import it with an hlo_call?

@mofeing
Collaborator Author

mofeing commented Dec 9, 2024

> Can we just have trace of pythoncall call Jax and import it with an hlo_call?

I'm not very familiar with hlo_call yet, but it could be an option. The only problem I see is that users would have to set up their Python code to call JAX, right? Or could we take charge of setting up the JAX part?

@wsmoses
Member

wsmoses commented Dec 9, 2024

In principle, we can automate this if the Python code is itself JAX-compatible.

@mofeing
Collaborator Author

mofeing commented Dec 9, 2024

So in theory, idiomatic use of NumPy should be enough to make it work with JAX. In practice, I fear that some things, like indexing with in-place updates, can cause problems, as described in https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates

@wsmoses
Member

wsmoses commented Dec 11, 2024

So, in principle, once #364 lands, we should be able to overload all uses of PythonCall.pycall to do whatever we want. In this case, we basically want to do what @sefffal started discussing on GitHub:

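using Reactant
using Reactant: Ops
using Enzyme
using PythonCall
jax = pyimport("jax")
numpy = pyimport("numpy")

# Overload pycall for traced arguments: lower the Python function with JAX
# and splice the resulting StableHLO module into the Reactant trace.
function PythonCall.pycall(f::Py, args::Reactant.TracedRArray...; kwargs...)
    # Abstract JAX inputs: NumPy/JAX are row-major, so the Julia
    # (column-major) shape is reversed. Only Float32 is handled here.
    inputs = map(args) do arg
        jax.ShapeDtypeStruct(reverse(size(arg)), numpy.float32)
    end
    # Lower with the abstract inputs, not the traced Julia arrays
    lowered = jax.jit(f).lower(inputs...)
    # Pass the traced arrays (not the placeholders) to the HLO call
    return Ops.hlo_call(
        pyconvert(String, lowered.as_text()),
        args...
    )
end

f = @compile jax.numpy.sum(
    Reactant.to_rarray(Float32[1, 2, 3]),
)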

@wsmoses
Member

wsmoses commented Dec 11, 2024

Of course, here we should also parse the TracedRArray eltype, transpose the shapes, pass the inputs variables to the JAX lower call, and pass the args variables to the hlo_call.
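The eltype and shape-transposition steps can be sketched in plain Julia. This is a minimal illustration on an ordinary `Array`, assuming a small hand-written eltype-to-dtype table; `NUMPY_DTYPE`, `numpy_dtype`, `row_major_shape`, and `rowmajor_get` are hypothetical helpers, and for a `TracedRArray` the element type would instead have to be read from its type parameters, which is not shown here.

```julia
# Hypothetical helper: map a Julia element type to a NumPy dtype name.
# Only a handful of types are covered; anything else throws a KeyError.
const NUMPY_DTYPE = Dict{DataType,String}(
    Float16 => "float16",
    Float32 => "float32",
    Float64 => "float64",
    Int32 => "int32",
    Int64 => "int64",
    Bool => "bool",
)
numpy_dtype(::AbstractArray{T}) where {T} = NUMPY_DTYPE[T]

# Julia arrays are column-major; NumPy/JAX default to row-major.
# Reading the same buffer row-major with the reversed shape gives the transpose.
row_major_shape(x::AbstractArray) = reverse(size(x))

# Element (i, j) of a 2-D buffer laid out row-major with the given shape
rowmajor_get(buf::AbstractVector, shape::NTuple{2,Int}, i, j) = buf[(i - 1) * shape[2] + j]

A = Float32[1 2 3; 4 5 6]     # 2×3 Julia matrix
shape = row_major_shape(A)     # (3, 2)
buf = vec(A)                   # column-major memory: 1, 4, 2, 5, 3, 6
```

Reading `buf` row-major with shape `(3, 2)` yields `A'`, which is why reversing the shape lets the Python side reuse the Julia buffer without a copy, at the cost of seeing the transposed array.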

@wsmoses
Member

wsmoses commented Dec 11, 2024

As an alternative or worst case (and perhaps just useful regardless), we can do something like the following:

# Reactant.jl (inside the `Ops` module)
function python_call end

# ReactantPythonCallExt.jl
using Reactant, PythonCall

function Reactant.Ops.python_call(python_string, args...)
    # ...
end

@mofeing
Collaborator Author

mofeing commented Dec 13, 2024

I've seen @sefffal's examples on Discourse and I'm impressed by how well they work!

This is the first step toward tracing over Python and getting compiled versions of functions without dependencies, which is super cool. But in my case, I need something more. Consider the following two cases:

  1. what if we want to pass a Julia object to Python...
  2. ...and what if we want to get a Python object back from it

For the first point, we can linearize the Julia object into its arrays, but then the user's Python code has to be a function whose arguments are arrays. The second point is similar but requires delinearization, except that @compile doesn't know which create_result method to call.

We could do several things here. For example, if there is already a conversion from the Python object to a Julia object (with pyconvert or another user-defined mechanism), we could convert the result to Julia and call create_result on it.
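The linearize/delinearize round trip described above can be sketched in plain Julia. This assumes a hypothetical `Circuit` struct with array fields and hypothetical `linearize`, `delinearize`, and `python_side` functions; none of these are Reactant's internal APIs, and `python_side` only stands in for user Python code that sees nothing but arrays.

```julia
# Hypothetical Julia object that has to cross the Python boundary
struct Circuit
    angles::Vector{Float32}
    weights::Matrix{Float32}
end

# Linearize: flatten the object into a tuple of its arrays (hypothetical helper)
linearize(c::Circuit) = (c.angles, c.weights)

# Delinearize: rebuild an object of a *known* type from the returned arrays
delinearize(::Type{Circuit}, arrays::Tuple) = Circuit(arrays...)

# Stand-in for the user's Python code: it only ever sees plain arrays
python_side(angles, weights) = (2 .* angles, weights .+ 1f0)

c = Circuit(Float32[1, 2], Float32[1 2; 3 4])
out = delinearize(Circuit, python_side(linearize(c)...))
```

Note that `delinearize` has to be told the target type `Circuit` explicitly. That is the information @compile lacks when the result comes back as an arbitrary Python object, which is why an existing Python-to-Julia conversion would help.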

I would like to give an example, but it doesn't fully work yet. In Tenet, we have a way to convert parametric Qiskit circuits in Python to Tenet tensor networks in Julia, and what I want is a compiled function that...

  1. receives the parameters of the quantum circuit as arguments
  2. constructs the quantum circuit in Python
  3. converts the quantum circuit in Python to a tensor network in Julia
  4. does some operations with the tensor network and returns a float
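The shape of the desired compiled function can be illustrated with a toy, pure-Julia stand-in for the four steps. Everything here is hypothetical: `build_circuit` replaces the Qiskit construction, `to_tensors` replaces the Tenet conversion, and the "circuit" is just a chain of single-qubit RY rotations, chosen only because its result is easy to check.

```julia
# Toy stand-in for the four steps; in the real use case step 2 is Qiskit (Python)
# and step 3 is Tenet (Julia). All names here are hypothetical.

# Step 2 stand-in: a "circuit" is a list of single-qubit RY rotations
build_circuit(θs::AbstractVector{<:Real}) = [(:RY, θ) for θ in θs]

# Step 3 stand-in: turn each gate into a 2×2 tensor (here, a matrix)
function to_tensors(circuit)
    return map(circuit) do g
        gate, θ = g
        @assert gate === :RY "only RY gates are supported in this toy"
        c, s = cos(θ / 2), sin(θ / 2)
        [c (-s); s c]
    end
end

# Step 4: contract the network with |0⟩ and return ⟨Z⟩ as a float
function expectation_z(tensors)
    ψ = foldl((state, U) -> U * state, tensors; init = [1.0, 0.0])
    return abs2(ψ[1]) - abs2(ψ[2])
end

# Step 1: the function to be compiled receives only the circuit parameters
energy(θs) = expectation_z(to_tensors(build_circuit(θs)))
```

Since RY rotations about the same axis add up, `energy([θ1, θ2])` equals `cos(θ1 + θ2)`. In the real pipeline, steps 2 and 3 would cross the Python/Julia boundary, which is the part Reactant currently cannot trace.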

All of these steps should be compiled by Reactant. The current problem is that Qiskit only accepts numpy.ndarrays as parameters, because that integrates well with SymPy and so on. I think we will find this pattern in more cases. The way I make it work right now is to convert the quantum circuit to Yao manually and use Yao instead, but Yao has fewer features, is giving us some problems, and many people in my field use Qiskit and Python directly.

@wsmoses
Member

wsmoses commented Dec 20, 2024

@sefffal, for functions that operate purely on arrays, the nice version of this is now implemented in #407.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants