Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement integer modular exponentiation using BigInteger#mod_pow #2006

Merged
merged 1 commit into from
Jun 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Performance:

* Enable lazy translation from the parser AST to the Truffle AST for user code by default. This should improve application startup time (#1992).
* `instance variable ... not initialized` and similar warnings are now optimized to have no peak performance impact if they are not printed (depends on `$VERBOSE`).
* Implement integer modular exponentiation using `BigInteger#mod_pow` (#1999)

# 20.1.0

Expand Down
4 changes: 4 additions & 0 deletions spec/ruby/core/integer/pow_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,9 @@
it "raises a ZeroDivisionError when the given argument is 0" do
-> { 2.pow(5, 0) }.should raise_error(ZeroDivisionError)
end

it "raises a RangeError when the first argument is negative and the second argument is present" do
-> { 2.pow(-5, 1) }.should raise_error(RangeError)
end
end
end
75 changes: 75 additions & 0 deletions src/main/java/org/truffleruby/core/cast/BigIntegerCastNode.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright (c) 2013, 2020 Oracle and/or its affiliates. All rights reserved. This
* code is released under a tri EPL/GPL/LGPL license. You can use it,
* redistribute it and/or modify it under the terms of the:
*
* Eclipse Public License version 2.0, or
* GNU General Public License version 2, or
* GNU Lesser General Public License version 2.1.
*/
package org.truffleruby.core.cast;

import java.math.BigInteger;

import org.truffleruby.Layouts;
import org.truffleruby.RubyContext;
import org.truffleruby.RubyLanguage;
import org.truffleruby.language.RubySourceNode;
import org.truffleruby.language.RubyGuards;
import org.truffleruby.language.RubyNode;
import org.truffleruby.language.control.RaiseException;

import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.CachedContext;
import com.oracle.truffle.api.dsl.GenerateUncached;
import com.oracle.truffle.api.dsl.ImportStatic;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.object.DynamicObject;

/** Casts a value into a BigInteger. */
@GenerateUncached
@ImportStatic(RubyGuards.class)
@NodeChild(value = "value", type = RubyNode.class)
public abstract class BigIntegerCastNode extends RubySourceNode {

public static BigIntegerCastNode create() {
return BigIntegerCastNodeGen.create(null);
}

public static BigIntegerCastNode create(RubyNode value) {
return BigIntegerCastNodeGen.create(value);
}

public abstract BigInteger executeCastBigInteger(Object value);

@Specialization
protected BigInteger doInt(int value) {
return BigInteger.valueOf(value);
}

@Specialization
protected BigInteger doLong(long value) {
return BigInteger.valueOf(value);
}

@Specialization(guards = "isRubyBignum(value)")
protected BigInteger doBignum(DynamicObject value) {
return Layouts.BIGNUM.getValue(value);
}

@Specialization(guards = "!isRubyInteger(value)")
protected BigInteger doBasicObject(Object value,
@CachedContext(RubyLanguage.class) RubyContext context) {
throw new RaiseException(context, notAnInteger(context, value));
}

@TruffleBoundary
private DynamicObject notAnInteger(RubyContext context, Object object) {
return context.getCoreExceptions().typeErrorIsNotA(
object.toString(),
"Integer",
this);
}

}
45 changes: 45 additions & 0 deletions src/main/java/org/truffleruby/core/numeric/IntegerNodes.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
import org.truffleruby.builtins.CoreMethodArrayArgumentsNode;
import org.truffleruby.builtins.CoreModule;
import org.truffleruby.builtins.Primitive;
import org.truffleruby.builtins.PrimitiveNode;
import org.truffleruby.builtins.PrimitiveArrayArgumentsNode;
import org.truffleruby.builtins.YieldingCoreMethodNode;
import org.truffleruby.core.CoreLibrary;
import org.truffleruby.core.cast.BooleanCastNode;
import org.truffleruby.core.cast.BigIntegerCastNode;
import org.truffleruby.core.cast.ToIntNode;
import org.truffleruby.core.cast.ToRubyIntegerNode;
import org.truffleruby.core.numeric.IntegerNodesFactory.AbsNodeFactory;
Expand All @@ -34,6 +36,7 @@
import org.truffleruby.core.string.StringNodes;
import org.truffleruby.core.symbol.CoreSymbols;
import org.truffleruby.language.NotProvided;
import org.truffleruby.language.RubyNode;
import org.truffleruby.language.WarnNode;
import org.truffleruby.language.control.RaiseException;
import org.truffleruby.language.dispatch.CallDispatchHeadNode;
Expand All @@ -42,6 +45,8 @@
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.CreateCast;
import com.oracle.truffle.api.dsl.NodeChild;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
Expand Down Expand Up @@ -1776,6 +1781,46 @@ protected int getLimit() {

}

@Primitive(name = "mod_pow")
@NodeChild(value = "base", type = RubyNode.class)
@NodeChild(value = "exponent", type = RubyNode.class)
@NodeChild(value = "modulo", type = RubyNode.class)
public abstract static class ModPowNode extends PrimitiveNode {
@Child private FixnumOrBignumNode fixnumOrBignum = new FixnumOrBignumNode();

@CreateCast("base")
protected RubyNode baseToBigInteger(RubyNode base) {
return BigIntegerCastNode.create(base);
}

@CreateCast("exponent")
protected RubyNode exponentToBigInteger(RubyNode exponent) {
return BigIntegerCastNode.create(exponent);
}

@CreateCast("modulo")
protected RubyNode moduloToBigInteger(RubyNode modulo) {
return BigIntegerCastNode.create(modulo);
}

@Specialization(guards = "modulo.signum() < 0")
protected Object mod_pow_neg(BigInteger base, BigInteger exponent, BigInteger modulo) {
BigInteger result = base.modPow(exponent, modulo.negate());
return fixnumOrBignum.fixnumOrBignum(result.signum() == 1 ? result.add(modulo) : result);
}

@Specialization(guards = "modulo.signum() > 0")
protected Object mod_pow_pos(BigInteger base, BigInteger exponent, BigInteger modulo) {
BigInteger result = base.modPow(exponent, modulo);
return fixnumOrBignum.fixnumOrBignum(result);
}

@Specialization(guards = "modulo.signum() == 0")
protected Object mod_pow_zero(BigInteger base, BigInteger exponent, BigInteger modulo) {
throw new RaiseException(getContext(), coreExceptions().zeroDivisionError(this));
}
}

@CoreMethod(
names = "downto",
needsBlock = true,
Expand Down
12 changes: 9 additions & 3 deletions src/main/ruby/truffleruby/core/integer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,15 @@ def nobits?(mask)
end

def pow(e, m=undefined)
return self ** e if Primitive.undefined?(m)
raise TypeError, '2nd argument not allowed unless all arguments are integers' unless Primitive.object_kind_of?(m, Integer)
(self ** e) % m
if Primitive.undefined?(m)
self ** e
else
raise TypeError, '2nd argument not allowed unless a 1st argument is integer' unless Primitive.object_kind_of?(e, Integer)
raise TypeError, '2nd argument not allowed unless all arguments are integers' unless Primitive.object_kind_of?(m, Integer)
raise RangeError, '1st argument cannot be negative when 2nd argument specified' if e.negative?

Primitive.mod_pow(self, e, m)
end
end

def times
Expand Down