rodrigodzf/jaxpole


jaxpole

jaxpole is a JAX implementation of a differentiable, time-varying all-pole filter, based on torchlpc.
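As a sketch of what the filter computes, assume jaxpole keeps torchlpc's sign convention (the coefficients in the usage example are consistent with it). Each output sample is then the input sample minus a weighted sum of the previous P outputs, with weights that may change at every time step. Here a_{t,k} is taken to be `A[b, t, k-1]`, and y_{-1}, …, y_{-P} come from the initial conditions `zi`. The second line gives the transfer function when the coefficients are constant. The third line shows why a conjugate pole pair p, p̄ gives a_1 = -2 Re(p) and a_2 = |p|^2:

```latex
\begin{aligned}
y_t &= x_t - \sum_{k=1}^{P} a_{t,k}\, y_{t-k}, \qquad t = 0, \dots, T-1, \\
H(z) &= \frac{1}{1 + \sum_{k=1}^{P} a_k z^{-k}} \qquad \text{(constant coefficients)}, \\
(1 - p z^{-1})(1 - \bar{p} z^{-1}) &= 1 - 2\,\mathrm{Re}(p)\, z^{-1} + |p|^2 z^{-2}.
\end{aligned}
```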

Install

```bash
pip install jaxpole
```

or install locally from source, including the `dev` extras:

```bash
pip install -e '.[dev]'
```

How to use

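```python
import jax
import jax.numpy as jnp
from jaxpole.filter import allpole

# a conjugate pole pair at radius 0.99 and angle pi/4 (a second-order resonator)
pole = 0.99 * jnp.exp(1j * jnp.pi / 4)
# a1 = -2 Re(p), a2 = |p|^2, the coefficients of (1 - p z^-1)(1 - conj(p) z^-1)
coeffs = jnp.array([-2 * pole.real, pole.real**2 + pole.imag**2])

x = jax.random.normal(jax.random.PRNGKey(0), (1, 1000))  # input signal, (B, T)
A = jnp.tile(coeffs, (1, x.shape[-1], 1))  # same coefficients at every step, (B, T, P)
zi = jnp.zeros((1, A.shape[-1]))  # zero initial conditions, (B, P)

# filter the signal
y = allpole(x, A, zi)
```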
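The same recursion can also be written as a plain `jax.lax.scan`. This is useful for checking outputs on small inputs and for seeing why the filter is differentiable. It is a minimal reference sketch, not jaxpole's implementation. The name `allpole_reference` is hypothetical. The `zi` layout (`zi[:, k]` holding y_{-1-k}) is an assumption and may not match jaxpole's. Every step is an ordinary JAX operation, so `jax.grad` can differentiate straight through it. The example also builds genuinely time-varying coefficients by sweeping the pole angle.

```python
import jax
import jax.numpy as jnp


def allpole_reference(x, A, zi):
    """Naive time-varying all-pole filter (hypothetical reference, not jaxpole's code).

    Computes y[b, t] = x[b, t] - sum_k A[b, t, k] * y[b, t - 1 - k].

    x:  (B, T) input signal
    A:  (B, T, P) per-sample coefficients
    zi: (B, P) initial conditions; ASSUMED layout: zi[:, k] holds y[-1 - k]
    """

    def step(state, inputs):
        # state: (B, P) past outputs, state[:, k] = y[t - 1 - k]
        x_t, a_t = inputs  # x_t: (B,), a_t: (B, P)
        y_t = x_t - jnp.sum(a_t * state, axis=-1)
        # shift the history so the newest output sits in slot 0
        new_state = jnp.concatenate([y_t[:, None], state[:, :-1]], axis=-1)
        return new_state, y_t

    # lax.scan iterates over the leading axis, so put time first
    _, ys = jax.lax.scan(step, zi, (x.T, jnp.moveaxis(A, 1, 0)))
    return ys.T  # back to (B, T)


# Time-varying example: a resonator whose pole angle sweeps upward over time
T = 1000
r = 0.99
theta = jnp.linspace(jnp.pi / 8, jnp.pi / 2, T)
A = jnp.stack([-2 * r * jnp.cos(theta), jnp.full(T, r**2)], axis=-1)[None]  # (1, T, 2)
x = jax.random.normal(jax.random.PRNGKey(0), (1, T))
zi = jnp.zeros((1, 2))
y = allpole_reference(x, A, zi)


# Gradients of a simple energy loss with respect to every per-sample coefficient
def loss_fn(A):
    return jnp.sum(allpole_reference(x, A, zi) ** 2)


grads = jax.grad(loss_fn)(A)  # same shape as A: (1, T, 2)
```

The scan runs its T steps one after another, so it is best treated as a correctness check. When comparing it against `allpole`, confirm the `zi` ordering first or keep `zi` at zero.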