Skip to content

Commit

Permalink
feat(orama): adds vector search
Browse files Browse the repository at this point in the history
  • Loading branch information
micheleriva committed Aug 3, 2023
1 parent 3d9d227 commit b33aaac
Show file tree
Hide file tree
Showing 14 changed files with 421 additions and 76 deletions.
13 changes: 13 additions & 0 deletions packages/orama/src/cjs/index.cts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import type { count as esmCount, getByID as esmGetByID } from '../methods/docs.j
import type { insert as esmInsert, insertMultiple as esminsertMultiple } from '../methods/insert.js'
import type { remove as esmRemove, removeMultiple as esmRemoveMultiple } from '../methods/remove.js'
import type { search as esmSearch } from '../methods/search.js'
import type { searchVector as esmSearchVector } from '../methods/search-vector.js'
import type { load as esmLoad, save as esmSave } from '../methods/serialization.js'
import type { update as esmUpdate, updateMultiple as esmUpdateMultiple } from '../methods/update.js'

Expand All @@ -18,6 +19,7 @@ let _esmSave: typeof esmSave
let _esmSearch: typeof esmSearch
let _esmUpdate: typeof esmUpdate
let _esmUpdateMultiple: typeof esmUpdateMultiple
let _esmSearchVector: typeof esmSearchVector

