Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update pca algorithm #870

Merged
merged 1 commit into from
Dec 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions frontend/packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"mime-types": "2.1.27",
"moment": "2.29.1",
"nprogress": "0.2.0",
"numeric": "1.2.6",
"polished": "4.0.5",
"query-string": "6.13.7",
"react": "17.0.1",
Expand Down Expand Up @@ -94,6 +95,7 @@
"@types/lodash": "4.14.165",
"@types/mime-types": "2.1.0",
"@types/nprogress": "0.2.0",
"@types/numeric": "1.2.1",
"@types/react": "17.0.0",
"@types/react-dom": "17.0.0",
"@types/react-helmet": "6.1.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

import type {PCAResult, Reduction, TSNEResult, UMAPResult} from '~/resource/high-dimensional';
import type {CalculateParams, CalculateResult, Reduction} from '~/resource/high-dimensional';
import React, {useCallback, useEffect, useImperativeHandle, useLayoutEffect, useMemo, useRef, useState} from 'react';
import ScatterChart, {ScatterChartRef} from '~/components/ScatterChart';

Expand All @@ -24,76 +24,8 @@ import type {WithStyled} from '~/utils/style';
import {rem} from '~/utils/style';
import styled from 'styled-components';
import {useTranslation} from 'react-i18next';
import useWebAssembly from '~/hooks/useWebAssembly';
import useWorker from '~/hooks/useWorker';

/**
 * Runs the dimensionality-reduction hooks for PCA, t-SNE and UMAP and returns
 * the hook result that belongs to the requested `reduction`.
 *
 * All three computation hooks are called unconditionally on every render; the
 * algorithms that are not selected receive empty placeholder input.
 *
 * @param reduction    Which algorithm's result to return.
 * @param vectors      Flat input data handed to the selected algorithm.
 * @param dim          Passed through as the `dim` argument of the selected algorithm.
 * @param is3D         Selects `n = 3` (true) or `n = 2` (false) for t-SNE and UMAP.
 * @param perplexity   t-SNE perplexity; only the value from the first render is used.
 * @param learningRate Forwarded to t-SNE as `epsilon`; only the first-render value is used.
 * @param neighbors    UMAP `neighbors` option.
 * @returns The result object produced by the hook of the selected algorithm.
 */
function useComputeHighDimensional(
    reduction: Reduction,
    vectors: Float32Array,
    dim: number,
    is3D: boolean,
    perplexity: number,
    learningRate: number,
    neighbors: number
) {
    // PCA takes positional arguments; the trailing 3 is sent regardless of `is3D`.
    const pcaArgs = useMemo(
        () => (reduction === 'pca' ? ([Array.from(vectors), dim, 3] as const) : [[], 0, 3]),
        [reduction, vectors, dim]
    );

    // The ref is never reassigned here, so it keeps the perplexity / learning
    // rate seen on the first render; they are intentionally not memo dependencies.
    const initialTsneOptions = useRef({perplexity, epsilon: learningRate});
    const tsneInput = useMemo(() => {
        const n = is3D ? 3 : 2;
        if (reduction !== 'tsne') {
            // Placeholder input while another algorithm is selected.
            return {input: new Float32Array(), dim: 0, n, perplexity: 5};
        }
        return {input: vectors, dim, n, ...initialTsneOptions.current};
    }, [reduction, vectors, dim, is3D]);

    const umapInput = useMemo(() => {
        const n = is3D ? 3 : 2;
        if (reduction !== 'umap') {
            // Placeholder input while another algorithm is selected.
            return {input: new Float32Array(), dim: 0, n, neighbors: 15};
        }
        return {input: vectors, dim, n, neighbors};
    }, [reduction, vectors, dim, is3D, neighbors]);

    const pca = useWebAssembly<PCAResult>('high_dimensional_pca', pcaArgs);
    const tsne = useWorker<TSNEResult>('high-dimensional/tsne', tsneInput);
    const umap = useWorker<UMAPResult>('high-dimensional/umap', umapInput);

    switch (reduction) {
        case 'pca':
            return pca;
        case 'tsne':
            return tsne;
        case 'umap':
            return umap;
        default:
            // Unknown reduction values are not expected by the type.
            return null as never;
    }
}

