Skip to content

Commit

Permalink
Fast multivariate multiplication (#18)
Browse files Browse the repository at this point in the history
Use Kronecker substitution to speed up multivariate multiplication.
  • Loading branch information
PoslavskySV authored Nov 26, 2017
1 parent bbfbce4 commit 55ceccb
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 4 deletions.
2 changes: 1 addition & 1 deletion rings/src/main/java/cc/redberry/rings/Rings.java
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public static UnivariateRing<UnivariatePolynomial<BigInteger>> UnivariateRingZp(
* @param factory factory
*/
public static <Term extends DegreeVector<Term>, Poly extends AMultivariatePolynomial<Term, Poly>>
PolynomialRing<Poly> MultivariateRing(Poly factory) {
MultivariateRing<Poly> MultivariateRing(Poly factory) {
return new MultivariateRing<>(factory);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1189,4 +1189,28 @@ public Set<Characteristics> characteristics() {
public final String toString() {
return toString(WithVariables.defaultVars(nVariables));
}

static long[] KroneckerMap(int[] degrees) {
long[] result = new long[degrees.length];
result[0] = 1L;
for (int i = 1; i < degrees.length; i++) {
result[i] = 1L;
double check = 1;
for (int j = 0; j < i; j++) {
long b = 2L * degrees[j] + 1;
result[i] *= b;
check *= b;
}

if (check > Long.MAX_VALUE) {
// long overflow -> can't use Kronecker's trick
return null;
}
}
return result;
}

/* shared constant */
/** when to switch to Kronecker's method */
static int KRONECKER_THRESHOLD = 256;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@

import cc.redberry.rings.*;
import cc.redberry.rings.bigint.BigInteger;
import cc.redberry.rings.poly.MachineArithmetic;
import cc.redberry.rings.poly.MultivariateRing;
import cc.redberry.rings.poly.PolynomialMethods;
import cc.redberry.rings.poly.UnivariateRing;
import cc.redberry.rings.poly.univar.UnivariatePolynomial;
import cc.redberry.rings.util.ArraysUtil;
import gnu.trove.iterator.TLongObjectIterator;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TLongObjectHashMap;
import org.apache.commons.math3.random.RandomGenerator;

import java.util.Arrays;
Expand Down Expand Up @@ -1354,6 +1357,20 @@ public MultivariatePolynomial<E> multiplyByBigInteger(BigInteger factor) {
@Override
public MultivariatePolynomial<E> multiply(MultivariatePolynomial<E> oth) {
assertSameCoefficientRingWith(oth);
if (oth.isZero())
return toZero();
if (isZero())
return this;
if (oth.isConstant())
return multiply(oth.cc());

if (size() > KRONECKER_THRESHOLD && oth.size() > KRONECKER_THRESHOLD)
return multiplyKronecker(oth);
else
return multiplyClassic(oth);
}

private MultivariatePolynomial<E> multiplyClassic(MultivariatePolynomial<E> oth) {
MonomialSet<Monomial<E>> newMap = new MonomialSet<>(ordering);
for (Monomial<E> othElement : oth.terms)
for (Monomial<E> thisElement : terms)
Expand All @@ -1362,6 +1379,94 @@ public MultivariatePolynomial<E> multiply(MultivariatePolynomial<E> oth) {
return loadFrom(newMap);
}

private MultivariatePolynomial<E> multiplyKronecker(MultivariatePolynomial<E> oth) {
int[] resultDegrees = new int[nVariables];
int[] thisDegrees = degrees();
int[] othDegrees = oth.degrees();
for (int i = 0; i < resultDegrees.length; i++)
resultDegrees[i] = thisDegrees[i] + othDegrees[i];

long[] map = KroneckerMap(resultDegrees);
if (map == null)
return multiplyClassic(oth);

// check that degrees fit long
double threshold = 0.;
for (int i = 0; i < nVariables; i++)
threshold += 1.0 * resultDegrees[i] * map[i];
threshold *= 2;

if (threshold > Long.MAX_VALUE)
return multiplyClassic(oth);

return fromKronecker(multiplySparseUnivariate(ring, toKronecker(map), oth.toKronecker(map)), map);
}

/**
* Convert to Kronecker's representation
*/
private TLongObjectHashMap<CfHolder<E>> toKronecker(long[] kroneckerMap) {
TLongObjectHashMap<CfHolder<E>> result = new TLongObjectHashMap<>(size());
for (Monomial<E> term : this) {
long exponent = term.exponents[0];
for (int i = 1; i < term.exponents.length; i++)
exponent += term.exponents[i] * kroneckerMap[i];
assert !result.contains(exponent);
result.put(exponent, new CfHolder<>(term.coefficient));
}
return result;
}

private static <E> TLongObjectHashMap<CfHolder<E>> multiplySparseUnivariate(Ring<E> ring,
TLongObjectHashMap<CfHolder<E>> a,
TLongObjectHashMap<CfHolder<E>> b) {
TLongObjectHashMap<CfHolder<E>> result = new TLongObjectHashMap<>(a.size() + b.size());
TLongObjectIterator<CfHolder<E>> ait = a.iterator();
while (ait.hasNext()) {
ait.advance();
TLongObjectIterator<CfHolder<E>> bit = b.iterator();
while (bit.hasNext()) {
bit.advance();

long deg = ait.key() + bit.key();
E val = ring.multiply(ait.value().coefficient, bit.value().coefficient);

CfHolder<E> r = result.putIfAbsent(deg, new CfHolder<>(val));
if (r != null)
r.coefficient = ring.add(r.coefficient, val);
}
}
return result;
}

private MultivariatePolynomial<E> fromKronecker(TLongObjectHashMap<CfHolder<E>> p,
long[] kroneckerMap) {
terms.clear();
TLongObjectIterator<CfHolder<E>> it = p.iterator();
while (it.hasNext()) {
it.advance();
if (ring.isZero(it.value().coefficient))
continue;
long exponent = it.key();
int[] exponents = new int[nVariables];
for (int i = 0; i < nVariables; i++) {
long div = exponent / kroneckerMap[nVariables - i - 1];
exponent = exponent - (div * kroneckerMap[nVariables - i - 1]);
exponents[nVariables - i - 1] = MachineArithmetic.safeToInt(div);
}
terms.add(new Monomial<>(exponents, it.value().coefficient));
}
return this;
}

static final class CfHolder<E> {
E coefficient;

CfHolder(E coefficient) {
this.coefficient = coefficient;
}
}

@Override
public MultivariatePolynomial<E> square() {
return multiply(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import cc.redberry.rings.poly.univar.UnivariatePolynomialZ64;
import cc.redberry.rings.poly.univar.UnivariatePolynomialZp64;
import cc.redberry.rings.util.ArraysUtil;
import gnu.trove.iterator.TLongObjectIterator;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TLongObjectHashMap;
import org.apache.commons.math3.random.RandomGenerator;

import java.util.Arrays;
Expand Down Expand Up @@ -1237,6 +1239,14 @@ public MultivariatePolynomialZp64 multiply(MultivariatePolynomialZp64 oth) {
return this;
if (oth.isConstant())
return multiply(oth.cc());

if (size() > KRONECKER_THRESHOLD && oth.size() > KRONECKER_THRESHOLD)
return multiplyKronecker(oth);
else
return multiplyClassic(oth);
}

private MultivariatePolynomialZp64 multiplyClassic(MultivariatePolynomialZp64 oth) {
MonomialSet<MonomialZp64> newMap = new MonomialSet<>(ordering);
for (MonomialZp64 othElement : oth.terms)
for (MonomialZp64 thisElement : terms)
Expand All @@ -1245,6 +1255,94 @@ public MultivariatePolynomialZp64 multiply(MultivariatePolynomialZp64 oth) {
return loadFrom(newMap);
}

private MultivariatePolynomialZp64 multiplyKronecker(MultivariatePolynomialZp64 oth) {
int[] resultDegrees = new int[nVariables];
int[] thisDegrees = degrees();
int[] othDegrees = oth.degrees();
for (int i = 0; i < resultDegrees.length; i++)
resultDegrees[i] = thisDegrees[i] + othDegrees[i];

long[] map = KroneckerMap(resultDegrees);
if (map == null)
return multiplyClassic(oth);

// check that degrees fit long
double threshold = 0.;
for (int i = 0; i < nVariables; i++)
threshold += 1.0 * resultDegrees[i] * map[i];
threshold *= 2;

if (threshold > Long.MAX_VALUE)
return multiplyClassic(oth);

return fromKronecker(multiplySparseUnivariate(ring, toKronecker(map), oth.toKronecker(map)), map);
}

/**
* Convert to Kronecker's representation
*/
private TLongObjectHashMap<CfHolder> toKronecker(long[] kroneckerMap) {
TLongObjectHashMap<CfHolder> result = new TLongObjectHashMap<>(size());
for (MonomialZp64 term : this) {
long exponent = term.exponents[0];
for (int i = 1; i < term.exponents.length; i++)
exponent += term.exponents[i] * kroneckerMap[i];
assert !result.contains(exponent);
result.put(exponent, new CfHolder(term.coefficient));
}
return result;
}

private static TLongObjectHashMap<CfHolder> multiplySparseUnivariate(IntegersZp64 ring,
TLongObjectHashMap<CfHolder> a,
TLongObjectHashMap<CfHolder> b) {
TLongObjectHashMap<CfHolder> result = new TLongObjectHashMap<>(a.size() + b.size());
TLongObjectIterator<CfHolder> ait = a.iterator();
while (ait.hasNext()) {
ait.advance();
TLongObjectIterator<CfHolder> bit = b.iterator();
while (bit.hasNext()) {
bit.advance();

long deg = ait.key() + bit.key();
long val = ring.multiply(ait.value().coefficient, bit.value().coefficient);

CfHolder r = result.putIfAbsent(deg, new CfHolder(val));
if (r != null)
r.coefficient = ring.add(r.coefficient, val);
}
}
return result;
}

private MultivariatePolynomialZp64 fromKronecker(TLongObjectHashMap<CfHolder> p,
long[] kroneckerMap) {
terms.clear();
TLongObjectIterator<CfHolder> it = p.iterator();
while (it.hasNext()) {
it.advance();
if (it.value().coefficient == 0)
continue;
long exponent = it.key();
int[] exponents = new int[nVariables];
for (int i = 0; i < nVariables; i++) {
long div = exponent / kroneckerMap[nVariables - i - 1];
exponent = exponent - (div * kroneckerMap[nVariables - i - 1]);
exponents[nVariables - i - 1] = MachineArithmetic.safeToInt(div);
}
terms.add(new MonomialZp64(exponents, it.value().coefficient));
}
return this;
}

static final class CfHolder {
long coefficient = 0;

CfHolder(long coefficient) {
this.coefficient = coefficient;
}
}

@Override
public MultivariatePolynomialZp64 square() {
return multiply(this);
Expand Down
21 changes: 18 additions & 3 deletions rings/src/main/java/cc/redberry/rings/util/ArraysUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,8 @@ public static int max(int[] array) {
return a;
}

public static int max(int[] array, int from , int to) {
int a = Integer.MIN_VALUE;;
public static int max(int[] array, int from, int to) {
int a = Integer.MIN_VALUE; ;
for (int i = from; i < to; i++)
a = Math.max(a, array[i]);
return a;
Expand All @@ -445,7 +445,7 @@ public static int min(int[] array) {
return a;
}

public static int min(int[] array, int from , int to) {
public static int min(int[] array, int from, int to) {
int a = Integer.MAX_VALUE;
for (int i = from; i < to; i++)
a = Math.min(a, array[i]);
Expand Down Expand Up @@ -538,6 +538,21 @@ public static double multiplyToDouble(final int[] array, int from, int to) {
return s;
}

public static double multiplyToDouble(final int[] array) {
return multiplyToDouble(array, 0, array.length);
}

public static double sumToDouble(final int[] array, int from, int to) {
double s = 0.;
for (int i = from; i < to; ++i)
s += array[i];
return s;
}

public static double sumToDouble(final int[] array) {
return sumToDouble(array, 0, array.length);
}

public static int or(final long[] array) {
return or(array, 0, array.length);
}
Expand Down
Loading

0 comments on commit 55ceccb

Please sign in to comment.