Usage on microcontroller (ARM Cortex-M4) / LiteRT #139

Open
patrickjedlicka opened this issue Oct 14, 2024 · 1 comment

@patrickjedlicka

Hi, does anyone have experience making predictions on a microcontroller? Is the YDF C++ library compatible with microcontroller architectures?

As far as I understand, YDF models are not yet compatible with TFLite (now LiteRT). Is this true, or has something changed here?

Any help is welcome!

Best regards,

/P

@rstz
Collaborator

rstz commented Oct 15, 2024

For GradientBoostedTrees, you can (experimentally) convert the model to JAX and then convert that to LiteRT. I've posted this in another issue (and should probably add it to the documentation):

import tensorflow as tf
import ydf
from jax.experimental import jax2tf

# Train a Gradient Boosted Trees model.
# df_fft_general: your training DataFrame, with the label in the "class" column.
gbt_learner = ydf.GradientBoostedTreesLearner(label="class")
gbt_model = gbt_learner.train(df_fft_general)

# Convert the model to JAX, restricted to TFLite-compatible code.
jax_model = gbt_model.to_jax_function(compatibility="TFL")

# Convert the JAX model to a TensorFlow model.
tf_model = tf.Module()
tf_model.predict = tf.function(
    jax2tf.convert(jax_model.predict, with_gradient=False),
    jit_compile=True,
    autograph=False,
)

# Convert the TensorFlow model to a TFLite model.
# test_ds: a DataFrame of examples; a single example fixes the input signature.
selected_examples = test_ds[:1].drop(gbt_model.label(), axis=1)
input_values = jax_model.encoder(selected_examples)
tf_input_specs = {
    k: tf.TensorSpec.from_tensor(tf.constant(v), name=k)
    for k, v in input_values.items()
}
concrete_predict = tf_model.predict.get_concrete_function(tf_input_specs)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [concrete_predict], tf_model
)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # enable TensorFlow ops.
]
tflite_model = converter.convert()  # serialized flatbuffer (bytes)
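
On a microcontroller there is usually no filesystem, so with LiteRT for Microcontrollers (TFLite Micro) the converted flatbuffer is commonly compiled into the firmware as a C/C++ byte array (the TFLM examples use `xxd -i` for this). Below is a minimal sketch of that step, assuming `tflite_model` is the `bytes` object returned by `converter.convert()`. The function name `to_c_array`, the variable name `g_model`, and the 16-byte alignment are illustrative choices, not part of YDF or LiteRT. Also note that TFLite Micro supports only a subset of the builtin ops, so a model that needs `SELECT_TF_OPS` may not load there.

```python
def to_c_array(data: bytes, var_name: str = "g_model", bytes_per_line: int = 12) -> str:
    """Render a serialized TFLite model as C++ source, similar to `xxd -i`.

    `var_name` and the alignas(16) alignment are illustrative assumptions.
    """
    lines = []
    for start in range(0, len(data), bytes_per_line):
        chunk = data[start:start + bytes_per_line]
        # One line of comma-separated hex bytes, e.g. "  0x1c, 0x00, ...,"
        lines.append("  " + ", ".join(f"0x{b:02x}" for b in chunk) + ",")
    body = "\n".join(lines)
    return (
        f"alignas(16) const unsigned char {var_name}[] = {{\n"
        f"{body}\n"
        f"}};\n"
        f"const unsigned int {var_name}_len = {len(data)};\n"
    )


if __name__ == "__main__":
    # Stand-in bytes; in practice pass `tflite_model` from the conversion above.
    print(to_c_array(b"\x1c\x00\x00\x00TFL3"))
```

In the conversion script you would then write `to_c_array(tflite_model)` to a `.cc` file and compile it into the firmware alongside the TFLite Micro interpreter.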

For Random Forests and Isolation Forests, we haven't had time to implement the JAX conversion yet, but I don't see any blockers.

The faster and probably more elegant approach is to compile the YDF inference code for your architecture. People have done this for some architectures, e.g., the Raspberry Pi, but we can't make any promises about it. For reference, here are the (since-deleted) instructions: https://github.com/google/yggdrasil-decision-forests/blob/4bdedd31c041706a3d022313f1edaf494dea53c1/documentation/installation.md#compilation-on-and-for-raspberry-pi

Note that we're now using Bazel 5.3.0 (and will be migrating to Bazel 6 or 7 at some point).

For the compilation step, building only the predict tool should be enough:

${BAZEL} build //yggdrasil_decision_forests/cli:predict \
  --config=linux_cpp17 --features=-fully_static_link --host_javabase=@local_jdk//:jdk
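
Once the binary builds, one way to check it on a Linux-capable target is to run it on a CSV file with a model saved in the YDF format (e.g. via `gbt_model.save("/tmp/gbt_model")`). The sketch below only assembles the command line; the `--model`, `--dataset=csv:…` and `--output=csv:…` flags follow the YDF CLI quickstart and should be confirmed with `predict --help`. The paths and helper names are hypothetical, and `bazel-bin/...` is simply Bazel's default output location for the target above.

```python
import subprocess
from typing import List


def build_predict_command(binary: str, model_dir: str,
                          dataset_csv: str, output_csv: str) -> List[str]:
    """Assemble a YDF `predict` CLI call (flag names assumed from the CLI docs)."""
    return [
        binary,
        f"--model={model_dir}",             # directory created by model.save()
        f"--dataset=csv:{dataset_csv}",     # input examples, CSV with header
        f"--output=csv:{output_csv}",       # where predictions are written
    ]


def run_predict(cmd: List[str]) -> None:
    """Run the command on the device; raises CalledProcessError on failure."""
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    cmd = build_predict_command(
        "bazel-bin/yggdrasil_decision_forests/cli/predict",  # hypothetical path
        "/tmp/gbt_model", "test.csv", "predictions.csv",
    )
    print(" ".join(cmd))
```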

If you're successful, please let us know; we'd be happy to include an updated guide in the repo.
