Skip to content

Commit

Permalink
[core] Use useRtl instead of useTheme to access direction (#14359)
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasTy authored Aug 29, 2024
1 parent a8794b6 commit bba14d0
Show file tree
Hide file tree
Showing 19 changed files with 101 additions and 122 deletions.
6 changes: 3 additions & 3 deletions packages/x-charts/src/ChartsLegend/ContinuousColorLegend.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import { ScaleSequential } from '@mui/x-charts-vendor/d3-scale';
import { useTheme } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import ChartsContinuousGradient from '../internals/components/ChartsAxesGradients/ChartsContinuousGradient';
import { AxisDefaultized, ContinuousScaleName } from '../models/axis';
import { useChartId, useDrawingArea } from '../hooks';
Expand Down Expand Up @@ -205,6 +206,7 @@ const defaultLabelFormatter: LabelFormatter = ({ formattedValue }) => formattedV

function ContinuousColorLegend(props: ContinuousColorLegendProps) {
const theme = useTheme();
const isRtl = useRtl();
const {
id: idProp,
minLabel = defaultLabelFormatter,
Expand All @@ -224,8 +226,6 @@ function ContinuousColorLegend(props: ContinuousColorLegendProps) {
const chartId = useChartId();
const id = idProp ?? `gradient-legend-${chartId}`;

const isRTL = theme.direction === 'rtl';

const axisItem = useAxis({ axisDirection, axisId });
const { width, height, left, right, top, bottom } = useDrawingArea();

Expand Down Expand Up @@ -277,7 +277,7 @@ function ContinuousColorLegend(props: ContinuousColorLegendProps) {
// Place bar and texts

const barBox =
direction === 'column' || (isRTL && direction === 'row')
direction === 'column' || (isRtl && direction === 'row')
? { width: thickness, height: size }
: { width: size, height: thickness };

Expand Down
9 changes: 5 additions & 4 deletions packages/x-charts/src/ChartsLegend/LegendPerItem.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import * as React from 'react';
import NoSsr from '@mui/material/NoSsr';
import { useTheme, styled } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import { DrawingArea } from '../context/DrawingProvider';
import { DefaultizedProps } from '../models/helpers';
import { ChartsText, ChartsTextStyle } from '../ChartsText';
Expand Down Expand Up @@ -111,7 +112,7 @@ export function LegendPerItem(props: LegendPerItemProps) {
labelStyle: inLabelStyle,
} = props;
const theme = useTheme();
const isRTL = theme.direction === 'rtl';
const isRtl = useRtl();
const drawingArea = useDrawingArea();

const labelStyle = React.useMemo(
Expand Down Expand Up @@ -200,11 +201,11 @@ export function LegendPerItem(props: LegendPerItemProps) {
<g
key={id}
className={classes?.series}
transform={`translate(${gapX + (isRTL ? legendWidth - positionX : positionX)} ${gapY + positionY})`}
transform={`translate(${gapX + (isRtl ? legendWidth - positionX : positionX)} ${gapY + positionY})`}
>
<rect
className={classes?.mark}
x={isRTL ? -itemMarkWidth : 0}
x={isRtl ? -itemMarkWidth : 0}
y={-itemMarkHeight / 2}
width={itemMarkWidth}
height={itemMarkHeight}
Expand All @@ -213,7 +214,7 @@ export function LegendPerItem(props: LegendPerItemProps) {
<ChartsText
style={labelStyle}
text={label}
x={(isRTL ? -1 : 1) * (itemMarkWidth + markGap)}
x={(isRtl ? -1 : 1) * (itemMarkWidth + markGap)}
y={0}
/>
</g>
Expand Down
5 changes: 3 additions & 2 deletions packages/x-charts/src/ChartsYAxis/ChartsYAxis.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import PropTypes from 'prop-types';
import useSlotProps from '@mui/utils/useSlotProps';
import composeClasses from '@mui/utils/composeClasses';
import { useThemeProps, useTheme, Theme } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import { useCartesianContext } from '../context/CartesianProvider';
import { useTicks } from '../hooks/useTicks';
import { useDrawingArea } from '../hooks/useDrawingArea';
Expand Down Expand Up @@ -77,7 +78,7 @@ function ChartsYAxis(inProps: ChartsYAxisProps) {
} = defaultizedProps;

const theme = useTheme();
const isRTL = theme.direction === 'rtl';
const isRtl = useRtl();

const classes = useUtilityClasses({ ...defaultizedProps, theme });

Expand Down Expand Up @@ -106,7 +107,7 @@ function ChartsYAxis(inProps: ChartsYAxisProps) {
const TickLabel = slots?.axisTickLabel ?? ChartsText;
const Label = slots?.axisLabel ?? ChartsText;

const revertAnchor = (!isRTL && position === 'right') || (isRTL && position !== 'right');
const revertAnchor = (!isRtl && position === 'right') || (isRtl && position !== 'right');
const axisTickLabelProps = useSlotProps({
elementType: TickLabel,
externalSlotProps: slotProps?.axisTickLabel,
Expand Down
8 changes: 4 additions & 4 deletions packages/x-charts/src/PieChart/PieChart.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import * as React from 'react';
import PropTypes from 'prop-types';
import { useRtl } from '@mui/system/RtlProvider';
import { useThemeProps } from '@mui/material/styles';
import {
ResponsiveChartContainer,
Expand Down Expand Up @@ -30,7 +31,6 @@ import {
ChartsXAxisProps,
ChartsYAxisProps,
} from '../models/axis';
import { useIsRTL } from '../internals/useIsRTL';
import {
ChartsOverlay,
ChartsOverlayProps,
Expand Down Expand Up @@ -153,12 +153,12 @@ const PieChart = React.forwardRef(function PieChart(inProps: PieChartProps, ref)
className,
...other
} = props;
const isRTL = useIsRTL();
const isRtl = useRtl();

const margin = { ...(isRTL ? defaultRTLMargin : defaultMargin), ...marginProps };
const margin = { ...(isRtl ? defaultRTLMargin : defaultMargin), ...marginProps };
const legend: ChartsLegendProps = {
direction: 'column',
position: { vertical: 'middle', horizontal: isRTL ? 'left' : 'right' },
position: { vertical: 'middle', horizontal: isRtl ? 'left' : 'right' },
...legendProps,
};

Expand Down
6 changes: 0 additions & 6 deletions packages/x-charts/src/internals/useIsRTL.ts

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as React from 'react';
import { useTheme } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import PropTypes from 'prop-types';
import MenuItem from '@mui/material/MenuItem';
import ListItemIcon from '@mui/material/ListItemIcon';
Expand All @@ -12,7 +12,7 @@ function GridColumnMenuPinningItem(props: GridColumnMenuItemProps) {
const { colDef, onClick } = props;
const apiRef = useGridApiContext();
const rootProps = useGridRootProps();
const theme = useTheme();
const isRtl = useRtl();

const pinColumn = React.useCallback(
(side: GridPinnedColumnPosition) => (event: React.MouseEvent<HTMLElement>) => {
Expand Down Expand Up @@ -76,7 +76,7 @@ function GridColumnMenuPinningItem(props: GridColumnMenuItemProps) {
);
}

if (theme.direction === 'rtl') {
if (isRtl) {
return (
<React.Fragment>
{pinToRightMenuItem}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as React from 'react';
import composeClasses from '@mui/utils/composeClasses';
import { useTheme } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import {
CursorCoordinates,
useGridApiEventHandler,
Expand Down Expand Up @@ -75,7 +75,7 @@ export const useGridColumnReorder = (
const removeDnDStylesTimeout = React.useRef<ReturnType<typeof setTimeout>>();
const ownerState = { classes: props.classes };
const classes = useUtilityClasses(ownerState);
const theme = useTheme();
const isRtl = useRtl();

React.useEffect(() => {
return () => {
Expand Down Expand Up @@ -219,14 +219,10 @@ export const useGridColumnReorder = (
const cursorMoveDirectionX = getCursorMoveDirectionX(cursorPosition.current, coordinates);
const hasMovedLeft =
cursorMoveDirectionX === CURSOR_MOVE_DIRECTION_LEFT &&
(theme.direction === 'rtl'
? dragColIndex < targetColIndex
: targetColIndex < dragColIndex);
(isRtl ? dragColIndex < targetColIndex : targetColIndex < dragColIndex);
const hasMovedRight =
cursorMoveDirectionX === CURSOR_MOVE_DIRECTION_RIGHT &&
(theme.direction === 'rtl'
? targetColIndex < dragColIndex
: dragColIndex < targetColIndex);
(isRtl ? targetColIndex < dragColIndex : dragColIndex < targetColIndex);

if (hasMovedLeft || hasMovedRight) {
let canBeReordered: boolean;
Expand Down Expand Up @@ -298,7 +294,7 @@ export const useGridColumnReorder = (
cursorPosition.current = coordinates;
}
},
[apiRef, logger, theme.direction],
[apiRef, logger, isRtl],
);

const handleDragEnd = React.useCallback<GridEventListener<'columnHeaderDragEnd'>>(
Expand Down
6 changes: 3 additions & 3 deletions packages/x-data-grid/src/components/cell/GridActionsCell.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import * as React from 'react';
import PropTypes from 'prop-types';
import MenuList from '@mui/material/MenuList';
import { useTheme } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import { unstable_useId as useId } from '@mui/utils';
import { GridRenderCellParams } from '../../models/params/gridCellParams';
import { gridClasses } from '../../constants/gridClasses';
Expand Down Expand Up @@ -47,7 +47,7 @@ function GridActionsCell(props: GridActionsCellProps) {
const buttonRef = React.useRef<HTMLButtonElement>(null);
const ignoreCallToFocus = React.useRef(false);
const touchRippleRefs = React.useRef<Record<string, TouchRippleActions | null>>({});
const theme = useTheme();
const isRtl = useRtl();
const menuId = useId();
const buttonId = useId();
const rootProps = useGridRootProps();
Expand Down Expand Up @@ -149,7 +149,7 @@ function GridActionsCell(props: GridActionsCellProps) {
}

// for rtl mode we need to reverse the direction
const rtlMod = theme.direction === 'rtl' ? -1 : 1;
const rtlMod = isRtl ? -1 : 1;
const indexMod = (direction === 'left' ? -1 : 1) * rtlMod;

// if the button that should receive focus is disabled go one more step
Expand Down
2 changes: 1 addition & 1 deletion packages/x-data-grid/src/hooks/core/gridCoreSelector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ import { GridStateCommunity } from '../../models/gridStateCommunity';
* Get the theme state
* @category Core
*/
export const gridThemeSelector = (state: GridStateCommunity) => state.theme;
export const gridIsRtlSelector = (state: GridStateCommunity) => state.isRtl;
4 changes: 2 additions & 2 deletions packages/x-data-grid/src/hooks/core/useGridInitialization.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import * as React from 'react';
import type { GridApiCommon, GridPrivateApiCommon } from '../../models/api/gridApiCommon';
import { DataGridProcessedProps } from '../../models/props/DataGridProps';
import { useGridRefs } from './useGridRefs';
import { useGridTheme } from './useGridTheme';
import { useGridIsRtl } from './useGridIsRtl';
import { useGridLoggerFactory } from './useGridLoggerFactory';
import { useGridApiInitialization } from './useGridApiInitialization';
import { useGridLocaleText } from './useGridLocaleText';
Expand All @@ -23,7 +23,7 @@ export const useGridInitialization = <
const privateApiRef = useGridApiInitialization<PrivateApi, Api>(inputApiRef, props);

useGridRefs(privateApiRef);
useGridTheme(privateApiRef);
useGridIsRtl(privateApiRef);
useGridLoggerFactory(privateApiRef, props);
useGridStateInitialization(privateApiRef);
useGridPipeProcessing(privateApiRef);
Expand Down
20 changes: 20 additions & 0 deletions packages/x-data-grid/src/hooks/core/useGridIsRtl.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import * as React from 'react';
import { useRtl } from '@mui/system/RtlProvider';
import { GridPrivateApiCommon } from '../../models/api/gridApiCommon';

export const useGridIsRtl = (apiRef: React.MutableRefObject<GridPrivateApiCommon>): void => {
const isRtl = useRtl();

if (apiRef.current.state.isRtl === undefined) {
apiRef.current.state.isRtl = isRtl;
}

const isFirstEffect = React.useRef(true);
React.useEffect(() => {
if (isFirstEffect.current) {
isFirstEffect.current = false;
} else {
apiRef.current.setState((state) => ({ ...state, isRtl }));
}
}, [apiRef, isRtl]);
};
20 changes: 0 additions & 20 deletions packages/x-data-grid/src/hooks/core/useGridTheme.tsx

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import * as React from 'react';
import clsx from 'clsx';
import { styled, useTheme } from '@mui/material/styles';
import { styled } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import { DataGridProcessedProps } from '../../../models/props/DataGridProps';
import { useGridSelector } from '../../utils';
import { useGridRootProps } from '../../utils/useGridRootProps';
Expand Down Expand Up @@ -98,7 +99,7 @@ export const useGridColumnHeaders = (props: UseGridColumnHeadersProps) => {
const [resizeCol, setResizeCol] = React.useState('');

const apiRef = useGridPrivateApiContext();
const theme = useTheme();
const isRtl = useRtl();
const rootProps = useGridRootProps();
const dimensions = useGridSelector(apiRef, gridDimensionsSelector);
const hasVirtualization = useGridSelector(apiRef, gridVirtualizationColumnEnabledSelector);
Expand All @@ -110,7 +111,7 @@ export const useGridColumnHeaders = (props: UseGridColumnHeadersProps) => {
const offsetLeft = computeOffsetLeft(
columnPositions,
renderContext,
theme.direction,
isRtl,
pinnedColumns.left.length,
);
const gridHasFiller = dimensions.columnsTotalWidth < dimensions.viewportOuterSize.width;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {
unstable_useEventCallback as useEventCallback,
} from '@mui/utils';
import useLazyRef from '@mui/utils/useLazyRef';
import { useTheme, Direction } from '@mui/material/styles';
import { useRtl } from '@mui/system/RtlProvider';
import {
findGridCellElementsFromCol,
findGridElement,
Expand Down Expand Up @@ -112,11 +112,11 @@ function flipResizeDirection(side: ResizeDirection) {
return 'Right';
}

function getResizeDirection(separator: HTMLElement, direction: Direction) {
function getResizeDirection(separator: HTMLElement, isRtl: boolean) {
const side = separator.classList.contains(gridClasses['columnSeparator--sideRight'])
? 'Right'
: 'Left';
if (direction === 'rtl') {
if (isRtl) {
// Resizing logic should be mirrored in the RTL case
return flipResizeDirection(side);
}
Expand Down Expand Up @@ -280,7 +280,7 @@ export const useGridColumnResize = (
| 'onColumnWidthChange'
>,
) => {
const theme = useTheme();
const isRtl = useRtl();
const logger = useGridLogger(apiRef, 'useGridColumnResize');

const refs = useLazyRef(createResizeRefs).current;
Expand Down Expand Up @@ -491,7 +491,7 @@ export const useGridColumnResize = (
? []
: findRightPinnedHeadersBeforeCol(apiRef.current, refs.columnHeaderElement);

resizeDirection.current = getResizeDirection(separator, theme.direction);
resizeDirection.current = getResizeDirection(separator, isRtl);

initialOffsetToSeparator.current = computeOffsetToSeparator(
xStart,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
GridPinnedColumnFields,
EMPTY_PINNED_COLUMN_FIELDS,
} from './gridColumnsInterfaces';
import { gridThemeSelector } from '../../core/gridCoreSelector';
import { gridIsRtlSelector } from '../../core/gridCoreSelector';

/**
* Get the columns state
Expand Down Expand Up @@ -85,13 +85,9 @@ export const gridVisiblePinnedColumnDefinitionsSelector = createSelectorMemoized
gridColumnsStateSelector,
gridPinnedColumnsSelector,
gridVisibleColumnFieldsSelector,
gridThemeSelector,
(columnsState, model, visibleColumnFields, theme) => {
const visiblePinnedFields = filterVisibleColumns(
model,
visibleColumnFields,
theme.direction === 'rtl',
);
gridIsRtlSelector,
(columnsState, model, visibleColumnFields, isRtl) => {
const visiblePinnedFields = filterVisibleColumns(model, visibleColumnFields, isRtl);
const visiblePinnedColumns = {
left: visiblePinnedFields.left.map((field) => columnsState.lookup[field]),
right: visiblePinnedFields.right.map((field) => columnsState.lookup[field]),
Expand Down
Loading

0 comments on commit bba14d0

Please sign in to comment.