I want to propose a PR for a new op, which could be in the form of a Triton or a CUDA kernel? #20658
Labels: keras-team-review-pending, type:feature
RWKV is a new-generation RNN architecture. Pre-trained versions are available in sizes ranging from 0.3B to 14B parameters. It achieves performance comparable to Transformer-based LLMs while retaining inference advantages similar to Mamba (constant memory and linear time per token).
I want to contribute the RNN part of RWKV to Keras, but I have a few questions. Firstly, the core operator of RWKV, the time-mix recurrence, is still iterating quickly across versions. Should I wait for a stable version before submitting a PR, or submit a new op for each minor version?
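For context, the time-mix operator boils down to a sequential per-channel recurrence, which is why a naive ops-level implementation is slow. A minimal NumPy sketch of the idea (roughly the RWKV-4-style WKV form, not the exact v6 kernel; the function name and the numerically naive exponentials are illustrative only):

```python
import numpy as np

def wkv_naive(w, u, k, v):
    """Naive O(T) sequential WKV recurrence for a single channel.

    w: decay (negative scalar), u: bonus scalar, k, v: shape (T,) arrays.
    Illustrative sketch only -- real kernels use a numerically stable
    log-space formulation and fuse the loop on the GPU.
    """
    T = k.shape[0]
    out = np.empty(T)
    num = 0.0  # running decayed sum of exp(k_i) * v_i
    den = 0.0  # running decayed sum of exp(k_i)
    for t in range(T):
        e = np.exp(u + k[t])
        out[t] = (num + e * v[t]) / (den + e)  # weighted average of v_0..v_t
        num = np.exp(w) * num + np.exp(k[t]) * v[t]
        den = np.exp(w) * den + np.exp(k[t])
    return out
```

The step-by-step data dependence (`num`/`den` at step t depend on step t-1) is what prevents a straightforward vectorized `keras.ops` implementation from being efficient.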
Secondly, we have implemented RWKV-6-Keras and found that an implementation using only `keras.ops` is relatively inefficient. To achieve high efficiency, we need a CUDA- or Triton-based implementation. Personally, I prefer to provide a Triton implementation: torch ships with the Triton library by default, and for JAX we only need to additionally install jax-triton. A CUDA implementation, by contrast, requires a complete CUDA toolchain, and the pip-installed jax and torch packages cannot directly compile CUDA operators. The Triton implementation therefore seems more user-friendly.
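One way to keep the op usable on every install is to dispatch to the fused kernel when Triton is importable and otherwise fall back to the portable slow path. This is only a sketch of the dispatch pattern; the function names are hypothetical, and the two backends are stubs standing in for the real implementations:

```python
import importlib.util

def has_triton() -> bool:
    # True if the `triton` package is importable (torch wheels bundle it;
    # JAX would need jax-triton installed separately).
    return importlib.util.find_spec("triton") is not None

def _time_mix_triton(r, k, v, w, u):
    # Stub standing in for the fused Triton kernel.
    raise NotImplementedError("sketch only")

def _time_mix_keras_ops(r, k, v, w, u):
    # Stub standing in for the slow, portable pure-keras.ops fallback.
    raise NotImplementedError("sketch only")

def time_mix(r, k, v, w, u):
    # Hypothetical dispatcher: prefer the fused kernel when available,
    # keep a portable fallback so CPU-only installs still work.
    impl = _time_mix_triton if has_triton() else _time_mix_keras_ops
    return impl(r, k, v, w, u)
```

With this shape, users who never install Triton still get a working (if slower) layer, and the kernel stays an internal implementation detail behind one public op.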