Skip to content

Multiple-attribute NSW implemented in Go

Notifications You must be signed in to change notification settings

RyanLiGod/MA-NSW

Repository files navigation

MA-NSW

MA-NSW is a Go implementation of a multiple-attribute approximate nearest-neighbour search algorithm based on HNSW (implemented in C++ in https://github.com/searchivarius/nmslib and described in https://arxiv.org/abs/1603.09320).

Usage

Simple usage example. See examples folder for more. Note that both index building and searching can be safely done in parallel with multiple goroutines. You can always extend the index, even while searching.

package main

import (
	hnsw ".."
	"fmt"
	"github.com/grd/stat"
	"math/rand"
	"time"
)

// Sizing parameters for the example run.
var (
	// NUM is the number of points inserted into the index.
	NUM = 3000

	// DIMENSION is the number of components in every point.
	DIMENSION = 32

	// TESTNUM is the number of queries executed against the index.
	TESTNUM = 10
)

// main builds an index over NUM random points tagged with random attributes,
// saves it to test.ind and loads it back, then runs TESTNUM attribute-filtered
// queries. Each query is answered both by brute force and by the index search;
// the program prints the per-query results, timing statistics, and the
// fraction of brute-force answers that the index search also returned.
func main() {

	const (
		M              = 16
		efConstruction = 400
		efSearch       = 400
		K              = 100
		distType       = "l2" // l2 or cosine
	)

	var zero hnsw.Point = make([]float32, DIMENSION)

	h := hnsw.New(M, efConstruction, zero, distType)
	h.Grow(NUM)

	provinces := []string{"blue", "red", "green", "yellow"}
	types := []string{"sky", "land", "sea"}
	titles := []string{"boy", "girl"}

	// randomAttrs picks one value from each attribute list. The bounds come
	// from the slice lengths so the lists can be edited without touching the
	// random draws.
	randomAttrs := func() []string {
		return []string{
			provinces[rand.Intn(len(provinces))],
			types[rand.Intn(len(types))],
			titles[rand.Intn(len(titles))],
		}
	}

	// IDs start at 1, so an ID of 0 never refers to an inserted point.
	for i := 1; i <= NUM; i++ {
		h.Add(randomPoint(), uint32(i), randomAttrs())
		if i%1000 == 0 {
			fmt.Printf("%v points added\n", i)
		}
	}

	fmt.Println("Saving index...")
	err := h.Save("test.ind", true) // true: product mode
	if err != nil {
		panic(fmt.Errorf("saving index to test.ind: %w", err))
	}
	fmt.Println("Done! Loading index...")
	h, timestamp, err := hnsw.Load("test.ind")
	if err != nil {
		// Continuing with a failed load would only crash later on h.
		panic(fmt.Errorf("loading index from test.ind: %w", err))
	}
	fmt.Printf("Index loaded, time saved was %v\n", time.Unix(timestamp, 0))

	fmt.Printf("Now searching with HNSW...\n")
	timeRecord := make([]float64, TESTNUM) // per-query search time in milliseconds
	hits := 0
	queries := make([]hnsw.Point, TESTNUM)
	truth := make([][]uint32, TESTNUM)
	for i := 0; i < TESTNUM; i++ {
		searchAttr := randomAttrs()
		fmt.Printf("Generating queries and calculating true answers using bruteforce search...\n")

		queries[i] = randomPoint()
		resultTruth := h.SearchBrute(queries[i], K, searchAttr)
		truth[i] = make([]uint32, K)
		for j := K - 1; j >= 0; j-- {
			item := resultTruth.Pop()
			if item == nil {
				// Fewer than K points matched the attributes. The slot keeps
				// its zero value, which cannot equal any inserted ID.
				continue
			}
			truth[i][j] = item.ID
		}

		startSearch := time.Now()
		result := h.Search(queries[i], efSearch, K, searchAttr)
		//result := h.Search(queries[i], efSearch, K, []string{"nil", "nil", "nil"})
		stopSearch := time.Since(startSearch)
		timeRecord[i] = stopSearch.Seconds() * 1000
		fmt.Print("Searching with attributes:")
		fmt.Println(searchAttr)
		if result.Size != 0 {
			for j := 0; j < K; j++ {
				item := result.Pop()
				fmt.Printf("%v  ", item)
				if item != nil {
					fmt.Println(h.GetNodeAttr(item.ID))
					// Count the item as a hit when brute force returned it too.
					for k := 0; k < K; k++ {
						if item.ID == truth[i][k] {
							hits++
						}
					}
				}
			}
		} else {
			fmt.Println("Can't return any node")
		}

		fmt.Println()
	}

	data := stat.Float64Slice(timeRecord)
	mean := stat.Mean(data)
	variance := stat.Variance(data)

	fmt.Printf("Mean of queries time(MS): %v\n", mean)
	fmt.Printf("Variance of queries time: %v\n", variance)
	fmt.Printf("%v queries / second (single thread)\n", 1000.0/mean)
	// Report the K that was actually searched rather than a fixed "10".
	fmt.Printf("Average %v-NN precision: %v\n", K, float64(hits)/(float64(TESTNUM)*float64(K)))
	fmt.Printf("\n")
	// Print, not Printf: the stats text is data, not a format string.
	fmt.Print(h.Stats())
}

// randomPoint returns a freshly allocated point with DIMENSION components,
// each drawn independently from rand.Float32.
func randomPoint() hnsw.Point {
	p := make(hnsw.Point, DIMENSION)
	for d := 0; d < len(p); d++ {
		p[d] = rand.Float32()
	}
	return p
}

About

Multiple-attribute NSW implemented in Go

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published