const Wrapper = styled.div`
height: 100%;
display: flex;
Expand Down Expand Up @@ -143,7 +75,7 @@ type HighDimensionalChartProps = {
neighbors: number;
highlightIndices?: number[];
onCalculate?: () => unknown;
onCalculated?: (data: PCAResult | TSNEResult | UMAPResult) => unknown;
onCalculated?: (data: CalculateResult) => unknown;
onError?: (e: Error) => unknown;
};

Expand Down Expand Up @@ -196,15 +128,42 @@ const HighDimensionalChart = React.forwardRef<HighDimensionalChartRef, HighDimen
}
}, []);

const {data, error, worker} = useComputeHighDimensional(
reduction,
vectors,
dim,
is3D,
perplexity,
learningRate,
neighbors
);
const params = useMemo<CalculateParams>(() => {
const result = {
input: vectors,
dim,
n: is3D ? 3 : 2
};
switch (reduction) {
case 'pca':
return {
reduction,
params: {
...result
}
};
case 'tsne':
return {
reduction,
params: {
perplexity,
epsilon: learningRate,
...result
}
};
case 'umap':
return {
reduction,
params: {
neighbors,
...result
}
};
default:
return null as never;
}
}, [dim, is3D, learningRate, neighbors, perplexity, reduction, vectors]);
const {data, error, worker} = useWorker<CalculateResult>('high-dimensional/calculate', params);

const iterationId = useRef<number | null>(null);
const iteration = useCallback(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ const Wrapper = styled(Field)`
export type PCADetailProps = {
dimension: Dimension;
variance: number[];
totalVariance: number;
};