export async function count(...args: Parameters<typeof esmCount>): ReturnType<typeof esmCount> {
if (!_esmCount) {
Expand Down Expand Up @@ -133,5 +135,16 @@ export async function updateMultiple(
return _esmUpdateMultiple(...args)
}

export async function searchVector(
...args: Parameters<typeof esmSearchVector>
): ReturnType<typeof esmSearchVector> {
if (!_esmSearchVector) {
const imported = await import('../methods/search-vector.js')
_esmSearchVector = imported.searchVector
}

return _esmSearchVector(...args)
}

export * as components from './components/defaults.cjs'
export * as internals from './internals.cjs'
43 changes: 43 additions & 0 deletions packages/orama/src/components/cosine-similarity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import type { Magnitude, VectorType } from '../types.js'

export type SimilarVector = {
id: string
score: number
}

export function getMagnitude(vector: Float32Array, vectorLength: number): number {
let magnitude = 0
for (let i = 0; i < vectorLength; i++) {
magnitude += vector[i] * vector[i]
}
return Math.sqrt(magnitude)
}

// @todo: Write plugins for Node and Browsers to use parallel computation for this function
export function findSimilarVectors(
targetVector: Float32Array,
vectors: Record<string, [Magnitude, VectorType]>,
length: number,
threshold = 0.8
) {
const targetMagnitude = getMagnitude(targetVector, length);

const similarVectors: SimilarVector[] = []

for (const [vectorId, [magnitude, vector]] of Object.entries(vectors)) {
let dotProduct = 0

for (let i = 0; i < length; i++) {
dotProduct += targetVector[i] * vector[i]
}

const similarity = dotProduct / (targetMagnitude * magnitude)

if (similarity >= threshold) {
similarVectors.push({ id: vectorId, score: similarity })
}
}

return similarVectors.sort((a, b) => b.score - a.score)
}

31 changes: 30 additions & 1 deletion packages/orama/src/components/defaults.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ export async function validateSchema<S extends Schema = Schema>(doc: Document, s

const typeOfType = typeof type

if (isVectorType(type as string)) {
// TODO: validate vector size
if (!Array.isArray(value)) {
// TODO: run actual validation
return undefined
}
continue
}

if (typeOfType === 'string' && isArrayType(type as SearchableType)) {
if (!Array.isArray(value)) {
return prop
Expand Down Expand Up @@ -78,14 +87,34 @@ const IS_ARRAY_TYPE: Record<SearchableType, boolean> = {
'number[]': true,
'boolean[]': true,
}

const INNER_TYPE: Record<ArraySearchableType, ScalarSearchableType> = {
'string[]': 'string',
'number[]': 'number',
'boolean[]': 'boolean',
}
export function isArrayType(type: SearchableType) {

export function isVectorType(type: string): boolean {
return /^vector\[\d+\]$/.test(type)
}

export function isArrayType(type: SearchableType): boolean {
return IS_ARRAY_TYPE[type]
}

export function getInnerType(type: ArraySearchableType): ScalarSearchableType {
return INNER_TYPE[type]
}

export function getVectorSize(type: string): number {
const size = Number(type.slice(7, -1))

switch (true) {
case isNaN(size):
throw createError('INVALID_VECTOR_VALUE', type)
case size <= 0:
throw createError('INVALID_VECTOR_SIZE', type)
default:
return size
}
}
124 changes: 85 additions & 39 deletions packages/orama/src/components/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
import type {
ArraySearchableType,
BM25Params,
ComparisonOperator,
IIndex,
Magnitude,
OpaqueDocumentStore,
OpaqueIndex,
Orama,
ScalarSearchableType,
Schema,
SearchableType,
SearchableValue,
SearchContext,
Tokenizer,
TokenScore,
VectorType,
} from '../types.js'
import { createError } from '../errors.js'
import {
create as avlCreate,
Expand All @@ -16,31 +34,16 @@ import {
Node as RadixNode,
removeDocumentByWord as radixRemoveDocument,
} from '../trees/radix.js'
import {
ArraySearchableType,
BM25Params,
ComparisonOperator,
IIndex,
OpaqueDocumentStore,
OpaqueIndex,
Orama,
ScalarSearchableType,
Schema,
SearchableType,
SearchableValue,
SearchContext,
Tokenizer,
TokenScore,
} from '../types.js'
import { intersect } from '../utils.js'
import { BM25 } from './algorithms.js'
import { getInnerType, isArrayType } from './defaults.js'
import { getInnerType, getVectorSize, isArrayType, isVectorType } from './defaults.js'
import {
DocumentID,
getInternalDocumentId,
InternalDocumentID,
InternalDocumentIDStore,
} from './internal-document-id-store.js'
import { getMagnitude } from './cosine-similarity.js'

export type FrequencyMap = {
[property: string]: {
Expand All @@ -57,9 +60,17 @@ export type BooleanIndex = {
false: InternalDocumentID[]
}

export type VectorIndex = {
size: number
vectors: {
[docID: string]: [Magnitude, VectorType]
}
}

export interface Index extends OpaqueIndex {
sharedInternalDocumentStore: InternalDocumentIDStore
indexes: Record<string, RadixNode | AVLNode<number, InternalDocumentID[]> | BooleanIndex>
vectorIndexes: Record<string, VectorIndex>
searchableProperties: string[]
searchablePropertiesWithTypes: Record<string, SearchableType>
frequencies: FrequencyMap
Expand Down Expand Up @@ -181,6 +192,7 @@ export async function create(
index = {
sharedInternalDocumentStore,
indexes: {},
vectorIndexes: {},
searchableProperties: [],
searchablePropertiesWithTypes: {},
frequencies: {},
Expand All @@ -200,29 +212,38 @@ export async function create(
continue
}

switch (type) {
case 'boolean':
case 'boolean[]':
index.indexes[path] = { true: [], false: [] }
break
case 'number':
case 'number[]':
index.indexes[path] = avlCreate<number, InternalDocumentID[]>(0, [])
break
case 'string':
case 'string[]':
index.indexes[path] = radixCreate()
index.avgFieldLength[path] = 0
index.frequencies[path] = {}
index.tokenOccurrences[path] = {}
index.fieldLengths[path] = {}
break
default:
throw createError('INVALID_SCHEMA_TYPE', Array.isArray(type) ? 'array' : (type as unknown as string), path)
}
if (isVectorType(type as string)) {
index.searchableProperties.push(path)
index.searchablePropertiesWithTypes[path] = (type as SearchableType)
index.vectorIndexes[path] = {
size: getVectorSize(type as string),
vectors: {},
}
} else {
switch (type) {
case 'boolean':
case 'boolean[]':
index.indexes[path] = { true: [], false: [] }
break
case 'number':
case 'number[]':
index.indexes[path] = avlCreate<number, InternalDocumentID[]>(0, [])
break
case 'string':
case 'string[]':
index.indexes[path] = radixCreate()
index.avgFieldLength[path] = 0
index.frequencies[path] = {}
index.tokenOccurrences[path] = {}
index.fieldLengths[path] = {}
break
default:
throw createError('INVALID_SCHEMA_TYPE', Array.isArray(type) ? 'array' : (type as unknown as string), path)
}

index.searchableProperties.push(path)
index.searchablePropertiesWithTypes[path] = type
index.searchableProperties.push(path)
index.searchablePropertiesWithTypes[path] = type
}
}

return index
Expand Down Expand Up @@ -276,6 +297,11 @@ export async function insert(
tokenizer: Tokenizer,
docsCount: number,
): Promise<void> {

if (isVectorType(schemaType)) {
return insertVector(index, prop, value as number[] | Float32Array, id)
}

if (!isArrayType(schemaType)) {
return insertScalar(
implementation,
Expand All @@ -299,6 +325,17 @@ export async function insert(
}
}

function insertVector(index: Index, prop: string, value: number[] | VectorType, id: DocumentID): void {
if (!(value instanceof Float32Array)) {
value = new Float32Array(value)
}

const size = index.vectorIndexes[prop].size
const magnitude = getMagnitude(value, size)

index.vectorIndexes[prop].vectors[id] = [magnitude, value]
}

async function removeScalar(
implementation: IIndex<Index>,
index: Index,
Expand Down Expand Up @@ -525,6 +562,7 @@ function loadNode(node: RadixNode): RadixNode {
export async function load<R = unknown>(sharedInternalDocumentStore: InternalDocumentIDStore, raw: R): Promise<Index> {
const {
indexes: rawIndexes,
vectorIndexes: rawVectorIndexes,
searchableProperties,
searchablePropertiesWithTypes,
frequencies,
Expand All @@ -534,6 +572,7 @@ export async function load<R = unknown>(sharedInternalDocumentStore: InternalDoc
} = raw as Index

const indexes: Index['indexes'] = {}
const vectorIndexes: Index['vectorIndexes'] = {}

for (const prop of Object.keys(rawIndexes)) {
const value = rawIndexes[prop]
Expand All @@ -547,9 +586,14 @@ export async function load<R = unknown>(sharedInternalDocumentStore: InternalDoc
indexes[prop] = loadNode(value)
}

for (const prop of Object.keys(rawVectorIndexes)) {
// TODO: load vector indexes, convert arrays into Float32Arrays
}

return {
sharedInternalDocumentStore,
indexes,
vectorIndexes,
searchableProperties,
searchablePropertiesWithTypes,
frequencies,
Expand All @@ -562,6 +606,7 @@ export async function load<R = unknown>(sharedInternalDocumentStore: InternalDoc
export async function save<R = unknown>(index: Index): Promise<R> {
const {
indexes,
vectorIndexes,
searchableProperties,
searchablePropertiesWithTypes,
frequencies,
Expand All @@ -572,6 +617,7 @@ export async function save<R = unknown>(index: Index): Promise<R> {

return {
indexes,
vectorIndexes,
searchableProperties,
searchablePropertiesWithTypes,
frequencies,
Expand Down
Loading

0 comments on commit b33aaac

Please sign in to comment.