From 981f99c08e13cc3f5524b29bdfd71fa049aac8c3 Mon Sep 17 00:00:00 2001 From: Grant Steffen <40581813+ghsteff@users.noreply.github.com> Date: Thu, 11 Jul 2024 07:11:34 -0400 Subject: [PATCH] Add fit to node canvas util (#258) * Add fit to node canvas util * Update docs * Pull fit to node logic into functions * Add support for fitting the view to multiple nodes * Add string type input to fitNodes * Add util comment --- docs/Advanced/Refs.mdx | 21 ++++- src/layout/useLayout.ts | 178 +++++++++++++++++------------------ src/layout/utils.test.ts | 82 +++++++++++++++- src/layout/utils.ts | 127 ++++++++++++++++++++++++- stories/Controls.stories.tsx | 8 +- 5 files changed, 319 insertions(+), 97 deletions(-) diff --git a/docs/Advanced/Refs.mdx b/docs/Advanced/Refs.mdx index d532101f..bab81881 100644 --- a/docs/Advanced/Refs.mdx +++ b/docs/Advanced/Refs.mdx @@ -69,14 +69,29 @@ export interface CanvasRef { containerHeight?: number; /** - * Center the canvas to the viewport. + * Positions the canvas to the viewport. */ - centerCanvas?: () => void; + positionCanvas?: (position: CanvasPosition, animated?: boolean) => void; /** * Fit the canvas to the viewport. */ - fitCanvas?: () => void; + fitCanvas?: (animated?: boolean) => void; + + /** + * Fit a group of nodes to the viewport. + */ + fitNodes?: (nodeIds: string | string[], animated?: boolean) => void; + + /** + * Scroll to X/Y + */ + setScrollXY?: (xy: [number, number], animated?: boolean) => void; + + /** + * Factor of zoom. + */ + zoom: number; /** * Set a zoom factor of the canvas. diff --git a/src/layout/useLayout.ts b/src/layout/useLayout.ts index a84c0a25..a6be95ac 100644 --- a/src/layout/useLayout.ts +++ b/src/layout/useLayout.ts @@ -1,19 +1,9 @@ -import { - RefObject, - useCallback, - useEffect, - useLayoutEffect, - useRef, - useState -} from 'react'; -import { - elkLayout, - CanvasDirection, - ElkCanvasLayoutOptions -} from './elkLayout'; +import { RefObject, useCallback, useEffect, useLayoutEffect, useRef, useState } from 'react'; import useDimensions from 'react-cool-dimensions'; import isEqual from 'react-fast-compare'; import { CanvasPosition, EdgeData, NodeData } from '../types'; +import { CanvasDirection, ElkCanvasLayoutOptions, elkLayout } from './elkLayout'; +import { calculateScrollPosition, calculateZoom, findNode } from './utils'; export interface ElkRoot { x?: number; @@ -84,35 +74,27 @@ export interface LayoutResult { /** * Positions the canvas to the viewport. */ - positionCanvas?: (position: CanvasPosition) => void; + positionCanvas?: (position: CanvasPosition, animated?: boolean) => void; /** * Fit the canvas to the viewport. */ - fitCanvas?: () => void; + fitCanvas?: (animated?: boolean) => void; + + /** + * Fit a group of nodes to the viewport. + */ + fitNodes?: (nodeIds: string | string[], animated?: boolean) => void; /** * Scroll to X/Y */ - setScrollXY?: (xy: [number, number]) => void; + setScrollXY?: (xy: [number, number], animated?: boolean) => void; observe: (el: HTMLDivElement) => void; } -export const useLayout = ({ - maxWidth, - maxHeight, - nodes = [], - edges = [], - fit, - pannable, - defaultPosition, - direction, - layoutOptions = {}, - zoom, - setZoom, - onLayoutChange -}: LayoutProps) => { +export const useLayout = ({ maxWidth, maxHeight, nodes = [], edges = [], fit, pannable, defaultPosition, direction, layoutOptions = {}, zoom, setZoom, onLayoutChange }: LayoutProps) => { const scrolled = useRef(false); const ref = useRef(); const { observe, width, height } = useDimensions(); @@ -122,6 +104,11 @@ export const useLayout = ({ const canvasHeight = pannable ? maxHeight : height; const canvasWidth = pannable ? maxWidth : width; + const scrollToXY = (xy: [number, number], animated = false) => { + ref.current.scrollTo({ left: xy[0], top: xy[1], behavior: animated ? 'smooth' : 'auto' }); + setScrollXY(xy); + }; + useEffect(() => { const promise = elkLayout(nodes, edges, { 'elk.direction': direction, @@ -151,21 +138,21 @@ export const useLayout = ({ const centerX = (canvasWidth - layout.width * zoom) / 2; const centerY = (canvasHeight - layout.height * zoom) / 2; switch (position) { - case CanvasPosition.CENTER: - setXY([centerX, centerY]); - break; - case CanvasPosition.TOP: - setXY([centerX, 0]); - break; - case CanvasPosition.LEFT: - setXY([0, centerY]); - break; - case CanvasPosition.RIGHT: - setXY([canvasWidth - layout.width * zoom, centerY]); - break; - case CanvasPosition.BOTTOM: - setXY([centerX, canvasHeight - layout.height * zoom]); - break; + case CanvasPosition.CENTER: + setXY([centerX, centerY]); + break; + case CanvasPosition.TOP: + setXY([centerX, 0]); + break; + case CanvasPosition.LEFT: + setXY([0, centerY]); + break; + case CanvasPosition.RIGHT: + setXY([canvasWidth - layout.width * zoom, centerY]); + break; + case CanvasPosition.BOTTOM: + setXY([centerX, canvasHeight - layout.height * zoom]); + break; } } }, @@ -173,26 +160,26 @@ export const useLayout = ({ ); const positionScroll = useCallback( - (position: CanvasPosition) => { + (position: CanvasPosition, animated = false) => { const scrollCenterX = (canvasWidth - width) / 2; const scrollCenterY = (canvasHeight - height) / 2; if (pannable) { switch (position) { - case CanvasPosition.CENTER: - setScrollXY([scrollCenterX, scrollCenterY]); - break; - case CanvasPosition.TOP: - setScrollXY([scrollCenterX, 0]); - break; - case CanvasPosition.LEFT: - setScrollXY([0, scrollCenterY]); - break; - case CanvasPosition.RIGHT: - setScrollXY([canvasWidth - width, scrollCenterY]); - break; - case CanvasPosition.BOTTOM: - setScrollXY([scrollCenterX, canvasHeight - height]); - break; + case CanvasPosition.CENTER: + scrollToXY([scrollCenterX, scrollCenterY], animated); + break; + case CanvasPosition.TOP: + scrollToXY([scrollCenterX, 0], animated); + break; + case CanvasPosition.LEFT: + scrollToXY([0, scrollCenterY], animated); + break; + case CanvasPosition.RIGHT: + scrollToXY([canvasWidth - width, scrollCenterY], animated); + break; + case CanvasPosition.BOTTOM: + scrollToXY([scrollCenterX, canvasHeight - height], animated); + break; } } }, @@ -200,32 +187,54 @@ export const useLayout = ({ ); const positionCanvas = useCallback( - (position: CanvasPosition) => { + (position: CanvasPosition, animated = false) => { positionVector(position); - positionScroll(position); + positionScroll(position, animated); }, [positionScroll, positionVector] ); - useEffect(() => { - ref?.current?.scrollTo(scrollXY[0], scrollXY[1]); - }, [scrollXY, ref]); - useEffect(() => { if (scrolled.current && defaultPosition) { positionVector(defaultPosition); } }, [positionVector, zoom, defaultPosition]); - const fitCanvas = useCallback(() => { - if (layout) { - const heightZoom = height / layout.height; - const widthZoom = width / layout.width; - const scale = Math.min(heightZoom, widthZoom, 1); - setZoom(scale - 1); - positionCanvas(CanvasPosition.CENTER); - } - }, [height, layout, width, setZoom, positionCanvas]); + const fitCanvas = useCallback( + (animated = false) => { + if (layout) { + const heightZoom = height / layout.height; + const widthZoom = width / layout.width; + const scale = Math.min(heightZoom, widthZoom, 1); + setZoom(scale - 1); + positionCanvas(CanvasPosition.CENTER, animated); + } + }, + [height, layout, width, setZoom, positionCanvas] + ); + + /** + * This centers the chart on the canvas, zooms in to fit the specified nodes, and scrolls to center the nodes in the viewport + */ + const fitNodes = useCallback( + (nodeIds: string | string[], animated = true) => { + if (layout && layout.children) { + const nodes = Array.isArray(nodeIds) ? nodeIds.map((nodeId) => findNode(layout.children, nodeId)) : [findNode(layout.children, nodeIds)]; + + if (nodes) { + // center the chart + positionVector(CanvasPosition.CENTER); + + const updatedZoom = calculateZoom({ nodes, viewportWidth: width, viewportHeight: height, maxViewportCoverage: 0.9, minViewportCoverage: 0.2 }); + const scrollPosition = calculateScrollPosition({ nodes, viewportWidth: width, viewportHeight: height, canvasWidth, canvasHeight, chartWidth: layout.width, chartHeight: layout.height, zoom: updatedZoom }); + + setZoom(updatedZoom - 1); + scrollToXY(scrollPosition, animated); + } + } + }, + [canvasHeight, canvasWidth, height, layout, positionVector, setZoom, width] + ); useLayoutEffect(() => { const scroller = ref.current; @@ -238,19 +247,7 @@ export const useLayout = ({ scrolled.current = true; } - }, [ - canvasWidth, - pannable, - canvasHeight, - layout, - height, - fit, - width, - defaultPosition, - positionCanvas, - fitCanvas, - ref - ]); + }, [canvasWidth, pannable, canvasHeight, layout, height, fit, width, defaultPosition, positionCanvas, fitCanvas, ref]); useLayoutEffect(() => { function onResize() { @@ -278,6 +275,7 @@ export const useLayout = ({ scrollXY, positionCanvas, fitCanvas, - setScrollXY + fitNodes, + setScrollXY: scrollToXY } as LayoutResult; }; diff --git a/src/layout/utils.test.ts b/src/layout/utils.test.ts index 269b4eff..1e3246f5 100644 --- a/src/layout/utils.test.ts +++ b/src/layout/utils.test.ts @@ -1,4 +1,4 @@ -import { parsePadding } from './utils'; +import { parsePadding, findNode, getChildCount, calculateZoom, calculateScrollPosition } from './utils'; test('should set all sides to input number, when a number is provided', () => { const expectedPadding = { @@ -29,3 +29,83 @@ test('should set each padding value individually, when an array with four number }; expect(parsePadding([20, 50, 100, 150])).toEqual(expectedPadding); }); + +test('should find a node by id', () => { + const layout = [ + { + x: 0, + y: 0, + id: '1', + children: [{ x: 0, y: 0, id: '1', children: [] }] + }, + { + x: 0, + y: 0, + id: '3', + children: [{ x: 0, y: 0, id: '4', children: [] }] + } + ]; + const node = findNode(layout, '4'); + + expect(node).toEqual({ x: 0, y: 0, id: '4', children: [] }); +}); + +test('should get the number of children a node has', () => { + const node = { + x: 0, + y: 0, + id: '1', + children: [ + { x: 0, y: 0, id: '1', children: [] }, + { x: 0, y: 0, id: '2', children: [{ x: 0, y: 0, id: '3', children: [] }] } + ] + }; + const count = getChildCount(node); + + expect(count).toEqual(3); +}); + +describe('calculateZoom', () => { + test('should calculate the zoom for a node', () => { + const node = { width: 100, height: 100, x: 0, y: 0, id: '1' }; + const zoom = calculateZoom({ nodes: [node], viewportWidth: 1000, viewportHeight: 1000, minViewportCoverage: 0.2, maxViewportCoverage: 0.9 }); + + expect(zoom).toEqual(2); + }); + + test('should calculate the zoom for a node with many children', () => { + const node = { width: 100, height: 100, x: 0, y: 0, id: '0', children: [{ x: 0, y: 0, id: '1', children: [{ x: 0, y: 0, id: '2', children: [{ x: 0, y: 0, id: '3', children: [] }] }] }] }; + const zoom = calculateZoom({ nodes: [node], viewportWidth: 1000, viewportHeight: 1000, minViewportCoverage: 0.2, maxViewportCoverage: 0.9 }); + + expect(zoom).toEqual(5); + }); + + test('should calculate the zoom for a group of nodes', () => { + const nodes = [ + { width: 100, height: 100, x: 0, y: 0, id: '0' }, + { width: 100, height: 100, x: 50, y: 50, id: '1' } + ]; + const zoom = calculateZoom({ nodes, viewportWidth: 1000, viewportHeight: 1000, minViewportCoverage: 0.2, maxViewportCoverage: 0.9 }); + + expect(zoom).toEqual(2); + }); +}); + +describe('calculateScrollPosition', () => { + test('should calculate the scroll position for a node', () => { + const node = { width: 100, height: 100, x: 0, y: 0, id: '1' }; + const scrollPosition = calculateScrollPosition({ nodes: [node], viewportWidth: 1000, viewportHeight: 1000, canvasWidth: 2000, canvasHeight: 2000, chartWidth: 500, chartHeight: 500, zoom: 1 }); + + expect(scrollPosition).toEqual([300, 300]); + }); + + test('should calculate the scroll position for a group of nodes', () => { + const nodes = [ + { width: 100, height: 100, x: 0, y: 0, id: '0' }, + { width: 100, height: 100, x: 50, y: 50, id: '1' } + ]; + const scrollPosition = calculateScrollPosition({ nodes, viewportWidth: 1000, viewportHeight: 1000, canvasWidth: 2000, canvasHeight: 2000, chartWidth: 500, chartHeight: 500, zoom: 1 }); + + expect(scrollPosition).toEqual([325, 325]); + }); +}); diff --git a/src/layout/utils.ts b/src/layout/utils.ts index 2231dc73..acafc835 100644 --- a/src/layout/utils.ts +++ b/src/layout/utils.ts @@ -1,5 +1,5 @@ import calculateSize from 'calculate-size'; -import { NodeData } from '../types'; +import { LayoutNodeData, NodeData } from '../types'; import ellipsize from 'ellipsize'; const MAX_CHAR_COUNT = 35; @@ -101,3 +101,128 @@ export function formatText(node: NodeData) { labelWidth: labelDim.width }; } + +/** + * Finds a node in a tree of nodes + * @param nodes - The nodes to search through + * @param nodeId - The id of the node to find + * @returns The node if found, undefined otherwise + */ +export const findNode = (nodes: LayoutNodeData[], nodeId: string): any | undefined => { + for (const node of nodes) { + if (node.id === nodeId) { + return node; + } + if (node.children) { + const foundNode = findNode(node.children, nodeId); + if (foundNode) { + return foundNode; + } + } + } + return undefined; +}; + +/** + * Finds the number of nested children a node has + * @param node - The node to search through + * @returns The number of children + */ +export const getChildCount = (node: LayoutNodeData): number => { + return ( + node.children?.reduce((acc, child) => { + if (child.children) { + return acc + 1 + getChildCount(child); + } + return acc + 1; + }, 0) ?? 0 + ); +}; + +/** + * Calculates the zoom for a group of nodes when fitting to the viewport + * @param nodes - The nodes to calculate the zoom for + * @param viewportWidth - The width of the viewport + * @param viewportHeight - The height of the viewport + * @param maxViewportCoverage - The maximum percentage of the viewport that the node group will take up + * @param minViewportCoverage - The minimum percentage of the viewport that the node group will take up + * @returns The zoom + */ +export const calculateZoom = ({ nodes, viewportWidth, viewportHeight, maxViewportCoverage = 0.9, minViewportCoverage = 0.2 }: { nodes: LayoutNodeData[]; viewportWidth: number; viewportHeight: number; maxViewportCoverage?: number; minViewportCoverage?: number }) => { + const maxChildren = Math.max( + 0, + nodes.map(getChildCount).reduce((acc, curr) => acc + curr, 0) + ); + const boundingBox = getNodesBoundingBox(nodes); + const boundingBoxWidth = boundingBox.x1 - boundingBox.x0; + const boundingBoxHeight = boundingBox.y1 - boundingBox.y0; + + // calculate the maximum zoom to ensure no single node takes up more than 20% of the viewport + const maxNodeWidth = Math.max(...nodes.map((node) => node.width)); + const maxNodeHeight = Math.max(...nodes.map((node) => node.height)); + // if a node has children, let it take up an extra 10% per child + const maxNodeZoomX = ((0.2 + maxChildren * 0.1) * viewportWidth) / maxNodeWidth; + const maxNodeZoomY = ((0.2 + maxChildren * 0.1) * viewportHeight) / maxNodeHeight; + const maxNodeZoom = Math.min(maxNodeZoomX, maxNodeZoomY); + + const viewportCoverage = Math.max(Math.min(maxViewportCoverage, maxNodeZoom), minViewportCoverage); + + const updatedHorizontalZoom = (viewportCoverage * viewportWidth) / boundingBoxWidth; + const updatedVerticalZoom = (viewportCoverage * viewportHeight) / boundingBoxHeight; + const updatedZoom = Math.min(updatedHorizontalZoom, updatedVerticalZoom, maxNodeZoom); + + return updatedZoom; +}; + +/** + * Calculates the scroll position for the canvas when fitting nodes to the viewport - assumes the chart is centered + * @param nodes - The nodes to calculate the zoom and position for + * @param viewportWidth - The width of the viewport + * @param viewportHeight - The height of the viewport + * @param canvasWidth - The width of the canvas + * @param canvasHeight - The height of the canvas + * @param chartWidth - The width of the chart + * @param chartHeight - The height of the chart + * @param zoom - The zoom level of the canvas + * @returns The scroll position + */ +export const calculateScrollPosition = ({ nodes, viewportWidth, viewportHeight, canvasWidth, canvasHeight, chartWidth, chartHeight, zoom }: { nodes: LayoutNodeData[]; viewportWidth: number; viewportHeight: number; canvasWidth: number; canvasHeight: number; chartWidth: number; chartHeight: number; zoom: number }): [number, number] => { + const { x0, y0, x1, y1 } = getNodesBoundingBox(nodes); + const boundingBoxWidth = (x1 - x0) * zoom; + const boundingBoxHeight = (y1 - y0) * zoom; + + // the chart is centered so we can assume the x and y positions + const chartPosition = { + x: (canvasWidth - chartWidth * zoom) / 2, + y: (canvasHeight - chartHeight * zoom) / 2 + }; + + const boxXPosition = chartPosition.x + x0 * zoom; + const boxYPosition = chartPosition.y + y0 * zoom; + + const boxCenterXPosition = boxXPosition + boundingBoxWidth / 2; + const boxCenterYPosition = boxYPosition + boundingBoxHeight / 2; + + // scroll to the spot that centers the node in the viewport + const scrollX = boxCenterXPosition - viewportWidth / 2; + const scrollY = boxCenterYPosition - viewportHeight / 2; + + return [scrollX, scrollY]; +}; + +/** + * Calculates the bounding box of a group of nodes + * @param nodes - The nodes to calculate the bounding box for + * @returns The bounding box + */ +export const getNodesBoundingBox = (nodes: LayoutNodeData[]) => { + return nodes.reduce( + (acc, node) => ({ + x0: Math.min(acc.x0, node.x), + y0: Math.min(acc.y0, node.y), + x1: Math.max(acc.x1, node.x + node.width), + y1: Math.max(acc.y1, node.y + node.height) + }), + { x0: nodes[0].x, y0: nodes[0].y, x1: nodes[0].x + nodes[0].width, y1: nodes[0].y + nodes[0].height } + ); +}; diff --git a/stories/Controls.stories.tsx b/stories/Controls.stories.tsx index 975e8a82..f8fd3a82 100644 --- a/stories/Controls.stories.tsx +++ b/stories/Controls.stories.tsx @@ -222,11 +222,15 @@ export const Zoom = () => { Zoom: {zoom}
- + + +