Hello everyone,
I am working on a project where I need to fit a curve to a large dataset using JAX for GPU-accelerated optimization. Here's a detailed breakdown of the problem:
Problem Statement:
I have a SQLite database containing three columns:
s_id: A unique identifier for each group.
x: The input variable, always positive.
y: The target variable, always positive; the underlying signal is continuous and differentiable.
The goal is to fit the following curve to the data:
pred_y(t) = [Σ_{i=0..14} c1_i * x(t-i)]^(-4)
          * [Σ_{i=0..29} c2_i * x(t-i)]
          * [Σ_{i=0..44} c3_i * x(t-i)]
          * [Σ_{i=0..59} c4_i * x(t-i)]
          * [Σ_{i=0..74} c5_i * x(t-i)]
That is, the five sums use sliding windows of 15, 30, 45, 60, and 75 coefficients respectively, 225 coefficients in total.
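In code, each bracketed sum is a sliding dot product of the trailing window of x with one block of coefficients, which jnp.convolve in 'valid' mode computes directly. Here is a minimal sketch of a single term (the name window_term is mine, purely for illustration):

    import jax.numpy as jnp

    def window_term(x, c):
        # For a length-N coefficient block c, output position k corresponds to
        # time t = k + N - 1 and equals sum_{i=0}^{N-1} c[i] * x[t - i]
        return jnp.convolve(x, c, mode='valid')

Reversing the coefficient block before convolving, as my rolling_sumproduct below does, only relabels the coefficients, which should not matter for the fit since they are free parameters.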
I need to determine the coefficients that minimize the Mean Squared Error (MSE). The total MSE is the mean of the per-s_id group MSEs (a short sketch of this objective follows the problem statement below).
My dataset is large, containing approximately 1 million rows, so I want to leverage GPU acceleration.
The y values are always positive, and the signal is continuous and differentiable, making LBFGS optimization applicable.
Each s_id group has a different number of points; e.g., one group may have 1500 points, another 1900, and so on.
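To make the objective concrete, here is a small plain-NumPy sketch of the loss I want to minimize; the function and variable names (total_mse, groups) are illustrative only and not part of my code below:

    import numpy as np

    def total_mse(groups):
        # groups: one (y_true, y_pred) pair per s_id; the two arrays in a
        # pair share that group's (variable) length
        group_mses = [np.mean((y_true - y_pred) ** 2) for y_true, y_pred in groups]
        # total loss = unweighted mean of the per-group MSEs
        return float(np.mean(group_mses))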
My Approach:
I wrote code that computes the objective function and optimizes the coefficients using the jaxopt.LBFGS solver.
My implementation uses JAX arrays and vectorized operations to maximize efficiency.
The Code:
Here is the full code I have written:
import jax
import jax.numpy as jnp
from jaxopt import LBFGS
import sqlite3
import pandas as pd
import time
import numpy as np
import cudf
def get_data_from_cudf():
    db_path = "/content/data.db"
    table_name = "data"
    conn = sqlite3.connect(db_path)
    df = cudf.from_pandas(pd.read_sql(f"SELECT * FROM {table_name}", conn))
    conn.close()
    df["s_id"] = df["s_id"].astype('int32')
    df["x"] = df["x"].astype('float32')
    df["y"] = df["y"].astype('float32')
    print('Data Fetched Successfully!')
    return df
# Helper function to compute a rolling sum-product via a 'valid' convolution,
# NaN-padded at the front so the output lines up with the input series.
def rolling_sumproduct(series, window_size, coeffs):
    result = jnp.convolve(series, coeffs[::-1], mode='valid')
    padded_result = jnp.concatenate((jnp.full(window_size - 1, jnp.nan), result))
    return padded_result
# Compute the predicted y column for one group of x values.
def get_pred_y_col(x, coeff):
    power = [-4, 1, 1, 1, 1]
    terms = 5
    window_sizes = [15 * i + 15 for i in range(terms)]  # 15, 30, 45, 60, 75
    # The n-th term takes the slice coeff[7.5*n*(n-1) : 7.5*n*(n+1)], i.e.
    # consecutive blocks of length 15, 30, 45, 60 and 75.
    sum_products = jnp.array([
        rolling_sumproduct(x, window_sizes[n - 1], coeff[int(7.5 * n * (n - 1)):int(7.5 * n * (n + 1))])
        for n in range(1, terms + 1)
    ])
    pred_y = jnp.prod(jnp.array([sp ** p for sp, p in zip(sum_products, power)]), axis=0)
    return pred_y
def obj_func(coeff, x, y, symbol_ids, unique_s_ids, masks):
    def compute_indiv_mse(coeff, mask, x, y):
        masked_x = x[mask]
        masked_y = y[mask]
        pred_y = get_pred_y_col(masked_x, coeff)
        error = (pred_y - masked_y) ** 2
        return jnp.nanmean(error)

    all_mses = jnp.array([compute_indiv_mse(coeff, mask, x, y) for mask in masks])
    return jnp.nanmean(all_mses)
# Main optimization
def optimize_coefficients():
    df = get_data_from_cudf()
    s_ids = jnp.array(df["s_id"].to_numpy())
    x = jnp.array(df["x"].to_numpy())
    y = jnp.array(df["y"].fillna(jnp.nan).to_numpy())
    unique_s_ids = jnp.array(np.unique(df["s_id"].to_numpy()))
    masks = jnp.array([s_ids == uid for uid in unique_s_ids])

    # Initialize coefficients: 15 + 30 + 45 + 60 + 75 = 225 in total
    coeff_init = jnp.ones(225)

    def wrapped_obj_func(coeff):
        return obj_func(coeff, x, y, s_ids, unique_s_ids, masks)

    solver = LBFGS(fun=wrapped_obj_func, maxiter=1000, tol=1e-6)
    solution = solver.run(coeff_init)
    optimized_coeff = solution.params
    print("Optimization Completed!")
    print("Optimized Coefficients:", optimized_coeff)

    def apply_pred_y(group):
        group_x = jnp.array(group["x"].to_numpy())
        pred_y = get_pred_y_col(group_x, optimized_coeff)
        return cudf.Series(pred_y, index=group.index)

    df["pred_y"] = df.groupby('s_id').apply(apply_pred_y)
    return df, optimized_coeff
df = get_data_from_cudf()
print(df)
s = time.time()
df, optimized_coeff = optimize_coefficients()
e = time.time()
print(f"Time Taken: {e-s}")
Problems I Am Facing:
Due to the dynamic sizes of the individual s_id groups, JAX throws errors in the objective function. Specifically, errors like the following (a minimal reproduction is included at the end of this section):
NonConcreteBooleanIndexError: Array boolean indices must be concrete.
I tried precomputing masks for each s_id to avoid dynamic sizing but still encountered issues related to how JAX handles such operations.
I am unsure how to rewrite the objective function to handle this situation effectively while still leveraging JAX's performance benefits.
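To show concretely what goes wrong, here is a minimal reduction of the error (my own toy example, not the exact call path in the code above): inside a traced function, boolean indexing produces an array whose size depends on the runtime values of the mask, which JAX cannot represent with static shapes.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def masked_mean(values, mask):
        # values[mask] has a shape that depends on how many mask entries are
        # True, so tracing this raises NonConcreteBooleanIndexError
        return values[mask].mean()

    masked_mean(jnp.arange(5.0), jnp.array([True, False, True, True, False]))
    # jax.errors.NonConcreteBooleanIndexError: Array boolean indices must be concrete

Presumably the same restriction surfaces once jaxopt traces my objective to differentiate it.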
Request for Help:
How can I correct the implementation to handle the dynamic sizing issue with jnp in the objective function?
Are there workarounds to use JAX with uneven group sizes or any suggestions for an alternative approach to compute the objective function efficiently on the GPU?
Any insights into why these errors occur and how I can resolve them?
Thank you for your time and help! I am an intermediate programmer, so any detailed explanation or alternative solutions would be greatly appreciated. Please let me know if you need more information.