diff --git a/algebird-core/src/main/scala/com/twitter/algebird/Scan.scala b/algebird-core/src/main/scala/com/twitter/algebird/Scan.scala new file mode 100644 index 000000000..de2d59d33 --- /dev/null +++ b/algebird-core/src/main/scala/com/twitter/algebird/Scan.scala @@ -0,0 +1,322 @@ +package com.twitter.algebird + +import scala.collection.compat._ + +object Scan { + + /** + * Most consumers of Scan don't care about the type of the type State type variable. But for those that do, + * we make an effort to expose it in all of our combinators. + * @tparam I + * @tparam S + * @tparam O + */ + type Aux[-I, S, +O] = Scan[I, O] { type State = S } + + implicit def applicative[I]: Applicative[({ type L[O] = Scan[I, O] })#L] = new ScanApplicative[I] + + def from[I, S, O](initState: S)(presentAndNextStateFn: (I, S) => (O, S)): Aux[I, S, O] = + new Scan[I, O] { + override type State = S + override val initialState = initState + override def presentAndNextState(i: I, s: State): (O, State) = presentAndNextStateFn(i, s) + } + + def fromFunction[I, O](f: I => O): Aux[I, Unit, O] = new Scan[I, O] { + override type State = Unit + override val initialState = () + override def presentAndNextState(i: I, stateBeforeProcessingI: Unit): (O, State) = (f(i), ()) + } + + /** + * Scans take streams of inputs to streams of outputs, but some scans have trivial inputs and just produce a stream of + * outputs. Streams can be thought of as being a hidden state that is queryable for a head element, and another hidden + * state that represents the rest of the stream. + * @param initState The initial state of the scan; think of this as an infinite stream. + * @param destructor This function decomposes a stream into the its head-element and tail-stream. + * @tparam S The hidden state of the stream that we are turning into a Scan. + * @tparam O The type of the elments of the stream that we are turning into a Scan + * @return A Scan whose inputs are irrelevant, and whose outputs are those that we would get from implementing + * a stream using the information provided to this method. + */ + def iterate[S, O](initState: S)(destructor: S => (O, S)): Aux[Any, S, O] = new Scan[Any, O] { + override type State = S + override val initialState = initState + override def presentAndNextState(i: Any, stateBeforeProcessingI: S): (O, S) = + destructor(stateBeforeProcessingI) + } + + /** + * A Scan whose `Nth` output is the number `N` (starting from 0). + */ + val index: Aux[Any, Long, Long] = iterate(0L)(n => (n, n + 1)) + + def identity[A]: Aux[A, Unit, A] = fromFunction[A, A](x => x) + + /** + * + * @param initStateCreator A call-by-name method that allocates new mutable state + * @param presentAndUpdateStateFn A function that both presents the output value, and has the side-effect of updating the mutable state + * @tparam I + * @tparam S + * @tparam O + * @return A Scan that safely encapsulates state while it's doing its thing. + */ + def mutable[I, S, O](initStateCreator: => S)(presentAndUpdateStateFn: (I, S) => O): Aux[I, S, O] = + new Scan[I, O] { + override type State = S + override def initialState = initStateCreator + override def presentAndNextState(i: I, s: S): (O, S) = (presentAndUpdateStateFn(i, s), s) + } + + /** + * The trivial scan that always returns the same value, regardless of input + * @param t + * @tparam T + */ + def const[T](t: T): Aux[Any, Unit, T] = fromFunction(_ => t) + + /** + * + * @param aggregator + * @param initState + * @tparam A + * @tparam B + * @tparam C + * @return A scan which, when given `[a_1, ..., a_n]` outputs `[c_1, ..., c_n]` where + * `c_i = initState + aggregator.prepare(a_1) + ... + aggregator.prepare(a_i)` + */ + def fromAggregator[A, B, C](aggregator: Aggregator[A, B, C], initState: B): Aux[A, B, C] = + from(initState) { (a: A, stateBeforeProcessingI: B) => + // nb: the order of the arguments to semigroup.plus here is what determines the order of the final summation; + // this matters because not all semigroups are commutative + val stateAfterProcessingA = + aggregator.append(stateBeforeProcessingI, a) + (aggregator.present(stateAfterProcessingA), stateAfterProcessingA) + } + + /** + * + * @param monoidAggregator + * @tparam A + * @tparam B + * @tparam C + * @return A scan which, when given `[a_1, ..., a_n]` outputs `[c_1, ..., c_n]` where + * `c_i = monoidAggregator.monoid.zero + aggregator.prepare(a_1) + ... + aggregator.prepare(a_i)` + */ + def fromMonoidAggregator[A, B, C](monoidAggregator: MonoidAggregator[A, B, C]): Aux[A, B, C] = + fromAggregator(monoidAggregator, monoidAggregator.monoid.zero) + +} + +/** + * The Scan trait is an alternative to the `scanLeft` method on iterators/other collections for a range of + * of use-cases where `scanLeft` is awkward to use. At a high level it provides some of the same functionality as + * `scanLeft`, but with a separation of "what is the state of the scan" from + * "what are the elements that I'm scanning over?". In particular, when scanning over an iterator with `N` elements, + * the output is an iterator with `N` elements (in contrast to scanLeft's `N+1`). + * + * If you find yourself writing a `scanLeft` over pairs of elements, where you only use one element of the pair within + * the `scanLeft`, then throw that element away in a `map` immediately after the scanLeft is done, then this + * abstraction is for you. + * + * The canonical method to use a scan is `apply`. + * + * + * @tparam I The type of elements that the computation is scanning over. + * @tparam O The output type of the scan (typically distinct from the hidden `State` of the scan). + */ +sealed abstract class Scan[-I, +O] extends Serializable { + + import Scan.{from, Aux} + + /** + * The computation of any given scan involves keeping track of a hidden state. + */ + type State + + /** + * The state of the scan before any elements have been processed + * @return + */ + def initialState: State + + /** + * + * @param i An element in the stream to process + * @param stateBeforeProcessingI The state of the scan before processing i + * @return The output of the scan corresponding to processing i with state stateBeforeProcessing, + * along with the result of updating stateBeforeProcessing with the information from i. + */ + def presentAndNextState(i: I, stateBeforeProcessingI: State): (O, State) + + /** + * @param iter + * @return If `iter = Iterator(a_1, ..., a_n)`, return:` + * `Iterator(o_1, ..., o_n)` where + * `(o_(i+1), state_(i+1)) = presentAndNextState(a_i, state_i)` + * and `state_0 = initialState` + * + */ + def scanIterator(iter: Iterator[I]): Iterator[O] = new AbstractIterator[O] { + override def hasNext: Boolean = iter.hasNext + var state: State = initialState + override def next: O = { + val thisState = state + val thisA = iter.next + val (thisC, nextState) = presentAndNextState(thisA, thisState) + state = nextState + thisC + } + } + + /** + * @param inputs + * @param bf + * @tparam In The type of the input collection + * @tparam Out The type of the output collection + * @return + * Given inputs as a collection of the form `[a_1, ..., a_n]` the output will be a collection of the form: + * `[o_1, ..., o_n]` where + * `(o_(i+1), state_(i+1)) = presentAndNextState(a_i, state_i)` + * and `state_0 = initialState`. + */ + def apply[In <: TraversableOnce[I], Out]( + inputs: In + )(implicit bf: BuildFrom[In, O, Out]): Out = + bf.fromSpecific(inputs)(scanIterator(inputs.toIterator)) + + // combinators + + /** + * Return a new scan that is the same as this scan, but with a different `initialState`. + * @param newInitialState + * @return + */ + def replaceState(newInitialState: => State): Aux[I, State, O] = + from(newInitialState)(presentAndNextState(_, _)) + + def composePrepare[I1](f: I1 => I): Aux[I1, State, O] = from(initialState) { (i, stateBeforeProcessingI) => + presentAndNextState(f(i), stateBeforeProcessingI) + } + + def andThenPresent[O1](g: O => O1): Aux[I, State, O1] = from(initialState) { (i, stateBeforeProcessingI) => + val (c, stateAfterProcessingA) = presentAndNextState(i, stateBeforeProcessingI) + (g(c), stateAfterProcessingA) + } + + /** + * Return a scan that is semantically identical to + * `this.join(Scan.identity[I1])`, but where we don't pollute the `State` by pairing it + * redundantly with `Unit`. + * @tparam I1 + * @return If this Scan's `apply` method is given inputs `[a_1, ..., a_n]` resulting in outputs + * of the form `[o_1, ..., o_n`, then this results in a Scan whose `apply` method + * returns `[(o_1, a_1), ..., (o_n, a_n)]` when given the same input. + */ + def joinWithInput[I1 <: I]: Aux[I1, State, (O, I1)] = from(initialState) { (i, stateBeforeProcessingI) => + val (o, stateAfterProcessingI) = presentAndNextState(i, stateBeforeProcessingI) + ((o, i), stateAfterProcessingI) + } + + /** + * Return a scan whose output is paired with the state of the scan before each input updates the state. + * @return If this Scan's `apply` method is given inputs [a_1, ..., a_n] resulting in outputs + * of the form `[o_1, ..., o_n]`, where `(o_(i+1), state_(i+1)) = presentAndNextState(a_i, state_i)` + * and `state_0 = initialState`, + * return a scan that whose apply method, when given inputs `[a_1, ..., a_n]` will return + * `[(o_1, state_0), ..., (o_n, state_(n-1))]`. + */ + def joinWithPriorState: Aux[I, State, (State, O)] = from(initialState) { (i, stateBeforeProcessingI) => + val (o, stateAfterProcessingA) = presentAndNextState(i, stateBeforeProcessingI) + ((stateBeforeProcessingI, o), stateAfterProcessingA) + } + + /** + * Return a scan whose output is paired with the state of the scan after each input updates the state. + * @return If this Scan's `apply` method is given inputs `[a_1, ..., a_n]` resulting in outputs + * of the form `[o_1, ..., o_n]`, where `(o_(i+1), state_(i+1)) = presentAndNextState(a_i, state_i)`` + * and state_0 = initialState, + * return a scan that whose apply method, when given inputs `[a_1, ..., a_n]` will return + * `[(o_1, state_1), ..., (o_n, state_n]`. + */ + def joinWithPosteriorState: Aux[I, State, (O, State)] = from(initialState) { (i, stateBeforeProcessingI) => + val (c, stateAfterProcessingA) = presentAndNextState(i, stateBeforeProcessingI) + ((c, stateAfterProcessingA), stateAfterProcessingA) + } + + /** + * For every `foo`, `scan.joinWithIndex(foo) == scan(foo).zipWithIndex`. + * @return + * If this Scan's `apply` method is given inputs `[a_1, ..., a_n]` resulting in outputs + * of the form `[o_1, ..., o_n]`, return a scan that whose apply method, when given the same input, will return + * `[(o_1, 1), ..., (o_n, n)]`. + */ + def joinWithIndex: Aux[I, (State, Long), (O, Long)] = join(Scan.index) + + /** + * Compose two scans pairwise such that, when given pairwise zipped inputs, the resulting scan will output pairwise + * zipped outputs. + * @param scan2 + * @tparam I2 + * @tparam O2 + * @return If this Scan's apply method is given inputs `[a_1, ..., a_n]` resulting in outputs of + * the form `[o_1, ..., o_n]`, and `scan2.apply([b_1, ..., b_n] = [p_1, ..., p_n]` then + * `zip` will return a scan whose apply method, when given input + * `[(a_1, b_1), ..., (a_n, b_n)]` results in the output `[(o_1, p_1), ..., (o_2, p_2)]`. + * In other words: `scan.zip(scan2)(foo.zip(bar)) == scan(foo).zip(scan2(bar))` + */ + def zip[I2, O2](scan2: Scan[I2, O2]): Aux[(I, I2), (State, scan2.State), (O, O2)] = + from((initialState, scan2.initialState)) { (i1i2, stateBeforeProcessingI1I2) => + val (o1, state1AfterProcesingI1) = + presentAndNextState(i1i2._1, stateBeforeProcessingI1I2._1) + val (o2, state2AfterProcesingI2) = + scan2.presentAndNextState(i1i2._2, stateBeforeProcessingI1I2._2) + ((o1, o2), (state1AfterProcesingI1, state2AfterProcesingI2)) + } + + /** + * Given a scan that takes compatible input to this one, pairwise compose the state and outputs of each scan + * on a common input stream. + * @param scan2 + * @tparam I2 + * @tparam O2 + * @return If this Scan's apply method is given inputs [a_1, ..., a_n] resulting in outputs of + * the form `[o_1, ..., o_n]`, and `scan2.apply([a_1, ..., a_n] = [p_1, ..., p_n]` then + * `join` will return a scan whose apply method returns `[(o_1, p_1), ..., (o_2, p_2)]`. + * In other words: `scan.join(scan2)(foo) == scan(foo).zip(scan2(foo))` + */ + def join[I2 <: I, O2](scan2: Scan[I2, O2]): Aux[I2, (State, scan2.State), (O, O2)] = + from((initialState, scan2.initialState)) { (i, stateBeforeProcessingI) => + val (o1, state1AfterProcesingI1) = presentAndNextState(i, stateBeforeProcessingI._1) + val (o2, state2AfterProcesingI2) = scan2.presentAndNextState(i, stateBeforeProcessingI._2) + ((o1, o2), (state1AfterProcesingI1, state2AfterProcesingI2)) + } + + /** + * Takes the output of this scan and feeds as input into scan2. + * @param scan2 + * @tparam P + * @return If this Scan's apply method is given inputs `[a_1, ..., a_n]` resulting in outputs of + * the form `[o_1, ..., o_n]`, and `scan2.apply([o_1, ..., o_n] = [p_1, ..., p_n]` then + * `compose` will return a scan which returns `[p_1, ..., p_n]`. + */ + def compose[P](scan2: Scan[O, P]): Aux[I, (State, scan2.State), P] = + from((initialState, scan2.initialState)) { (i, stateBeforeProcessingI) => + val (o, state1AfterProcesingI) = presentAndNextState(i, stateBeforeProcessingI._1) + val (p, state2AfterProcesingO) = scan2.presentAndNextState(o, stateBeforeProcessingI._2) + (p, (state1AfterProcesingI, state2AfterProcesingO)) + } + +} + +class ScanApplicative[I] extends Applicative[({ type L[O] = Scan[I, O] })#L] { + override def map[T, U](mt: Scan[I, T])(fn: T => U): Scan[I, U] = + mt.andThenPresent(fn) + + override def apply[T](v: T): Scan[I, T] = + Scan.const(v) + + override def join[T, U](mt: Scan[I, T], mu: Scan[I, U]): Scan[I, (T, U)] = + mt.join(mu) +} diff --git a/algebird-test/src/test/scala/com/twitter/algebird/ScanTest.scala b/algebird-test/src/test/scala/com/twitter/algebird/ScanTest.scala new file mode 100644 index 000000000..4538f4f8f --- /dev/null +++ b/algebird-test/src/test/scala/com/twitter/algebird/ScanTest.scala @@ -0,0 +1,169 @@ +package com.twitter.algebird + +import org.scalacheck.Gen +import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks +import org.scalatest.{Matchers, WordSpec} + +import scala.collection.mutable.Queue + +object ScanTest { + // normal people will use Scan not Scan.Aux, so it's good for most of the tests to be using the more common interface. + type StringScan = Scan[Char, String] + + // technically speaking, these aren't exactly the Free scanner, since that would output a giant tree structure from + // the whole scan, but that giant tree structure is pretty close to a String. + + val directFreeScan: Scan.Aux[Char, List[Char], String] = Scan.from(List.empty[Char]) { + (char, previousState) => + val nextState = char :: previousState + (nextState.reverse.mkString, nextState) + } + + val mutableFreeScan: StringScan = Scan.mutable(new Queue[Char]()) { (char, previousState) => + previousState.enqueue(char) + previousState.mkString + } + + val aggregatorFreeScan: StringScan = { + val aggregator = Aggregator.fromMonoid[List[Char]] + + Scan + .fromMonoidAggregator(aggregator) + .composePrepare[Char](c => List(c)) + .andThenPresent(_.mkString) + + } + + val joinWithPosteriorStateFreeScan: StringScan = + directFreeScan + .andThenPresent(_ => ()) + .joinWithPosteriorState + .andThenPresent { case ((), state) => state.reverse.mkString } + + val joinWithPriorStateFreeScan1: StringScan = + directFreeScan + .andThenPresent(_ => ()) + .joinWithPriorState + .joinWithInput + .andThenPresent { case ((state, ()), input) => (input :: state).mkString.reverse } + + val joinWithPriorStateFreeScan2: StringScan = + directFreeScan + .andThenPresent(_ => ()) + .joinWithPriorState + .join(Scan.identity[Char]) + .andThenPresent { case ((state, ()), input) => (input :: state).mkString.reverse } + +} + +class ScanTest extends WordSpec with Matchers with ScalaCheckDrivenPropertyChecks { + import ScanTest._ + + def freeScanLaws(freeScan: StringScan): Unit = + forAll(Gen.listOf(Gen.alphaLowerChar)) { inputList => + val outputList = freeScan(inputList) + + outputList.length should equal(inputList.length) + outputList.zipWithIndex + .foreach { + case (ithOutput, i) => + val expectedOutput = inputList.slice(0, i + 1).mkString + ithOutput should equal(expectedOutput) + } + } + + def zipLaws(scan1: StringScan, scan2: StringScan): Unit = + forAll(Gen.listOf(Gen.alphaLowerChar), Gen.listOf(Gen.alphaLowerChar)) { (inputList1, inputList2) => + val outputList1 = scan1(inputList1) + val outputList2 = scan2(inputList2) + val zippedOutput = outputList1.zip(outputList2) + + val zippedScan = scan1.zip(scan2) + val zippedInput = inputList1.zip(inputList2) + val zippedScanOutput = zippedScan(zippedInput) + + (zippedOutput should contain).theSameElementsInOrderAs(zippedScanOutput) + } + + def joinWithIndexLaws(freeScan: StringScan): Unit = + forAll(Gen.listOf(Gen.alphaLowerChar)) { inputList => + val unIndexedOutput = freeScan(inputList) + + val joinedWithIndexOutput = freeScan.joinWithIndex(inputList) + (unIndexedOutput.zipWithIndex should contain).theSameElementsInOrderAs(joinedWithIndexOutput) + } + + "an illustrative example without scalacheck" should { + "work as you'd expect" in { + val output = directFreeScan(List('a', 'b', 'c')) + (output should contain).theSameElementsInOrderAs(List("a", "ab", "abc")) + } + } + + "freeAggreator laws" should { + "be obeyed by a direct implementation of the almost-free Scan" in { + freeScanLaws(directFreeScan) + } + + "be obeyed by a mutable implementation of the almost-free Scan" in { + freeScanLaws(mutableFreeScan) + } + + "be obeyed by an implementation of the almost-free Scan using fromAggregator, composePrepare, and andThenPresent" in { + freeScanLaws(aggregatorFreeScan) + } + + "be obeyed by an implementation of the almost-free Scan using a direct implementation, andThenPresent, and joinWithPosteriorState" in { + freeScanLaws(joinWithPosteriorStateFreeScan) + } + + "be obeyed by an implementation of the almost-free Scan using a direct implmeentation, andThenPresent, joinWithPriorState, and joinWithInput" in { + freeScanLaws(joinWithPriorStateFreeScan1) + } + + "be obeyed by an implementation of the almost-free Scan using a direct implmeentation, andThenPresent, joinWithPriorState, and join with scan.Identity" in { + freeScanLaws(joinWithPriorStateFreeScan2) + } + + "be obeyed by composing the identity scan on either side of a direct-implementation of the almost-free Scan" in { + freeScanLaws(Scan.identity.compose(directFreeScan)) + freeScanLaws(directFreeScan.compose(Scan.identity)) + } + + } + + "zipping aggregators" should { + "obey its laws" in { + zipLaws(directFreeScan, directFreeScan) + } + } + + "joinWithIndex" should { + "obey its laws" in { + joinWithIndexLaws(directFreeScan) + } + + "replaceState" should { + "behave as you'd expect" in { + forAll(Gen.listOf(Gen.alphaLowerChar), Gen.listOf(Gen.alphaLowerChar)) { (inputList1, inputList2) => + // first we'll run the scan on inputList1 ++ inputList2, which will result in output1 ++ output2. + // We should be able to replace the initial state of the scan such that just scanning only inputList2 + // will return output2. + val (_, output2) = directFreeScan(inputList1 ++ inputList2).splitAt(inputList1.length) + val stateOfScanAfterProcessingList1 = inputList1.reverse + val scanAfterReplacingState = directFreeScan.replaceState(stateOfScanAfterProcessingList1) + scanAfterReplacingState(inputList2) should equal(output2) + } + } + + "Scan.const" should { + "behave as you'd expect" in { + forAll(Gen.alphaLowerChar, Gen.listOf(Gen.alphaLowerChar)) { (const, inputList) => + (Scan.const(const)(inputList) should contain) + .theSameElementsInOrderAs(List.fill(inputList.length)(const)) + } + } + } + } + } +}