Skip to content

Commit

Permalink
Switch mj_solveM2 to use CSR representation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 717639495
Change-Id: I59eee0f23606481480ec5acb8a3bba2c14b2ea2f
  • Loading branch information
yuvaltassa authored and copybara-github committed Jan 20, 2025
1 parent baf1c43 commit a5ab7a9
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions src/engine/engine_core_smooth.c
Original file line number Diff line number Diff line change
Expand Up @@ -1668,9 +1668,10 @@ void mj_solveLDs(mjtNum* restrict x, const mjtNum* qLDs, const mjtNum* qLDiagInv
continue;
}

int adr = rowadr[i];
int d = rownnz[i] - 1;
if (d > 0) {
int d;
if ((d = rownnz[i] - 1) > 0) {
int adr = rowadr[i];

// one vector
if (n == 1) {
x[i] -= mju_dotSparse(qLDs+adr, x, d, colind+adr, /*flg_unc1=*/0);
Expand Down Expand Up @@ -1769,41 +1770,53 @@ void mj_solveM_island(const mjModel* m, const mjData* d, mjtNum* restrict x, int
void mj_solveM2(const mjModel* m, mjData* d, mjtNum* x, const mjtNum* y,
const mjtNum* sqrtInvD, int n) {
// local copies of key variables
mjtNum* qLD = d->qLD;
int* dof_Madr = m->dof_Madr;
int* dof_parentid = m->dof_parentid;
int nv = m->nv;
int nv = m->nv, nC = m->nC;
const int* rownnz = d->C_rownnz;
const int* rowadr = d->C_rowadr;
const int* colind = d->C_colind;
const int* diagnum = m->dof_simplenum;

// x = y
mju_copy(x, y, n * nv);

// loop over the n input vectors
for (int ivec=0; ivec < n; ivec++) {
int offset = ivec*nv;
// temporary: make local CSR version of qLD
mj_markStack(d);
mjtNum* qLD = mjSTACKALLOC(d, nC, mjtNum);
for (int i=0; i < nC; i++) {
qLD[i] = d->qLD[d->mapM2C[i]];
}

// x <- inv(L') * x; skip simple, exploit sparsity of input vector
for (int i=nv-1; i >= 0; i--) {
mjtNum tmp;
if (!m->dof_simplenum[i] && (tmp = x[i+offset])) {
// init
int Madr_ij = dof_Madr[i]+1;
int j = dof_parentid[i];
// x <- L^-T x
for (int i=nv-1; i > 0; i--) {
// skip diagonal rows
if (diagnum[i]) {
continue;
}

// traverse ancestors backwards
while (j >= 0) {
x[j+offset] -= qLD[Madr_ij++] * tmp; // x(j) -= L(i,j) * x(i)
// prepare row i column address range
int start = rowadr[i];
int end = start + rownnz[i] - 1;

// advance to parent
j = dof_parentid[j];
// process all vectors
for (int offset=0; offset < n*nv; offset+=nv) {
mjtNum x_i;
if ((x_i = x[i+offset])) {
for (int adr=start; adr < end; adr++) {
x[offset + colind[adr]] -= qLD[adr] * x_i;
}
}
}
}

// x <- sqrt(inv(D)) * x
for (int i=0; i < nv; i++) {
x[i+offset] *= sqrtInvD[i]; // x(i) /= sqrt(L(i,i))
// x <- D^-1/2 x
for (int i=0; i < nv; i++) {
mjtNum invD_i = sqrtInvD[i];
for (int offset=0; offset < n*nv; offset+=nv) {
x[i+offset] *= invD_i;
}
}

mj_freeStack(d);
}


Expand Down

0 comments on commit a5ab7a9

Please sign in to comment.