Skip to content

Commit

Permalink
repackage
Browse files Browse the repository at this point in the history
  • Loading branch information
Jelmer Kuperus committed Dec 30, 2024
1 parent d2c2764 commit 9d50dbc
Show file tree
Hide file tree
Showing 14 changed files with 63 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package com.github.jelmerk.server.registration
package com.github.jelmerk.registration.client

import java.net.{InetSocketAddress, SocketAddress}

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse, RegistrationServiceGrpc}
import io.grpc.netty.NettyChannelBuilder

object RegistrationClient {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package com.github.jelmerk.server.registration
package com.github.jelmerk.registration.server

import java.net.InetSocketAddress
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}

import scala.concurrent.Future

import com.github.jelmerk.server.registration.{RegisterRequest, RegisterResponse}
import com.github.jelmerk.server.registration.RegistrationServiceGrpc.RegistrationService

class DefaultRegistrationService(val registrationLatch: CountDownLatch) extends RegistrationService {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.github.jelmerk.server.registration
package com.github.jelmerk.registration.server

import java.net.{InetAddress, InetSocketAddress}
import java.util.concurrent.{CountDownLatch, Executors}
Expand All @@ -7,6 +7,7 @@ import scala.concurrent.ExecutionContext
import scala.jdk.CollectionConverters.mapAsScalaConcurrentMapConverter
import scala.util.Try

import com.github.jelmerk.server.registration.RegistrationServiceGrpc
import io.grpc.netty.NettyServerBuilder

// TODO this is all a bit messy
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.github.jelmerk.spark.knn
package com.github.jelmerk.serving.client

import java.net.InetSocketAddress
import java.util.concurrent.{Executors, LinkedBlockingQueue}
Expand All @@ -10,8 +10,8 @@ import scala.concurrent.duration.Duration
import scala.language.implicitConversions
import scala.util.Random

import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.server.index._
import com.github.jelmerk.server.registration.PartitionAndReplica
import io.grpc.netty.NettyChannelBuilder
import io.grpc.stub.StreamObserver

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package com.github.jelmerk.spark.knn
package com.github.jelmerk.serving.client

import java.net.InetSocketAddress

import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.server.index.{Result, SearchRequest}
import com.github.jelmerk.server.registration.PartitionAndReplica

class IndexClientFactory[TId, TVector, TDistance](
vectorConverter: TVector => SearchRequest.Vector,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.github.jelmerk.serving.client

/** A single nearest-neighbor result: the identifier of the neighboring item paired with its distance to the item
  * that was queried.
  *
  * @param neighbor
  *   identifier of the neighboring item
  * @param distance
  *   distance between the neighbor and the item
  *
  * @tparam TId
  *   type of the index item identifier
  * @tparam TDistance
  *   type of distance
  */
// TODO move this
case class Neighbor[TId, TDistance](
    neighbor: TId,
    distance: TDistance
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package com.github.jelmerk.server.index
package com.github.jelmerk.serving.server

import scala.concurrent.{ExecutionContext, Future}
import scala.language.implicitConversions

import com.github.jelmerk.knn.scalalike.{Index, Item}
import com.github.jelmerk.server.index._
import com.github.jelmerk.server.index.IndexServiceGrpc.IndexService
import io.grpc.stub.StreamObserver
import org.apache.commons.io.output.CountingOutputStream
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.github.jelmerk.server.index
package com.github.jelmerk.serving.server

import java.net.{InetAddress, InetSocketAddress}
import java.util.concurrent.{LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
Expand All @@ -7,6 +7,7 @@ import scala.concurrent.ExecutionContext
import scala.util.Try

import com.github.jelmerk.knn.scalalike.{Index, Item}
import com.github.jelmerk.server.index.{IndexServiceGrpc, Result, SearchRequest}
import io.grpc.netty.NettyServerBuilder
import org.apache.hadoop.conf.Configuration

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import scala.util.control.NonFatal

import com.github.jelmerk.knn.ObjectSerializer
import com.github.jelmerk.knn.scalalike._
import com.github.jelmerk.server.index.IndexServerFactory
import com.github.jelmerk.server.registration.{PartitionAndReplica, RegistrationClient, RegistrationServerFactory}
import com.github.jelmerk.registration.client.RegistrationClient
import com.github.jelmerk.registration.server.{PartitionAndReplica, RegistrationServerFactory}
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.serving.server.IndexServerFactory
import com.github.jelmerk.spark.util.SerializableConfiguration
import org.apache.hadoop.fs.Path
import org.apache.spark.{Partitioner, SparkContext, TaskContext}
Expand Down Expand Up @@ -43,20 +45,6 @@ private[knn] case class ModelMetaData(
paramMap: Map[String, Any]
)

/** Neighbor of an item.
*
* @param neighbor
* identifies the neighbor
* @param distance
* distance to the item
*
* @tparam TId
* type of the index item identifier
* @tparam TDistance
* type of distance
*/
private[knn] case class Neighbor[TId, TDistance](neighbor: TId, distance: TDistance)

private[knn] trait IndexType {

/** Type of index. */
Expand Down Expand Up @@ -501,7 +489,7 @@ private[knn] abstract class KnnModelReader[TModel <: KnnModelBase[TModel]: Class

val indexRdd = sc
.makeRDD(partitionPaths)
.partitionBy(new PartitionIdPassthrough(metadata.numPartitions))
.partitionBy(new PartitionIdPartitioner(metadata.numPartitions))
.withResources(profile)
.mapPartitions { it =>
val (partitionId, indexPath) = it.next()
Expand Down Expand Up @@ -791,7 +779,7 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid
)
.as[(Int, TItem)]
.rdd
.partitionBy(new PartitionIdPassthrough(getNumPartitions))
.partitionBy(new PartitionIdPartitioner(getNumPartitions))
.values
else
dataset
Expand Down Expand Up @@ -886,17 +874,22 @@ private[knn] abstract class KnnAlgorithm[TModel <: KnnModelBase[TModel]](overrid

}

/** Partitioner that uses precomputed partitions
/** Partitioner that uses precomputed partitions. Each partition id is its own partition.
*
* @param numPartitions
* number of partitions
*/
private[knn] class PartitionIdPassthrough(override val numPartitions: Int) extends Partitioner {
private[knn] class PartitionIdPartitioner(override val numPartitions: Int) extends Partitioner {
// The key is expected to already be the Int partition id: it is cast to Int and returned unchanged,
// so no hashing or range lookup takes place here.
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

// TODO rename this
private[knn] class PartitionReplicaPartitioner(partitions: Int, replicas: Int) extends Partitioner {
/**
* Partitioner that uses precomputed partitions. Each unique partition and replica combination is its own partition.
*
* @param partitions the total number of partitions
* @param replicas the total number of replicas
*/
private[knn] class PartitionReplicaIdPartitioner(partitions: Int, replicas: Int) extends Partitioner {

override def numPartitions: Int = partitions + (replicas * partitions)

Expand Down Expand Up @@ -928,7 +921,9 @@ private[knn] trait ModelLogging extends Logging {
/** Prints a message to standard output, prefixed with the partition number zero-padded to four digits.
  *
  * @param partition
  *   partition number shown in the prefix
  * @param message
  *   text to print after the prefix
  */
protected def logInfo(partition: Int, message: String): Unit = {
  val prefix = f"partition $partition%04d"
  println(s"$prefix: $message")
}

// protected def logInfo(partition: Int, replica: Int, message: String): Unit = logInfo(f"partition $partition%04d replica $partition%04d: $message")
protected def logInfo(partition: Int, replica: Int, message: String): Unit = println(f"partition $partition%04d replica $partition%04d: $message")
/** Prints a message to standard output, prefixed with the partition and replica numbers, each zero-padded to four
  * digits.
  *
  * @param partition
  *   partition number shown in the prefix
  * @param replica
  *   replica number shown in the prefix
  * @param message
  *   text to print after the prefix
  */
protected def logInfo(partition: Int, replica: Int, message: String): Unit = println(
  // The second placeholder must be the replica number; it previously repeated the partition number.
  f"partition $partition%04d replica $replica%04d: $message"
)
}

private[knn] trait IndexServing extends ModelLogging with IndexType {
Expand Down Expand Up @@ -956,7 +951,7 @@ private[knn] trait IndexServing extends ModelLogging with IndexType {
}

val replicaRdd =
if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaPartitioner(numPartitions, numReplicas))
if (numReplicas > 0) keyedIndexRdd.partitionBy(new PartitionReplicaIdPartitioner(numPartitions, numReplicas))
else keyedIndexRdd

val server = RegistrationServerFactory.create(numPartitions, numReplicas)
Expand All @@ -974,7 +969,11 @@ private[knn] trait IndexServing extends ModelLogging with IndexType {

val serverAddress = server.address

logInfo(partitionNum, replicaNum, s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}")
logInfo(
partitionNum,
replicaNum,
s"started index server on host ${serverAddress.getHostName}:${serverAddress.getPort}"
)

logInfo(
partitionNum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package com.github.jelmerk.spark.knn

import java.net.InetSocketAddress

import com.github.jelmerk.server.registration.PartitionAndReplica
import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.serving.client.{IndexClientFactory, Neighbor}

class QueryIterator[TId, TVector, TDistance, TQueryId](
indices: Map[PartitionAndReplica, InetSocketAddress],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import scala.reflect.runtime.universe._
import com.github.jelmerk.knn.ObjectSerializer
import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item}
import com.github.jelmerk.knn.scalalike.bruteforce.BruteForceIndex
import com.github.jelmerk.server.registration.PartitionAndReplica
import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.spark.knn._
import org.apache.spark.SparkContext
import org.apache.spark.ml.param._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import scala.reflect.runtime.universe._
import com.github.jelmerk.knn
import com.github.jelmerk.knn.scalalike.{DistanceFunction, Item}
import com.github.jelmerk.knn.scalalike.hnsw._
import com.github.jelmerk.server.registration.PartitionAndReplica
import com.github.jelmerk.registration.server.PartitionAndReplica
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.spark.knn._
import org.apache.spark.SparkContext
import org.apache.spark.ml.param._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ import com.github.jelmerk.server.index.{
DenseVector,
DoubleArrayVector,
FloatArrayVector,
IndexServerFactory,
Result,
SearchRequest,
SparseVector
}
import com.github.jelmerk.serving.client.IndexClientFactory
import com.github.jelmerk.serving.server.IndexServerFactory
import com.github.jelmerk.spark.linalg.functions.VectorDistanceFunctions
import org.apache.spark.ml.linalg.{DenseVector => SparkDenseVector, SparseVector => SparkSparseVector, Vector, Vectors}

Expand Down
3 changes: 2 additions & 1 deletion hnswlib-spark/src/test/scala/ServerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ import scala.concurrent.ExecutionContext

import com.github.jelmerk.knn.scalalike.{floatCosineDistance, Item}
import com.github.jelmerk.knn.scalalike.hnsw.HnswIndex
import com.github.jelmerk.server.index.{DefaultIndexService, IndexServiceGrpc}
import com.github.jelmerk.server.index.IndexServiceGrpc
import com.github.jelmerk.serving.server.DefaultIndexService
import io.grpc.netty.NettyServerBuilder
import org.apache.hadoop.conf.Configuration

Expand Down

0 comments on commit 9d50dbc

Please sign in to comment.