const PCADetail: FunctionComponent<PCADetailProps> = ({dimension, variance}) => {
const PCADetail: FunctionComponent<PCADetailProps> = ({dimension, variance, totalVariance}) => {
const {t} = useTranslation(['high-dimensional', 'common']);

const dim = useMemo(() => (dimension === '3d' ? 3 : 2), [dimension]);
const totalVariance = useMemo(() => variance.reduce((s, c) => s + c, 0), [variance]);

return (
<Wrapper>
Expand Down
44 changes: 34 additions & 10 deletions frontend/packages/core/src/pages/high-dimensional.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ import useWorker from '~/hooks/useWorker';

const MODE = import.meta.env.MODE;

/**
 * Per-reduction cap on the number of vectors, forwarded to the parser as
 * `maxCount`. An `undefined` entry means no cap is applied.
 */
const MAX_COUNT: Record<Reduction, number | undefined> = {
    pca: 50000,
    tsne: 10000,
    umap: 5000
};

/**
 * Per-reduction cap on the vector dimension, forwarded to the parser as
 * `maxDimension`. An `undefined` entry means no cap is applied.
 */
const MAX_DIMENSION: Record<Reduction, number | undefined> = {
    pca: 200,
    tsne: undefined,
    umap: undefined
};

const AsideTitle = styled.div`
font-size: ${rem(16)};
line-height: ${rem(16)};
Expand Down Expand Up @@ -182,6 +194,12 @@ const HighDimensional: FunctionComponent = () => {
);
const labelByLabels = useMemo(() => getLabelByLabels(labelBy), [getLabelByLabels, labelBy]);

// dimension of display
const [dimension, setDimension] = useState<Dimension>('3d');
const [reduction, setReduction] = useState<Reduction>('pca');

const is3D = useMemo(() => dimension === '3d', [dimension]);

const readFile = useCallback(
(phase: string, file: File | null, setter: React.Dispatch<React.SetStateAction<string>>) => {
if (file) {
Expand Down Expand Up @@ -221,12 +239,17 @@ const HighDimensional: FunctionComponent = () => {
}, []);

const params = useMemo<ParseParams>(() => {
const maxValues = {
maxCount: MAX_COUNT[reduction],
maxDimension: MAX_DIMENSION[reduction]
};
if (vectorContent) {
return {
from: 'string',
params: {
vectors: vectorContent,
metadata: metadataContent
metadata: metadataContent,
...maxValues
}
};
}
Expand All @@ -236,12 +259,13 @@ const HighDimensional: FunctionComponent = () => {
params: {
shape: selectedEmbedding.shape,
vectors: tensorData.data,
metadata: metadataData ?? ''
metadata: metadataData ?? '',
...maxValues
}
};
}
return null;
}, [vectorContent, metadataContent, selectedEmbedding, tensorData, metadataData]);
}, [reduction, vectorContent, selectedEmbedding, tensorData, metadataContent, metadataData]);
const result = useWorker<ParseResult, ParseParams>('high-dimensional/parse-data', params);
useEffect(() => {
const {error, data} = result;
Expand All @@ -264,12 +288,6 @@ const HighDimensional: FunctionComponent = () => {
selectedEmbedding
]);

// dimension of display
const [dimension, setDimension] = useState<Dimension>('3d');
const [reduction, setReduction] = useState<Reduction>('pca');

const is3D = useMemo(() => dimension === '3d', [dimension]);

const [perplexity, setPerplexity] = useState(5);
const [learningRate, setLearningRate] = useState(10);

Expand Down Expand Up @@ -324,7 +342,13 @@ const HighDimensional: FunctionComponent = () => {
const detail = useMemo(() => {
switch (reduction) {
case 'pca':
return <PCADetail dimension={dimension} variance={(data as PCAResult)?.variance ?? []} />;
return (
<PCADetail
dimension={dimension}
variance={(data as PCAResult)?.variance ?? []}
totalVariance={(data as PCAResult)?.totalVariance ?? 0}
/>
);
case 'tsne':
return (
<TSNEDetail
Expand Down
13 changes: 8 additions & 5 deletions frontend/packages/core/src/resource/high-dimensional/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@
*/

export type {
CalculateParams,
CalculateResult,
Dimension,
Reduction,
PcaParams,
PCAResult,
Vectors,
ParseParams,
ParseResult,
PCAParams,
PCAResult,
Reduction,
TSNEParams,
TSNEResult,
UMAPParams,
UMAPResult
UMAPResult,
Vectors
} from './types';

export {parseFromBlob, parseFromString, ParserError} from './parser';

export {default as PCA} from './pca';
export {default as tSNE} from './tsne';
export {default as UMAP} from './umap';
52 changes: 39 additions & 13 deletions frontend/packages/core/src/resource/high-dimensional/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ export class ParserError extends Error {
}
}

function split<T = string>(str: string, processer?: (item: string) => T): T[][] {
function split<T = string>(str: string, handler?: (item: string) => T): T[][] {
return safeSplit(str, '\n')
.map(r => safeSplit(r, '\t').map(n => (processer ? processer(n) : n) as T))
.map(r => safeSplit(r, '\t').map(n => (handler ? handler(n) : n) as T))
.filter(r => r.length);
}

Expand All @@ -64,13 +64,19 @@ function alignItems<T>(data: T[][], dimension: number, defaultValue: T): T[][] {
});
}

function parseVectors(str: string): VectorResult {
function parseVectors(str: string, maxCount?: number, maxDimension?: number): VectorResult {
if (!str) {
throw new ParserError('Tenser file is empty', ParserError.CODES.TENSER_EMPTY);
}
let vectors = split(str, Number.parseFloat);
// TODO: sampling
const dimension = Math.min(...vectors.map(vector => vector.length));
// TODO: random sampling
if (maxCount) {
vectors = vectors.slice(0, maxCount);
}
let dimension = Math.min(...vectors.map(vector => vector.length));
if (maxDimension) {
dimension = Math.min(dimension, maxDimension);
}
vectors = alignItems(vectors, dimension, 0);
return {
dimension,
Expand Down Expand Up @@ -124,31 +130,51 @@ function genMetadataAndLabels(metadata: string, count: number) {
};
}

export function parseFromString({vectors: v, metadata: m}: ParseFromStringParams): ParseResult {
export function parseFromString({vectors: v, metadata: m, maxCount, maxDimension}: ParseFromStringParams): ParseResult {
const result: ParseResult = {
count: 0,
dimension: 0,
vectors: new Float32Array(),
labels: [],
metadata: []
};
if (v) {
const {dimension, vectors, count} = parseVectors(v);
const {dimension, vectors, count} = parseVectors(v, maxCount, maxDimension);
result.dimension = dimension;
result.vectors = vectors;
Object.assign(result, genMetadataAndLabels(m, count));
const metadataAndLabels = genMetadataAndLabels(m, count);
metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count);
Object.assign(result, metadataAndLabels);
}
return result;
}

export async function parseFromBlob({shape, vectors: v, metadata: m}: ParseFromBlobParams): Promise<ParseResult> {
const [count, dimension] = shape;
const vectors = new Float32Array(await v.arrayBuffer());
if (count * dimension !== vectors.length) {
export async function parseFromBlob({
shape,
vectors: v,
metadata: m,
maxCount,
maxDimension
}: ParseFromBlobParams): Promise<ParseResult> {
// TODO: random sampling
const [originalCount, originalDimension] = shape;
const originalVectors = new Float32Array(await v.arrayBuffer());
if (originalCount * originalDimension !== originalVectors.length) {
throw new ParserError('Size of tensor does not match.', ParserError.CODES.SHAPE_MISMATCH);
}
const count = maxCount ? Math.min(originalCount, maxCount) : originalCount;
const dimension = maxDimension ? Math.min(originalDimension, maxDimension) : originalDimension;
const vectors = new Float32Array(count * dimension);
for (let c = 0; c < count; c++) {
const offset = c * originalDimension;
vectors.set(originalVectors.subarray(offset, offset + dimension), c * dimension);
}
const metadataAndLabels = genMetadataAndLabels(m, originalCount);
metadataAndLabels.metadata = metadataAndLabels.metadata.slice(0, count);
return {
count,
dimension,
vectors,
...genMetadataAndLabels(m, count)
...metadataAndLabels
};
}
Loading