Faster curve analysis #1192
I looked into this in some detail. By modifying the following function, we can get the same result as `scipy.minimize`:

```python
def residual(params):
    return jnp.sum(
        jnp.square(y_jax - jnp.exp(-params[0] * x) * jnp.sin(2 * np.pi * params[1] * x) - params[2])
    )
```

There is also `jaxopt`, an optimization library for JAX. We can use it as below:

```python
import jaxopt

# scipy.optimize.minimize wrapper
%timeit jaxopt.ScipyMinimize(fun=residual).run(init_guess_jax)

jaxopt_res = jaxopt.ScipyMinimize(fun=residual).run(init_guess_jax)
# use .params instead of .x
plt.plot(x, function(x, *jaxopt_res.params))
```
Cool. Thanks @to24toro for the follow-up. Could you please compare the performance of the jaxopt and scipy optimizations (JAX cost function vs. NumPy cost function) on your platform? I think JAX functions need a warmup run so that the compile overhead of the first call is excluded.
You are correct.

```python
@jit
def jax_scipy(x):
    return jaxopt.ScipyMinimize(fun=residual).run(x)

jax_res = jax_scipy(init_guess_jax)
```

Meanwhile, we can jit `jaxopt.BFGS` as an example. Of course, I am not sure we can directly compare it with `ScipyMinimize`. Just in case, I show the measured times for a few approaches:

```python
res = least_squares(fobj, x0=init_guess)
%timeit least_squares(fobj, x0=init_guess)

%timeit -n 1000 jaxopt.ScipyMinimize(fun=residual).run(init_guess_jax)

@jit
def jax_bfgs(x):
    return jaxopt.BFGS(fun=residual).run(x)

# Warm up to exclude the JIT compile overhead
jax_res = jax_bfgs(init_guess_jax)
%timeit jax_bfgs(init_guess_jax)
```
Closed and merged into #1268
## Suggested feature

### Background
CurveAnalysis is one of the core fitter base classes in Qiskit Experiments and is currently used by every calibration experiment. We also support the ParallelExperiment wrapper, which allows us to combine multiple (non-qubit-overlapping) experiment instances in the same run, so calibration experiments can be seamlessly scaled to the device level on top of this framework. In principle, we can calibrate production QC systems with 1000+ qubits with the current framework; however, the performance bottlenecks must be carefully identified and resolved at this scale.
### Current status
Let's go deep inside the analysis with a simple example, the T1 experiment. Note that the most time-consuming operation in curve analysis is figure generation. We don't plan to build our own (possibly faster) plotter; indeed, in the era of 1000+ qubit devices, no hardware engineer will visually inspect plots from every experiment instance. So let's disable the plotter for now.
The following (very naive) benchmark code is written under the above assumption.
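The original benchmark snippet did not survive extraction. As a stand-in, here is a minimal sketch of the kind of timing one could run, assuming a T1-style exponential decay model and scipy's `least_squares`; the model, data, and values below are illustrative, not the original benchmark:

```python
import time

import numpy as np
from scipy.optimize import least_squares

# Hypothetical T1-style decay model: y = a * exp(-x / tau) + b
def t1_model(x, a, tau, b):
    return a * np.exp(-x / tau) + b

rng = np.random.default_rng(seed=123)
x = np.linspace(0.0, 300e-6, 30)  # delay times in seconds
y = t1_model(x, 1.0, 100e-6, 0.0) + rng.normal(0.0, 0.02, x.size)

def fobj(params):
    # Vector of residuals, as expected by scipy.optimize.least_squares
    return y - t1_model(x, *params)

start = time.perf_counter()
res = least_squares(fobj, x0=[1.0, 80e-6, 0.0])
elapsed = time.perf_counter() - start
print(f"tau estimate: {res.x[1]:.3e} s, fit time: {elapsed * 1e3:.2f} ms")
```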
Because `BaseAnalysis.run` invokes another thread, I just tested `BaseAnalysis._run_analysis`, which is the core function of the analysis, so that I can benchmark it in a Jupyter notebook (we should use another profiler to analyze the entire performance, including the initialization cost of the experiment data container). As you can see, more than 50% of the time is consumed by the scipy `least_squares` solver that minimizes the residual for the model parameter search. By default, this solver numerically computes the Jacobian matrix for gradient information.
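To illustrate the Jacobian point: since the fit model is known in closed form, one can pass an analytic Jacobian to `least_squares` and skip the finite-difference model evaluations. A minimal sketch using a damped sinusoid like the one in the test below (the data and starting point are my own illustrative choices):

```python
import numpy as np
from scipy.optimize import least_squares

# Damped sinusoid fit model: exp(-a*x) * sin(2*pi*f*x) + c
def model(x, a, f, c):
    return np.exp(-a * x) * np.sin(2 * np.pi * f * x) + c

rng = np.random.default_rng(seed=0)
x = np.linspace(0.0, 2.0, 200)
y = model(x, 1.2, 2.0, 0.1) + rng.normal(0.0, 0.02, x.size)

def residual(params):
    return model(x, *params) - y

def jacobian(params):
    # Analytic Jacobian of the residual vector w.r.t. (a, f, c)
    a, f, c = params
    da = -x * np.exp(-a * x) * np.sin(2 * np.pi * f * x)
    df = 2 * np.pi * x * np.exp(-a * x) * np.cos(2 * np.pi * f * x)
    dc = np.ones_like(x)
    return np.stack([da, df, dc], axis=1)

x0 = [1.0, 1.9, 0.0]
res_numeric = least_squares(residual, x0=x0)                 # finite-difference Jacobian
res_analytic = least_squares(residual, x0=x0, jac=jacobian)  # no extra model calls
print(res_numeric.x, res_analytic.x)
```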
### Proposal
There would be two approaches to speed up this fitting operation. One could implement the minimization solver in a (potentially faster) compiled language such as Rust or Julia; for example, a fitting library is written in Rust. However, according to the scipy documentation, the underlying Levenberg-Marquardt implementation is a wrapper around MINPACK. Because MINPACK is a FORTRAN library, I'm not sure how much performance gain we can obtain in return for giving up scipy (indeed, scipy offers a rich selection of solvers).
Apart from this avenue, one could focus on the bottleneck of the Jacobian matrix computation. Because we already know the exact fit model, this approach sounds more natural. Fortunately, we don't need to write any callback to compute the Jacobian matrix; instead we can just rely on JAX, which offers a scipy minimize wrapper, jax.scipy.optimize.minimize (seemingly still under development, though).
JAX implements a Python interface and a math library comparable with NumPy. Once we implement the residual function with their math library, we can even JIT-compile the minimization routine. As a side effect, this approach may make JAX a requirement of our software (which might not be acceptable to all users), but Qiskit recently released the arraylias package, which provides aliasing of JAX/NumPy. With arraylias, we can introduce JAX as an optional dependency, and the CurveAnalysis solver can fall back to the NumPy residual function when JAX is not available. I believe this approach is much more promising.
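A hand-rolled sketch of the optional-dependency idea (arraylias would automate this aliasing; the `xp` alias and no-op `jit` fallback here are my own illustrative stand-ins):

```python
# Prefer JAX when installed, otherwise fall back to NumPy
try:
    import jax.numpy as xp
    from jax import jit
except ImportError:
    import numpy as xp

    def jit(fun, **kwargs):
        # No-op stand-in so the decorated code still runs without JAX
        return fun

@jit
def residual(params, x, y):
    # Same math works against either backend via the xp alias
    model = xp.exp(-params[0] * x) * xp.sin(2 * xp.pi * params[1] * x) + params[2]
    return xp.sum((model - y) ** 2)

x = xp.linspace(0.0, 1.0, 50)
y = xp.exp(-1.0 * x) * xp.sin(2 * xp.pi * 2.0 * x) + 0.5
loss = residual([1.0, 2.0, 0.5], x, y)  # exact parameters give ~zero loss
```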
### Test
In this test code, I use a damped sinusoidal function as the fit model. The fit data has some noise along the Y-axis. Because I'm very new to JAX, the code below is just a patchwork of some example code.
This is the conventional NumPy residual function with the scipy `least_squares` solver.
Then, try the JAX residual function and the JAX scipy minimize solver with JIT compilation.
Finally, compare the results.
As you can see, the JAX result is a bit off from the actual data points. I guess the difference is due to the return value of the JAX residual function: it errored when I returned a vector value, while the scipy solver could handle that properly. I believe there is some trick to fix this. On the other hand, it's worth calling attention to the performance improvement: the JAX solver consumed only 1.5% of the scipy solver's time to return fit parameters!
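On the vector-vs-scalar point: scipy's `least_squares` expects the residual *vector* and squares and sums it internally, while general minimizers such as `scipy.optimize.minimize` (and `jax.scipy.optimize.minimize`) expect a single scalar cost. A scipy-only sketch showing both forms reaching the same minimum (the model, data, and starting point are illustrative):

```python
import numpy as np
from scipy.optimize import least_squares, minimize

def model(x, a, f, c):
    return np.exp(-a * x) * np.sin(2 * np.pi * f * x) + c

rng = np.random.default_rng(seed=1)
x = np.linspace(0.0, 2.0, 200)
y = model(x, 1.2, 2.0, 0.1) + rng.normal(0.0, 0.01, x.size)

def vector_residual(params):
    # least_squares wants the residual vector; it squares and sums internally
    return model(x, *params) - y

def scalar_residual(params):
    # minimize (like jax.scipy.optimize.minimize) wants a single scalar cost
    return np.sum(vector_residual(params) ** 2)

x0 = [1.0, 1.9, 0.0]
res_ls = least_squares(vector_residual, x0=x0)
res_min = minimize(scalar_residual, x0=x0, method="BFGS")
print(res_ls.x, res_min.x)  # should land on the same minimum
```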
### Note
Because JIT compilation is invoked when the wrapped function is called for the first time, ideally the residual function should be a singleton. Otherwise, JIT compilation runs in every fitter instance and doesn't give a significant performance improvement in a parallel environment. In the above example, it was most efficient to JIT-compile the minimize function; when I JIT-compiled only the residual function, I didn't obtain any performance gain -- this needs more investigation.
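The singleton idea can be sketched without JAX by caching the "compiled" residual per model, so that repeated fitter instances share one compilation. Here `functools.cache` stands in for the real `jax.jit` step, and all names are hypothetical:

```python
from functools import cache

import numpy as np

compile_calls = []  # records each (notional) compilation

@cache
def get_solver(model_id: str):
    # With JAX, this body would call jax.jit on the model's residual.
    # Caching at module scope means repeated fitter instances share one
    # compiled function instead of recompiling per instance.
    compile_calls.append(model_id)

    def residual(params, x, y):
        return np.sum((np.exp(-params[0] * x) - y) ** 2)

    return residual

# Two fitter instances requesting the same model reuse one compilation
r1 = get_solver("exp_decay")
r2 = get_solver("exp_decay")
```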
I hope the future fit model will look something like the following.
The CurveAnalysis class can then just consume this model object, which provides an efficient cost function along with an initial-guess generator. This allows us to completely decouple the fit model from the fit protocol and reduce the maintenance overhead of CurveAnalysis subclasses.
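As a rough sketch of what such a model object could look like (the `FitModel` name and interface below are hypothetical, not an existing Qiskit Experiments API):

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np
from scipy.optimize import least_squares

@dataclass(frozen=True)
class FitModel:
    """Hypothetical container decoupling the fit model from the fit protocol."""

    func: Callable   # y = func(x, *params)
    guess: Callable  # (x, y) -> initial parameter array

    def residual(self, x, y):
        return lambda params: self.func(x, *params) - y

    def fit(self, x, y):
        return least_squares(self.residual(x, y), x0=self.guess(x, y))

# Example: exponential decay with a crude data-driven initial guess
decay = FitModel(
    func=lambda x, a, tau: a * np.exp(-x / tau),
    guess=lambda x, y: np.array([y[0], x[-1] / 3.0]),
)
x = np.linspace(0.0, 3.0, 50)
y = 2.0 * np.exp(-x / 0.7)
res = decay.fit(x, y)
print(res.x)
```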