diff --git a/packages/rath-client/package.json b/packages/rath-client/package.json index b45885b2..7a92a21a 100644 --- a/packages/rath-client/package.json +++ b/packages/rath-client/package.json @@ -24,8 +24,6 @@ "airtable": "^0.11.4", "ali-react-table": "^2.6.1", "codemirror": "^6.0.1", - "d3-dag": "^0.11.5", - "d3-shape": "^3.1.0", "dayjs": "^1.11.6", "immer": "^9.0.6", "localforage": "^1.10.0", @@ -55,7 +53,6 @@ "@testing-library/react": "^11.2.3", "@testing-library/user-event": "^12.6.0", "@types/crypto-js": "^4.1.0", - "@types/d3-shape": "^3.1.0", "@types/jest": "^26.0.20", "@types/node": "^12.19.12", "@types/react": "^17.0.2", diff --git a/packages/rath-client/src/components/fieldPlaceholder.tsx b/packages/rath-client/src/components/fieldPlaceholder.tsx index d5af9b64..643a5678 100644 --- a/packages/rath-client/src/components/fieldPlaceholder.tsx +++ b/packages/rath-client/src/components/fieldPlaceholder.tsx @@ -34,7 +34,7 @@ export const PillPlaceholder = styled.div` ` interface FieldPlaceholderProps { - fields: IFieldMeta[]; + fields: readonly IFieldMeta[]; onAdd: (fid: string) => void; } const FieldPlaceholder: React.FC = props => { diff --git a/packages/rath-client/src/components/filterCreationPill.tsx b/packages/rath-client/src/components/filterCreationPill.tsx index 8dd7dc96..d66a89c3 100644 --- a/packages/rath-client/src/components/filterCreationPill.tsx +++ b/packages/rath-client/src/components/filterCreationPill.tsx @@ -29,7 +29,7 @@ const Cont = styled.div` min-width: 16em; `; interface FilterCreationPillProps { - fields: IFieldMeta[]; + fields: readonly IFieldMeta[]; onFilterSubmit: (field: IFieldMeta, filter: IFilter) => void; onRenderPill?: (text: string, handleClick: () => void) => void; } diff --git a/packages/rath-client/src/components/react-vega.tsx b/packages/rath-client/src/components/react-vega.tsx index 71f5befa..3fd04667 100644 --- a/packages/rath-client/src/components/react-vega.tsx +++ b/packages/rath-client/src/components/react-vega.tsx @@ -6,7 +6,7 @@ import { EDITOR_URL } from '../constants'; import { getVegaTimeFormatRules } from '../utils'; interface ReactVegaProps { - dataSource: any[]; + dataSource: readonly any[]; spec: any; actions?: boolean; signalHandler?: { diff --git a/packages/rath-client/src/constants.ts b/packages/rath-client/src/constants.ts index 0528b2e0..15830c81 100644 --- a/packages/rath-client/src/constants.ts +++ b/packages/rath-client/src/constants.ts @@ -47,7 +47,7 @@ export const STORAGES = { DATASOURCE: 'datasource', WORKSPACE: 'workspace', META: 'meta', - MODEL: 'model', + CAUSAL_MODEL: 'causal', STATE: 'state', ITERATOR: 'iterator', CONFIG: 'config', diff --git a/packages/rath-client/src/hooks/use-bounding-client-rect.ts b/packages/rath-client/src/hooks/use-bounding-client-rect.ts new file mode 100644 index 00000000..186f80b3 --- /dev/null +++ b/packages/rath-client/src/hooks/use-bounding-client-rect.ts @@ -0,0 +1,69 @@ +import { RefObject, useEffect, useRef, useState } from "react"; + + +export type BoundingClientRectAttributes = { + /** @default true */ + -readonly [key in keyof Omit]?: boolean; +}; + +/** + * Updates on certain keys of DOMRect changes, detected using `ResizeObserver`. + * DISCUSS: use `IntersectionObserver` with ref elements to implements position changes. + */ +const useBoundingClientRect = < + T extends BoundingClientRectAttributes = { -readonly [key in keyof Omit]: true }, + M extends { -readonly [key in keyof Omit]: T[key] extends true ? key : never } = { -readonly [key in keyof Omit]: T[key] extends true ? key : never }, + E extends Exclude & keyof DOMRect = Exclude & keyof DOMRect, + R extends { readonly [key in E]?: DOMRect[key] } = { readonly [key in E]?: DOMRect[key] }, +>( + ref: RefObject, + /** @default {height:true,width:true,x:true,y:true,bottom:true,left:true,right:true,top:true} */ + attributes: T = { + height: true, + width: true, + x: true, + y: true, + bottom: true, + left: true, + right: true, + top: true, + } as T, +): R => { + const compareKeysRef = useRef<(keyof BoundingClientRectAttributes)[]>([]); + compareKeysRef.current = (["height", "width", "x", "y", "bottom", "left", "right", "top"] as const).filter(key => { + return attributes[key] === true; + }); + + const [box, setBox] = useState({} as R); + + const prevRectRef = useRef(); + const shouldReportRef = useRef<(next: DOMRect) => boolean>(() => true); + shouldReportRef.current = (next: DOMRect): boolean => { + return !prevRectRef.current || compareKeysRef.current.some(k => next[k] !== prevRectRef.current![k]); + }; + + useEffect(() => { + const { current: element } = ref; + + if (element) { + const cb = () => { + const rect = element.getBoundingClientRect(); + if (shouldReportRef.current(rect)) { + setBox(Object.fromEntries(compareKeysRef.current.map(key => [key, rect[key]])) as R); + } + prevRectRef.current = rect; + }; + const ro = new ResizeObserver(cb); + ro.observe(element); + cb(); + return () => { + ro.disconnect(); + }; + } + }, [ref]); + + return box; +}; + + +export default useBoundingClientRect; diff --git a/packages/rath-client/src/pages/causal/config.ts b/packages/rath-client/src/pages/causal/config.ts index ccafa6a3..63fd7c8f 100644 --- a/packages/rath-client/src/pages/causal/config.ts +++ b/packages/rath-client/src/pages/causal/config.ts @@ -28,6 +28,7 @@ export type IAlgoSchema = { }; /** + * @deprecated * a number match { -1 | [0, 1] } * * -1 for not connected: src ---x--> tar @@ -36,6 +37,9 @@ export type IAlgoSchema = { */ export type BgConfidenceLevel = number; +/** + * @deprecated + */ export type BgKnowledge = { src: string; tar: string; @@ -56,8 +60,10 @@ export interface PagLink { tar_type: PAG_NODE; } +/** @deprecated */ export type BgKnowledgePagLink = PagLink; +/** @deprecated */ export type ModifiableBgKnowledge = { src: BgKnowledge['src']; tar: BgKnowledge['tar']; diff --git a/packages/rath-client/src/pages/causal/datasetPanel.tsx b/packages/rath-client/src/pages/causal/datasetPanel.tsx index 4fb33d2d..a97dc70e 100644 --- a/packages/rath-client/src/pages/causal/datasetPanel.tsx +++ b/packages/rath-client/src/pages/causal/datasetPanel.tsx @@ -5,12 +5,11 @@ import { Label, SelectionMode, Slider, - Spinner, Stack, } from '@fluentui/react'; import { observer } from 'mobx-react-lite'; import styled from 'styled-components'; -import React, { useCallback, useMemo, useRef } from 'react'; +import { FC, useCallback, useMemo, useRef } from 'react'; import produce from 'immer'; import intl from 'react-intl-universal' import { useGlobalStore } from '../../store'; @@ -18,7 +17,6 @@ import FilterCreationPill from '../../components/filterCreationPill'; import LaTiaoConsole from '../../components/latiaoConsole/index'; import type { IFieldMeta } from '../../interfaces'; import { FilterCell } from './filters'; -import type { useDataViews } from './hooks/dataViews'; const TableContainer = styled.div` @@ -62,29 +60,32 @@ const Row = styled.div<{ selected: boolean }>` const SelectedKey = 'selected'; -export interface DatasetPanelProps { - context: ReturnType; -} - -const DatasetPanel: React.FC = ({ context }) => { +const DatasetPanel: FC = () => { const { dataSourceStore, causalStore } = useGlobalStore(); - const { fieldMetas, cleanedData } = dataSourceStore; - const { focusFieldIds } = causalStore; - const totalFieldsRef = useRef(fieldMetas); - totalFieldsRef.current = fieldMetas; + const { cleanedData } = dataSourceStore; + const { + fields, allFields, filteredDataSize, sampleRate, sampleSize, filters + } = causalStore.dataset; - const { dataSubset, sampleRate, setSampleRate, appliedSampleRate, filters, setFilters, sampleSize } = context; + const totalFieldsRef = useRef(allFields); + totalFieldsRef.current = allFields; - const focusFieldIdsRef = useRef(focusFieldIds); - focusFieldIdsRef.current = focusFieldIds; + const fieldsRef = useRef(fields); + fieldsRef.current = fields; const toggleFocus = useCallback((fid: string) => { - causalStore.setFocusFieldIds(produce(focusFieldIdsRef.current, draft => { - const idx = draft.findIndex(key => fid === key); + const prevIndices = fieldsRef.current.map( + f => totalFieldsRef.current.findIndex(which => f.fid === which.fid) + ).filter(idx => idx !== -1); + causalStore.dataset.selectFields(produce(prevIndices, draft => { + const idx = totalFieldsRef.current.findIndex(f => f.fid === fid); if (idx !== -1) { - draft.splice(idx, 1); - } else { - draft.push(fid); + const i = draft.findIndex(which => which === idx); + if (i !== -1) { + draft.splice(i, 1); + } else { + draft.push(idx); + } } })); }, [causalStore]); @@ -97,15 +98,15 @@ const DatasetPanel: React.FC = ({ context }) => { onRenderHeader: () => { const handleClick = (_: unknown, checked?: boolean | undefined) => { if (checked) { - causalStore.setFocusFieldIds(totalFieldsRef.current.map(f => f.fid)); + causalStore.selectFields(totalFieldsRef.current.map((_, i) => i)); } else { - causalStore.setFocusFieldIds([]); + causalStore.selectFields([]); } }; return ( 0 && focusFieldIds.length < totalFieldsRef.current.length} + checked={fields.length === totalFieldsRef.current.length} + indeterminate={fields.length > 0 && fields.length < totalFieldsRef.current.length} onChange={handleClick} styles={{ root: { @@ -117,7 +118,7 @@ const DatasetPanel: React.FC = ({ context }) => { }, onRender: (item) => { const field = item as IFieldMeta; - const checked = focusFieldIds.includes(field.fid); + const checked = fields.some(f => f.fid === field.fid); return ( ); @@ -128,7 +129,7 @@ const DatasetPanel: React.FC = ({ context }) => { }, { key: 'name', - name: `因素 (${focusFieldIds.length} / ${totalFieldsRef.current.length})`, + name: `因素 (${fields.length} / ${totalFieldsRef.current.length})`, onRender: (item) => { const field = item as IFieldMeta; return ( @@ -231,7 +232,7 @@ const DatasetPanel: React.FC = ({ context }) => { maxWidth: 100, }, ]; - }, [focusFieldIds, causalStore]); + }, [fields, causalStore]); return ( <> @@ -239,46 +240,7 @@ const DatasetPanel: React.FC = ({ context }) => { - - setSampleRate(val)} - valueFormat={(val) => `${(val * 100).toFixed(0)}%`} - styles={{ - root: { - flexGrow: 0, - flexShrink: 0, - display: 'flex', - flexDirection: 'row', - flexWrap: 'wrap', - alignItems: 'center', - }, - container: { - minWidth: '160px', - maxWidth: '300px', - flexGrow: 1, - flexShrink: 0, - marginInline: '1vmax', - }, - }} - /> - - {`原始大小: ${cleanedData.length} 行,样本量: `} - {sampleRate !== appliedSampleRate ? ( - - ) : ( - `${sampleSize} 行` - )} - - - + @@ -303,38 +265,64 @@ const DatasetPanel: React.FC = ({ context }) => { }} > {filters.map((filter, i) => { - const field = fieldMetas.find((f) => f.fid === filter.fid); + const field = allFields.find((f) => f.fid === filter.fid); return field ? ( - setFilters((list) => { - return produce(list, (draft) => { - draft.splice(i, 1); - }); - }) - } + remove={() => causalStore.dataset.removeFilter(i)} /> ) : null; })} )} - {`${filters.length ? `筛选后子集大小: ${dataSubset.length} 行` : '(无筛选项)'}`} + {`原始大小: ${cleanedData.length} 行,${filters.length ? `筛选后子集大小: ${filteredDataSize} 行` : '(无筛选项)'}`} + + + + causalStore.dataset.sampleRate = val} + valueFormat={(val) => `${(val * 100).toFixed(0)}%`} + styles={{ + root: { + flexGrow: 0, + flexShrink: 0, + display: 'flex', + flexDirection: 'row', + flexWrap: 'wrap', + alignItems: 'center', + }, + container: { + minWidth: '160px', + maxWidth: '300px', + flexGrow: 1, + flexShrink: 0, + marginInline: '1vmax', + }, + }} + /> + + {`样本量: ${sampleSize} 行`} { const field = props?.item as IFieldMeta; - const checked = focusFieldIds.includes(field.fid); + const checked = fields.some(f => f.fid === field.fid); return ( toggleFocus(field.fid)}> {defaultRender?.(props)} diff --git a/packages/rath-client/src/pages/causal/exploration/autoVis/index.tsx b/packages/rath-client/src/pages/causal/exploration/autoVis/index.tsx new file mode 100644 index 00000000..35d03fe4 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/autoVis/index.tsx @@ -0,0 +1,113 @@ +import { FC, useCallback, useMemo } from "react"; +import { observer } from "mobx-react-lite"; +import { Stack } from "@fluentui/react"; +import styled from "styled-components"; +import { NodeSelectionMode, useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import { useGlobalStore } from "../../../../store"; +import { IFieldMeta } from "../../../../interfaces"; +import ViewField from "../../../megaAutomation/vizOperation/viewField"; +import FieldPlaceholder from "../../../../components/fieldPlaceholder"; +import MetaList from "./metaList"; +import Vis from "./vis"; +import NeighborList from "./neighborList"; + + +const Container = styled.div` + display: flex; + flex-direction: column; + > * { + flex-grow: 0; + flex-shrink: 0; + & header { + font-size: 1rem; + font-weight: 500; + padding: 0.5em 0 0; + } + } + & .ms-DetailsList { + text-align: center; + & * { + line-height: 1.6em; + min-height: unset; + } + & .ms-DetailsList-headerWrapper { + & * { + height: 2.2em; + } + } + & [role=gridcell] { + display: inline-block; + padding: 0.2em 8px; + height: max-content; + } + & .vega-embed { + margin: 0 0 -10px; + } + } +`; + +const PillContainer = styled.div` + display: flex; + flex-direction: row; + flex-wrap: wrap; + > * { + flex-grow: 0; + flex-shrink: 0; + } +`; + +export interface IAutoVisProps {} + +const AutoVis: FC = () => { + const { causalStore } = useGlobalStore(); + const { fields } = causalStore.dataset; + const viewContext = useCausalViewContext(); + + const { + graphNodeSelectionMode = NodeSelectionMode.NONE, selectedField = null, selectedFieldGroup = [] + } = viewContext ?? {}; + + const selectedFields = useMemo(() => { + if (graphNodeSelectionMode === NodeSelectionMode.NONE) { + return []; + } else if (graphNodeSelectionMode === NodeSelectionMode.SINGLE) { + return selectedField ? [selectedField] : []; + } else { + return selectedFieldGroup; + } + }, [graphNodeSelectionMode, selectedField, selectedFieldGroup]); + + const appendFieldHandler = useCallback((fid: string) => { + viewContext?.selectNode(fid); + }, [viewContext]); + + return viewContext && ( + + + + {selectedFields.map((f: IFieldMeta) => ( + { + viewContext.toggleNodeSelected(f.fid); + }} + /> + ))} + {graphNodeSelectionMode === NodeSelectionMode.MULTIPLE && ( + + )} + + + + + + + + + ); +}; + + +export default observer(AutoVis); diff --git a/packages/rath-client/src/pages/causal/exploration/autoVis/metaList.tsx b/packages/rath-client/src/pages/causal/exploration/autoVis/metaList.tsx new file mode 100644 index 00000000..f312974e --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/autoVis/metaList.tsx @@ -0,0 +1,91 @@ +import { DetailsList, IColumn, SelectionMode } from "@fluentui/react"; +import { observer } from "mobx-react-lite"; +import { FC, useMemo } from "react"; +import { useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import DistributionChart from "../../../dataSource/metaView/distChart"; + + +const metaKeys = ['dist', 'unique', 'mean', 'min', 'qt_25', 'qt_50', 'qt_75', 'max', 'stdev'] as const; + +const COL_WIDTH = 128; +const DIST_CHART_HEIGHT = 20; + +const MetaList: FC = () => { + const viewContext = useCausalViewContext(); + const { selectedFieldGroup } = viewContext ?? {}; + + const columns = useMemo(() => { + return new Array({ + key: 'KEY', + name: '', + minWidth: 100, + maxWidth: 100, + isResizable: false, + onRender(key: typeof metaKeys[number]) { + return { + dist: '分布', + unique: '唯一值数量', + mean: '均值', + min: '最小值', + qt_25: '25% 分位数', + qt_50: '50% 分位数', + qt_75: '75% 分位数', + max: '最大值', + stdev: '标准差', + }[key]; + }, + }).concat(selectedFieldGroup?.map(f => ({ + key: f.fid, + name: f.name || f.fid, + minWidth: COL_WIDTH, + maxWidth: COL_WIDTH, + isResizable: false, + onRender(key: typeof metaKeys[number]) { + if (key === 'dist') { + return ( + + ); + } + const value = f.features[key]; + if (typeof value === 'number') { + if (key === 'unique') { + return value.toFixed(0); + } + if (Number.isFinite(value)) { + if (Math.abs(value - Math.floor(value)) < Number.MIN_VALUE) { + return value.toFixed(0); + } + return value > 0 && value < 1e-2 ? value.toExponential(2) : value.toPrecision(4); + } + return '-'; + } + return value ?? '-'; + }, + })) ?? []); + }, [selectedFieldGroup]); + + return selectedFieldGroup?.length ? ( +
+
+ 统计信息 +
+ +
+ ) : null; +}; + + +export default observer(MetaList); diff --git a/packages/rath-client/src/pages/causal/exploration/autoVis/neighborList.tsx b/packages/rath-client/src/pages/causal/exploration/autoVis/neighborList.tsx new file mode 100644 index 00000000..f8e93308 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/autoVis/neighborList.tsx @@ -0,0 +1,102 @@ +import { DetailsList, IColumn, SelectionMode } from "@fluentui/react"; +import { observer } from "mobx-react-lite"; +import { FC, useMemo } from "react"; +import { useGlobalStore } from "../../../../store"; +import { useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import { PAG_NODE } from "../../config"; + + +const NeighborList: FC = () => { + const { causalStore } = useGlobalStore(); + const { fields } = causalStore; + const { mutualMatrix, causality } = causalStore.model; + const viewContext = useCausalViewContext(); + const { selectedFieldGroup = [] } = viewContext ?? {}; + + const neighbors = useMemo(() => { + if (!mutualMatrix || !causality) { + return []; + } + return causality.reduce<{ + cause: string; + effect: string; + corr: number; + }[]>((list, link) => { + const isIncluded = [link.src, link.tar].some(fid => selectedFieldGroup.some(f => f.fid === fid)); + const srcIdx = fields.findIndex(f => f.fid === link.src); + const tarIdx = fields.findIndex(f => f.fid === link.tar); + const src = fields[srcIdx]; + const tar = fields[tarIdx]; + if (isIncluded && src && tar) { + if (link.src_type !== PAG_NODE.ARROW) { + list.push({ cause: src.name || src.fid, effect: tar.name || tar.fid, corr: mutualMatrix[srcIdx][tarIdx] }); + } + if (link.tar_type !== PAG_NODE.ARROW) { + list.push({ cause: tar.name || tar.fid, effect: src.name || src.fid, corr: mutualMatrix[tarIdx][srcIdx] }); + } + } + return list; + }, []); + }, [mutualMatrix, causality, selectedFieldGroup, fields]); + + const columns = useMemo(() => { + return [ + { + key: 'cause', + name: '因', + minWidth: 100, + maxWidth: 100, + isResizable: false, + onRender(item: typeof neighbors[number]) { + return item.cause; + }, + }, + { + key: 'corr', + name: '相关系数', + minWidth: 120, + maxWidth: 120, + isResizable: false, + onRender(item: typeof neighbors[number]) { + const value = item.corr; + if (typeof value === 'number') { + if (Number.isFinite(value)) { + if (Math.abs(value - Math.floor(value)) < Number.MIN_VALUE) { + return value.toFixed(0); + } + return value > 0 && value < 1e-2 ? value.toExponential(2) : value.toPrecision(4); + } + return '-'; + } + return value ?? '-'; + }, + }, + { + key: 'effect', + name: '果', + minWidth: 100, + maxWidth: 100, + isResizable: false, + onRender(item: typeof neighbors[number]) { + return item.effect; + }, + }, + ]; + }, []); + + return selectedFieldGroup?.length ? ( +
+
+ 关联因素 +
+ +
+ ) : null; +}; + + +export default observer(NeighborList); diff --git a/packages/rath-client/src/pages/causal/exploration/autoVis/vis.tsx b/packages/rath-client/src/pages/causal/exploration/autoVis/vis.tsx new file mode 100644 index 00000000..1cdd97d0 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/autoVis/vis.tsx @@ -0,0 +1,65 @@ +import { FC, useMemo } from "react"; +import { IPattern } from "@kanaries/loa"; +import { observer } from "mobx-react-lite"; +import { toJS } from "mobx"; +import { NodeSelectionMode, useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import { distVis } from "../../../../queries/distVis"; +import ErrorBoundary from "../../../../components/visErrorBoundary"; +import ReactVega from "../../../../components/react-vega"; +import { useGlobalStore } from "../../../../store"; + + +const Vis: FC = () => { + const { causalStore } = useGlobalStore(); + const { visSample } = causalStore.dataset; + const viewContext = useCausalViewContext(); + + const { + graphNodeSelectionMode = NodeSelectionMode.NONE, selectedField = null, selectedFieldGroup = [] + } = viewContext ?? {}; + + const selectedFields = useMemo(() => { + if (graphNodeSelectionMode === NodeSelectionMode.NONE) { + return []; + } else if (graphNodeSelectionMode === NodeSelectionMode.SINGLE) { + return selectedField ? [selectedField] : []; + } else { + return selectedFieldGroup; + } + }, [graphNodeSelectionMode, selectedField, selectedFieldGroup]); + + const viewPattern = useMemo(() => { + if (selectedFields.length === 0) { + return null; + } + return { + fields: selectedFields, + imp: selectedFields[0].features.entropy, + }; + }, [selectedFields]); + + const viewSpec = useMemo(() => { + if (viewPattern === null) { + return null; + } + return distVis({ + pattern: toJS(viewPattern), + interactive: true, + specifiedEncodes: viewPattern.encodes, + }); + }, [viewPattern]); + + return viewContext && viewSpec && ( +
+
+ 可视化分析 +
+ + + +
+ ); +}; + + +export default observer(Vis); diff --git a/packages/rath-client/src/pages/causal/exploration/causalBlame/index.tsx b/packages/rath-client/src/pages/causal/exploration/causalBlame/index.tsx new file mode 100644 index 00000000..dd796f93 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/causalBlame/index.tsx @@ -0,0 +1,170 @@ +import { DetailsList, IColumn, SelectionMode, Stack } from "@fluentui/react"; +import { observer } from "mobx-react-lite"; +import { FC, useCallback, useMemo, useRef } from "react"; +import styled from "styled-components"; +import useBoundingClientRect from "../../../../hooks/use-bounding-client-rect"; +import { useGlobalStore } from "../../../../store"; +import { useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import FullDistViz from "../../../dataSource/profilingView/fullDistViz"; +import { PAG_NODE } from "../../config"; + + +const Section = styled.div` + display: flex; + flex-direction: column; + > header { + font-size: 0.9rem; + font-weight: 500; + padding: 1em 0; + } +`; + +const META_VIEW_PADDING = 60; +const META_VIEW_HEIGHT = 200; + +const CausalBlame: FC = () => { + const { causalStore } = useGlobalStore(); + const { fields } = causalStore; + const { mutualMatrix, causality } = causalStore.model; + const viewContext = useCausalViewContext(); + const { selectedField } = viewContext ?? {}; + const metaViewContainerRef = useRef(null); + const { width = META_VIEW_PADDING } = useBoundingClientRect(metaViewContainerRef, { width: true }); + + const handleSelect = useCallback((ticks: readonly number[]) => { + // hello world + }, []); + + const neighbors = useMemo<{ + cause: string | undefined; + causeWeight: number | undefined; + causeCorr: number; + effect: string | undefined; + effectWeight: number | undefined; + effectCorr: number; + }[]>(() => { + return selectedField ? fields.filter(f => f.fid !== selectedField.fid).map(f => { + const cause = causality?.find(link => { + if (![link.src, link.tar].every(node => [selectedField.fid, f.fid].includes(node))) { + return false; + } + const currType = link.tar === f.fid ? link.tar_type : link.src_type; + return currType !== PAG_NODE.ARROW; + }); + const effect = causality?.find(link => { + if (![link.src, link.tar].every(node => [selectedField.fid, f.fid].includes(node))) { + return false; + } + const targetType = link.src === selectedField.fid ? link.src_type : link.tar_type; + return targetType !== PAG_NODE.ARROW; + }); + const selectedIdx = fields.findIndex(which => which.fid === selectedField.fid); + const currIdx = fields.findIndex(which => which.fid === f.fid); + return { + cause: cause ? (f.name || f.fid) : undefined, + causeWeight: cause ? -1 : undefined, + causeCorr: mutualMatrix?.[currIdx]?.[selectedIdx] ?? -1, + effect: effect ? (f.name || f.fid) : undefined, + effectWeight: effect ? -1 : undefined, + effectCorr: mutualMatrix?.[selectedIdx]?.[currIdx] ?? -1, + }; + }) : []; + }, [fields, selectedField, mutualMatrix, causality]); + + const columns = useMemo(() => { + return [ + { + key: 'cause', + name: 'Cause', + iconName: 'AlignHorizontalLeft', + isResizable: false, + minWidth: 80, + maxWidth: 80, + onRender(item) { + return item['cause']; + }, + }, + { + key: 'causeWeight', + name: 'Responsibility', + isResizable: false, + minWidth: 120, + maxWidth: 120, + onRender(item) { + return item['causeWeight']; + }, + }, + { + key: 'causeCorr', + name: 'Correlation', + isResizable: false, + minWidth: 120, + maxWidth: 120, + onRender(item) { + return item['causeCorr']; + }, + }, + { + key: 'effectCorr', + name: 'Correlation', + isResizable: false, + minWidth: 120, + maxWidth: 120, + onRender(item) { + return item['effectCorr']; + }, + }, + { + key: 'effectWeight', + name: 'Responsibility', + isResizable: false, + minWidth: 120, + maxWidth: 120, + onRender(item) { + return item['effectWeight']; + }, + }, + { + key: 'effect', + name: 'Effect', + iconName: 'AlignHorizontalRight', + isResizable: false, + minWidth: 80, + maxWidth: 80, + onRender(item) { + return item['effect']; + }, + }, + ]; + }, []); + + return selectedField ? ( + +
+
单变量分析
+ void} + /> +
+
+
关联因素
+ +
+
+ ) : null; +}; + + +export default observer(CausalBlame); diff --git a/packages/rath-client/src/pages/causal/crossFilter/colDist.tsx b/packages/rath-client/src/pages/causal/exploration/crossFilter/colDist.tsx similarity index 97% rename from packages/rath-client/src/pages/causal/crossFilter/colDist.tsx rename to packages/rath-client/src/pages/causal/exploration/crossFilter/colDist.tsx index c9521b7e..a4df14e7 100644 --- a/packages/rath-client/src/pages/causal/crossFilter/colDist.tsx +++ b/packages/rath-client/src/pages/causal/exploration/crossFilter/colDist.tsx @@ -2,8 +2,8 @@ import { ISemanticType } from '@kanaries/loa'; import React, { useEffect, useRef } from 'react'; import { View } from 'vega'; import embed from 'vega-embed'; -import { IRow } from '../../../interfaces'; -import { throttle } from '../../../utils'; +import { IRow } from '../../../../interfaces'; +import { throttle } from '../../../../utils'; export const SELECT_SIGNAL_NAME = '__select__'; export interface IBrushSignalStore { @@ -12,7 +12,7 @@ export interface IBrushSignalStore { values: any[]; } interface ColDistProps { - data: IRow[]; + data: readonly IRow[]; fid: string; name?: string; semanticType: ISemanticType; diff --git a/packages/rath-client/src/pages/causal/crossFilter/index.tsx b/packages/rath-client/src/pages/causal/exploration/crossFilter/index.tsx similarity index 96% rename from packages/rath-client/src/pages/causal/crossFilter/index.tsx rename to packages/rath-client/src/pages/causal/exploration/crossFilter/index.tsx index c529101e..82c541dd 100644 --- a/packages/rath-client/src/pages/causal/crossFilter/index.tsx +++ b/packages/rath-client/src/pages/causal/exploration/crossFilter/index.tsx @@ -1,7 +1,7 @@ import { IconButton } from '@fluentui/react'; import React, { useCallback, useEffect, useState } from 'react'; import styled from 'styled-components'; -import { IFieldMeta, IRow } from '../../../interfaces'; +import { IFieldMeta, IRow } from '../../../../interfaces'; import ColDist, { IBrushSignalStore } from './colDist'; const VizContainer = styled.div` @@ -19,8 +19,8 @@ const VizCard = styled.div` `; interface CrossFilterProps { - fields: IFieldMeta[]; - dataSource: IRow[]; + fields: readonly IFieldMeta[]; + dataSource: readonly IRow[]; onVizEdit?: (fid: string) => void; onVizClue?: (fid: string) => void; onVizDelete?: (fid: string) => void; diff --git a/packages/rath-client/src/pages/causal/explainer/RExplainer.tsx b/packages/rath-client/src/pages/causal/exploration/explainer/RExplainer.tsx similarity index 80% rename from packages/rath-client/src/pages/causal/explainer/RExplainer.tsx rename to packages/rath-client/src/pages/causal/exploration/explainer/RExplainer.tsx index 466fabea..93c692aa 100644 --- a/packages/rath-client/src/pages/causal/explainer/RExplainer.tsx +++ b/packages/rath-client/src/pages/causal/exploration/explainer/RExplainer.tsx @@ -1,42 +1,33 @@ import { observer } from 'mobx-react-lite'; import styled from 'styled-components'; import { DefaultButton, Dropdown, Stack, Toggle } from '@fluentui/react'; -import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { FC, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { applyFilters } from '@kanaries/loa'; -import { useGlobalStore } from '../../../store'; -import type { useInteractFieldGroups } from '../hooks/interactFieldGroup'; -import type { useDataViews } from '../hooks/dataViews'; -import { IFieldMeta, IFilter, IRow } from '../../../interfaces'; -import type { IRInsightExplainResult, IRInsightExplainSubspace } from '../../../workers/insight/r-insight.worker'; -import { RInsightService } from '../../../services/r-insight'; -import type { IFunctionalDep, PagLink } from '../config'; +import { useGlobalStore } from '../../../../store'; +import { useCausalViewContext } from '../../../../store/causalStore/viewStore'; +import { IFieldMeta, IFilter, IRow } from '../../../../interfaces'; +import type { IRInsightExplainResult, IRInsightExplainSubspace } from '../../../../workers/insight/r-insight.worker'; +import { RInsightService } from '../../../../services/r-insight'; import ChartItem from './explainChart'; import RInsightView from './RInsightView'; const Container = styled.div``; -export interface RExplainerProps { - context: ReturnType; - interactFieldGroups: ReturnType; - functionalDependencies: IFunctionalDep[]; - edges: PagLink[]; -} - export const SelectedFlag = '__RExplainer_selected__'; -const RExplainer: React.FC = ({ context, interactFieldGroups, functionalDependencies, edges }) => { +const RExplainer: FC = () => { const { dataSourceStore, causalStore } = useGlobalStore(); const { fieldMetas } = dataSourceStore; - const { fieldGroup } = interactFieldGroups; - const { selectedFields } = causalStore; - - const { sample, vizSampleData } = context; + const viewContext = useCausalViewContext(); + const { selectedFieldGroup = [] } = viewContext ?? {}; + const { fields, sample, visSample } = causalStore.dataset; + const { mergedPag, functionalDependencies } = causalStore.model; - const mainField = fieldGroup.at(-1) ?? null; + const mainField = selectedFieldGroup.at(-1) ?? null; const [indexKey, setIndexKey] = useState(null); - const [aggr, setAggr] = useState<"sum" | "mean" | "count" | null>('count'); - const [diffMode, setDiffMode] = useState<"full" | "other" | "two-group">("full"); + const [aggr, setAggr] = useState<"sum" | "mean" | "count" | null>('sum'); + const [diffMode, setDiffMode] = useState<"full" | "other" | "two-group">("other"); useEffect(() => { setIndexKey(ik => ik ? fieldMetas.find(f => f.fid === ik.fid) ?? null : null); @@ -49,11 +40,12 @@ const RExplainer: React.FC = ({ context, interactFieldGroups, f }, [mainField, aggr]); const [irResult, setIrResult] = useState({ causalEffects: [] }); - const [serviceMode, setServiceMode] = useState<'worker' | 'server'>('worker'); + const [serviceMode, setServiceMode] = useState<'worker' | 'server'>('server'); const pendingRef = useRef>(); const calculate = useCallback(() => { + viewContext?.clearLocalWeights(); if (!subspaces || !mainField) { setIrResult({ causalEffects: [] }); return; @@ -65,25 +57,27 @@ const RExplainer: React.FC = ({ context, interactFieldGroups, f } const p = new Promise(resolve => { const fieldsInSight = new Set(current.predicates.map(pdc => pdc.fid).concat([mainField.fid])); - RInsightService({ - data: sample, - fields: selectedFields, - causalModel: { - funcDeps: functionalDependencies, - edges, - }, - groups: { - current, - other, - }, - view: { - dimensions: [...fieldsInSight].filter(fid => fid !== mainField.fid), - measures: [mainField].map(ms => ({ - fid: ms.fid, - op: aggr, - })), - }, - }, serviceMode).then(resolve); + sample.getAll().then(data => { + RInsightService({ + data, + fields, + causalModel: { + funcDeps: functionalDependencies, + edges: mergedPag, + }, + groups: { + current, + other, + }, + view: { + dimensions: [...fieldsInSight].filter(fid => fid !== mainField.fid), + measures: [mainField].map(ms => ({ + fid: ms.fid, + op: aggr, + })), + }, + }, serviceMode).then(resolve); + }); }); pendingRef.current = p; p.then(res => { @@ -93,20 +87,21 @@ const RExplainer: React.FC = ({ context, interactFieldGroups, f item => Number.isFinite(item.responsibility)// && item.responsibility !== 0 ).sort((a, b) => b.responsibility - a.responsibility) }); + viewContext?.setLocalWeights(res); } }).finally(() => { pendingRef.current = undefined; }); - }, [aggr, mainField, sample, selectedFields, subspaces, edges, serviceMode, functionalDependencies]); + }, [aggr, mainField, sample, fields, subspaces, mergedPag, serviceMode, functionalDependencies, viewContext]); - const [selectedSet, setSelectedSet] = useState([]); + const [selectedSet, setSelectedSet] = useState([]); const [indicesA, indicesB] = useMemo<[number[], number[]]>(() => { if (!subspaces) { return [[], []]; } const indexName = '__this_is_the_index_of_the_row__'; - const data = sample.map((row, i) => ({ ...row, [indexName]: i })); + const data = visSample.map((row, i) => ({ ...row, [indexName]: i })); const indicesA = applyFilters(data, subspaces[0].predicates).map(row => row[indexName]) as number[]; // console.log('indices'); // console.log(indicesA.join(',')); @@ -116,28 +111,28 @@ const RExplainer: React.FC = ({ context, interactFieldGroups, f index => !indicesA.includes(index) ); return [indicesA, indicesB]; - }, [subspaces, sample, diffMode]); + }, [subspaces, visSample, diffMode]); useEffect(() => { setIrResult({ causalEffects: [] }); - }, [indexKey, mainField, sample, subspaces, edges]); + }, [indexKey, mainField, visSample, subspaces, mergedPag]); const applySelection = useCallback(() => { if (!subspaces) { - return setSelectedSet(sample); + return setSelectedSet(visSample); } setSelectedSet( - sample.map((row, i) => ({ ...row, [SelectedFlag]: indicesA.includes(i) ? 1 : indicesB.includes(i) ? 2 : 0 })) + visSample.map((row, i) => ({ ...row, [SelectedFlag]: indicesA.includes(i) ? 1 : indicesB.includes(i) ? 2 : 0 })) ); calculate(); - }, [subspaces, sample, indicesA, indicesB, calculate]); + }, [subspaces, visSample, indicesA, indicesB, calculate]); useEffect(() => { if (!subspaces) { - setSelectedSet(sample); + setSelectedSet(visSample); return; } - }, [subspaces, sample]); + }, [subspaces, visSample]); const [editingGroupIdx, setEditingGroupIdx] = useState<1 | 2>(1); @@ -209,8 +204,8 @@ const RExplainer: React.FC = ({ context, interactFieldGroups, f label="对照选择"//"Diff Mode" selectedKey={diffMode} options={[ - { key: 'full', text: '数据全集' || 'Full' }, { key: 'other', text: '数据补集' || 'Other' }, + { key: 'full', text: '数据全集' || 'Full' }, { key: 'two-group', text: '自选两个集合' || 'Two Groups' }, ]} onChange={(_, option) => { @@ -259,7 +254,7 @@ const RExplainer: React.FC = ({ context, interactFieldGroups, f )}
= ({ context, interactFieldGroups, f <> = ({ context, interactFieldGroups, f /> = ({ context, interactFieldGroups, f data={selectedSet} result={irResult} mainField={mainField} - mainFieldAggregation={aggr} entryDimension={indexKey} mode={diffMode} subspaces={subspaces} indices={[indicesA, indicesB]} - functionalDependencies={functionalDependencies} aggr={aggr} serviceMode={serviceMode} - context={context} - edges={edges} /> )} diff --git a/packages/rath-client/src/pages/causal/explainer/RInsightView.tsx b/packages/rath-client/src/pages/causal/exploration/explainer/RInsightView.tsx similarity index 89% rename from packages/rath-client/src/pages/causal/explainer/RInsightView.tsx rename to packages/rath-client/src/pages/causal/exploration/explainer/RInsightView.tsx index f5331a39..ac4204f1 100644 --- a/packages/rath-client/src/pages/causal/explainer/RInsightView.tsx +++ b/packages/rath-client/src/pages/causal/exploration/explainer/RInsightView.tsx @@ -4,29 +4,23 @@ import { observer } from "mobx-react-lite"; import { FC, Fragment, useCallback, useEffect, useRef, useState } from "react"; import styled from "styled-components"; import { useId } from '@fluentui/react-hooks'; -import type { IFieldMeta, IRow } from "../../../interfaces"; -import { useGlobalStore } from "../../../store"; -import type { IRInsightExplainResult, IRInsightExplainSubspace } from "../../../workers/insight/r-insight.worker"; -import { RInsightService } from '../../../services/r-insight'; -import type { IFunctionalDep, PagLink } from '../config'; -import type { useDataViews } from '../hooks/dataViews'; +import type { IFieldMeta, IRow } from "../../../../interfaces"; +import { useGlobalStore } from "../../../../store"; +import type { IRInsightExplainResult, IRInsightExplainSubspace } from "../../../../workers/insight/r-insight.worker"; +import { RInsightService } from '../../../../services/r-insight'; import DiffChart from "./diffChart"; import ExplainChart from "./explainChart"; import VisText, { IVisTextProps } from './visText'; export interface IRInsightViewProps { - data: IRow[]; + data: readonly IRow[]; result: IRInsightExplainResult; mainField: IFieldMeta; - mainFieldAggregation: "sum" | "mean" | "count" | null; entryDimension: IFieldMeta | null; mode: "full" | "other" | "two-group"; indices: [number[], number[]]; subspaces: [IRInsightExplainSubspace, IRInsightExplainSubspace]; - context: ReturnType; - functionalDependencies: IFunctionalDep[]; - edges: PagLink[]; aggr: "sum" | "mean" | "count" | null; serviceMode: "worker" | "server"; } @@ -137,16 +131,15 @@ const ExploreQueue = styled.div` `; const RInsightView: FC = ({ - data, result, mainField, mainFieldAggregation, entryDimension, - mode, indices, subspaces, context, functionalDependencies, edges, - aggr, serviceMode, + data, result, mainField, entryDimension, + mode, indices, subspaces, serviceMode, }) => { const { dataSourceStore, causalStore } = useGlobalStore(); const { fieldMetas } = dataSourceStore; - const { selectedFields } = causalStore; + const { fields, sample } = causalStore.dataset; + const { mergedPag, functionalDependencies } = causalStore.model; const [normalize, setNormalize] = useState(true); const [cursor, setCursor] = useState(0); - const { sample } = context; const [localIrResult, setLocalIrResult] = useState<{ addedMeasure: string; @@ -178,25 +171,27 @@ const RInsightView: FC = ({ [mainField.fid, measure] ) ); - RInsightService({ - data: sample, - fields: selectedFields, - causalModel: { - funcDeps: functionalDependencies, - edges, - }, - groups: { - current, - other, - }, - view: { - dimensions: [...fieldsInSight], - measures: [measure].map(fid => ({ - fid: fid, - op: aggr, - })), - }, - }, serviceMode).then(resolve); + sample.getAll().then(data => { + RInsightService({ + data, + fields, + causalModel: { + funcDeps: functionalDependencies, + edges: mergedPag, + }, + groups: { + current, + other, + }, + view: { + dimensions: [...fieldsInSight], + measures: [measure].map(fid => ({ + fid: fid, + op: null, + })), + }, + }, serviceMode).then(resolve); + }); }); pendingRef.current = p; p.then(res => { @@ -334,8 +329,8 @@ const RInsightView: FC = ({ = ({ title="对比分布" data={data} subspaces={indices} - mainField={mainField} - mainFieldAggregation={mainFieldAggregation} + mainField={tar} + mainFieldAggregation={null} dimension={dim} mode={mode} /> diff --git a/packages/rath-client/src/pages/causal/explainer/diffChart.tsx b/packages/rath-client/src/pages/causal/exploration/explainer/diffChart.tsx similarity index 97% rename from packages/rath-client/src/pages/causal/explainer/diffChart.tsx rename to packages/rath-client/src/pages/causal/exploration/explainer/diffChart.tsx index 96fb16dd..c8ca631c 100644 --- a/packages/rath-client/src/pages/causal/explainer/diffChart.tsx +++ b/packages/rath-client/src/pages/causal/exploration/explainer/diffChart.tsx @@ -3,14 +3,14 @@ import type { View } from 'vega'; import intl from 'react-intl-universal'; import { observer } from 'mobx-react-lite'; import embed from 'vega-embed'; -import { EDITOR_URL } from '../../../constants'; -import type { IFieldMeta, IRow } from '../../../interfaces'; -import { getVegaTimeFormatRules } from '../../../utils'; +import { EDITOR_URL } from '../../../../constants'; +import type { IFieldMeta, IRow } from '../../../../interfaces'; +import { getVegaTimeFormatRules } from '../../../../utils'; interface DiffChartProps { title?: string; - data: IRow[]; + data: readonly IRow[]; subspaces: [number[], number[]]; mainField: IFieldMeta; mainFieldAggregation: null | 'sum' | 'mean' | 'count'; diff --git a/packages/rath-client/src/pages/causal/explainer/explainChart.tsx b/packages/rath-client/src/pages/causal/exploration/explainer/explainChart.tsx similarity index 97% rename from packages/rath-client/src/pages/causal/explainer/explainChart.tsx rename to packages/rath-client/src/pages/causal/exploration/explainer/explainChart.tsx index 242f33ba..92bfd158 100644 --- a/packages/rath-client/src/pages/causal/explainer/explainChart.tsx +++ b/packages/rath-client/src/pages/causal/exploration/explainer/explainChart.tsx @@ -4,14 +4,14 @@ import intl from 'react-intl-universal'; import { observer } from 'mobx-react-lite'; import embed from 'vega-embed'; import { Subject, throttleTime } from 'rxjs'; -import { EDITOR_URL } from '../../../constants'; -import type { IFieldMeta, IRow, IFilter } from '../../../interfaces'; -import { getRange, getVegaTimeFormatRules } from '../../../utils'; +import { EDITOR_URL } from '../../../../constants'; +import type { IFieldMeta, IRow, IFilter } from '../../../../interfaces'; +import { getRange, getVegaTimeFormatRules } from '../../../../utils'; import { SelectedFlag } from './RExplainer'; interface ExplainChartProps { title?: string; - data: IRow[]; + data: readonly IRow[]; mainField: IFieldMeta; mainFieldAggregation: null | 'sum' | 'mean' | 'count'; indexKey: IFieldMeta | null; diff --git a/packages/rath-client/src/pages/causal/explainer/visText.tsx b/packages/rath-client/src/pages/causal/exploration/explainer/visText.tsx similarity index 98% rename from packages/rath-client/src/pages/causal/explainer/visText.tsx rename to packages/rath-client/src/pages/causal/exploration/explainer/visText.tsx index cfc61c17..02eeedbe 100644 --- a/packages/rath-client/src/pages/causal/explainer/visText.tsx +++ b/packages/rath-client/src/pages/causal/exploration/explainer/visText.tsx @@ -2,7 +2,7 @@ import { Icon } from "@fluentui/react"; import { observer } from "mobx-react-lite"; import { createElement, forwardRef, ForwardRefExoticComponent, Fragment, PropsWithoutRef, RefAttributes, useMemo } from "react"; import styled, { StyledComponentProps } from "styled-components"; -import type { IFieldMeta } from "../../../interfaces"; +import type { IFieldMeta } from "../../../../interfaces"; type AllowedDOMType = 'div' | 'p' | 'pre' | 'span' | 'output'; diff --git a/packages/rath-client/src/pages/causal/manualAnalyzer.tsx b/packages/rath-client/src/pages/causal/exploration/index.tsx similarity index 58% rename from packages/rath-client/src/pages/causal/manualAnalyzer.tsx rename to packages/rath-client/src/pages/causal/exploration/index.tsx index 450f1d40..697c9343 100644 --- a/packages/rath-client/src/pages/causal/manualAnalyzer.tsx +++ b/packages/rath-client/src/pages/causal/exploration/index.tsx @@ -1,20 +1,20 @@ import { ActionButton, Pivot, PivotItem, Stack } from '@fluentui/react'; import { observer } from 'mobx-react-lite'; -import { forwardRef, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'; +import { forwardRef, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from 'react'; import { GraphicWalker } from '@kanaries/graphic-walker'; import type { IPattern } from '@kanaries/loa'; import styled from 'styled-components'; import type { Specification } from 'visual-insights'; -import type { IFieldMeta } from '../../interfaces'; -import { useGlobalStore } from '../../store'; -import SemiEmbed from '../semiAutomation/semiEmbed'; +import type { IFieldMeta } from '../../../interfaces'; +import { useGlobalStore } from '../../../store'; +import SemiEmbed from '../../semiAutomation/semiEmbed'; +import { PAG_NODE } from '../config'; +import { ExplorationKey, ExplorationOptions, useCausalViewContext } from '../../../store/causalStore/viewStore'; import CrossFilter from './crossFilter'; -import type { useInteractFieldGroups } from './hooks/interactFieldGroup'; -import type { useDataViews } from './hooks/dataViews'; -import RExplainer from './explainer/RExplainer'; -import type { IFunctionalDep, PagLink } from './config'; import PredictPanel from './predictPanel'; -import type { ExplorerProps } from './explorer'; +import RExplainer from './explainer/RExplainer'; +import AutoVis from './autoVis'; +// import CausalBlame from './causalBlame'; const Container = styled.div` @@ -33,50 +33,41 @@ const Container = styled.div` } `; -export interface ManualAnalyzerProps { - context: ReturnType; - interactFieldGroups: ReturnType; - functionalDependencies: IFunctionalDep[]; - edges: PagLink[]; - +export interface Subtree { + node: IFieldMeta; + neighbors: { + field: IFieldMeta; + rootType: PAG_NODE; + neighborType: PAG_NODE; + }[]; } -const CustomAnalysisModes = [ - { key: 'crossFilter', text: '因果验证' }, - { key: 'explainer', text: '可解释探索' }, - { key: 'graphicWalker', text: '可视化自助分析' }, - { key: 'predict', text: '模型预测' }, -] as const; - -type CustomAnalysisMode = typeof CustomAnalysisModes[number]['key']; - -const ManualAnalyzer = forwardRef<{ onSubtreeSelected?: ExplorerProps['onNodeSelected'] }, ManualAnalyzerProps>(function ManualAnalyzer ( - { context, interactFieldGroups, functionalDependencies, edges }, ref -) { - const { dataSourceStore, causalStore, langStore } = useGlobalStore(); +const Exploration = forwardRef<{ + onSubtreeSelected?: (subtree: Subtree | null) => void; +}, {}>(function ManualAnalyzer (_, ref) { + const { dataSourceStore, langStore, causalStore } = useGlobalStore(); const { fieldMetas } = dataSourceStore; - const { fieldGroup, setFieldGroup, clearFieldGroup } = interactFieldGroups; const [showSemiClue, setShowSemiClue] = useState(false); const [clueView, setClueView] = useState(null); - const [customAnalysisMode, setCustomAnalysisMode] = useState('crossFilter'); - const { selectedFields } = causalStore; + const { fields, visSample, filters } = causalStore.dataset; - const { vizSampleData, filters } = context; + const viewContext = useCausalViewContext(); + const { selectedFieldGroup = [] } = viewContext ?? {}; useEffect(() => { - if (fieldGroup.length > 0) { + if (selectedFieldGroup.length > 0) { setClueView({ - fields: [...fieldGroup], - filters: [...filters], + fields: selectedFieldGroup.slice(0), + filters: filters.slice(0), imp: 0, }); } else { setClueView(null); } - }, [fieldGroup, filters]); + }, [selectedFieldGroup, filters]); const initialSpec = useMemo(() => { - const [discreteChannel, concreteChannel] = fieldGroup.reduce<[IFieldMeta[], IFieldMeta[]]>( + const [discreteChannel, concreteChannel] = selectedFieldGroup.reduce<[IFieldMeta[], IFieldMeta[]]>( ([discrete, concrete], f, i) => { if (i === 0 || f.semanticType === 'quantitative' || f.semanticType === 'temporal') { concrete.push(f); @@ -87,7 +78,7 @@ const ManualAnalyzer = forwardRef<{ onSubtreeSelected?: ExplorerProps['onNodeSel }, [[], []] ); - return fieldGroup.length + return selectedFieldGroup.length ? { position: concreteChannel.map((f) => f.fid), color: discreteChannel[0] ? [discreteChannel[0].fid] : [], @@ -113,39 +104,51 @@ const ManualAnalyzer = forwardRef<{ onSubtreeSelected?: ExplorerProps['onNodeSel // position: ['gw_count_fid'], // facets: fieldGroup.map(f => f.fid), // }; - }, [fieldGroup]); + }, [selectedFieldGroup]); const predictPanelRef = useRef<{ updateInput?: (input: { features: Readonly[]; targets: Readonly[] }) => void }>({}); useImperativeHandle(ref, () => ({ - onSubtreeSelected: (node, simpleCause) => { - if (customAnalysisMode === 'predict' && node && simpleCause.length > 0) { - const features = simpleCause.map(cause => cause.field); + onSubtreeSelected: (subtree) => { + if (viewContext?.explorationKey === 'predict' && subtree && subtree.neighbors.length > 0) { + const features = subtree.neighbors.filter(neighbor => { + return !( + [PAG_NODE.BLANK, PAG_NODE.CIRCLE].includes(neighbor.rootType) && neighbor.neighborType === PAG_NODE.ARROW + ); + }).map(cause => cause.field); predictPanelRef.current.updateInput?.({ features, - targets: [node], + targets: [subtree.node], }); } }, })); - return ( + const clearFieldGroup = useCallback(() => { + viewContext?.clearSelected(); + }, [viewContext]); + + const removeSelectedField = useCallback((fid: string) => { + viewContext?.toggleNodeSelected(fid); + }, [viewContext]); + + return viewContext && ( { - item && setCustomAnalysisMode(item.props.itemKey as CustomAnalysisMode); + item && viewContext.setExplorationKey(item.props.itemKey as ExplorationKey); }} > - {CustomAnalysisModes.map(mode => ( + {ExplorationOptions.map(mode => ( ))} - {new Array('crossFilter', 'graphicWalker').includes(customAnalysisMode) && ( + {[ExplorationKey.CROSS_FILTER, ExplorationKey.GRAPHIC_WALKER].includes(viewContext.explorationKey) && ( f.fid) : []} /> )} - {new Array('crossFilter', 'explainer').includes(customAnalysisMode) && ( + {[ExplorationKey.AUTO_VIS, ExplorationKey.CROSS_FILTER, ExplorationKey.CAUSAL_INSIGHT].includes(viewContext.explorationKey) && ( )}
{{ - predict: ( - - ), - explainer: vizSampleData.length > 0 && fieldGroup.length > 0 && ( - + // [ExplorationKey.CAUSAL_BLAME]: ( + // + // ), + [ExplorationKey.AUTO_VIS]: ( + ), - crossFilter: vizSampleData.length > 0 && fieldGroup.length > 0 && ( + [ExplorationKey.CROSS_FILTER]: visSample.length > 0 && selectedFieldGroup.length > 0 && ( { - const field = selectedFields.find((f) => f.fid === fid); + const field = fields.find((f) => f.fid === fid); if (field) { setClueView({ fields: [field], @@ -192,15 +190,16 @@ const ManualAnalyzer = forwardRef<{ onSubtreeSelected?: ExplorerProps['onNodeSel setShowSemiClue(true); } }} - onVizDelete={(fid) => { - setFieldGroup((list) => list.filter((f) => f.fid !== fid)); - }} + onVizDelete={removeSelectedField} /> ), - graphicWalker: ( + [ExplorationKey.CAUSAL_INSIGHT]: visSample.length > 0 && ( + + ), + [ExplorationKey.GRAPHIC_WALKER]: ( /* 小心这里的内存占用 */ ), - }[customAnalysisMode]} + [ExplorationKey.PREDICT]: ( + + ), + }[viewContext.explorationKey]}
); }); -export default observer(ManualAnalyzer); +export default observer(Exploration); diff --git a/packages/rath-client/src/pages/causal/exploration/predictPanel/configPanel.tsx b/packages/rath-client/src/pages/causal/exploration/predictPanel/configPanel.tsx new file mode 100644 index 00000000..110ca4a9 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/predictPanel/configPanel.tsx @@ -0,0 +1,159 @@ +import { Checkbox, DetailsList, Dropdown, IColumn, Label, SelectionMode } from "@fluentui/react"; +import produce from "immer"; +import { observer } from "mobx-react-lite"; +import { FC, useMemo } from "react"; +import styled from "styled-components"; +import type { IFieldMeta } from "../../../../interfaces"; +import { useGlobalStore } from "../../../../store"; +import { PredictAlgorithm, PredictAlgorithms } from "../../predict"; + + +const TableContainer = styled.div` + flex-grow: 0; + flex-shrink: 0; + overflow: auto; +`; + +const Row = styled.div<{ selected: 'attribution' | 'target' | false }>` + > div { + background-color: ${({ selected }) => ( + selected === 'attribution' ? 'rgba(194,132,2,0.2)' : selected === 'target' ? 'rgba(66,121,242,0.2)' : undefined + )}; + filter: ${({ selected }) => selected ? 'unset' : 'opacity(0.8)'}; + cursor: pointer; + :hover { + filter: unset; + } + } +`; + +const ConfigPanel: FC<{ + algo: PredictAlgorithm; + setAlgo: (algo: PredictAlgorithm) => void; + running: boolean; + predictInput: { + features: IFieldMeta[]; + targets: IFieldMeta[]; + }; + setPredictInput: (predictInput: { + features: IFieldMeta[]; + targets: IFieldMeta[]; + }) => void; +}> = ({ algo, setAlgo, running, predictInput, setPredictInput }) => { + const { causalStore } = useGlobalStore(); + const { fields } = causalStore; + + const fieldsTableCols = useMemo(() => { + return [ + { + key: 'selectedAsFeature', + name: `特征 (${predictInput.features.length} / ${fields.length})`, + onRender: (item) => { + const field = item as IFieldMeta; + const checked = predictInput.features.some(f => f.fid === field.fid); + return ( + { + if (running) { + return; + } + setPredictInput(produce(predictInput, draft => { + draft.features = draft.features.filter(f => f.fid !== field.fid); + draft.targets = draft.targets.filter(f => f.fid !== field.fid); + if (ok) { + draft.features.push(field); + } + })); + }} + /> + ); + }, + isResizable: false, + minWidth: 90, + maxWidth: 90, + }, + { + key: 'selectedAsTarget', + name: `目标 (${predictInput.targets.length} / ${fields.length})`, + onRender: (item) => { + const field = item as IFieldMeta; + const checked = predictInput.targets.some(f => f.fid === field.fid); + return ( + { + if (running) { + return; + } + setPredictInput(produce(predictInput, draft => { + draft.features = draft.features.filter(f => f.fid !== field.fid); + draft.targets = draft.targets.filter(f => f.fid !== field.fid); + if (ok) { + draft.targets.push(field); + } + })); + }} + /> + ); + }, + isResizable: false, + minWidth: 90, + maxWidth: 90, + }, + { + key: 'name', + name: '因素', + onRender: (item) => { + const field = item as IFieldMeta; + return ( + + {field.name || field.fid} + + ); + }, + minWidth: 120, + }, + ]; + }, [fields, predictInput, running, setPredictInput]); + + return ( + <> + ({ key: algo.key, text: algo.text }))} + selectedKey={algo} + onChange={(_, option) => { + const item = PredictAlgorithms.find(which => which.key === option?.key); + if (item) { + setAlgo(item.key); + } + }} + style={{ width: 'max-content' }} + /> + + + { + const field = props?.item as IFieldMeta; + const checkedAsAttr = predictInput.features.some(f => f.fid === field.fid); + const checkedAsTar = predictInput.targets.some(f => f.fid === field.fid); + return ( + + {defaultRender?.(props)} + + ); + }} + /> + + + ); +}; + + +export default observer(ConfigPanel); diff --git a/packages/rath-client/src/pages/causal/exploration/predictPanel/index.tsx b/packages/rath-client/src/pages/causal/exploration/predictPanel/index.tsx new file mode 100644 index 00000000..b0e5e160 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/predictPanel/index.tsx @@ -0,0 +1,169 @@ +import { DefaultButton, Icon, Spinner } from "@fluentui/react"; +import { observer } from "mobx-react-lite"; +import { nanoid } from "nanoid"; +import { forwardRef, useCallback, useEffect, useImperativeHandle, useMemo, useRef, useState } from "react"; +import styled from "styled-components"; +import type { IFieldMeta } from "../../../../interfaces"; +import { useGlobalStore } from "../../../../store"; +import { useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import { execPredict, IPredictProps, PredictAlgorithm, TrainTestSplitFlag } from "../../predict"; +import TabList from "./tablist"; + + +const Container = styled.div` + flex-grow: 1; + flex-shrink: 1; + display: flex; + flex-direction: column; + overflow: hidden; + > .content { + flex-grow: 1; + flex-shrink: 1; + display: flex; + flex-direction: column; + padding: 0.5em; + overflow: auto; + > * { + flex-grow: 0; + flex-shrink: 0; + } + } +`; + +const ModeOptions = [ + { key: 'classification', text: '分类' }, + { key: 'regression', text: '回归' }, +] as const; + +const TRAIN_RATE = 0.2; + +const PredictPanel = forwardRef<{ + updateInput?: (input: { features: IFieldMeta[]; targets: IFieldMeta[] }) => void; +}, {}>(function PredictPanel (_, ref) { + const { causalStore, dataSourceStore } = useGlobalStore(); + const { fields } = causalStore; + const { cleanedData, fieldMetas } = dataSourceStore; + const viewContext = useCausalViewContext(); + + const [predictInput, setPredictInput] = useState<{ features: IFieldMeta[]; targets: IFieldMeta[] }>({ + features: [], + targets: [], + }); + const [algo, setAlgo] = useState('decisionTree'); + const [mode, setMode] = useState('classification'); + + useImperativeHandle(ref, () => ({ + updateInput: input => setPredictInput(input), + })); + + useEffect(() => { + setPredictInput(before => { + if (before.features.length || before.targets.length) { + return { + features: fields.filter(f => before.features.some(feat => feat.fid === f.fid)), + targets: fields.filter(f => before.targets.some(tar => tar.fid === f.fid)), + }; + } + return { + features: fields.slice(1).map(f => f), + targets: fields.slice(0, 1), + }; + }); + }, [fields]); + + const [running, setRunning] = useState(false); + + const canExecute = predictInput.features.length > 0 && predictInput.targets.length > 0; + const pendingRef = useRef>(); + + useEffect(() => { + pendingRef.current = undefined; + setRunning(false); + }, [predictInput]); + + const dataSourceRef = useRef(cleanedData); + dataSourceRef.current = cleanedData; + const allFieldsRef = useRef(fieldMetas); + allFieldsRef.current = fieldMetas; + + const [tab, setTab] = useState<'config' | 'result'>('config'); + + const trainTestSplitIndices = useMemo(() => { + const indices = cleanedData.map((_, i) => i); + const trainSetIndices = new Map(); + const trainSetTargetSize = Math.floor(cleanedData.length * TRAIN_RATE); + while (trainSetIndices.size < trainSetTargetSize && indices.length) { + const [index] = indices.splice(Math.floor(indices.length * Math.random()), 1); + trainSetIndices.set(index, 1); + } + return cleanedData.map((_, i) => trainSetIndices.has(i) ? TrainTestSplitFlag.train : TrainTestSplitFlag.test); + }, [cleanedData]); + + const trainTestSplitIndicesRef = useRef(trainTestSplitIndices); + trainTestSplitIndicesRef.current = trainTestSplitIndices; + + const handleClickExec = useCallback(() => { + const startTime = Date.now(); + setRunning(true); + const task = execPredict({ + dataSource: dataSourceRef.current, + fields: allFieldsRef.current, + model: { + algorithm: algo, + features: predictInput.features.map(f => f.fid), + targets: predictInput.targets.map(f => f.fid), + }, + trainTestSplitIndices: trainTestSplitIndicesRef.current, + mode, + }); + pendingRef.current = task; + task.then(res => { + if (task === pendingRef.current && res) { + const completeTime = Date.now(); + viewContext?.pushPredictResult({ + id: nanoid(8), + algo, + startTime, + completeTime, + data: res, + }); + setTab('result'); + } + }).finally(() => { + pendingRef.current = undefined; + setRunning(false); + }); + }, [predictInput, algo, mode, viewContext]); + + useEffect(() => { + viewContext?.clearPredictResults(); + }, [mode, viewContext]); + + return ( + + running ? : } + style={{ width: 'max-content', flexGrow: 0, flexShrink: 0, marginLeft: '0.6em' }} + split + menuProps={{ + items: ModeOptions.map(opt => opt), + onItemClick: (_e, item) => { + if (item) { + setMode(item.key as typeof mode); + } + }, + }} + > + {`${ModeOptions.find(m => m.key === mode)?.text}预测`} + + + + ); +}); + + +export default observer(PredictPanel); diff --git a/packages/rath-client/src/pages/causal/exploration/predictPanel/resultPanel.tsx b/packages/rath-client/src/pages/causal/exploration/predictPanel/resultPanel.tsx new file mode 100644 index 00000000..129cb239 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/predictPanel/resultPanel.tsx @@ -0,0 +1,200 @@ +import { Checkbox, DefaultButton, DetailsList, IColumn, Icon, SelectionMode } from "@fluentui/react"; +import { observer } from "mobx-react-lite"; +import { FC, useEffect, useMemo, useRef, useState } from "react"; +import styled from "styled-components"; +import { useGlobalStore } from "../../../../store"; +import { useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import { PredictAlgorithms } from "../../predict"; + + +const TableContainer = styled.div` + flex-grow: 0; + flex-shrink: 0; + overflow: auto; +`; + +const ResultPanel: FC = () => { + const { dataSourceStore } = useGlobalStore(); + const { cleanedData, fieldMetas } = dataSourceStore; + const viewContext = useCausalViewContext(); + const { predictCache = [] } = viewContext ?? {}; + + const dataSourceRef = useRef(cleanedData); + dataSourceRef.current = cleanedData; + const allFieldsRef = useRef(fieldMetas); + allFieldsRef.current = fieldMetas; + + const sortedResults = useMemo(() => { + return predictCache.slice(0).sort((a, b) => b.completeTime - a.completeTime); + }, [predictCache]); + + const [comparison, setComparison] = useState(null); + + useEffect(() => { + setComparison(group => { + if (!group) { + return null; + } + const next = group.filter(id => predictCache.some(rec => rec.id === id)); + if (next.length === 0) { + return null; + } + return next as [string] | [string, string]; + }); + }, [predictCache]); + + const resultTableCols = useMemo(() => { + return [ + { + key: 'selected', + name: '对比', + onRender: (item) => { + const record = item as typeof sortedResults[number]; + const selected = (comparison ?? [] as string[]).includes(record.id); + return ( + { + if (checked) { + setComparison(group => { + if (group === null) { + return [record.id]; + } + return [group[0], record.id]; + }); + } else if (selected) { + setComparison(group => { + if (group?.some(id => id === record.id)) { + return group.length === 1 ? null : group.filter(id => id !== record.id) as [string]; + } + return null; + }); + } + }} + /> + ); + }, + isResizable: false, + minWidth: 30, + maxWidth: 30, + }, + { + key: 'index', + name: '运行次数', + minWidth: 70, + maxWidth: 70, + isResizable: false, + onRender(_, index) { + return <>{index !== undefined ? (sortedResults.length - index) : ''}; + }, + }, + { + key: 'algo', + name: '预测模型', + minWidth: 70, + onRender(item) { + const record = item as typeof sortedResults[number]; + return <>{PredictAlgorithms.find(which => which.key === record.algo)?.text} + }, + }, + { + key: 'accuracy', + name: '准确率', + minWidth: 150, + onRender(item, index) { + if (!item || index === undefined) { + return <>; + } + const record = item as typeof sortedResults[number]; + const previous = sortedResults[index + 1]; + const comparison: 'better' | 'worse' | 'same' | null = previous ? ( + previous.data.accuracy === record.data.accuracy ? 'same' + : record.data.accuracy > previous.data.accuracy ? 'better' : 'worse' + ) : null; + return ( + + {comparison && ( + + )} + {record.data.accuracy} + + ); + }, + }, + ]; + }, [sortedResults, comparison]); + + const diff = useMemo(() => { + if (comparison?.length === 2) { + const before = sortedResults.find(res => res.id === comparison[0]); + const after = sortedResults.find(res => res.id === comparison[1]); + if (before && after) { + const temp: unknown[] = []; + for (let i = 0; i < before.data.result.length; i += 1) { + const row = dataSourceRef.current[before.data.result[i][0]]; + const prev = before.data.result[i][1]; + const next = after.data.result[i][1]; + if (next === 1 && prev === 0) { + temp.push(Object.fromEntries(Object.entries(row).map(([k, v]) => [ + allFieldsRef.current.find(f => f.fid === k)?.name ?? k, + v, + ]))); + } + } + return temp; + } + } + }, [sortedResults, comparison]); + + useEffect(() => { + if (diff) { + // TODO: 在界面上实现一个 diff view,代替这个 console + // eslint-disable-next-line no-console + console.table(diff); + } + }, [diff]); + + return ( + <> + viewContext?.clearPredictResults()} + style={{ width: 'max-content' }} + > + 清空记录 + + + + + + ); +}; + + +export default observer(ResultPanel); diff --git a/packages/rath-client/src/pages/causal/exploration/predictPanel/tablist.tsx b/packages/rath-client/src/pages/causal/exploration/predictPanel/tablist.tsx new file mode 100644 index 00000000..3d73ad86 --- /dev/null +++ b/packages/rath-client/src/pages/causal/exploration/predictPanel/tablist.tsx @@ -0,0 +1,109 @@ +import { Pivot, PivotItem } from "@fluentui/react"; +import { observer } from "mobx-react-lite"; +import { FC, useEffect, useMemo, useRef, useState } from "react"; +import type { IFieldMeta } from "../../../../interfaces"; +import { useGlobalStore } from "../../../../store"; +import { useCausalViewContext } from "../../../../store/causalStore/viewStore"; +import { PredictAlgorithm } from "../../predict"; +import ConfigPanel from "./configPanel"; +import ResultPanel from "./resultPanel"; + + +const TabList: FC<{ + algo: PredictAlgorithm; + setAlgo: (algo: PredictAlgorithm) => void; + tab: 'config' | 'result'; + setTab: (tab: 'config' | 'result') => void; + running: boolean; + predictInput: { + features: IFieldMeta[]; + targets: IFieldMeta[]; + }; + setPredictInput: (predictInput: { + features: IFieldMeta[]; + targets: IFieldMeta[]; + }) => void; +}> = ({ algo, setAlgo, tab, setTab, running, predictInput, setPredictInput }) => { + const { dataSourceStore } = useGlobalStore(); + const { cleanedData, fieldMetas } = dataSourceStore; + const viewContext = useCausalViewContext(); + const { predictCache = [] } = viewContext ?? {}; + + const dataSourceRef = useRef(cleanedData); + dataSourceRef.current = cleanedData; + const allFieldsRef = useRef(fieldMetas); + allFieldsRef.current = fieldMetas; + + const sortedResults = useMemo(() => { + return predictCache.slice(0).sort((a, b) => b.completeTime - a.completeTime); + }, [predictCache]); + + const [comparison, setComparison] = useState(null); + + useEffect(() => { + setComparison(group => { + if (!group) { + return null; + } + const next = group.filter(id => predictCache.some(rec => rec.id === id)); + if (next.length === 0) { + return null; + } + return next as [string] | [string, string]; + }); + }, [predictCache]); + + const diff = useMemo(() => { + if (comparison?.length === 2) { + const before = sortedResults.find(res => res.id === comparison[0]); + const after = sortedResults.find(res => res.id === comparison[1]); + if (before && after) { + const temp: unknown[] = []; + for (let i = 0; i < before.data.result.length; i += 1) { + const row = dataSourceRef.current[before.data.result[i][0]]; + const prev = before.data.result[i][1]; + const next = after.data.result[i][1]; + if (next === 1 && prev === 0) { + temp.push(Object.fromEntries(Object.entries(row).map(([k, v]) => [ + allFieldsRef.current.find(f => f.fid === k)?.name ?? k, + v, + ]))); + } + } + return temp; + } + } + }, [sortedResults, comparison]); + + useEffect(() => { + if (diff) { + // TODO: 在界面上实现一个 diff view,代替这个 console + // eslint-disable-next-line no-console + console.table(diff); + } + }, [diff]); + + return ( + <> + { + item && setTab(item.props.itemKey as typeof tab); + }} + style={{ marginTop: '0.5em' }} + > + + + +
+ {{ + config: , + result: + }[tab]} +
+ + ); +}; + + +export default observer(TabList); diff --git a/packages/rath-client/src/pages/causal/explorer/DAGView.tsx b/packages/rath-client/src/pages/causal/explorer/DAGView.tsx deleted file mode 100644 index db606bdd..00000000 --- a/packages/rath-client/src/pages/causal/explorer/DAGView.tsx +++ /dev/null @@ -1,262 +0,0 @@ -import { forwardRef, MouseEvent, useCallback, useEffect, useMemo, useState } from "react"; -import { line as d3Line, curveCatmullRom } from 'd3-shape'; -import { - dagStratify, - sugiyama, - decrossOpt, - coordGreedy, - coordQuad, - decrossTwoLayer, - layeringLongestPath, - layeringSimplex, - twolayerAgg, - twolayerGreedy, -} from 'd3-dag'; -import styled, { StyledComponentProps } from "styled-components"; -import type { IFieldMeta } from "../../../interfaces"; -import { Flow, mergeFlows } from "./flowAnalyzer"; -import type { DiagramGraphData } from "."; - - -const line = d3Line<{ x: number; y: number }>().curve(curveCatmullRom).x(d => d.x).y(d => d.y); - -const Container = styled.div` - position: relative; - > svg { - position: absolute; - left: 0; - top: 0; - width: 100%; - height: 100%; - & *:not(circle) { - pointer-events: none; - } - & circle { - cursor: pointer; - pointer-events: all; - } - & text { - user-select: none; - } - & line { - opacity: 0.67; - } - } -`; - -export type DAGViewProps = Omit[]; - value: Readonly; - cutThreshold: number; - mode: 'explore' | 'edit'; - focus: number | null; - onClickNode?: (node: DiagramGraphData['nodes'][number]) => void; -}, never>, 'onChange' | 'ref'>; - -const MIN_RADIUS = 0.2; -const MAX_RADIUS = 0.38; -const MIN_STROKE_WIDTH = 0.04; -const MAX_STROKE_WIDTH = 0.09; - -const DAGView = forwardRef(( - { fields, value, onClickNode, focus, cutThreshold, mode, ...props }, - ref -) => { - const [data] = useMemo(() => { - let totalScore = 0; - const nodeCauseWeights = value.nodes.map(() => 0); - const nodeEffectWeights = value.nodes.map(() => 0); - value.links.forEach(link => { - nodeCauseWeights[link.effectId] += link.score; - nodeEffectWeights[link.causeId] += link.score; - totalScore += link.score * 2; - }); - return [{ - nodes: value.nodes.map((node, i) => ({ - id: node.nodeId, - index: i, - causeSum: nodeCauseWeights[i], - effectSum: nodeEffectWeights[i], - score: (nodeCauseWeights[i] + nodeEffectWeights[i]) / totalScore, - diff: (nodeCauseWeights[i] - nodeEffectWeights[i]) / totalScore, - })), - links: value.links.map(link => ({ - source: link.causeId, - target: link.effectId, - value: link.score / nodeCauseWeights[link.effectId], - })), - }, totalScore]; - }, [value]); - - const normalizedLinks = useMemo(() => { - const max = value.links.reduce((m, d) => m > d.score ? m : d.score, 0); - return data.links.map((link, i) => ({ - ...link, - score: value.links[i].score / (max || 1), - })); - }, [value.links, data.links]); - - const normalizedNodes = useMemo(() => { - const max = data.nodes.reduce((m, d) => m > d.score ? m : d.score, 0); - return data.nodes.map(node => ({ - ...node, - score: node.score / (max || 1), - })); - }, [data.nodes]); - - const flows = useMemo(() => { - const flows: Flow[] = []; - for (const node of data.nodes) { - flows.push({ - id: `${node.id}`, - parentIds: [], - }); - } - for (const link of normalizedLinks) { - if (link.score > 0.001 && link.score >= cutThreshold) { - mergeFlows(flows, { - id: `${link.target}`, - parentIds: [`${link.source}`], - }); - } - } - return flows; - }, [data.nodes, normalizedLinks, cutThreshold]); - - const tooManyLinks = data.links.length >= 16; - - const layout = useMemo(() => { - return tooManyLinks - ? sugiyama().layering( - layeringSimplex() - ).decross( - decrossTwoLayer().order(twolayerGreedy().base(twolayerAgg())) - ).coord( - coordGreedy() - ) - : sugiyama().layering( - layeringLongestPath() - ).decross( - decrossOpt() - ).coord( - coordQuad() - ); - }, [tooManyLinks]); - - const dag = useMemo(() => { - const dag = dagStratify()(flows); - return { - // @ts-ignore - size: layout(dag), - steps: dag.size(), - nodes: dag.descendants(), - links: dag.links(), - }; - }, [flows, layout]); - - const nodes = useMemo(() => { - return dag.nodes.map(node => { - const me = normalizedNodes[parseInt(node.data.id)]; - if (me) { - return me; - } - return null; - }); - }, [dag.nodes, normalizedNodes]); - - const links = useMemo(() => { - return dag.links.map(link => { - const source = dag.nodes.find(node => node === link.source)?.data.id; - const target = dag.nodes.find(node => node === link.target)?.data.id; - if (source && target) { - const me = data.links.find(which => `${which.source}` === source && `${which.target}` === target); - return me ?? null; - } - return null; - }); - }, [dag, data]); - - const draggingSource = useMemo(() => mode === 'edit' && typeof focus === 'number' ? focus : null, [mode, focus]); - const [cursorPos, setCursorPos] = useState<[number, number]>([NaN, NaN]); - - useEffect(() => setCursorPos([NaN, NaN]), [draggingSource]); - - const handleMouseMove = useCallback((e: MouseEvent) => { - if (draggingSource === null) { - return; - } - const target = e.target as HTMLDivElement; - const { left, top } = target.getBoundingClientRect(); - const x = e.clientX - left; - const y = e.clientY - top; - setCursorPos([x, y]); - }, [draggingSource]); - - // const focusedNode = dag.nodes.find(node => {parseInt(node.data.id, 10)}); - - return ( - - - {dag.links.map((link, i) => ( - - ))} - {/* {Number.isFinite(cursorPos[0]) && Number.isFinite(cursorPos[1]) && focus !== null && ( - - )} */} - {dag.nodes.map((node, i) => { - const idx = parseInt(node.data.id, 10); - const f = fields[idx]; - return ( - - { - e.stopPropagation(); - onClickNode?.(value.nodes[idx]); - }} - /> - - {f.name ?? f.fid} - - - ); - })} - - - ); -}); - - -export default DAGView; diff --git a/packages/rath-client/src/pages/causal/explorer/explorerMainView.tsx b/packages/rath-client/src/pages/causal/explorer/explorerMainView.tsx index 18819042..b3d56f42 100644 --- a/packages/rath-client/src/pages/causal/explorer/explorerMainView.tsx +++ b/packages/rath-client/src/pages/causal/explorer/explorerMainView.tsx @@ -1,13 +1,9 @@ import { forwardRef } from "react"; import styled, { StyledComponentProps } from "styled-components"; import type { IFieldMeta } from "../../../interfaces"; -import useErrorBoundary from "../../../hooks/use-error-boundary"; -import type { ModifiableBgKnowledge } from "../config"; -// import DAGView from "./DAGView"; -// import ForceView from "./forceView"; +import type { EdgeAssert } from "../../../store/causalStore/modelStore"; +import type { Subtree } from "../exploration"; import GraphView from "./graphView"; -import type { GraphNodeAttributes } from "./graph-utils"; -import type { DiagramGraphData } from "."; const Container = styled.div` @@ -22,30 +18,21 @@ const Container = styled.div` `; export type ExplorerMainViewProps = Omit; /** @default 0 */ cutThreshold?: number; limit: number; mode: 'explore' | 'edit'; onClickNode?: (fid: string | null) => void; - toggleFlowAnalyzer?: () => void; - focus: number | null; - onLinkTogether: (srcFid: string, tarFid: string, type: ModifiableBgKnowledge['type']) => void; + onLinkTogether: (srcFid: string, tarFid: string, type: EdgeAssert) => void; onRevertLink: (srcFid: string, tarFid: string) => void; onRemoveLink: (srcFid: string, tarFid: string) => void; - preconditions: ModifiableBgKnowledge[]; forceRelayoutRef: React.MutableRefObject<() => void>; - autoLayout: boolean; - renderNode?: (node: Readonly) => GraphNodeAttributes | undefined, allowZoom: boolean; handleLasso?: (fields: IFieldMeta[]) => void; + handleSubTreeSelected?: (subtree: Subtree | null) => void; }, never>, 'onChange' | 'ref'>; const ExplorerMainView = forwardRef(({ - selectedSubtree, - value, - focus, cutThreshold = 0, mode, limit, @@ -53,79 +40,32 @@ const ExplorerMainView = forwardRef(({ onLinkTogether, onRevertLink, onRemoveLink, - preconditions, forceRelayoutRef, - autoLayout, - renderNode, - toggleFlowAnalyzer, allowZoom, handleLasso, + handleSubTreeSelected, ...props }, ref) => { - const ErrorBoundary = useErrorBoundary((err, info) => { - // console.error(err ?? info); - return
; - // return

{info}

; - }, [value, cutThreshold, preconditions]); - return ( - {/* */} - - {})} - onLinkTogether={onLinkTogether} - onRevertLink={onRevertLink} - onRemoveLink={onRemoveLink} - focus={focus} - autoLayout={autoLayout} - renderNode={renderNode} - allowZoom={allowZoom} - handleLasso={handleLasso} - style={{ - flexGrow: 1, - flexShrink: 1, - width: '100%', - }} - /> - {/* */} - + /> ); }); diff --git a/packages/rath-client/src/pages/causal/explorer/flowAnalyzer.tsx b/packages/rath-client/src/pages/causal/explorer/flowAnalyzer.tsx deleted file mode 100644 index 38c02fd7..00000000 --- a/packages/rath-client/src/pages/causal/explorer/flowAnalyzer.tsx +++ /dev/null @@ -1,685 +0,0 @@ -/* eslint no-fallthrough: ["error", { "allowEmptyCase": true }] */ -import { FC, useCallback, useEffect, useMemo, useRef, useState } from "react"; -import { - dagStratify, - sugiyama, - decrossOpt, - layeringLongestPath, - layeringSimplex, - decrossTwoLayer, - twolayerGreedy, - twolayerAgg, - coordGreedy, - coordQuad, -} from 'd3-dag'; -import { line as d3Line/*, curveMonotoneY*/, curveCatmullRom } from 'd3-shape'; -import { Dropdown } from "@fluentui/react"; -import { observer } from "mobx-react-lite"; -import styled from "styled-components"; -import type { IFieldMeta, IRow } from "../../../interfaces"; -import { deepcopy } from "../../../utils"; -import ColDist, { IBrushSignalStore } from "../crossFilter/colDist"; -import { useGlobalStore } from "../../../store"; -import type { DiagramGraphData } from "."; - - -export type NodeWithScore = { - field: Readonly; - score: number; -}; - -export interface FlowAnalyzerProps { - display: boolean; - dataSource: IRow[]; - data: DiagramGraphData; - index: number; - cutThreshold: number; - onUpdate: ( - node: Readonly | null, - simpleCause: readonly Readonly[], - simpleEffect: readonly Readonly[], - composedCause: readonly Readonly[], - composedEffect: readonly Readonly[], - ) => void; - onClickNode?: (fid: string | null) => void; - limit: number; -} - -export type Flow = { - id: string; - parentIds: string[]; -}; - -export const mergeFlows = (flows: Flow[], entering: Flow): void => { - const item = flows.find(f => f.id === entering.id); - if (item) { - item.parentIds.push(...entering.parentIds); - } else { - flows.push(entering); - } -}; - -const FLOW_HEIGHT = 500; - -const SVGGroup = styled.div` - flex-grow: 0; - flex-shrink: 0; - width: 100%; - min-height: 50px; - border: 1px solid #e3e2e2; - border-top: 0; - display: flex; - flex-direction: column; - align-items: center; - > svg { - width: 100%; - height: 50vh; - overflow: hidden; - & text { - user-select: none; - } - & *:not(circle) { - pointer-events: none; - } - & circle { - pointer-events: all; - cursor: pointer; - } - } - > div:not(.tools):not(.msg) { - flex-grow: 0; - flex-shrink: 0; - display: flex; - position: relative; - width: 100%; - height: ${FLOW_HEIGHT}px; - > * { - position: absolute; - left: 0; - top: 0; - width: 100%; - height: 100%; - } - > div { - > div { - position: absolute; - box-sizing: content-box; - transform: translate(-50%, -50%); - background-color: #463782; - border: 2px solid #463782; - } - } - } - > div.msg { - padding: 0.8em 2em 1.6em; - display: flex; - flex-direction: column; - align-items: center; - justify-content: center; - position: static; - color: #a87c40; - } -`; - -const line = d3Line<{ x: number; y: number }>().curve(curveCatmullRom).x(d => d.x).y(d => d.y); - -const FlowAnalyzer: FC = ({ display, dataSource, data, index, cutThreshold, onUpdate, onClickNode, limit }) => { - const { causalStore } = useGlobalStore(); - const { selectedFields: fields } = causalStore; - const field = useMemo(() => fields[index], [fields, index]); - - const normalizedLinks = useMemo(() => { - const nodeCauseWeights = data.nodes.map(() => 0); - const nodeEffectWeights = data.nodes.map(() => 0); - data.links.forEach(link => { - nodeCauseWeights[link.effectId] += link.score; - nodeEffectWeights[link.causeId] += link.score; - }); - return data.links.map(link => ({ - causeId: link.causeId, - effectId: link.effectId, - score: link.score / nodeCauseWeights[link.effectId], - type: link.type, - })); - }, [data]); - - const linksInView = useMemo(() => { - return normalizedLinks.filter(link => link.score >= cutThreshold).sort( - (a, b) => b.score - a.score - ).slice(0, limit); - }, [normalizedLinks, cutThreshold, limit]); - - const getPathScore = useCallback((effectIdx: number) => { - const scores = new Map(); - const walk = (rootIdx: number, weight: number, flags = new Map()) => { - if (flags.has(rootIdx)) { - return; - } - flags.set(rootIdx, 1); - const paths = data.links.filter(link => link.effectId === rootIdx); - for (const path of paths) { - const nodeIdx = path.causeId; - const value = path.score * weight; - scores.set(nodeIdx, (scores.get(nodeIdx) ?? 0) + value); - walk(nodeIdx, value, flags); - } - }; - walk(effectIdx, 1); - return (causeIdx: number) => scores.get(causeIdx); - }, [data.links]); - - const flowsAsOrigin = useMemo(() => { - if (field) { - let links = linksInView.map(link => link); - const ready = [index]; - const flows: Flow[] = [{ - id: `${index}`, - parentIds: [], - }]; - while (ready.length) { - const source = ready.shift()!; - const nextLinks: typeof links = []; - for (const link of links) { - switch (link.type) { - case 'directed': - case 'weak directed': { - if (link.causeId === source) { - mergeFlows(flows, { - id: `${link.effectId}`, - parentIds: [`${link.causeId}`], - }); - mergeFlows(flows, { - id: `${link.causeId}`, - parentIds: [], - }); - ready.push(link.effectId); - } else { - nextLinks.push(link); - } - break; - } - case 'bidirected': - case 'undirected': { - if (link.causeId === source) { - mergeFlows(flows, { - id: `${link.effectId}`, - parentIds: [`${link.causeId}`], - }); - mergeFlows(flows, { - id: `${link.causeId}`, - parentIds: [], - }); - ready.push(link.effectId); - } else if (link.effectId === source) { - mergeFlows(flows, { - id: `${link.causeId}`, - parentIds: [`${link.effectId}`], - }); - mergeFlows(flows, { - id: `${link.effectId}`, - parentIds: [], - }); - ready.push(link.causeId); - } else { - nextLinks.push(link); - } - break; - } - default: { - break; - } - } - } - links = nextLinks; - } - return flows; - } - return []; - }, [linksInView, field, index]); - - const flowsAsDestination = useMemo(() => { - if (field) { - let links = linksInView.map(link => link); - const ready = [index]; - const flows: Flow[] = [{ - id: `${index}`, - parentIds: [], - }]; - while (ready.length) { - const source = ready.shift()!; - const nextLinks: typeof links = []; - for (const link of links) { - switch (link.type) { - case 'directed': - case 'weak directed': { - if (link.effectId === source) { - mergeFlows(flows, { - id: `${link.effectId}`, - parentIds: [`${link.causeId}`], - }); - mergeFlows(flows, { - id: `${link.causeId}`, - parentIds: [], - }); - ready.push(link.causeId); - } else { - nextLinks.push(link); - } - break; - } - case 'bidirected': - case 'undirected': { - if (link.effectId === source) { - mergeFlows(flows, { - id: `${link.effectId}`, - parentIds: [`${link.causeId}`], - }); - mergeFlows(flows, { - id: `${link.causeId}`, - parentIds: [], - }); - ready.push(link.causeId); - } else if (link.causeId === source) { - mergeFlows(flows, { - id: `${link.causeId}`, - parentIds: [`${link.effectId}`], - }); - mergeFlows(flows, { - id: `${link.effectId}`, - parentIds: [], - }); - ready.push(link.effectId); - } else { - nextLinks.push(link); - } - break; - } - default: { - break; - } - } - } - links = nextLinks; - } - return flows; - } - return []; - }, [linksInView, field, index]); - - useEffect(() => { - if (field) { - const getCauseScore = getPathScore(index); - const [simpleCause, composedCause] = flowsAsDestination.reduce<[NodeWithScore[], NodeWithScore[]]>(([simple, composed], flow) => { - const effectId = parseInt(flow.id, 10); - const target = fields[effectId]; - for (const causeId of flow.parentIds.map(id => parseInt(id, 10))) { - const source = fields[causeId]; - const score = getCauseScore(causeId); - if (score) { - if (target.fid === field.fid) { - simple.push({ - field: source, - score, - }); - } else if (!composed.some(f => f.field.fid === source.fid)) { - composed.push({ - field: source, - score, - }); - } - } - } - return [simple, composed]; - }, [[], []]); - const [simpleEffect, composedEffect] = flowsAsOrigin.reduce<[NodeWithScore[], NodeWithScore[]]>(([simple, composed], flow) => { - const effectId = parseInt(flow.id, 10); - const target = fields[effectId]; - for (const causeId of flow.parentIds.map(id => parseInt(id, 10))) { - const source = fields[causeId]; - const score = getPathScore(effectId)(index); - if (score) { - if (source.fid === field.fid) { - simple.push({ - field: target, - score, - }); - } else if (!composed.some(f => f.field.fid === target.fid)) { - composed.push({ - field: target, - score, - }); - } - } - } - return [simple, composed]; - }, [[], []]); - onUpdate(field, simpleCause, simpleEffect, composedCause, composedEffect); - } else { - onUpdate(null, [], [], [], []); - } - }, [onUpdate, fields, field, flowsAsDestination, flowsAsOrigin, getPathScore, index]); - - const combinedFlows = useMemo(() => { - const flows = deepcopy(flowsAsDestination) as typeof flowsAsDestination; - for (const flow of flowsAsOrigin) { - mergeFlows(flows, flow); - } - return flows; - }, [flowsAsDestination, flowsAsOrigin]); - - const tooManyLinks = data.links.length >= 16; - - const layout = useMemo(() => { - return tooManyLinks - ? sugiyama().layering( - layeringSimplex() - ).decross( - decrossTwoLayer().order(twolayerGreedy().base(twolayerAgg())) - ).coord( - coordGreedy() - ) - : sugiyama().layering( - layeringLongestPath() - ).decross( - decrossOpt() - ).coord( - coordQuad() - ); - }, [tooManyLinks]); - - const [destinationTree, destTreeMsg] = useMemo(() => { - if (flowsAsDestination.length === 0) { - return [null, null]; - } - try { - const dag = dagStratify()(flowsAsDestination); - return [{ - // @ts-ignore - size: layout(dag), - steps: dag.size(), - nodes: dag.descendants(), - links: dag.links(), - }, null]; - } catch (error) { - return [null, `${error}`]; - } - }, [flowsAsDestination, layout]); - - const [originTree, oriTreeMsg] = useMemo(() => { - if (flowsAsOrigin.length === 0) { - return [null, null]; - } - try { - const dag = dagStratify()(flowsAsOrigin); - return [{ - // @ts-ignore - size: layout(dag), - steps: dag.size(), - nodes: dag.descendants(), - links: dag.links(), - }, null]; - } catch (error) { - return [null, `${error}`]; - } - }, [flowsAsOrigin, layout]); - - const [combinedTree, cbnTreeMsg] = useMemo(() => { - if (combinedFlows.length === 0) { - return [null, null]; - } - try { - const dag = dagStratify()(combinedFlows); - return [{ - // @ts-ignore - size: layout(dag), - steps: dag.size(), - nodes: dag.descendants(), - links: dag.links(), - }, null]; - } catch (error) { - if (display) { - console.warn(error); - } - return [null, null]; - } - }, [combinedFlows, layout, display]); - - const [mode, setMode] = useState<'cause' | 'effect'>('effect'); - - const subtree = useMemo(() => mode === 'cause' ? destinationTree : originTree, [mode, destinationTree, originTree]); - const subtreeMsg = useMemo(() => mode === 'cause' ? destTreeMsg : oriTreeMsg, [mode, destTreeMsg, oriTreeMsg]); - - const [brush, setBrush] = useState([]); - const [brushIdx, setBrushIdx] = useState(-1); - - const ref = useRef(null); - - // 没写反,就是横过来 - const w = (subtree?.size.height ?? 0) + 1; - const h = (subtree?.size.width ?? 0) + 1; - - const [width, setWidth] = useState(0); - - const [fx, fy, fSize] = useMemo<[(x: number) => number, (y: number) => number, (size: number) => number]>(() => { - if (w / width >= h / FLOW_HEIGHT) { - const scale = width / w; - const yl = h * scale; - const yPad = (FLOW_HEIGHT - yl) / 2; - return [ - x => (x + 0.5) * scale, - y => yPad + (y + 0.5) * scale, - size => size * scale, - ]; - } else { - const scale = FLOW_HEIGHT / h; - const xl = w * scale; - const xPad = (width - xl) / 2; - return [ - x => xPad + (x + 0.5) * scale, - y => (y + 0.5) * scale, - size => size * scale, - ]; - } - }, [w, h, width]); - - useEffect(() => { - const { current: container } = ref; - if (subtree && container) { - const cb = () => { - const { width: w } = container.getBoundingClientRect(); - setWidth(w); - }; - const ro = new ResizeObserver(cb); - ro.observe(container); - return () => ro.disconnect(); - } - }, [subtree]); - - return display ? ( - e.stopPropagation()}> - {field ? [combinedTree/*, destinationTree, originTree*/].map((tree, i) => tree ? ( - - - - - - - {tree.links.map((link, i, { length }) => ( - ({ x: p.y + 0.5, y: p.x + 0.5 }))) ?? ''} - fill="none" - stroke="#441ce3" - strokeWidth={0.03} - markerEnd="url(#flow-arrow)" - opacity={0.25} - style={{ - filter: `hue-rotate(${180 * i / length}deg)`, - }} - /> - ))} - {tree.nodes.map((node, i) => { - const idx = parseInt(node.data.id, 10); - const f = fields[idx]; - return ( - - { - e.stopPropagation(); - if (index !== idx) { - onClickNode?.(fields[idx].fid); - } - }} - /> - - {f.name ?? f.fid} - - - ); - })} - - ) : ( -
-

{'选中结点的关联路径不是一张有向无环图。'}

- {/*

{'Cannot display corresponding subset because it is not a directed acyclic graph.'}

*/} -

{'尝试查看其他的结点、调大权重筛选或调小显示上限。'}

- {/*

{'Try to click on a different node, turn up the link filter above or turn down the display limit.'}

*/} - {cbnTreeMsg} -
- )) : null} - {field && ( -
- { - const key = option?.key as undefined | typeof mode; - if (key) { - setMode(key); - } - }} - options={[ - { key: 'cause', text: `${field.name ?? field.fid} 如何被其他因素影响` }, - { key: 'effect', text: `${field.name ?? field.fid} 如何影响其他因素` }, - ]} - // options={[ - // { key: 'cause', text: `How ${field.name ?? field.fid} is effected by other fields` }, - // { key: 'effect', text: `How ${field.name ?? field.fid} effects other fields` }, - // ]} - styles={{ - root: { - width: '26em', - } - }} - /> - {combinedTree && !subtree ? ( - //

Click on a node to explore.

-

点击一个结点以在有向图结构上探索。

- ) : null} -
- )} - {field ? ( - subtree ? ( -
- - - - - - - {subtree.links.map((link, i, { length }) => ( - ({ x: p.y + 0.5, y: p.x + 0.5 }))) ?? ''} - fill="none" - stroke="#441ce3" - strokeWidth={0.03} - markerEnd="url(#flow-arrow)" - opacity={0.25} - style={{ - filter: `hue-rotate(${180 * i / length}deg)`, - }} - /> - ))} - -
- {subtree.nodes.map((node, i) => { - const idx = parseInt(node.data.id, 10); - const f = fields[idx]; - return ( -
- { - if (!brush) { - return; - } - setBrush(brush); - setBrushIdx(i); - }} - width={fSize(0.8)} - height={fSize(0.9)} - brush={brushIdx === i ? null : brush} - /> - -
- ); - })} -
-
- ) : ( -
-

{'选中的组可能包含环结构。'}

- {/*

{'Cannot display the group because it is not a directed acyclic graph.'}

*/} -

{'尝试查看其他的结点、调整权重筛选、显示上限,或切换探索模式。'}

- {/*

{'Try to click on a different node, adjust the link filter or display limit, or change the exploration mode.'}

*/} - {subtreeMsg} -
- ) - ) : null} -
- ) : null; -}; - - -export default observer(FlowAnalyzer); diff --git a/packages/rath-client/src/pages/causal/explorer/graph-helper.ts b/packages/rath-client/src/pages/causal/explorer/graph-helper.ts index 8bcfde09..2673054d 100644 --- a/packages/rath-client/src/pages/causal/explorer/graph-helper.ts +++ b/packages/rath-client/src/pages/causal/explorer/graph-helper.ts @@ -1,35 +1,65 @@ -import { RefObject, useEffect, useRef, MutableRefObject } from "react"; -import G6, { Graph } from "@antv/g6"; +import { RefObject, useEffect, useRef, MutableRefObject, useMemo } from "react"; +import G6, { Graph, INode } from "@antv/g6"; +import { NodeSelectionMode, useCausalViewContext } from "../../../store/causalStore/viewStore"; +import type { Subtree } from "../exploration"; +import { PAG_NODE } from "../config"; import type { IFieldMeta } from "../../../interfaces"; import { GRAPH_HEIGHT, useGraphOptions, useRenderData } from "./graph-utils"; -export const useReactiveGraph = ( - containerRef: RefObject, - width: number, - graphRef: MutableRefObject, - options: ReturnType, - data: ReturnType, - mode: "explore" | "edit", - handleNodeClick: ((fid: string | null) => void) | undefined, - handleEdgeClick: ((edge: { srcFid: string, tarFid: string } | null) => void) | undefined, - fields: readonly IFieldMeta[], - updateSelectedRef: MutableRefObject<(idx: number) => void> | undefined, - forceRelayoutFlag: 0 | 1, - focus: number | null, - selectedSubtree: readonly string[], - allowZoom: boolean, -) => { +export interface IReactiveGraphProps { + containerRef: RefObject; + width: number; + graphRef: MutableRefObject; + options: ReturnType; + data: ReturnType; + mode: "explore" | "edit"; + handleNodeClick?: ((fid: string | null) => void) | undefined; + handleNodeDblClick?: ((fid: string | null) => void) | undefined; + handleEdgeClick?: ((edge: { srcFid: string, tarFid: string } | null) => void) | undefined; + fields: readonly IFieldMeta[]; + allowZoom: boolean; + handleSubtreeSelected?: (subtree: Subtree | null) => void | undefined; +} + +export interface IReactiveGraphHandler { + readonly refresh: () => void; +} + +export const useReactiveGraph = ({ + containerRef, + width, + graphRef, + options, + data, + mode, + handleNodeClick, + handleNodeDblClick, + handleEdgeClick, + fields, + allowZoom, + handleSubtreeSelected, +}: IReactiveGraphProps): IReactiveGraphHandler => { const cfgRef = useRef(options); cfgRef.current = options; const dataRef = useRef(data); dataRef.current = data; const handleNodeClickRef = useRef(handleNodeClick); handleNodeClickRef.current = handleNodeClick; + const handleNodeDblClickRef = useRef(handleNodeDblClick); + handleNodeDblClickRef.current = handleNodeDblClick; const fieldsRef = useRef(fields); fieldsRef.current = fields; const handleEdgeClickRef = useRef(handleEdgeClick); handleEdgeClickRef.current = handleEdgeClick; + const handleSubtreeSelectedRef = useRef(handleSubtreeSelected); + handleSubtreeSelectedRef.current = handleSubtreeSelected; + + const viewContext = useCausalViewContext(); + const { selectedFieldGroup = [], graphNodeSelectionMode = NodeSelectionMode.NONE } = viewContext ?? {}; + + const graphNodeSelectionModeRef = useRef(graphNodeSelectionMode); + graphNodeSelectionModeRef.current = graphNodeSelectionMode; useEffect(() => { const { current: container } = containerRef; @@ -46,38 +76,36 @@ export const useReactiveGraph = ( graph.render(); graph.on('node:click', (e: any) => { - const nodeId = e.item._cfg.id; - if (typeof nodeId === 'string') { - const idx = parseInt(nodeId, 10); - handleNodeClickRef.current?.(fieldsRef.current[idx].fid); + const fid = e.item._cfg.id; + if (typeof fid === 'string') { + handleNodeClickRef.current?.(fid); } else { handleNodeClickRef.current?.(null); } }); + graph.on('node:dblclick', (e: any) => { + const fid = e.item._cfg.id; + if (typeof fid === 'string') { + handleNodeDblClickRef.current?.(fid); + } else { + handleNodeDblClickRef.current?.(null); + } + }); + graph.on('edge:click', (e: any) => { const edge = e.item; if (edge) { - const src = (edge._cfg?.source as any)?._cfg.id; - const tar = (edge._cfg?.target as any)?._cfg.id; - if (src && tar) { - const srcF = fieldsRef.current[parseInt(src, 10)]; - const tarF = fieldsRef.current[parseInt(tar, 10)]; - handleEdgeClickRef.current?.({ srcFid: srcF.fid, tarFid: tarF.fid }); + const srcFid = (edge._cfg?.source as any)?._cfg.id as string | undefined; + const tarFid = (edge._cfg?.target as any)?._cfg.id as string | undefined; + if (srcFid && tarFid) { + handleEdgeClickRef.current?.({ srcFid, tarFid }); } else { handleEdgeClickRef.current?.(null); } } }); - if (updateSelectedRef) { - updateSelectedRef.current = idx => { - if (idx === -1) { - handleNodeClickRef.current?.(null); - } - }; - } - graphRef.current = graph; return () => { @@ -85,7 +113,7 @@ export const useReactiveGraph = ( container.innerHTML = ''; }; } - }, [containerRef, graphRef, updateSelectedRef]); + }, [containerRef, graphRef]); useEffect(() => { if (graphRef.current) { @@ -98,7 +126,7 @@ export const useReactiveGraph = ( // for rendering after each iteration tick: () => { graphRef.current?.refreshPositions(); - } + }, }); graphRef.current.render(); } @@ -107,27 +135,41 @@ export const useReactiveGraph = ( useEffect(() => { const { current: graph } = graphRef; if (graph) { - graph.data(dataRef.current); - graph.render(); + if (mode === 'explore') { + // It is found that under explore mode, + // it works strange that the edges are not correctly synchronized with changeData() method, + // while it's checked that the input data is always right. + // This unexpected behavior never occurs under edit mode. + // Fortunately we have data less frequently updated under explore mode, + // unlike what goes under edit mode, which behaviors well. + // Thus, this is a reasonable solution to completely reset the layout + // using read() method (is a combination of data() and render()). + // If a better solution which always perfectly prevents the unexpected behavior mentioned before, + // just remove this clause. + // @author kyusho antoineyang99@gmail.com + graph.read(data); + } else { + graph.changeData(data); + graph.refresh(); + } } - }, [forceRelayoutFlag, graphRef]); + }, [graphRef, data, mode]); useEffect(() => { const { current: graph } = graphRef; if (graph) { - graph.updateLayout(options); - graph.refresh(); + graph.data(dataRef.current); + graph.render(); } - }, [options, graphRef]); + }, [graphRef]); useEffect(() => { - const { current: container } = containerRef; const { current: graph } = graphRef; - if (container && graph) { - graph.changeData(data); - graph.refresh(); + if (graph) { + graph.updateLayout(options); + graph.render(); } - }, [data, graphRef, containerRef]); + }, [options, graphRef]); useEffect(() => { const { current: graph } = graphRef; @@ -139,65 +181,84 @@ export const useReactiveGraph = ( useEffect(() => { const { current: graph } = graphRef; if (graph) { - const focusedNode = graph.getNodes().find(node => { - const id = (() => { - try { - return parseInt(node._cfg?.id ?? '-1', 10); - } catch { - return -1; - } - })(); - return id === focus; + const focusedNodes = graph.getNodes().filter(node => { + const fid = node._cfg?.id as string | undefined; + return fid !== undefined && selectedFieldGroup.some(field => field.fid === fid); }); - const subtree = focusedNode ? graph.getNeighbors(focusedNode).map(node => { - const idx = (() => { - try { - return parseInt(node._cfg?.id ?? '-1', 10); - } catch { - return -1; + const subtreeNodes = focusedNodes.reduce((list, focusedNode) => { + for (const node of graph.getNeighbors(focusedNode)) { + if (focusedNodes.some(item => item === node) || list.some(item => item === node)) { + continue; } - })(); - return fieldsRef.current[idx]?.fid; - }) : []; + list.push(node); + } + return list; + }, []); + const subtreeFidArr = subtreeNodes.map(node => { + return node._cfg?.id as string | undefined; + }).filter(Boolean) as string[]; + const subtreeFields = subtreeFidArr.reduce((list, fid) => { + const f = fieldsRef.current.find(which => which.fid === fid); + if (f) { + return list.concat([f]); + } + return list; + }, []); + const subtreeRoot = ( + graphNodeSelectionModeRef.current === NodeSelectionMode.SINGLE && selectedFieldGroup.length === 1 + ) ? selectedFieldGroup[0] : null; + handleSubtreeSelectedRef.current?.(subtreeRoot ? { + node: subtreeRoot, + neighbors: subtreeFields.map(node => ({ + field: fieldsRef.current.find(f => f.fid === node.fid)!, + // FIXME: 查询这条边上的节点状态 + rootType: PAG_NODE.EMPTY, + neighborType: PAG_NODE.EMPTY, + })), + } : null); graph.getNodes().forEach(node => { - const isFocused = node === focusedNode; + const isFocused = focusedNodes.some(item => item === node); graph.setItemState(node, 'focused', isFocused); - const isInSubtree = focusedNode ? graph.getNeighbors(focusedNode).some(neighbor => neighbor === node) : false; + const isInSubtree = isFocused ? false : subtreeNodes.some(neighbor => neighbor === node); graph.setItemState(node, 'highlighted', isInSubtree); - graph.setItemState(node, 'faded', focus !== null && !isFocused && !isInSubtree); + graph.setItemState(node, 'faded', selectedFieldGroup.length !== 0 && !isFocused && !isInSubtree); + graph.updateItem(node, { + labelCfg: { + style: { + opacity: focusedNodes.length === 0 ? 1 : isFocused ? 1 : isInSubtree ? 0.5 : 0.2, + fontWeight: isFocused ? 600 : 400, + }, + }, + }); }); graph.getEdges().forEach(edge => { - const sourceIdx = (() => { - try { - return parseInt((edge._cfg?.source as any)?._cfg?.id ?? '-1', 10); - } catch { - return -1; - } - })(); - const targetIdx = (() => { - try { - return parseInt((edge._cfg?.target as any)?._cfg?.id ?? '-1', 10); - } catch { - return -1; - } - })(); - const isInSubtree = focus !== null && [ - fieldsRef.current[sourceIdx]?.fid, fieldsRef.current[targetIdx]?.fid - ].includes(fieldsRef.current[focus]?.fid) && [ - fieldsRef.current[sourceIdx]?.fid, fieldsRef.current[targetIdx]?.fid - ].every(fid => { - return [fieldsRef.current[focus]?.fid].concat(subtree).includes(fid); - }); + const sourceFid = (edge._cfg?.source as any)?._cfg?.id as string | undefined; + const targetFid = (edge._cfg?.target as any)?._cfg?.id as string | undefined; + const nodesSelected = [ + sourceFid, targetFid + ].filter(fid => typeof fid === 'string' && selectedFieldGroup.some(f => f.fid === fid)); + const nodesInSubtree = [ + sourceFid, targetFid + ].filter(fid => typeof fid === 'string' && subtreeFidArr.some(f => f === fid)); + const isInSubtree = nodesSelected.length === 2; + const isHalfInSubtree = nodesSelected.length === 1 && nodesInSubtree.length === 1; graph.updateItem(edge, { labelCfg: { style: { - opacity: isInSubtree ? 1 : 0, + opacity: isInSubtree ? 1 : isHalfInSubtree ? 0.6 : 0, }, }, }); graph.setItemState(edge, 'highlighted', isInSubtree); - graph.setItemState(edge, 'faded', focus !== null && !isInSubtree); + graph.setItemState(edge, 'semiHighlighted', isHalfInSubtree); + graph.setItemState(edge, 'faded', selectedFieldGroup.length !== 0 && !isInSubtree && !isHalfInSubtree); }); } - }, [graphRef, focus, selectedSubtree]); + }, [graphRef, selectedFieldGroup, data]); + + return useMemo(() => ({ + refresh() { + graphRef.current?.read(dataRef.current); + }, + }), [graphRef]); }; diff --git a/packages/rath-client/src/pages/causal/explorer/graph-utils.ts b/packages/rath-client/src/pages/causal/explorer/graph-utils.ts index 088ca351..6a431a04 100644 --- a/packages/rath-client/src/pages/causal/explorer/graph-utils.ts +++ b/packages/rath-client/src/pages/causal/explorer/graph-utils.ts @@ -1,8 +1,7 @@ import { useMemo, useRef, CSSProperties } from "react"; import G6, { Graph, GraphData, GraphOptions } from "@antv/g6"; -import type { ModifiableBgKnowledge } from "../config"; +import { PagLink, PAG_NODE } from "../config"; import type { IFieldMeta } from "../../../interfaces"; -import type { CausalLink } from "."; export const GRAPH_HEIGHT = 500; @@ -30,49 +29,10 @@ export type GraphNodeAttributes< }>; const arrows = { - undirected: { - start: '', - end: '', - }, - directed: { - start: '', - end: 'M 8.4,0 L 19.6,5.6 L 19.6,-5.6 Z', - }, - bidirected: { - start: 'M 8.4,0 L 19.6,5.6 L 19.6,-5.6 Z', - end: 'M 8.4,0 L 19.6,5.6 L 19.6,-5.6 Z', - }, - 'weak directed': { - start: 'M 8.4,0 a 5.6,5.6 0 1,0 11.2,0 a 5.6,5.6 0 1,0 -11.2,0 Z', - end: 'M 8.4,0 L 19.6,5.6 L 19.6,-5.6 Z', - }, - 'weak undirected': { - start: 'M 8.4,0 a 5.6,5.6 0 1,0 11.2,0 a 5.6,5.6 0 1,0 -11.2,0 Z', - end: 'M 8.4,0 a 5.6,5.6 0 1,0 11.2,0 a 5.6,5.6 0 1,0 -11.2,0 Z', - }, -} as const; - -const bkArrows = { - "must-link": { - fill: '#0027b4', - start: '', - end: '', - }, - "must-not-link": { - fill: '#c50f1f', - start: '', - end: '', - }, - "directed-must-link": { - fill: '#0027b4', - start: '', - end: 'M 8.4,0 L 19.6,5.6 L 19.6,-5.6 Z', - }, - "directed-must-not-link": { - fill: '#c50f1f', - start: '', - end: 'M 8.4,0 L 19.6,5.6 L 19.6,-5.6 Z', - }, + [PAG_NODE.EMPTY]: '', + [PAG_NODE.BLANK]: '', + [PAG_NODE.ARROW]: 'M 8.4,0 L 19.6,5.6 L 19.6,-5.6 Z', + [PAG_NODE.CIRCLE]: 'M 8.4,0 a 5.6,5.6 0 1,0 11.2,0 a 5.6,5.6 0 1,0 -11.2,0 Z', } as const; export const ForbiddenEdgeType = 'forbidden-edge'; @@ -106,80 +66,108 @@ G6.registerEdge( 'line', ); -export const useRenderData = ( - data: { nodes: { id: number }[]; links: { source: number; target: number; type: CausalLink['type']; score?: number }[] }, - mode: "explore" | "edit", - preconditions: readonly ModifiableBgKnowledge[], - fields: readonly Readonly[], +export interface IRenderDataProps { + mode: "explore" | "edit"; + fields: readonly Readonly[]; + PAG: readonly PagLink[]; + /** @default undefined */ + weights?: Map> | undefined; + /** @default 0 */ + cutThreshold?: number; + /** @default Infinity */ + limit?: number; renderNode?: (node: Readonly) => GraphNodeAttributes | undefined, -) => { +} + +export const useRenderData = ({ + mode, + fields, + PAG, + weights = undefined, + cutThreshold = 0, + limit = Infinity, + renderNode, +}: IRenderDataProps) => { return useMemo(() => ({ - nodes: data.nodes.map((node, i) => { + nodes: fields.map((f) => { return { - id: `${node.id}`, - description: fields[i].name ?? fields[i].fid, - ...renderNode?.(fields[i]), + id: `${f.fid}`, + description: f.name ?? f.fid, + ...renderNode?.(f), }; }), - edges: mode === 'explore' ? data.links.map((link, i) => { + edges: mode === 'explore' ? PAG.filter(link => { + const w = weights?.get(link.src)?.get(link.tar); + return w === undefined || w >= cutThreshold; + }).slice(0, limit).map((link, i) => { + const w = weights?.get(link.src)?.get(link.tar); + return { id: `link_${i}`, - source: `${link.source}`, - target: `${link.target}`, + source: link.src, + target: link.tar, style: { startArrow: { fill: '#F6BD16', - path: arrows[link.type].start, + path: arrows[link.src_type], }, endArrow: { fill: '#F6BD16', - path: arrows[link.type].end, + path: arrows[link.tar_type], }, - lineWidth: typeof link.score === 'number' ? 1 + link.score * 2 : undefined, + lineWidth: typeof w === 'number' ? 1 + w * 2 : undefined, }, - label: typeof link.score === 'number' ? `${link.score.toPrecision(2)}` : undefined, + label: typeof w === 'number' ? `${(w * 100).toFixed(2).replace(/\.?0+$/, '')}%` : undefined, labelCfg: { style: { opacity: 0, }, }, }; - }) : preconditions.map((bk, i) => ({ - id: `bk_${i}`, - source: `${fields.findIndex(f => f.fid === bk.src)}`, - target: `${fields.findIndex(f => f.fid === bk.tar)}`, - style: { - lineWidth: 2, - lineAppendWidth: 5, - stroke: bkArrows[bk.type].fill, - startArrow: { - fill: bkArrows[bk.type].fill, - stroke: bkArrows[bk.type].fill, - path: bkArrows[bk.type].start, - }, - endArrow: { - fill: bkArrows[bk.type].fill, - stroke: bkArrows[bk.type].fill, - path: bkArrows[bk.type].end, - }, - }, - edgeStateStyles: { - active: { + }) : PAG.map((assr, i) => { + const isForbiddenType = [assr.src_type, assr.tar_type].includes(PAG_NODE.EMPTY); + const color = isForbiddenType ? '#c50f1f' : '#0027b4'; + + return { + id: `bk_${i}`, + source: assr.src, + target: assr.tar, + style: { lineWidth: 2, + lineAppendWidth: 5, + stroke: color, + startArrow: { + fill: color, + stroke: color, + path: arrows[assr.src_type], + }, + endArrow: { + fill: color, + stroke: color, + path: arrows[assr.tar_type], + }, }, - }, - type: bk.type === 'must-not-link' || bk.type === 'directed-must-not-link' ? ForbiddenEdgeType : undefined, - })), - }), [data, mode, preconditions, fields, renderNode]); + type: isForbiddenType ? ForbiddenEdgeType : undefined, + }; + }), + }), [fields, mode, PAG, limit, renderNode, weights, cutThreshold]); }; -export const useGraphOptions = ( - width: number, - fields: readonly Readonly[], - handleLasso: ((fields: IFieldMeta[]) => void) | undefined, - handleLink: (srcFid: string, tarFid: string) => void, - graphRef: { current: Graph | undefined }, -) => { +export interface IGraphOptions { + width: number; + fields: readonly Readonly[]; + handleLasso?: ((fields: IFieldMeta[]) => void) | undefined; + handleLink?: (srcFid: string, tarFid: string) => void | undefined; + graphRef: { current: Graph | undefined }; +} + +export const useGraphOptions = ({ + width, + fields, + handleLasso, + handleLink, + graphRef, +}: IGraphOptions) => { const widthRef = useRef(width); widthRef.current = width; const fieldsRef = useRef(fields); @@ -190,16 +178,16 @@ export const useGraphOptions = ( handleLinkRef.current = handleLink; return useMemo>(() => { - let createEdgeFrom = -1; + let createEdgeFrom: string | null = null; const exploreMode = ['drag-canvas', 'drag-node', { type: 'lasso-select', trigger: 'shift', onSelect(nodes: any, edges: any) { const selected: IFieldMeta[] = []; for (const node of nodes) { - const idx = node._cfg?.id; - if (idx) { - const f = fieldsRef.current[parseInt(idx, 10)]; + const fid = node._cfg?.id as string | undefined; + if (fid) { + const f = fieldsRef.current.find(which => which.fid === fid); if (f) { selected.push(f); } @@ -216,25 +204,23 @@ export const useGraphOptions = ( type: 'create-edge', trigger: 'drag', shouldBegin(e: any) { - const source = e.item?._cfg?.id; - if (source) { - createEdgeFrom = parseInt(source, 10); + const sourceFid = e.item?._cfg?.id as string | undefined; + if (sourceFid) { + createEdgeFrom = sourceFid; } return true; }, shouldEnd(e: any) { - if (createEdgeFrom === -1) { + if (createEdgeFrom === null) { return false; } - const target = e.item?._cfg?.id; - if (target) { - const origin = fieldsRef.current[createEdgeFrom]; - const destination = fieldsRef.current[parseInt(target, 10)]; - if (origin.fid !== destination.fid) { - handleLinkRef.current(origin.fid, destination.fid); + const targetFid = e.item?._cfg?.id as string | undefined; + if (targetFid) { + if (createEdgeFrom !== targetFid) { + handleLinkRef.current?.(createEdgeFrom, targetFid); } } - createEdgeFrom = -1; + createEdgeFrom = null; return false; }, }]; @@ -270,13 +256,14 @@ export const useGraphOptions = ( focused: { lineWidth: 1.5, opacity: 1, + shadowColor: '#F6BD16', + shadowBlur: 8, }, highlighted: { - lineWidth: 1.25, - opacity: 1, + opacity: 0.4, }, faded: { - opacity: 0.4, + opacity: 0.2, }, }, defaultEdge: { @@ -288,8 +275,11 @@ export const useGraphOptions = ( highlighted: { opacity: 1, }, + semiHighlighted: { + opacity: 0.8, + }, faded: { - opacity: 0.2, + opacity: 0.12, }, }, }; diff --git a/packages/rath-client/src/pages/causal/explorer/graphView.tsx b/packages/rath-client/src/pages/causal/explorer/graphView.tsx index 874db83b..2c166ddf 100644 --- a/packages/rath-client/src/pages/causal/explorer/graphView.tsx +++ b/packages/rath-client/src/pages/causal/explorer/graphView.tsx @@ -4,13 +4,18 @@ import { Graph } from "@antv/g6"; import { observer } from "mobx-react-lite"; import { ActionButton, Dropdown } from "@fluentui/react"; import type { IFieldMeta } from "../../../interfaces"; -import type { ModifiableBgKnowledge } from "../config"; +import type { Subtree } from "../exploration"; +import { EdgeAssert, NodeAssert } from "../../../store/causalStore/modelStore"; +import { useCausalViewContext } from "../../../store/causalStore/viewStore"; import { useGlobalStore } from "../../../store"; -import { GraphNodeAttributes, useGraphOptions, useRenderData } from "./graph-utils"; +import { useGraphOptions, useRenderData } from "./graph-utils"; import { useReactiveGraph } from "./graph-helper"; -import type { DiagramGraphData } from "."; +const sNormalize = (matrix: readonly (readonly number[])[]): number[][] => { + return matrix.map(vec => vec.map(n => 2 / (1 + Math.exp(-n)) - 1)); +}; + const Container = styled.div` overflow: hidden; position: relative; @@ -23,142 +28,100 @@ const Container = styled.div` left: 1em; top: 1em; padding: 0.8em; + & * { + user-select: none; + } } `; export type GraphViewProps = Omit; cutThreshold: number; limit: number; mode: 'explore' | 'edit'; - focus: number | null; onClickNode?: (fid: string | null) => void; - toggleFlowAnalyzer: () => void; - onLinkTogether: (srcFid: string, tarFid: string, type: ModifiableBgKnowledge['type']) => void; + onLinkTogether: (srcFid: string, tarFid: string, type: EdgeAssert) => void; onRevertLink: (srcFid: string, tarFid: string) => void; onRemoveLink: (srcFid: string, tarFid: string) => void; - preconditions: ModifiableBgKnowledge[]; forceRelayoutRef: React.MutableRefObject<() => void>; - autoLayout: boolean; - renderNode?: (node: Readonly) => GraphNodeAttributes | undefined; handleLasso?: (fields: IFieldMeta[]) => void; + handleSubtreeSelected?: (subtree: Subtree | null) => void; allowZoom: boolean; }, never>, 'onChange' | 'ref'>; -/** 调试用的,不需要的时候干掉 */ -type ExportableGraphData = { - nodes: { id: string }[]; - edges: { source: string; target: string }[]; -}; -/** 调试用的,不需要的时候干掉 */ -const ExportGraphButton: React.FC<{ data: DiagramGraphData; fields: readonly Readonly[] }> = ({ data, fields }) => { - const value = useMemo(() => { - const graph: ExportableGraphData = { - nodes: fields.map(f => ({ id: f.fid })), - edges: [], - }; - for (const link of data.links) { - const source = fields[link.causeId].fid; - const target = fields[link.effectId].fid; - graph.edges.push({ source, target }); - if (link.type === 'bidirected' || link.type === 'undirected') { - graph.edges.push({ source: target, target: source }); - } - } - return new File([JSON.stringify(graph, undefined, 2)], `test - ${new Date().toLocaleString()}.json`); - }, [data, fields]); - const dataUrlRef = useRef(''); - useEffect(() => { - dataUrlRef.current = URL.createObjectURL(value); - return () => { - URL.revokeObjectURL(dataUrlRef.current); - }; - }, [value]); - const handleExport = useCallback(() => { - const a = document.createElement('a'); - a.href = dataUrlRef.current; - a.download = value.name; - a.click(); - a.remove(); - }, [value.name]); - return ( - - 导出为图 - - ); -}; - const GraphView = forwardRef(({ - selectedSubtree, - value, onClickNode, - focus, cutThreshold, limit, mode, onLinkTogether, onRevertLink, onRemoveLink, - preconditions, forceRelayoutRef, - autoLayout, - renderNode, - toggleFlowAnalyzer, allowZoom, handleLasso, + handleSubtreeSelected, ...props }, ref) => { const { causalStore } = useGlobalStore(); - const { selectedFields: fields } = causalStore; - - const [data] = useMemo(() => { - let totalScore = 0; - const nodeCauseWeights = value.nodes.map(() => 0); - const nodeEffectWeights = value.nodes.map(() => 0); - value.links.forEach(link => { - nodeCauseWeights[link.effectId] += link.score; - nodeEffectWeights[link.causeId] += link.score; - totalScore += link.score * 2; - }); - return [{ - nodes: value.nodes.map((node, i) => ({ - id: node.nodeId, - index: i, - causeSum: nodeCauseWeights[i], - effectSum: nodeEffectWeights[i], - score: (nodeCauseWeights[i] + nodeEffectWeights[i]) / totalScore, - diff: (nodeCauseWeights[i] - nodeEffectWeights[i]) / totalScore, - })), - links: value.links.map(link => ({ - source: link.causeId, - target: link.effectId, - score: link.score / nodeCauseWeights[link.effectId], - type: link.type, - })).filter(link => link.score >= cutThreshold).sort((a, b) => b.score - a.score).slice(0, limit), - }, totalScore]; - }, [value, cutThreshold, limit]); + const { fields } = causalStore; + const { causality, assertionsAsPag, mutualMatrix } = causalStore.model; + const { onRenderNode, localWeights } = useCausalViewContext() ?? {}; const containerRef = useRef(null); const [width, setWidth] = useState(0); - const updateSelectedRef = useRef<(idx: number) => void>(() => {}); - - const [createEdgeMode, setCreateEdgeMode] = useState('directed-must-link'); + const [createEdgeMode, setCreateEdgeMode] = useState(EdgeAssert.TO_EFFECT); const handleLinkTogether = useCallback((srcFid: string, tarFid: string) => { onLinkTogether(srcFid, tarFid, createEdgeMode); }, [createEdgeMode, onLinkTogether]); + const W = useMemo> | undefined>(() => { + if (!causality || !mutualMatrix || mutualMatrix.length !== fields.length) { + return undefined; + } + + const scoreMatrix = sNormalize(mutualMatrix); + + const map = new Map>(); + + for (const link of causality) { + const srcIdx = fields.findIndex(f => f.fid === link.src); + const tarIdx = fields.findIndex(f => f.fid === link.tar); + if (srcIdx !== -1 && tarIdx !== -1) { + const w = Math.abs(scoreMatrix[srcIdx][tarIdx]); + if (!map.has(link.src)) { + map.set(link.src, new Map()); + } + map.get(link.src)!.set(link.tar, w); + } + } + + return map; + }, [causality, fields, mutualMatrix]); + const graphRef = useRef(); - const renderData = useRenderData(data, mode, preconditions, fields, renderNode); - const cfg = useGraphOptions(width, fields, handleLasso, handleLinkTogether, graphRef); + const renderData = useRenderData({ + mode, + fields, + PAG: mode === 'edit' ? assertionsAsPag : causality ?? [], + weights: mode === 'edit' ? undefined : localWeights ?? W, + cutThreshold, + limit, + renderNode: onRenderNode, + }); + const cfg = useGraphOptions({ + width, + fields, + handleLasso, + handleLink: handleLinkTogether, + graphRef, + }); const cfgRef = useRef(cfg); cfgRef.current = cfg; - const [forceRelayoutFlag, setForceRelayoutFlag] = useState<0 | 1>(0); - const [clickEdgeMode, setClickEdgeMode] = useState<'delete' | 'forbid'>('forbid'); + const [dblClickNodeMode, setDblClickNodeMode] = useState(NodeAssert.FORBID_AS_CAUSE); const handleEdgeClick = useCallback((edge: { srcFid: string; tarFid: string; } | null) => { if (edge) { @@ -176,46 +139,41 @@ const GraphView = forwardRef(({ } }, [onRevertLink, onRemoveLink, clickEdgeMode]); - useReactiveGraph( + const handleNodeDblClick = useCallback((fid: string | null) => { + if (mode === 'edit' && fid) { + const overload = causalStore.model.assertions.find(decl => 'fid' in decl && decl.fid === fid); + if (overload?.assertion === dblClickNodeMode) { + // remove it + causalStore.model.removeNodeAssertion(fid); + } else { + causalStore.model.addNodeAssertion(fid, dblClickNodeMode); + } + } + }, [mode, dblClickNodeMode, causalStore]); + + const graph = useReactiveGraph({ containerRef, width, graphRef, - cfg, - renderData, + options: cfg, + data: renderData, mode, - onClickNode, + handleNodeClick: onClickNode, handleEdgeClick, + handleNodeDblClick, fields, - updateSelectedRef, - forceRelayoutFlag, - focus, - selectedSubtree, allowZoom, - ); - - useEffect(() => { - const { current: graph } = graphRef; - if (graph) { - graph.stopAnimate(); - graph.destroyLayout(); - if (autoLayout) { - graph.updateLayout(cfgRef.current.layout); - } - } - }, [autoLayout]); + handleSubtreeSelected, + }); useEffect(() => { - forceRelayoutRef.current = () => setForceRelayoutFlag(flag => flag === 0 ? 1 : 0); + forceRelayoutRef.current = () => { + graph.refresh(); + }; return () => { forceRelayoutRef.current = () => {}; }; - }, [forceRelayoutRef]); - - useEffect(() => { - if (focus !== null) { - updateSelectedRef.current(focus); - } - }, [focus]); + }, [forceRelayoutRef, graph]); useEffect(() => { const { current: container } = containerRef; @@ -236,24 +194,21 @@ const GraphView = forwardRef(({ { - if (e.shiftKey) { - toggleFlowAnalyzer(); - } - e.stopPropagation(); - }} >
{mode === 'edit' && (
+ causalStore.model.clearAssertions()}> + 清空所有 + { if (!option) { @@ -318,9 +273,43 @@ const GraphView = forwardRef(({ }, }} /> + { + if (!option) { + return; + } + const assrType = option.key as typeof dblClickNodeMode; + setDblClickNodeMode(assrType); + }} + styles={{ + title: { + fontSize: '0.8rem', + lineHeight: '1.8em', + height: '1.8em', + padding: '0 2.8em 0 0.8em', + border: 'none', + borderBottom: '1px solid #8888', + }, + caretDownWrapper: { + fontSize: '0.8rem', + lineHeight: '1.8em', + height: '1.8em', + }, + caretDown: { + fontSize: '0.8rem', + lineHeight: '1.8em', + height: '1.8em', + }, + }} + />
)} - ); }); diff --git a/packages/rath-client/src/pages/causal/explorer/index.tsx b/packages/rath-client/src/pages/causal/explorer/index.tsx index a06c66d6..d13dcc28 100644 --- a/packages/rath-client/src/pages/causal/explorer/index.tsx +++ b/packages/rath-client/src/pages/causal/explorer/index.tsx @@ -1,20 +1,19 @@ -import { DefaultButton, Icon, Slider, Toggle } from "@fluentui/react"; +import { DefaultButton, Icon, Slider, Stack, Toggle } from "@fluentui/react"; import { observer } from "mobx-react-lite"; -import { FC, useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { FC, useCallback, useEffect, useRef, useState } from "react"; import styled from "styled-components"; -import useErrorBoundary from "../../../hooks/use-error-boundary"; -import type { IFieldMeta, IRow } from "../../../interfaces"; +import type { IFieldMeta } from "../../../interfaces"; import { useGlobalStore } from "../../../store"; -import { CausalLinkDirection } from "../../../utils/resolve-causal"; -import type { ModifiableBgKnowledge } from "../config"; +import type { EdgeAssert } from "../../../store/causalStore/modelStore"; +import { useCausalViewContext } from "../../../store/causalStore/viewStore"; +import type { Subtree } from "../exploration"; import Floating from "../floating"; import ExplorerMainView from "./explorerMainView"; -import FlowAnalyzer, { NodeWithScore } from "./flowAnalyzer"; -import type { GraphNodeAttributes } from "./graph-utils"; export type CausalNode = { nodeId: number; + fid: string; } export type CausalLink = { @@ -24,39 +23,19 @@ export type CausalLink = { type: 'directed' | 'bidirected' | 'undirected' | 'weak directed' | 'weak undirected'; } -export interface DiagramGraphData { - readonly nodes: readonly Readonly[]; - readonly links: readonly Readonly[]; -} - export interface ExplorerProps { allowEdit: boolean; - dataSource: IRow[]; - scoreMatrix: readonly (readonly number[])[]; - preconditions: ModifiableBgKnowledge[]; - onNodeSelected: ( - node: Readonly | null, - simpleCause: readonly Readonly[], - simpleEffect: readonly Readonly[], - composedCause: readonly Readonly[], - composedEffect: readonly Readonly[], - ) => void; - onLinkTogether: (srcIdx: number, tarIdx: number, type: ModifiableBgKnowledge['type']) => void; + onLinkTogether: (srcFid: string, tarFid: string, type: EdgeAssert) => void; onRevertLink: (srcFid: string, tarFid: string) => void; onRemoveLink: (srcFid: string, tarFid: string) => void; - renderNode?: (node: Readonly) => GraphNodeAttributes | undefined; - synchronizePredictionsUsingCausalResult: () => void; handleLasso?: (fields: IFieldMeta[]) => void; + handleSubTreeSelected?: (subtree: Subtree | null) => void; } -const sNormalize = (matrix: readonly (readonly number[])[]): number[][] => { - return matrix.map(vec => vec.map(n => 2 / (1 + Math.exp(-n)) - 1)); -}; - const Container = styled.div` width: 100%; display: flex; - flex-direction: row; + flex-direction: column; align-items: stretch; position: relative; `; @@ -100,234 +79,76 @@ const MainView = styled.div` const Explorer: FC = ({ allowEdit, - dataSource, - scoreMatrix, - onNodeSelected, onLinkTogether, onRevertLink, onRemoveLink, - preconditions, - renderNode, - synchronizePredictionsUsingCausalResult, handleLasso, + handleSubTreeSelected, }) => { const { causalStore } = useGlobalStore(); - const { causalStrength, selectedFields } = causalStore; + const { causality } = causalStore.model; const [cutThreshold, setCutThreshold] = useState(0); const [mode, setMode] = useState<'explore' | 'edit'>('explore'); const [allowZoom, setAllowZoom] = useState(false); - const data = useMemo(() => sNormalize(scoreMatrix), [scoreMatrix]); - - const nodes = useMemo(() => { - return selectedFields.map((_, i) => ({ nodeId: i })); - }, [selectedFields]); - - const links = useMemo(() => { - if (causalStrength.length === 0) { - return []; - } - if (causalStrength.length !== data.length) { - console.warn(`lengths of matrixes do not match`); - return []; - } - - const links: CausalLink[] = []; - - for (let i = 0; i < data.length - 1; i += 1) { - for (let j = i + 1; j < data.length; j += 1) { - const weight = Math.abs(data[i][j]); - const direction = causalStrength[i][j]; - switch (direction) { - case CausalLinkDirection.none: { - break; - } - case CausalLinkDirection.directed: { - links.push({ - causeId: i, - effectId: j, - score: weight, - type: 'directed', - }); - break; - } - case CausalLinkDirection.reversed: { - links.push({ - causeId: j, - effectId: i, - score: weight, - type: 'directed', - }); - break; - } - case CausalLinkDirection.weakDirected: { - links.push({ - causeId: i, - effectId: j, - score: weight, - type: 'weak directed', - }); - break; - } - case CausalLinkDirection.weakReversed: { - links.push({ - causeId: j, - effectId: i, - score: weight, - type: 'weak directed', - }); - break; - } - case CausalLinkDirection.undirected: { - links.push({ - causeId: i, - effectId: j, - score: weight, - type: 'undirected', - }); - break; - } - case CausalLinkDirection.weakUndirected: { - links.push({ - causeId: i, - effectId: j, - score: weight, - type: 'weak undirected', - }); - break; - } - case CausalLinkDirection.bidirected: { - links.push({ - causeId: i, - effectId: j, - score: weight, - type: 'bidirected', - }); - break; - } - default: { - break; - } - } - } - } - - return links.sort((a, b) => Math.abs(b.score) - Math.abs(a.score)); - }, [data, causalStrength]); - - const value = useMemo(() => ({ nodes, links }), [nodes, links]); - - const [focus, setFocus] = useState(-1); - const [showFlowAnalyzer, setShowFlowAnalyzer] = useState(false); + const viewContext = useCausalViewContext(); const handleClickCircle = useCallback((fid: string | null) => { if (fid === null) { - return setFocus(-1); + return viewContext?.clearSelected(); } - const idx = selectedFields.findIndex(f => f.fid === fid); if (mode === 'explore') { - setFocus(idx === focus ? -1 : idx); + viewContext?.toggleNodeSelected(fid); } - }, [mode, focus, selectedFields]); - - const toggleFlowAnalyzer = useCallback(() => { - setShowFlowAnalyzer(display => !display); - }, []); - - const ErrorBoundary = useErrorBoundary((err, info) => { - // console.error(err ?? info); - return ( -
-

- {"Failed to visualize flows as DAG. Click a different node or turn up the link filter."} -

- {err?.message ?? info} -
- ); - }, [selectedFields, value, mode === 'explore' ? focus : -1, cutThreshold]); - - const handleLink = useCallback((srcFid: string, tarFid: string, type: ModifiableBgKnowledge['type']) => { - if (srcFid === tarFid) { - return; - } - onLinkTogether(selectedFields.findIndex(f => f.fid === srcFid), selectedFields.findIndex(f => f.fid === tarFid), type); - }, [selectedFields, onLinkTogether]); - - const [selectedSubtree, setSelectedSubtree] = useState([]); - - const onNodeSelectedRef = useRef(onNodeSelected); - onNodeSelectedRef.current = onNodeSelected; - - const handleNodeSelect = useCallback((node, simpleCause, simpleEffect, composedCause, composedEffect) => { - onNodeSelectedRef.current(node, simpleCause, simpleEffect, composedCause, composedEffect); - const shallowSubtree = simpleEffect.reduce[]>( - (list, f) => { - if (!list.some((which) => which.field.fid === f.field.fid)) { - list.push(f); - } - return list; - }, - [...simpleCause] - ); - setSelectedSubtree(shallowSubtree.map(node => node.field.fid)); - }, []); + }, [mode, viewContext]); const forceRelayoutRef = useRef<() => void>(() => {}); useEffect(() => { - setFocus(-1); - onNodeSelectedRef.current(null, [], [], [], []); - }, [mode]); + viewContext?.clearSelected(); + }, [mode, viewContext]); const [limit, setLimit] = useState(20); - const [autoLayout, setAutoLayout] = useState(true); const forceLayout = useCallback(() => { - setAutoLayout(true); forceRelayoutRef.current(); }, []); - useEffect(() => { - if (mode === 'edit') { - synchronizePredictionsUsingCausalResult(); - } - }, [mode, synchronizePredictionsUsingCausalResult]); - useEffect(() => { setMode('explore'); }, [allowEdit]); - return (<> - focus !== -1 && setFocus(-1)}> + return ( + + + + 重新布局 + + = ({ ()}> - - 刷新布局 - = ({ inlineLabel /> )} - setAutoLayout(Boolean(checked))} - onText="On" - offText="Off" - inlineLabel - /> setLimit(value)} /> @@ -398,19 +199,7 @@ const Explorer: FC = ({ - - - - ); + ); }; diff --git a/packages/rath-client/src/pages/causal/functionalDependencies/FDBatch.tsx b/packages/rath-client/src/pages/causal/functionalDependencies/FDBatch.tsx index cab53fd9..57bbf720 100644 --- a/packages/rath-client/src/pages/causal/functionalDependencies/FDBatch.tsx +++ b/packages/rath-client/src/pages/causal/functionalDependencies/FDBatch.tsx @@ -1,12 +1,11 @@ import { ActionButton, DefaultButton, Spinner, Stack } from '@fluentui/react'; import { observer } from 'mobx-react-lite'; -import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import { FC, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import styled from 'styled-components'; import produce from 'immer'; import { useGlobalStore } from '../../../store'; import type { IFunctionalDep } from '../config'; -import type { FDPanelProps } from './FDPanel'; -import { getGeneratedFDFromAutoDetection, getGeneratedFDFromExtInfo } from './utils'; +import { getGeneratedFDFromAutoDetection } from './utils'; import FDEditor from './FDEditor'; @@ -55,28 +54,26 @@ const dropdownOptions: { key: BatchUpdateMode; text: string }[] = [ }, ]; -const FDBatch: React.FC = ({ - context, functionalDependencies, setFunctionalDependencies, renderNode, -}) => { +const FDBatch: FC = () => { const { causalStore } = useGlobalStore(); - const { selectedFields } = causalStore; + const { sample } = causalStore.dataset; + const { functionalDependencies } = causalStore.model; const [displayPreview, setDisplayPreview] = useState(false); - const [preview, setPreview] = useState(null); + const [preview, setPreview] = useState(null); const isPending = displayPreview && preview === null; const [mode, setMode] = useState(BatchUpdateMode.OVERWRITE_ONLY); - const { dataSubset } = context; - const updatePreview = useMemo(() => { + const updatePreview = useMemo<(fdArr: IFunctionalDep[] | ((prev: readonly IFunctionalDep[] | null) => readonly IFunctionalDep[])) => void>(() => { if (displayPreview) { - return setPreview as typeof setFunctionalDependencies; + return setPreview; } return () => {}; }, [displayPreview]); const generateFDFromExtInfo = useCallback(() => { - setPreview(getGeneratedFDFromExtInfo(selectedFields)); + setPreview(causalStore.model.generatedFDFromExtInfo); setDisplayPreview(true); - }, [selectedFields]); + }, [causalStore]); const pendingRef = useRef>(); useEffect(() => { @@ -85,7 +82,7 @@ const FDBatch: React.FC = ({ } }, [displayPreview]); const generateFDFromAutoDetection = useCallback(() => { - const p = getGeneratedFDFromAutoDetection(dataSubset, selectedFields.map(f => f.fid)); + const p = sample.getAll().then(data => getGeneratedFDFromAutoDetection(data)); pendingRef.current = p; p.then(res => { if (p === pendingRef.current) { @@ -100,11 +97,11 @@ const FDBatch: React.FC = ({ pendingRef.current = undefined; }); setDisplayPreview(true); - }, [selectedFields, dataSubset]); + }, [sample]); const handleClear = useCallback(() => { - setFunctionalDependencies([]); - }, [setFunctionalDependencies]); + causalStore.model.updateFunctionalDependencies([]); + }, [causalStore]); const submittable = useMemo(() => { if (preview) { @@ -118,7 +115,7 @@ const FDBatch: React.FC = ({ }); } return deps.concat([dep]); - }, functionalDependencies); + }, functionalDependencies.slice(0)); } case BatchUpdateMode.FILL_ONLY: { return preview.reduce((deps, dep) => { @@ -134,25 +131,25 @@ const FDBatch: React.FC = ({ }); } return deps; - }, functionalDependencies); + }, functionalDependencies.slice(0)); } case BatchUpdateMode.FULLY_REPLACE: { - return preview; + return preview.slice(0); } default: { - return functionalDependencies; + return functionalDependencies.slice(0); } } } else { - return functionalDependencies; + return functionalDependencies.slice(0); } }, [preview, functionalDependencies, mode]); const handleSubmit = useCallback(() => { - setFunctionalDependencies(submittable); + causalStore.model.updateFunctionalDependencies(submittable); setDisplayPreview(false); setPreview(null); - }, [setFunctionalDependencies, submittable]); + }, [causalStore, submittable]); const handleCancel = useCallback(() => { setPreview(null); @@ -169,13 +166,10 @@ const FDBatch: React.FC = ({ 使用扩展字段计算图 - + {/* 导入影响关系 - - - 导入因果模型 - - + */} + 自动识别 @@ -188,10 +182,8 @@ const FDBatch: React.FC = ({ ) : ( )}
diff --git a/packages/rath-client/src/pages/causal/functionalDependencies/FDEditor.tsx b/packages/rath-client/src/pages/causal/functionalDependencies/FDEditor.tsx index 3034cded..b5f37d0b 100644 --- a/packages/rath-client/src/pages/causal/functionalDependencies/FDEditor.tsx +++ b/packages/rath-client/src/pages/causal/functionalDependencies/FDEditor.tsx @@ -1,20 +1,20 @@ import { observer } from 'mobx-react-lite'; -import React from 'react'; +import type { FC } from 'react'; +import type { IFunctionalDep } from '../config'; import FDGraph from './FDGraph'; -import type { FDPanelProps } from './FDPanel'; -const FDEditor: React.FC = ({ - context, functionalDependencies, setFunctionalDependencies, renderNode, title = '编辑视图', -}) => { +const FDEditor: FC<{ + title?: string; + functionalDependencies: readonly IFunctionalDep[]; + setFunctionalDependencies: (fdArr: IFunctionalDep[] | ((prev: readonly IFunctionalDep[] | null) => readonly IFunctionalDep[])) => void; +}> = ({ functionalDependencies, setFunctionalDependencies, title = '编辑视图' }) => { return ( <>

{title}

); diff --git a/packages/rath-client/src/pages/causal/functionalDependencies/FDGraph.tsx b/packages/rath-client/src/pages/causal/functionalDependencies/FDGraph.tsx index 08008abd..ad384024 100644 --- a/packages/rath-client/src/pages/causal/functionalDependencies/FDGraph.tsx +++ b/packages/rath-client/src/pages/causal/functionalDependencies/FDGraph.tsx @@ -1,15 +1,15 @@ import { observer } from 'mobx-react-lite'; -import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; +import React, { useCallback, useEffect, useRef, useState } from 'react'; import type { Graph } from '@antv/g6'; -import produce from 'immer'; import { DefaultButton } from '@fluentui/react'; import styled from 'styled-components'; +import produce from 'immer'; import { useGlobalStore } from '../../../store'; -import type { CausalLink } from '../explorer'; import { useRenderData, useGraphOptions } from '../explorer/graph-utils'; import { useReactiveGraph } from '../explorer/graph-helper'; -import type { ModifiableBgKnowledge } from '../config'; -import type { FDPanelProps } from './FDPanel'; +import { transformFuncDepsToPag } from '../../../store/causalStore/pag'; +import type { IFunctionalDep } from '../config'; +import { useCausalViewContext } from '../../../store/causalStore/viewStore'; const Container = styled.div` @@ -35,26 +35,23 @@ const Container = styled.div` } `; -const FDGraph: React.FC = ({ - functionalDependencies, setFunctionalDependencies, renderNode, +const FDGraph: React.FC<{ + functionalDependencies: readonly IFunctionalDep[]; + setFunctionalDependencies: (fdArr: IFunctionalDep[] | ((prev: readonly IFunctionalDep[] | null) => readonly IFunctionalDep[])) => void; +}> = ({ + functionalDependencies, + setFunctionalDependencies, }) => { const { causalStore } = useGlobalStore(); - const { selectedFields } = causalStore; + const { fields } = causalStore; + const functionalDependenciesAsPag = transformFuncDepsToPag(functionalDependencies); + const { onRenderNode } = useCausalViewContext() ?? {}; const containerRef = useRef(null); const [width, setWidth] = useState(0); - const nodes = useMemo(() => selectedFields.map((f, i) => ({ id: i, fid: f.fid })), [selectedFields]); - const data = useMemo<{ - nodes: { id: number }[]; - links: { source: number; target: number; type: CausalLink['type'] }[]; - }>(() => ({ - nodes, - links: [], - }), [nodes]); - const onLinkTogether = useCallback((srcFid: string, tarFid: string) => { - setFunctionalDependencies(list => produce(list, draft => { + setFunctionalDependencies(list => produce(list ?? [], draft => { const linked = draft.find(fd => fd.fid === tarFid); if (linked && !linked.params.some(prm => prm.fid === srcFid)) { linked.params.push({ fid: srcFid }); @@ -77,7 +74,7 @@ const FDGraph: React.FC = ({ const onRemoveLink = useCallback((edge: { srcFid: string; tarFid: string; } | null) => { if (edge) { - setFunctionalDependencies(list => produce(list, draft => { + setFunctionalDependencies(list => produce(list ?? [], draft => { const linkedIdx = draft.findIndex(fd => fd.fid === edge.tarFid && fd.params.some(prm => prm.fid === edge.srcFid)); if (linkedIdx !== -1) { const linked = draft[linkedIdx]; @@ -94,43 +91,33 @@ const FDGraph: React.FC = ({ } }, [setFunctionalDependencies]); - const conditions = useMemo(() => { - return functionalDependencies.reduce((list, fd) => { - for (const from of fd.params) { - list.push({ - src: from.fid, - tar: fd.fid, - type: 'directed-must-link', - }); - } - return list; - }, []); - }, [functionalDependencies]); - const graphRef = useRef(); - const renderData = useRenderData(data, 'edit', conditions, selectedFields, renderNode); - const cfg = useGraphOptions(width, selectedFields, undefined, onLinkTogether, graphRef); + const renderData = useRenderData({ + mode: 'edit', + fields, + PAG: functionalDependenciesAsPag, + renderNode: onRenderNode, + }); + const cfg = useGraphOptions({ + width, + fields, + handleLink: onLinkTogether, + graphRef, + }); const cfgRef = useRef(cfg); cfgRef.current = cfg; - const [forceUpdateFlag, setUpdateFlag] = useState<1 | 0>(1); - - useReactiveGraph( + const graph = useReactiveGraph({ containerRef, width, graphRef, - cfg, - renderData, - 'edit', - undefined, - onRemoveLink, - selectedFields, - undefined, - forceUpdateFlag, - null, - [], - false, - ); + options: cfg, + data: renderData, + mode: 'edit', + handleEdgeClick: onRemoveLink, + fields, + allowZoom: false, + }); useEffect(() => { const { current: container } = containerRef; @@ -147,6 +134,10 @@ const FDGraph: React.FC = ({ } }, []); + const handleForceLayout = useCallback(() => { + graph.refresh(); + }, [graph]); + return (
@@ -157,10 +148,10 @@ const FDGraph: React.FC = ({ padding: '0.4em 0', height: 'unset', }} - onClick={() => setUpdateFlag(flag => flag ? 0 : 1)} - iconProps={{ iconName: 'Repair' }} + onClick={handleForceLayout} + iconProps={{ iconName: 'Play' }} > - 刷新布局 + 重新布局
diff --git a/packages/rath-client/src/pages/causal/functionalDependencies/FDPanel.tsx b/packages/rath-client/src/pages/causal/functionalDependencies/FDPanel.tsx index b9cfc4d7..a7179352 100644 --- a/packages/rath-client/src/pages/causal/functionalDependencies/FDPanel.tsx +++ b/packages/rath-client/src/pages/causal/functionalDependencies/FDPanel.tsx @@ -1,10 +1,8 @@ import { observer } from 'mobx-react-lite'; -import React from 'react'; +import { FC, useCallback } from 'react'; import styled from 'styled-components'; -import type { IFunctionalDep } from '../config'; -import type { IFieldMeta } from '../../../interfaces'; -import type { GraphNodeAttributes } from '../explorer/graph-utils'; -import type { useDataViews } from '../hooks/dataViews'; +import { useGlobalStore } from '../../../store'; +import { IFunctionalDep } from '../config'; import FDBatch from './FDBatch'; import FDEditor from './FDEditor'; @@ -22,29 +20,22 @@ const Container = styled.div` } `; -export interface FDPanelProps { - context: ReturnType; - functionalDependencies: IFunctionalDep[]; - setFunctionalDependencies: (fdArr: IFunctionalDep[] | ((prev: IFunctionalDep[]) => IFunctionalDep[])) => void; - renderNode?: (node: Readonly) => GraphNodeAttributes | undefined; -} +const FDPanel: FC = () => { + const { causalStore } = useGlobalStore(); + const { functionalDependencies } = causalStore.model; + + const setFunctionalDependencies = useCallback(( + fdArr: IFunctionalDep[] | ((prev: readonly IFunctionalDep[] | null) => readonly IFunctionalDep[]) + ) => { + causalStore.model.updateFunctionalDependencies(Array.isArray(fdArr) ? fdArr : fdArr(functionalDependencies)); + }, [causalStore, functionalDependencies]); -const FDPanel: React.FC = ({ - context, functionalDependencies, setFunctionalDependencies, renderNode, -}) => { return ( - + ); diff --git a/packages/rath-client/src/pages/causal/functionalDependencies/utils.ts b/packages/rath-client/src/pages/causal/functionalDependencies/utils.ts index fbdf3586..fe6c70f8 100644 --- a/packages/rath-client/src/pages/causal/functionalDependencies/utils.ts +++ b/packages/rath-client/src/pages/causal/functionalDependencies/utils.ts @@ -1,89 +1,60 @@ import { notify } from "../../../components/error"; -import type { IFieldMeta, IRow } from "../../../interfaces"; +import type { IRow } from "../../../interfaces"; import { getGlobalStore } from "../../../store"; -import type { IFunctionalDep, ModifiableBgKnowledge } from "../config"; +import { IFunctionalDep, IFunctionalDepParam, PAG_NODE } from "../config"; -export const getGeneratedPreconditionsFromExtInfo = (fields: IFieldMeta[]): ModifiableBgKnowledge[] => { - return fields.reduce((list, f) => { - if (f.extInfo) { - for (const from of f.extInfo.extFrom) { - list.push({ - src: from, - tar: f.fid, - type: 'directed-must-link', - }); - } - } - return list; - }, []); -}; - -export const getGeneratedFDFromExtInfo = (fields: IFieldMeta[]): IFunctionalDep[] => { - return fields.reduce((list, f) => { - if (f.extInfo) { - list.push({ - fid: f.fid, - params: f.extInfo.extFrom.map(from => ({ - fid: from, - })), - func: f.extInfo.extOpt, - extInfo: f.extInfo, - }); - } - return list; - }, []); -}; +const AutoDetectionApiPath = 'causal/FuncDepTest'; -// FIXME: path -const AutoDetectionApiPath = 'autoDetect'; - -export const getGeneratedFDFromAutoDetection = async ( - dataSource: IRow[], - fields: string[], -): Promise => { +export const getGeneratedFDFromAutoDetection = async (dataSource: readonly IRow[]): Promise => { try { - const { causalStore, dataSourceStore } = getGlobalStore(); - const { apiPrefix } = causalStore; - const { fieldMetas } = dataSourceStore; - const res = await fetch(`${apiPrefix}/${AutoDetectionApiPath}`, { + const { causalStore } = getGlobalStore(); + const { causalServer } = causalStore.operator; + const { allFields, fields } = causalStore.dataset; + const res = await fetch(`${causalServer}/${AutoDetectionApiPath}`, { method: 'POST', headers: { 'Content-Type': 'application/json', }, - // FIXME: I have no idea what is the payload body: JSON.stringify({ dataSource, - fields: fieldMetas, - focusedFields: fields, + fields: allFields, + focusedFields: fields.map(f => f.fid), + bgKnowledgesPag: [], + funcDeps: [], + params: { + alpha: -3.3010299956639813, + catEncodeType: "topk-with-noise", + indep_test: "chisq", + o_alpha: 3, + orient: "ANM", + quantEncodeType: "bin" + }, }), }); const result = await res.json(); if (result.success) { - return result.data; + const matrix = result.data.matrix as PAG_NODE[][]; + const deps: IFunctionalDep[] = []; + for (let j = 0; j < matrix.length; j += 1) { + const params: IFunctionalDepParam[] = []; + for (let i = 0; i < matrix.length; i += 1) { + if (i === j || matrix[i][j] !== PAG_NODE.ARROW || matrix[j][i] !== PAG_NODE.BLANK) { + continue; + } + params.push({ fid: fields[i].fid, type: 'FuncDepTest' }); + } + if (params.length > 0) { + deps.push({ + fid: fields[j].fid, + params, + }); + } + } + return deps; } else { throw new Error(result.message); } - // // FIXME: mock data - // await new Promise(resolve => setTimeout(resolve, 2_000)); - // const selectedFields = fieldMetas.filter(f => fields.includes(f.fid)); - // const fidArr = selectedFields.map(f => f.fid); - // const list: ModifiableBgKnowledge[] = []; - // while (list.length < 6 && fidArr.length >= 2) { - // const srcIdx = Math.floor(Math.random() * fidArr.length); - // const tarIdx = (srcIdx + Math.floor(Math.random() * (fidArr.length - 1))) % fidArr.length; - // if (srcIdx !== tarIdx) { - // list.push({ - // src: fidArr[srcIdx], - // tar: fidArr[tarIdx], - // type: (['must-link', 'must-not-link', 'directed-must-link', 'directed-must-not-link'] as const)[ - // Math.floor(Math.random() * 4) - // ], - // }); - // } - // fidArr.splice(srcIdx, 1); - // } - // return list; } catch (error) { notify({ title: 'Causal Preconditions Auto Detection Error', diff --git a/packages/rath-client/src/pages/causal/hooks/dataViews.ts b/packages/rath-client/src/pages/causal/hooks/dataViews.ts deleted file mode 100644 index d0c97714..00000000 --- a/packages/rath-client/src/pages/causal/hooks/dataViews.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { useState, useMemo, useEffect } from "react"; -import { applyFilters, IFilter } from '@kanaries/loa' -import { IRow } from "../../../interfaces"; -import { focusedSample } from "../../../utils/sample"; -import { useGlobalStore } from "../../../store"; -import { baseDemoSample } from "../../../utils/view-sample"; - -const VIZ_SUBSET_LIMIT = 2_000; -const SAMPLE_UPDATE_DELAY = 500; - -/** 这是一个局部状态,不要在 causal page 以外的任何组件使用它 */ -export function useDataViews (originData: IRow[]) { - const { causalStore } = useGlobalStore(); - const { selectedFields } = causalStore; - const [sampleRate, setSampleRate] = useState(1); - const [appliedSampleRate, setAppliedSampleRate] = useState(sampleRate); - const [filters, setFilters] = useState([]); - const sampleSize = Math.round(originData.length * appliedSampleRate); - const filteredData = useMemo(() => { - return applyFilters(originData, filters); - }, [originData, filters]); - const sample = useMemo(() => { - return focusedSample(filteredData, selectedFields, sampleSize).map(i => filteredData[i]); - }, [filteredData, selectedFields, sampleSize]); - const vizSampleData = useMemo(() => { - if (sample.length < VIZ_SUBSET_LIMIT) { - return sample; - } - return baseDemoSample(sample, VIZ_SUBSET_LIMIT); - }, [sample]); - - useEffect(() => { - if (sampleRate !== appliedSampleRate) { - const delayedTask = setTimeout(() => { - setAppliedSampleRate(sampleRate); - }, SAMPLE_UPDATE_DELAY); - - return () => { - clearTimeout(delayedTask); - }; - } - }, [sampleRate, appliedSampleRate]); - return { - vizSampleData, - dataSubset: sample, - sample, - filteredData, - sampleRate, - setSampleRate, - appliedSampleRate, - setAppliedSampleRate, - filters, - setFilters, - sampleSize - } -} \ No newline at end of file diff --git a/packages/rath-client/src/pages/causal/hooks/interactFieldGroup.ts b/packages/rath-client/src/pages/causal/hooks/interactFieldGroup.ts deleted file mode 100644 index b11a1c56..00000000 --- a/packages/rath-client/src/pages/causal/hooks/interactFieldGroup.ts +++ /dev/null @@ -1,34 +0,0 @@ -import { useCallback, useState } from "react"; -import { IFieldMeta } from "../../../interfaces"; - -/** 这是一个局部状态,不要在 causal page 以外的任何组件使用它 */ -export function useInteractFieldGroups (fieldMetas: IFieldMeta[]) { - const [fieldGroup, setFieldGroup] = useState([]); - const appendFields2Group = useCallback( - (fids: string[]) => { - // causalStore.setFocusNodeIndex(fieldMetas.findIndex((f) => f.fid === xFid)); - setFieldGroup((group) => { - const nextGroup = [...group]; - for (let fid of fids) { - const fm = fieldMetas.find((f) => f.fid === fid); - if (fm && !nextGroup.find((f) => f.fid === fid)) { - nextGroup.push(fm); - } - } - return nextGroup; - }); - }, - [setFieldGroup, fieldMetas] - ); - - const clearFieldGroup = useCallback(() => { - setFieldGroup([]); - }, [setFieldGroup]); - - return { - fieldGroup, - setFieldGroup, - appendFields2Group, - clearFieldGroup - } -} \ No newline at end of file diff --git a/packages/rath-client/src/pages/causal/index.tsx b/packages/rath-client/src/pages/causal/index.tsx index ac91b67f..04b83ea4 100644 --- a/packages/rath-client/src/pages/causal/index.tsx +++ b/packages/rath-client/src/pages/causal/index.tsx @@ -1,14 +1,10 @@ import { observer } from 'mobx-react-lite'; -import { FC, useCallback, useEffect, useRef, useState } from 'react'; +import { FC, useCallback, useRef, useState } from 'react'; import styled from 'styled-components'; -import type { IFieldMeta } from '../../interfaces'; import { useGlobalStore } from '../../store'; -import type { IFunctionalDep, ModifiableBgKnowledge } from './config'; -import { useInteractFieldGroups } from './hooks/interactFieldGroup'; -import { useDataViews } from './hooks/dataViews'; -import type { GraphNodeAttributes } from './explorer/graph-utils'; +import { useCausalViewProvider } from '../../store/causalStore/viewStore'; +import type { IFunctionalDep } from './config'; import { CausalStepPager } from './step'; -import { getGeneratedFDFromExtInfo } from './functionalDependencies/utils'; const Main = styled.div` @@ -23,43 +19,9 @@ const Main = styled.div` `; const CausalPage: FC = () => { - const { dataSourceStore, causalStore } = useGlobalStore(); - const { fieldMetas, cleanedData } = dataSourceStore; - const { selectedFields } = causalStore; - const interactFieldGroups = useInteractFieldGroups(fieldMetas); + const { causalStore } = useGlobalStore(); - useEffect(() => { - causalStore.setFocusFieldIds( - fieldMetas - .filter((f) => f.disable !== true) - .slice(0, 10) - .map((f) => f.fid) - ); // 默认只使用前 10 个) - }, [fieldMetas, causalStore]); - - const [modifiablePrecondition, __unsafeSetModifiablePrecondition] = useState([]); - - const setModifiablePrecondition = useCallback((next: ModifiableBgKnowledge[] | ((prev: ModifiableBgKnowledge[]) => ModifiableBgKnowledge[])) => { - __unsafeSetModifiablePrecondition(prev => { - const list = typeof next === 'function' ? next(prev) : next; - return list.reduce((links, link) => { - if (link.src === link.tar) { - // 禁止自环边 - return links; - } - const overloadIdx = links.findIndex( - which => [which.src, which.tar].every(node => [link.src, link.tar].includes(node)) - ); - if (overloadIdx !== -1) { - const temp = links.map(l => l); - temp.splice(overloadIdx, 1, link); - return temp; - } else { - return links.concat([link]); - } - }, []); - }); - }, []); + const ViewContextProvider = useCausalViewProvider(causalStore); const [functionalDependencies, __unsafeSetFunctionalDependencies] = useState([]); @@ -72,50 +34,20 @@ const CausalPage: FC = () => { }); }, []); - const dataContext = useDataViews(cleanedData); - - useEffect(() => { - causalStore.updateCausalAlgorithmList(fieldMetas); - }, [causalStore, fieldMetas]); - - // 结点可以 project 一些字段信息 - const renderNode = useCallback((node: Readonly): GraphNodeAttributes | undefined => { - const value = 2 / (1 + Math.exp(-1 * node.features.entropy / 2)) - 1; - return { - style: { - stroke: `rgb(${Math.floor(95 * (1 - value))},${Math.floor(149 * (1 - value))},255)`, - }, - }; - }, []); - const submitRef = useRef(setFunctionalDependencies); submitRef.current = setFunctionalDependencies; const fdRef = useRef(functionalDependencies); fdRef.current = functionalDependencies; - useEffect(() => { - setTimeout(() => { - if (fdRef.current.length === 0) { - const fds = getGeneratedFDFromExtInfo(selectedFields); - submitRef.current(fds); - } - }, 400); - }, [selectedFields]); return (
-
-

因果分析

-
- -
+ +
+

因果分析

+
+ +
+
); }; diff --git a/packages/rath-client/src/pages/causal/matrixPanel/directionMatrix.tsx b/packages/rath-client/src/pages/causal/matrixPanel/directionMatrix.tsx index 1781f86b..20532f17 100644 --- a/packages/rath-client/src/pages/causal/matrixPanel/directionMatrix.tsx +++ b/packages/rath-client/src/pages/causal/matrixPanel/directionMatrix.tsx @@ -9,7 +9,7 @@ import { CausalLinkDirection, describeDirection, stringifyDirection } from '../. interface Props { mark: 'circle' | 'square'; - data: DeepReadonly; + data: DeepReadonly; fields: DeepReadonly; onSelect?: (xFieldId: string, yFieldId: string) => void; } diff --git a/packages/rath-client/src/pages/causal/matrixPanel/index.tsx b/packages/rath-client/src/pages/causal/matrixPanel/index.tsx index ff50ae86..d5192e9f 100644 --- a/packages/rath-client/src/pages/causal/matrixPanel/index.tsx +++ b/packages/rath-client/src/pages/causal/matrixPanel/index.tsx @@ -1,8 +1,8 @@ import { Dropdown, Pivot, PivotItem, PrimaryButton, Spinner, Stack } from '@fluentui/react'; import { observer } from 'mobx-react-lite'; -import React, { useEffect, useState } from 'react'; +import { FC, useState } from 'react'; import styled from 'styled-components'; -import { IFieldMeta, IRow } from '../../../interfaces'; +import type { IFieldMeta } from '../../../interfaces'; import { useGlobalStore } from '../../../store'; import DirectionMatrix from './directionMatrix'; import RelationMatrixHeatMap from './relationMatrixHeatMap'; @@ -54,28 +54,24 @@ const MARK_LABELS = [ { key: 'square', text: '矩形' }, ]; -function showMatrix(causalFields: IFieldMeta[], mat: number[][], computing?: boolean): boolean { +function showMatrix(causalFields: readonly IFieldMeta[], mat: readonly (readonly number[])[], computing: boolean): boolean { return causalFields.length > 0 && mat.length > 0 && causalFields.length === mat.length && !computing; } interface MatrixPanelProps { onMatrixPointClick?: (xFid: string, yFid: string) => void; - fields: IFieldMeta[]; - dataSource: IRow[]; onCompute: (type: MATRIX_TYPE) => void; diagram?: JSX.Element; } -const MatrixPanel: React.FC = (props) => { - const { onMatrixPointClick, fields, onCompute, dataSource, diagram } = props; +const MatrixPanel: FC = (props) => { + const { onMatrixPointClick, onCompute, diagram } = props; const [viewType, setViewType] = useState(VIEW_TYPE.diagram); const [selectedKey, setSelectedKey] = useState(MATRIX_TYPE.causal); const [markType, setMarkType] = useState<'circle' | 'square'>('circle'); const { causalStore } = useGlobalStore(); - const { computing, igCondMatrix, igMatrix, causalStrength } = causalStore; - - useEffect(() => { - causalStore.computeIGMatrix(dataSource, fields); - }, [dataSource, fields, causalStore]); + const { fields } = causalStore; + const { mutualMatrix, condMutualMatrix, causalityRaw } = causalStore.model; + const { busy } = causalStore.operator; return ( @@ -99,19 +95,19 @@ const MatrixPanel: React.FC = (props) => { onRenderText={(props, defaultRenderer) => { return (
- {computing && } + {busy && } {defaultRenderer?.(props)}
); }} - disabled={computing} + disabled={busy} onClick={() => { - if (computing) { + if (busy) { return; } onCompute(selectedKey); }} - iconProps={computing ? undefined : { iconName: 'Rerun' }} + iconProps={busy ? undefined : { iconName: 'Rerun' }} style={{ width: 'max-content', transition: 'width 400ms' }} /> {selectedKey === MATRIX_TYPE.causal && ( @@ -158,37 +154,37 @@ const MatrixPanel: React.FC = (props) => {
- {selectedKey === MATRIX_TYPE.mutualInfo && showMatrix(fields, igMatrix, computing) && ( + {selectedKey === MATRIX_TYPE.mutualInfo && mutualMatrix && showMatrix(fields, mutualMatrix, busy) && ( )} - {selectedKey === MATRIX_TYPE.conditionalMutualInfo && showMatrix(fields, igCondMatrix, computing) && ( + {selectedKey === MATRIX_TYPE.conditionalMutualInfo && condMutualMatrix && showMatrix(fields, condMutualMatrix, busy) && ( )} {selectedKey === MATRIX_TYPE.causal && ( viewType === VIEW_TYPE.diagram ? ( - computing || diagram - ) : showMatrix(fields, causalStrength, computing) && ( + busy || diagram + ) : causalityRaw && showMatrix(fields, causalityRaw, busy) && ( ) )} - {computing && } + {busy && }
); diff --git a/packages/rath-client/src/pages/causal/modelStorage/index.tsx b/packages/rath-client/src/pages/causal/modelStorage/index.tsx index 54080037..4f0f36ad 100644 --- a/packages/rath-client/src/pages/causal/modelStorage/index.tsx +++ b/packages/rath-client/src/pages/causal/modelStorage/index.tsx @@ -1,6 +1,6 @@ import { ChoiceGroup, DefaultButton, Label, Modal, PrimaryButton, Stack } from '@fluentui/react'; import { observer } from 'mobx-react-lite'; -import React, { Fragment, useState } from 'react'; +import { FC, Fragment, useState } from 'react'; import styled from 'styled-components'; import { notify } from '../../../components/error'; import { useGlobalStore } from '../../../store'; @@ -9,10 +9,9 @@ const ModalInnerContainer = styled.div` padding: 1em; `; -interface ModelStorageProps {} -const ModelStorage: React.FC = (props) => { +const ModelStorage: FC = () => { const { causalStore } = useGlobalStore(); - const { userModelKeys } = causalStore; + const { saveKeys } = causalStore; const [selectedModelKey, setSelectedModelKey] = useState(undefined); const [showModels, setShowModels] = useState(false); return ( @@ -21,22 +20,21 @@ const ModelStorage: React.FC = (props) => { text="保存因果模型" iconProps={{ iconName: 'Save' }} onClick={() => { - causalStore - .saveCausalModel() - .then(() => { + causalStore.save().then(ok => { + if (ok) { notify({ title: 'Causal Model Saved', content: 'Causal model saved successfully.', type: 'success', }); - }) - .catch((err) => { + } else { notify({ title: 'Causal Model Save Failed', - content: `${err}`, + content: 'DatasetId is null.', type: 'error', }); - }); + } + }); }} /> = (props) => { iconProps={{ iconName: 'CloudDownload' }} onClick={() => { setShowModels(true); - causalStore.getCausalModelList(); + causalStore.updateSaveKeys(); }} /> = (props) => { { + options={saveKeys.map((key) => { return { key, text: key, @@ -74,7 +72,7 @@ const ModelStorage: React.FC = (props) => { text="使用" onClick={() => { if (selectedModelKey) { - causalStore.fetchCausalModel(selectedModelKey); + causalStore.checkout(selectedModelKey); } setShowModels(false); }} diff --git a/packages/rath-client/src/pages/causal/params.tsx b/packages/rath-client/src/pages/causal/params.tsx index 69830f26..2ccdd419 100644 --- a/packages/rath-client/src/pages/causal/params.tsx +++ b/packages/rath-client/src/pages/causal/params.tsx @@ -7,37 +7,35 @@ import { PrimaryButton, } from '@fluentui/react'; import produce from 'immer'; -import { toJS } from 'mobx'; +import { runInAction, toJS } from 'mobx'; import { observer } from 'mobx-react-lite'; -import React, { useEffect, useMemo, useState } from 'react'; +import { FC, useEffect, useMemo, useState } from 'react'; import { makeRenderLabelHandler } from '../../components/labelTooltip'; -import { IRow } from '../../interfaces'; import { useGlobalStore } from '../../store'; -import type { BgKnowledge, BgKnowledgePagLink, IFunctionalDep } from './config'; +import { useCausalViewContext } from '../../store/causalStore/viewStore'; +import { IAlgoSchema } from './config'; import DynamicForm from './dynamicForm'; -const Params: React.FC<{ - dataSource: IRow[]; - focusFields: string[]; - bgKnowledge: BgKnowledgePagLink[]; - /** @deprecated */precondition: BgKnowledge[]; - funcDeps: IFunctionalDep[]; -}> = ({ precondition, bgKnowledge, dataSource, funcDeps }) => { +const Params: FC = () => { const { causalStore } = useGlobalStore(); - const { causalAlgorithm, causalParams, showSettings, causalAlgorithmForm, causalAlgorithmOptions } = causalStore; + const { algorithm, causalAlgorithmForm, params: causalParams, causalAlgorithmOptions } = causalStore.operator; + const viewContext = useCausalViewContext(); + const { shouldDisplayAlgorithmPanel } = viewContext ?? {}; - const [algoName, setAlgoName] = useState(causalAlgorithm); - const [params, setParams] = useState<{ [algo: string]: { [key: string]: any } }>(causalParams[causalAlgorithm]); + const [algoName, setAlgoName] = useState(algorithm); + const [params, setParams] = useState<{ [key: string]: any }>(algorithm ? causalParams[algorithm] : {}); useEffect(() => { - setAlgoName(causalAlgorithm); - }, [causalAlgorithm, showSettings]); + setAlgoName(algorithm); + }, [algorithm, shouldDisplayAlgorithmPanel]); useEffect(() => { - setParams(causalParams[algoName]); - }, [causalParams, algoName, showSettings]); + setParams(algoName && algoName in causalParams ? causalParams[algoName] : {}); + }, [causalParams, algoName, shouldDisplayAlgorithmPanel]); - const form = useMemo(() => causalAlgorithmForm[algoName], [causalAlgorithmForm, algoName]); + const form = useMemo(() => { + return algoName && algoName in causalAlgorithmForm ? causalAlgorithmForm[algoName] : null; + }, [causalAlgorithmForm, algoName]); const updateParam = (key: string, value: any) => { setParams(p => produce(toJS(p), draft => { @@ -46,9 +44,14 @@ const Params: React.FC<{ }; const saveParamsAndRun = () => { - causalStore.updateCausalAlgoAndParams(algoName, params); - causalStore.reRunCausalDiscovery(dataSource, precondition, bgKnowledge, funcDeps); - causalStore.toggleSettings(false); + if (algoName === null) { + return; + } + runInAction(() => { + causalStore.operator.updateConfig(algoName, params); + causalStore.run(); + viewContext?.closeAlgorithmPanel(); + }); }; return ( @@ -56,14 +59,12 @@ const Params: React.FC<{ causalStore.toggleSettings(true)} + onClick={() => viewContext?.openAlgorithmPanel()} /> { - causalStore.toggleSettings(false); - }} + onDismiss={() => viewContext?.closeAlgorithmPanel()} > -
{ form.description }
- - + {form && ( + <> +
{ form.description }
+ + + + )}
); diff --git a/packages/rath-client/src/pages/causal/precondition/preconditionBatch.tsx b/packages/rath-client/src/pages/causal/precondition/preconditionBatch.tsx deleted file mode 100644 index 0c52e4ef..00000000 --- a/packages/rath-client/src/pages/causal/precondition/preconditionBatch.tsx +++ /dev/null @@ -1,231 +0,0 @@ -import { DefaultButton, Spinner, Stack } from '@fluentui/react'; -import { observer } from 'mobx-react-lite'; -import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import styled from 'styled-components'; -import produce from 'immer'; -import { useGlobalStore } from '../../../store'; -import type { ModifiableBgKnowledge } from '../config'; -import type { PreconditionPanelProps } from './preconditionPanel'; -import { getGeneratedPreconditionsFromAutoDetection, getGeneratedPreconditionsFromExtInfo } from './utils'; -import PreconditionEditor from './preconditionEditor'; - - -const Container = styled.div` - > button { - margin: 0 1em; - :first-child { - margin: 0 2em 0 0; - } - } -`; - -const Mask = styled.div` - position: fixed; - top: 0; - left: 0; - z-index: 9999; - width: 100vw; - height: 100vh; - display: flex; - align-items: center; - justify-content: center; - background-color: #fff8; - > div { - box-shadow: 0 0 12px rgba(0, 0, 0, 0.15), 0 0 8px rgba(0, 0, 0, 0.03); - background-color: #fff; - padding: 2em; - > div.container { - width: 600px; - > * { - width: 100%; - } - } - } -`; - -enum BatchUpdateMode { - OVERWRITE_ONLY = 'overwrite-only', - FILL_ONLY = 'fill-only', - FULLY_REPLACE = 'fully replace', -} - -const dropdownOptions: { key: BatchUpdateMode; text: string }[] = [ - { - key: BatchUpdateMode.OVERWRITE_ONLY, - text: '更新并替换',//BatchUpdateMode.OVERWRITE_ONLY, - }, - { - key: BatchUpdateMode.FILL_ONLY, - text: '补充不替换',//BatchUpdateMode.FILL_ONLY, - }, - { - key: BatchUpdateMode.FULLY_REPLACE, - text: '全部覆盖',//BatchUpdateMode.FULLY_REPLACE, - }, -]; - -const PreconditionBatch: React.FC = ({ - context, modifiablePrecondition, setModifiablePrecondition, renderNode, -}) => { - const { causalStore } = useGlobalStore(); - const { selectedFields } = causalStore; - const [displayPreview, setDisplayPreview] = useState(false); - const [preview, setPreview] = useState(null); - const isPending = displayPreview && preview === null; - const [mode, setMode] = useState(BatchUpdateMode.OVERWRITE_ONLY); - const { dataSubset } = context; - - const updatePreview = useMemo(() => { - if (displayPreview) { - return setPreview as typeof setModifiablePrecondition; - } - return () => {}; - }, [displayPreview]); - - const generatePreconditionsFromExtInfo = useCallback(() => { - setPreview(getGeneratedPreconditionsFromExtInfo(selectedFields)); - setDisplayPreview(true); - }, [selectedFields]); - - const pendingRef = useRef>(); - useEffect(() => { - if (!displayPreview) { - pendingRef.current = undefined; - } - }, [displayPreview]); - const generatePreconditionsFromAutoDetection = useCallback(() => { - const p = getGeneratedPreconditionsFromAutoDetection(dataSubset, selectedFields.map(f => f.fid)); - pendingRef.current = p; - p.then(res => { - if (p === pendingRef.current) { - setPreview(res); - } - }).catch(err => { - if (p === pendingRef.current) { - setPreview([]); - } - console.warn(err); - }).finally(() => { - pendingRef.current = undefined; - }); - setDisplayPreview(true); - }, [selectedFields, dataSubset]); - - const handleClear = useCallback(() => { - setModifiablePrecondition([]); - }, [setModifiablePrecondition]); - - const submittable = useMemo(() => { - if (preview) { - switch (mode) { - case BatchUpdateMode.OVERWRITE_ONLY: { - return preview.reduce((links, link) => { - const overloadIdx = links.findIndex( - which => [which.src, which.tar].every(node => [link.src, link.tar].includes(node)) - ); - if (overloadIdx !== -1) { - return produce(links, draft => { - draft.splice(overloadIdx, 1, link); - }); - } - return links.concat([link]); - }, modifiablePrecondition); - } - case BatchUpdateMode.FILL_ONLY: { - return preview.reduce((links, link) => { - const alreadyDefined = links.find( - which => [which.src, which.tar].every(node => [link.src, link.tar].includes(node)) - ); - if (!alreadyDefined) { - return links.concat([link]); - } - return links; - }, modifiablePrecondition); - } - case BatchUpdateMode.FULLY_REPLACE: { - return preview; - } - default: { - return modifiablePrecondition; - } - } - } else { - return modifiablePrecondition; - } - }, [preview, modifiablePrecondition, mode]); - - const handleSubmit = useCallback(() => { - setModifiablePrecondition(submittable); - setDisplayPreview(false); - setPreview(null); - }, [setModifiablePrecondition, submittable]); - - const handleCancel = useCallback(() => { - setPreview(null); - setDisplayPreview(false); - }, []); - - return ( - <> -

快捷操作

- - - 全部删除 - - - 使用扩展字段计算图 - - - 导入影响关系 - - - 导入因果模型 - - - 自动识别 - - - {displayPreview && ( - -
-
- {isPending ? ( - - ) : ( - - )} -
- - opt.key === mode)?.text ?? '确定'} - onClick={handleSubmit} - primary - split - menuProps={{ - items: dropdownOptions, - onItemClick: (_e, item) => { - if (item) { - setMode(item.key as BatchUpdateMode); - } - }, - }} - /> - - -
-
- )} - - ); -}; - -export default observer(PreconditionBatch); diff --git a/packages/rath-client/src/pages/causal/precondition/preconditionEditor.tsx b/packages/rath-client/src/pages/causal/precondition/preconditionEditor.tsx deleted file mode 100644 index fec28790..00000000 --- a/packages/rath-client/src/pages/causal/precondition/preconditionEditor.tsx +++ /dev/null @@ -1,69 +0,0 @@ -import { Pivot, PivotItem } from '@fluentui/react'; -import { observer } from 'mobx-react-lite'; -import React, { useState } from 'react'; -import PreconditionTable from './preconditionTable'; -import PreconditionGraph from './preconditionGraph'; -import type { PreconditionPanelProps } from './preconditionPanel'; - - -const EditModes = [{ - itemKey: 'diagram', - text: '图', - iconName: 'BranchPullRequest', -}, { -// itemKey: 'matrix', // TODO: 实现矩阵编辑 -// text: '矩阵', -// iconName: 'GridViewSmall', -// }, { - itemKey: 'table', - text: '表', - iconName: 'BulletedListText', -}] as const; - -type EditMode = (typeof EditModes)[number]['itemKey']; - -const PreconditionEditor: React.FC = ({ - context, modifiablePrecondition, setModifiablePrecondition, renderNode, title = '编辑视图', -}) => { - const [editMode, setEditMode] = useState('diagram'); - - return ( - <> -

{title}

- { - if (item) { - setEditMode(item.props.itemKey as EditMode); - } - }} - > - {EditModes.map((item) => { - return ; - })} - - {{ - diagram: ( - - ), - matrix: null, - table: ( - - ), - }[editMode]} - - ); -}; - -export default observer(PreconditionEditor); diff --git a/packages/rath-client/src/pages/causal/precondition/preconditionGraph.tsx b/packages/rath-client/src/pages/causal/precondition/preconditionGraph.tsx deleted file mode 100644 index 1422667f..00000000 --- a/packages/rath-client/src/pages/causal/precondition/preconditionGraph.tsx +++ /dev/null @@ -1,172 +0,0 @@ -import { observer } from 'mobx-react-lite'; -import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react'; -import type { Graph } from '@antv/g6'; -import produce from 'immer'; -import { DefaultButton, Dropdown } from '@fluentui/react'; -import styled from 'styled-components'; -import { useGlobalStore } from '../../../store'; -import type { CausalLink } from '../explorer'; -import { useRenderData, useGraphOptions } from '../explorer/graph-utils'; -import { useReactiveGraph } from '../explorer/graph-helper'; -import type { ModifiableBgKnowledge } from '../config'; -import type { PreconditionPanelProps } from './preconditionPanel'; - - -const Container = styled.div` - height: 600px; - position: relative; - > div:first-child { - width: 100%; - height: 100%; - } - > .tools { - position: absolute; - left: 0; - top: 0; - padding: 1em; - flex-grow: 0; - flex-shrink: 0; - flex-basis: max-content; - font-size: 0.8rem; - opacity: 0.7; - :hover { - opacity: 0.95; - } - } -`; - -const PreconditionGraph: React.FC = ({ - modifiablePrecondition, setModifiablePrecondition, renderNode, -}) => { - const { causalStore } = useGlobalStore(); - const { selectedFields } = causalStore; - - const containerRef = useRef(null); - const [width, setWidth] = useState(0); - - const nodes = useMemo(() => selectedFields.map((f, i) => ({ id: i, fid: f.fid })), [selectedFields]); - const data = useMemo<{ - nodes: { id: number }[]; - links: { source: number; target: number; type: CausalLink['type'] }[]; - }>(() => ({ - nodes, - links: [], - }), [nodes]); - - const [createEdgeMode, setCreateEdgeMode] = useState('directed-must-link'); - const onLinkTogether = useCallback((srcFid: string, tarFid: string) => { - setModifiablePrecondition(list => produce(list, draft => { - draft.push({ - src: srcFid, - tar: tarFid, - type: createEdgeMode, - }); - })); - }, [setModifiablePrecondition, createEdgeMode]); - const onRemoveLink = useCallback((edge: { srcFid: string; tarFid: string; } | null) => { - if (edge) { - setModifiablePrecondition( - list => list.filter(link => [link.src, link.tar].some(fid => ![edge.srcFid, edge.tarFid].includes(fid))) - ); - } - }, [setModifiablePrecondition]); - - const graphRef = useRef(); - const renderData = useRenderData(data, 'edit', modifiablePrecondition, selectedFields, renderNode); - const cfg = useGraphOptions(width, selectedFields, undefined, onLinkTogether, graphRef); - const cfgRef = useRef(cfg); - cfgRef.current = cfg; - - const [forceUpdateFlag, setUpdateFlag] = useState<1 | 0>(1); - - useReactiveGraph( - containerRef, - width, - graphRef, - cfg, - renderData, - 'edit', - undefined, - onRemoveLink, - selectedFields, - undefined, - forceUpdateFlag, - null, - [], - false, - ); - - useEffect(() => { - const { current: container } = containerRef; - if (container) { - const cb = () => { - const { width: w } = container.getBoundingClientRect(); - setWidth(w); - }; - const ro = new ResizeObserver(cb); - ro.observe(container); - return () => { - ro.disconnect(); - }; - } - }, []); - - return ( - -
-
- setUpdateFlag(flag => flag ? 0 : 1)} - iconProps={{ iconName: 'Repair' }} - > - 刷新布局 - - { - if (!option) { - return; - } - const linkType = option.key as typeof createEdgeMode; - setCreateEdgeMode(linkType); - }} - styles={{ - title: { - fontSize: '0.8rem', - lineHeight: '1.8em', - minWidth: '18em', - height: '1.8em', - padding: '0 2.8em 0 0.8em', - border: 'none', - borderBottom: '1px solid #8888', - }, - caretDownWrapper: { - fontSize: '0.8rem', - lineHeight: '1.8em', - height: '1.8em', - }, - caretDown: { - fontSize: '0.8rem', - lineHeight: '1.8em', - height: '1.8em', - }, - }} - /> -
- - ); -}; - -export default observer(PreconditionGraph); diff --git a/packages/rath-client/src/pages/causal/precondition/preconditionPanel.tsx b/packages/rath-client/src/pages/causal/precondition/preconditionPanel.tsx deleted file mode 100644 index ae7dd1a7..00000000 --- a/packages/rath-client/src/pages/causal/precondition/preconditionPanel.tsx +++ /dev/null @@ -1,53 +0,0 @@ -import { observer } from 'mobx-react-lite'; -import React from 'react'; -import styled from 'styled-components'; -import type { ModifiableBgKnowledge } from '../config'; -import type { IFieldMeta } from '../../../interfaces'; -import type { GraphNodeAttributes } from '../explorer/graph-utils'; -import type { useDataViews } from '../hooks/dataViews'; -import PreconditionBatch from './preconditionBatch'; -import PreconditionEditor from './preconditionEditor'; - - -const Container = styled.div` - overflow: hidden auto; - padding: 0.4em 1.6em; - & h3 { - font-size: 0.8rem; - font-weight: 500; - padding: 0.4em 0; - :not(:first-child) { - margin-top: 0.4em; - } - } -`; - -export interface PreconditionPanelProps { - context: ReturnType; - modifiablePrecondition: ModifiableBgKnowledge[]; - setModifiablePrecondition: (precondition: ModifiableBgKnowledge[] | ((prev: ModifiableBgKnowledge[]) => ModifiableBgKnowledge[])) => void; - renderNode?: (node: Readonly) => GraphNodeAttributes | undefined; -} - -const PreconditionPanel: React.FC = ({ - context, modifiablePrecondition, setModifiablePrecondition, renderNode, -}) => { - return ( - - - - - ); -}; - -export default observer(PreconditionPanel); diff --git a/packages/rath-client/src/pages/causal/precondition/preconditionTable.tsx b/packages/rath-client/src/pages/causal/precondition/preconditionTable.tsx deleted file mode 100644 index b15ef5ef..00000000 --- a/packages/rath-client/src/pages/causal/precondition/preconditionTable.tsx +++ /dev/null @@ -1,265 +0,0 @@ -import { - ActionButton, - Dropdown, - IColumn, - DetailsList, - SelectionMode, - Label, - Stack, -} from '@fluentui/react'; -import { observer } from 'mobx-react-lite'; -import React, { useEffect, useMemo, useState } from 'react'; -import produce from 'immer'; -import { useGlobalStore } from '../../../store'; -import type { ModifiableBgKnowledge } from '../config'; -import type { PreconditionPanelProps } from './preconditionPanel'; - - -const PreconditionTable: React.FC = ({ modifiablePrecondition, setModifiablePrecondition }) => { - const { causalStore } = useGlobalStore(); - const { selectedFields } = causalStore; - - const [editingPrecondition, setEditingPrecondition] = useState>({ - type: 'must-link', - }); - - useEffect(() => { - setEditingPrecondition({ type: 'must-link' }); - }, [selectedFields]); - - const preconditionTableCols = useMemo(() => { - return [ - { - key: 'delete-btn', - name: '', - onRender: (item, index) => - typeof index === 'number' ? ( - setModifiablePrecondition((list) => { - const next = [...list]; - next.splice(index, 1); - return next; - })} - /> - ) : null, - minWidth: 30, - maxWidth: 30, - onRenderHeader: () => ( - setModifiablePrecondition([])} - /> - ), - }, - { - key: 'src', - name: '因素', //'Source', - onRender: (item) => ( - - {selectedFields.find((f) => f.fid === item.src)?.name ?? item.src} - - ), - minWidth: 160, - maxWidth: 160, - }, - { - key: 'type', - name: '影响约束', //'Constraint', - onRender: (item: ModifiableBgKnowledge, index) => - typeof index === 'number' ? ( - { - if (!option) { - return; - } - const linkType = option.key as typeof item.type; - setModifiablePrecondition((p) => - produce(p, (draft) => { - draft[index].type = linkType; - }) - ); - }} - styles={{ - title: { - fontSize: '0.8rem', - lineHeight: '1.8em', - height: '1.8em', - padding: '0 2.8em 0 0.8em', - border: 'none', - borderBottom: '1px solid #8888', - }, - caretDownWrapper: { - fontSize: '0.8rem', - lineHeight: '1.8em', - height: '1.8em', - }, - caretDown: { - fontSize: '0.8rem', - lineHeight: '1.8em', - height: '1.8em', - }, - }} - /> - ) : null, - minWidth: 200, - maxWidth: 200, - }, - { - key: 'tar', - name: '因素', //'Target', - onRender: item => ( - - {selectedFields.find((f) => f.fid === item.tar)?.name ?? item.tar} - - ), - minWidth: 160, - maxWidth: 160, - }, - { - key: 'empty', - name: '', - onRender: () =>
, - minWidth: 0, - }, - ]; - }, [selectedFields, setModifiablePrecondition]); - - return ( - -
-
- - { - if (!option) { - return; - } - const fid = option.key as string; - setEditingPrecondition((p) => ({ - type: p.type, - src: fid, - tar: p.tar === fid ? undefined : p.tar, - })); - }} - options={selectedFields.map((f) => ({ - key: f.fid, - text: f.name ?? f.fid, - }))} - styles={{ root: { width: '28%', margin: '0 1%' } }} - /> - { - if (!option) { - return; - } - setEditingPrecondition((p) => ({ - ...p, - type: option.key as typeof p['type'], - })); - }} - options={[ - { key: 'directed-must-link', text: '单向一定影响' }, - { key: 'directed-must-not-link', text: '单向一定不影响' }, - { key: 'must-link', text: '至少在一个方向存在影响' }, - { key: 'must-not-link', text: '在任意方向一定不影响' }, - ]} - styles={{ root: { width: '20%' }, title: { textAlign: 'center' } }} - /> - { - if (!option) { - return; - } - const fid = option.key as string; - setEditingPrecondition((p) => ({ - type: p.type, - tar: fid, - src: p.src === fid ? undefined : p.src, - })); - }} - options={selectedFields.map((f) => ({ - key: f.fid, - text: f.name ?? f.fid, - }))} - styles={{ root: { width: '28%', margin: '0 1%' } }} - /> - { - if ( - editingPrecondition.src && - editingPrecondition.tar && - editingPrecondition.type && - editingPrecondition.src !== editingPrecondition.tar - ) { - setEditingPrecondition({ type: editingPrecondition.type }); - setModifiablePrecondition((list) => [ - ...list, - editingPrecondition as ModifiableBgKnowledge, - ]); - } - }} - /> -
-
- -
- ); -}; - -export default observer(PreconditionTable); diff --git a/packages/rath-client/src/pages/causal/precondition/utils.ts b/packages/rath-client/src/pages/causal/precondition/utils.ts deleted file mode 100644 index b55f5dd9..00000000 --- a/packages/rath-client/src/pages/causal/precondition/utils.ts +++ /dev/null @@ -1,79 +0,0 @@ -import { notify } from "../../../components/error"; -import type { IFieldMeta, IRow } from "../../../interfaces"; -import { getGlobalStore } from "../../../store"; -import type { ModifiableBgKnowledge } from "../config"; - - -export const getGeneratedPreconditionsFromExtInfo = (fields: IFieldMeta[]): ModifiableBgKnowledge[] => { - return fields.reduce((list, f) => { - if (f.extInfo) { - for (const from of f.extInfo.extFrom) { - list.push({ - src: from, - tar: f.fid, - type: 'directed-must-link', - }); - } - } - return list; - }, []); -}; - -// FIXME: path -const AutoDetectionApiPath = 'autoDetect'; - -export const getGeneratedPreconditionsFromAutoDetection = async ( - dataSource: IRow[], - fields: string[], -): Promise => { - try { - const { causalStore, dataSourceStore } = getGlobalStore(); - const { apiPrefix } = causalStore; - const { fieldMetas } = dataSourceStore; - const res = await fetch(`${apiPrefix}/${AutoDetectionApiPath}`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - // FIXME: I have no idea what is the payload - body: JSON.stringify({ - dataSource, - fields: fieldMetas, - focusedFields: fields, - }), - }); - const result = await res.json(); - if (result.success) { - return result.data; - } else { - throw new Error(result.message); - } - // // FIXME: mock data - // await new Promise(resolve => setTimeout(resolve, 2_000)); - // const selectedFields = fieldMetas.filter(f => fields.includes(f.fid)); - // const fidArr = selectedFields.map(f => f.fid); - // const list: ModifiableBgKnowledge[] = []; - // while (list.length < 6 && fidArr.length >= 2) { - // const srcIdx = Math.floor(Math.random() * fidArr.length); - // const tarIdx = (srcIdx + Math.floor(Math.random() * (fidArr.length - 1))) % fidArr.length; - // if (srcIdx !== tarIdx) { - // list.push({ - // src: fidArr[srcIdx], - // tar: fidArr[tarIdx], - // type: (['must-link', 'must-not-link', 'directed-must-link', 'directed-must-not-link'] as const)[ - // Math.floor(Math.random() * 4) - // ], - // }); - // } - // fidArr.splice(srcIdx, 1); - // } - // return list; - } catch (error) { - notify({ - title: 'Causal Preconditions Auto Detection Error', - type: 'error', - content: `${error}`, - }); - return []; - } -}; diff --git a/packages/rath-client/src/pages/causal/predict.ts b/packages/rath-client/src/pages/causal/predict.ts index c0c9b7b4..61f36624 100644 --- a/packages/rath-client/src/pages/causal/predict.ts +++ b/packages/rath-client/src/pages/causal/predict.ts @@ -1,6 +1,5 @@ import { notify } from "../../components/error"; import type { IRow, IFieldMeta } from "../../interfaces"; -import { getGlobalStore } from "../../store"; export const PredictAlgorithms = [ @@ -48,7 +47,8 @@ export interface IPredictResult { result: PredictResultItem[]; } -const PredictApiPath = 'api/train_test'; +// TODO: 模型预测服务:上生产环境后改称线上服务地址 +const PredictApiPath = 'http://127.0.0.1:5533/api/train_test'; export const execPredict = async (props: IPredictProps): Promise => { try { @@ -57,9 +57,7 @@ export const execPredict = async (props: IPredictProps): Promise .content { - flex-grow: 1; - flex-shrink: 1; - display: flex; - flex-direction: column; - padding: 0.5em; - overflow: auto; - > * { - flex-grow: 0; - flex-shrink: 0; - } - } -`; - -const TableContainer = styled.div` - flex-grow: 0; - flex-shrink: 0; - overflow: auto; -`; - -const Row = styled.div<{ selected: 'attribution' | 'target' | false }>` - > div { - background-color: ${({ selected }) => ( - selected === 'attribution' ? 'rgba(194,132,2,0.2)' : selected === 'target' ? 'rgba(66,121,242,0.2)' : undefined - )}; - filter: ${({ selected }) => selected ? 'unset' : 'opacity(0.8)'}; - cursor: pointer; - :hover { - filter: unset; - } - } -`; - -const ModeOptions = [ - { key: 'classification', text: '分类' }, - { key: 'regression', text: '回归' }, -] as const; - -// FIXME: 防止切到别的流程时预测结果被清空,先在全局存一下,决定好要不要保留 && 状态应该存哪里以后及时迁走 -const predictCache: { - id: string; algo: PredictAlgorithm; startTime: number; completeTime: number; data: IPredictResult; -}[] = []; - -const PredictPanel = forwardRef<{ - updateInput?: (input: { features: IFieldMeta[]; targets: IFieldMeta[] }) => void; -}, {}>(function PredictPanel (_, ref) { - const { causalStore, dataSourceStore } = useGlobalStore(); - const { selectedFields } = causalStore; - const { cleanedData, fieldMetas } = dataSourceStore; - - const [predictInput, setPredictInput] = useState<{ features: IFieldMeta[]; targets: IFieldMeta[] }>({ - features: [], - targets: [], - }); - const [algo, setAlgo] = useState('decisionTree'); - const [mode, setMode] = useState('classification'); - - useImperativeHandle(ref, () => ({ - updateInput: input => setPredictInput(input), - })); - - useEffect(() => { - setPredictInput(before => { - if (before.features.length || before.targets.length) { - return { - features: selectedFields.filter(f => before.features.some(feat => feat.fid === f.fid)), - targets: selectedFields.filter(f => before.targets.some(tar => tar.fid === f.fid)), - }; - } - return { - features: selectedFields.slice(1).map(f => f), - targets: selectedFields.slice(0, 1), - }; - }); - }, [selectedFields]); - - const [running, setRunning] = useState(false); - - const fieldsTableCols = useMemo(() => { - return [ - { - key: 'selectedAsFeature', - name: `特征 (${predictInput.features.length} / ${selectedFields.length})`, - onRender: (item) => { - const field = item as IFieldMeta; - const checked = predictInput.features.some(f => f.fid === field.fid); - return ( - { - if (running) { - return; - } - setPredictInput(produce(predictInput, draft => { - draft.features = draft.features.filter(f => f.fid !== field.fid); - draft.targets = draft.targets.filter(f => f.fid !== field.fid); - if (ok) { - draft.features.push(field); - } - })); - }} - /> - ); - }, - isResizable: false, - minWidth: 90, - maxWidth: 90, - }, - { - key: 'selectedAsTarget', - name: `目标 (${predictInput.targets.length} / ${selectedFields.length})`, - onRender: (item) => { - const field = item as IFieldMeta; - const checked = predictInput.targets.some(f => f.fid === field.fid); - return ( - { - if (running) { - return; - } - setPredictInput(produce(predictInput, draft => { - draft.features = draft.features.filter(f => f.fid !== field.fid); - draft.targets = draft.targets.filter(f => f.fid !== field.fid); - if (ok) { - draft.targets.push(field); - } - })); - }} - /> - ); - }, - isResizable: false, - minWidth: 90, - maxWidth: 90, - }, - { - key: 'name', - name: '因素', - onRender: (item) => { - const field = item as IFieldMeta; - return ( - - {field.name || field.fid} - - ); - }, - minWidth: 120, - }, - ]; - }, [selectedFields, predictInput, running]); - - const canExecute = predictInput.features.length > 0 && predictInput.targets.length > 0; - const pendingRef = useRef>(); - - useEffect(() => { - pendingRef.current = undefined; - setRunning(false); - }, [predictInput]); - - const dataSourceRef = useRef(cleanedData); - dataSourceRef.current = cleanedData; - const allFieldsRef = useRef(fieldMetas); - allFieldsRef.current = fieldMetas; - - const [results, setResults] = useState<{ - id: string; algo: PredictAlgorithm; startTime: number; completeTime: number; data: IPredictResult; - }[]>([]); - - // FIXME: 防止切到别的流程时预测结果被清空,先在全局存一下,决定好要不要保留 && 状态应该存哪里以后及时迁走 - useEffect(() => { - setResults(predictCache); - return () => { - setResults(res => { - predictCache.splice(0, Infinity, ...res); - return []; - }); - }; - }, [cleanedData, fieldMetas]); - - const [tab, setTab] = useState<'config' | 'result'>('config'); - - const trainTestSplitIndices = useMemo(() => { - const TRAIN_RATE = 0.2; - const indices = cleanedData.map((_, i) => i); - const trainSetIndices = new Map(); - const trainSetTargetSize = Math.floor(cleanedData.length * TRAIN_RATE); - while (trainSetIndices.size < trainSetTargetSize && indices.length) { - const [index] = indices.splice(Math.floor(indices.length * Math.random()), 1); - trainSetIndices.set(index, 1); - } - return cleanedData.map((_, i) => trainSetIndices.has(i) ? TrainTestSplitFlag.train : TrainTestSplitFlag.test); - }, [cleanedData]); - - const trainTestSplitIndicesRef = useRef(trainTestSplitIndices); - trainTestSplitIndicesRef.current = trainTestSplitIndices; - - const handleClickExec = useCallback(() => { - const startTime = Date.now(); - setRunning(true); - const task = execPredict({ - dataSource: dataSourceRef.current, - fields: allFieldsRef.current, - model: { - algorithm: algo, - features: predictInput.features.map(f => f.fid), - targets: predictInput.targets.map(f => f.fid), - }, - trainTestSplitIndices: trainTestSplitIndicesRef.current, - mode, - }); - pendingRef.current = task; - task.then(res => { - if (task === pendingRef.current && res) { - const completeTime = Date.now(); - setResults(list => { - const record = { - id: nanoid(8), - algo, - startTime, - completeTime, - data: res, - }; - if (list.length > 0 && list[0].algo !== algo) { - return [record]; - } - return list.concat([record]); - }); - setTab('result'); - } - }).finally(() => { - pendingRef.current = undefined; - setRunning(false); - }); - }, [predictInput, algo, mode]); - - const sortedResults = useMemo(() => { - return results.slice(0).sort((a, b) => b.completeTime - a.completeTime); - }, [results]); - - const [comparison, setComparison] = useState(null); - - useEffect(() => { - setComparison(group => { - if (!group) { - return null; - } - const next = group.filter(id => results.some(rec => rec.id === id)); - if (next.length === 0) { - return null; - } - return next as [string] | [string, string]; - }); - }, [results]); - - const resultTableCols = useMemo(() => { - return [ - { - key: 'selected', - name: '对比', - onRender: (item) => { - const record = item as typeof sortedResults[number]; - const selected = (comparison ?? [] as string[]).includes(record.id); - return ( - { - if (checked) { - setComparison(group => { - if (group === null) { - return [record.id]; - } - return [group[0], record.id]; - }); - } else if (selected) { - setComparison(group => { - if (group?.some(id => id === record.id)) { - return group.length === 1 ? null : group.filter(id => id !== record.id) as [string]; - } - return null; - }); - } - }} - /> - ); - }, - isResizable: false, - minWidth: 30, - maxWidth: 30, - }, - { - key: 'index', - name: '运行次数', - minWidth: 70, - maxWidth: 70, - isResizable: false, - onRender(_, index) { - return <>{index !== undefined ? (sortedResults.length - index) : ''}; - }, - }, - { - key: 'algo', - name: '预测模型', - minWidth: 70, - onRender(item) { - const record = item as typeof sortedResults[number]; - return <>{PredictAlgorithms.find(which => which.key === record.algo)?.text} - }, - }, - { - key: 'accuracy', - name: '准确率', - minWidth: 150, - onRender(item, index) { - if (!item || index === undefined) { - return <>; - } - const record = item as typeof sortedResults[number]; - const previous = sortedResults[index + 1]; - const comparison: 'better' | 'worse' | 'same' | null = previous ? ( - previous.data.accuracy === record.data.accuracy ? 'same' - : record.data.accuracy > previous.data.accuracy ? 'better' : 'worse' - ) : null; - return ( - - {comparison && ( - - )} - {record.data.accuracy} - - ); - }, - }, - ]; - }, [sortedResults, comparison]); - - const diff = useMemo(() => { - if (comparison?.length === 2) { - const before = sortedResults.find(res => res.id === comparison[0]); - const after = sortedResults.find(res => res.id === comparison[1]); - if (before && after) { - const temp: unknown[] = []; - for (let i = 0; i < before.data.result.length; i += 1) { - const row = dataSourceRef.current[before.data.result[i][0]]; - const prev = before.data.result[i][1]; - const next = after.data.result[i][1]; - if (next === 1 && prev === 0) { - temp.push(Object.fromEntries(Object.entries(row).map(([k, v]) => [ - allFieldsRef.current.find(f => f.fid === k)?.name ?? k, - v, - ]))); - } - } - return temp; - } - } - }, [sortedResults, comparison]); - - useEffect(() => { - if (diff) { - // TODO: 在界面上实现一个 diff view,代替这个 console - // eslint-disable-next-line no-console - console.table(diff); - } - }, [diff]); - - return ( - - running ? : } - style={{ width: 'max-content', flexGrow: 0, flexShrink: 0, marginLeft: '0.6em' }} - split - menuProps={{ - items: ModeOptions.map(opt => opt), - onItemClick: (_e, item) => { - if (item) { - setMode(item.key as typeof mode); - } - }, - }} - > - {`${ModeOptions.find(m => m.key === mode)?.text}预测`} - - { - item && setTab(item.props.itemKey as typeof tab); - }} - style={{ marginTop: '0.5em' }} - > - - - -
- {{ - config: ( - <> - ({ key: algo.key, text: algo.text }))} - selectedKey={algo} - onChange={(_, option) => { - const item = PredictAlgorithms.find(which => which.key === option?.key); - if (item) { - setAlgo(item.key); - } - }} - style={{ width: 'max-content' }} - /> - - - { - const field = props?.item as IFieldMeta; - const checkedAsAttr = predictInput.features.some(f => f.fid === field.fid); - const checkedAsTar = predictInput.targets.some(f => f.fid === field.fid); - return ( - - {defaultRender?.(props)} - - ); - }} - /> - - - ), - result: ( - <> - setResults([])} - style={{ width: 'max-content' }} - > - 清空记录 - - - - - - ), - }[tab]} -
-
- ); -}); - - -export default observer(PredictPanel); diff --git a/packages/rath-client/src/pages/causal/service.ts b/packages/rath-client/src/pages/causal/service.ts index effedaa7..30e71e53 100644 --- a/packages/rath-client/src/pages/causal/service.ts +++ b/packages/rath-client/src/pages/causal/service.ts @@ -8,13 +8,13 @@ import CausalComputationWorker from './computation.worker.js?worker'; type ICausalProps = { task: 'ig'; - dataSource: IRow[]; - fields: IFieldMeta[]; + dataSource: readonly IRow[]; + fields: readonly IFieldMeta[]; } | { task: 'ig_cond'; - dataSource: IRow[]; - fields: IFieldMeta[]; - matrix: number[][]; + dataSource: readonly IRow[]; + fields: readonly IFieldMeta[]; + matrix: readonly (readonly number[])[]; } export async function causalService(props: ICausalProps): Promise { diff --git a/packages/rath-client/src/pages/causal/step/FDConfig.tsx b/packages/rath-client/src/pages/causal/step/FDConfig.tsx index 062f2f4b..863d4aca 100644 --- a/packages/rath-client/src/pages/causal/step/FDConfig.tsx +++ b/packages/rath-client/src/pages/causal/step/FDConfig.tsx @@ -1,33 +1,12 @@ import { observer } from 'mobx-react-lite'; -import React from 'react'; -import type { IFieldMeta } from '../../../interfaces'; -import type { IFunctionalDep } from '../config'; -import type { useDataViews } from '../hooks/dataViews'; -import type { GraphNodeAttributes } from '../explorer/graph-utils'; +import type { FC } from 'react'; import FDPanel from '../functionalDependencies/FDPanel'; -export interface CausalFDConfigProps { - dataContext: ReturnType; - functionalDependencies: IFunctionalDep[]; - setFunctionalDependencies: (fdArr: IFunctionalDep[] | ((prev: IFunctionalDep[]) => IFunctionalDep[])) => void; - renderNode: (node: Readonly) => GraphNodeAttributes | undefined; -} - -const CausalFDConfig: React.FC = ({ - dataContext, - functionalDependencies, - setFunctionalDependencies, - renderNode, -}) => { +const CausalFDConfig: FC = () => { return ( <> - + ); }; diff --git a/packages/rath-client/src/pages/causal/step/causalModel.tsx b/packages/rath-client/src/pages/causal/step/causalModel.tsx index 186171e0..7d49b028 100644 --- a/packages/rath-client/src/pages/causal/step/causalModel.tsx +++ b/packages/rath-client/src/pages/causal/step/causalModel.tsx @@ -1,20 +1,16 @@ import { Stack } from '@fluentui/react'; import { observer } from 'mobx-react-lite'; -import React, { RefObject, useCallback, useEffect, useMemo, useRef } from 'react'; -import produce from 'immer'; +import { FC, RefObject, useCallback, useRef } from 'react'; import styled from 'styled-components'; import { IFieldMeta } from '../../../interfaces'; import { useGlobalStore } from '../../../store'; -import { mergeCausalPag, resolvePreconditionsFromCausal, transformPreconditions } from '../../../utils/resolve-causal'; -import Explorer, { ExplorerProps } from '../explorer'; +import { useCausalViewContext } from '../../../store/causalStore/viewStore'; +import type { EdgeAssert } from '../../../store/causalStore/modelStore'; +import Explorer from '../explorer'; import Params from '../params'; -import type { BgKnowledge, BgKnowledgePagLink, IFunctionalDep, ModifiableBgKnowledge } from '../config'; import ModelStorage from '../modelStorage'; -import ManualAnalyzer from '../manualAnalyzer'; +import Exploration, { Subtree } from '../exploration'; import MatrixPanel, { MATRIX_TYPE } from '../matrixPanel'; -import type { useInteractFieldGroups } from '../hooks/interactFieldGroup'; -import type { useDataViews } from '../hooks/dataViews'; -import type { GraphNodeAttributes } from '../explorer/graph-utils'; const Container = styled.div` @@ -36,229 +32,104 @@ const Container = styled.div` } `; -export interface CausalModalProps { - dataContext: ReturnType; - modifiablePrecondition: ModifiableBgKnowledge[]; - setModifiablePrecondition: (precondition: ModifiableBgKnowledge[] | ((prev: ModifiableBgKnowledge[]) => ModifiableBgKnowledge[])) => void; - functionalDependencies: IFunctionalDep[]; - renderNode: (node: Readonly) => GraphNodeAttributes | undefined; - interactFieldGroups: ReturnType; -} - -export const CausalExplorer = observer< - Omit & { - allowEdit: boolean; listenerRef?: RefObject<{ onSubtreeSelected?: ExplorerProps['onNodeSelected'] }>; - } ->(function CausalExplorer ({ +export const CausalExplorer = observer<{ + allowEdit: boolean; + listenerRef?: RefObject<{ onSubtreeSelected?: (subtree: Subtree | null) => void }>; +}>(function CausalExplorer ({ allowEdit, - dataContext, - modifiablePrecondition, - setModifiablePrecondition, - renderNode, - interactFieldGroups, listenerRef, }) { const { causalStore } = useGlobalStore(); - const { igMatrix, selectedFields, causalStrength } = causalStore; - const { dataSubset } = dataContext; - const { appendFields2Group, setFieldGroup } = interactFieldGroups; - const handleLasso = useCallback((fields: IFieldMeta[]) => { - setFieldGroup(fields); - }, [setFieldGroup]); + const viewContext = useCausalViewContext(); - const handleSubTreeSelected = useCallback(( - node, simpleCause, simpleEffect, composedCause, composedEffect, - ) => { - if (node) { - appendFields2Group([node.fid]); - } - listenerRef?.current?.onSubtreeSelected?.(node, simpleCause, simpleEffect, composedCause, composedEffect); - }, - [appendFields2Group, listenerRef] - ); + const handleLasso = useCallback((fields: IFieldMeta[]) => { + for (const f of fields) { + viewContext?.toggleNodeSelected(f.fid); + } + }, [viewContext]); - const handleLinkTogether = useCallback((srcIdx: number, tarIdx: number, type: ModifiableBgKnowledge['type']) => { - setModifiablePrecondition((list) => { - return list.concat([{ - src: selectedFields[srcIdx].fid, - tar: selectedFields[tarIdx].fid, - type, - }]); - }); - }, [selectedFields, setModifiablePrecondition]); + const handleSubTreeSelected = useCallback((subtree: Subtree | null) => { + listenerRef?.current?.onSubtreeSelected?.(subtree); + }, [listenerRef]); - const handleRevertLink = useCallback((srcFid: string, tarFid: string) => setModifiablePrecondition((list) => { - return list.map((link) => { - if (link.src === srcFid && link.tar === tarFid) { - return produce(link, draft => { - draft.type = ({ - "must-link": 'must-not-link', - "must-not-link": 'must-link', - "directed-must-link": 'directed-must-not-link', - "directed-must-not-link": 'directed-must-link', - } as const)[draft.type]; - }); - } - return link; - }); - }), [setModifiablePrecondition]); + const handleLinkTogether = useCallback((srcFid: string, tarFid: string, assert: EdgeAssert) => { + causalStore.model.addEdgeAssertion(srcFid, tarFid, assert); + }, [causalStore]); - const handleRemoveLink = useCallback((srcFid: string, tarFid: string) => setModifiablePrecondition((list) => { - return list.filter((link) => { - return !(link.src === srcFid && link.tar === tarFid); - }); - }), [setModifiablePrecondition]); + const handleRevertLink = useCallback((srcFid: string, tarFid: string) => { + causalStore.model.revertEdgeAssertion([srcFid, tarFid]); + }, [causalStore]); - const synchronizePredictionsUsingCausalResult = useCallback(() => { - setModifiablePrecondition(resolvePreconditionsFromCausal(causalStrength, selectedFields)); - }, [setModifiablePrecondition, causalStrength, selectedFields]); + const handleRemoveLink = useCallback((srcFid: string, tarFid: string) => { + causalStore.model.removeEdgeAssertion([srcFid, tarFid]); + }, [causalStore]); return ( ); }); -const CausalModal: React.FC = ({ - dataContext, - modifiablePrecondition, - setModifiablePrecondition, - renderNode, - functionalDependencies, - interactFieldGroups, -}) => { - const { dataSourceStore, causalStore } = useGlobalStore(); - const { fieldMetas } = dataSourceStore; - const { focusFieldIds, computing, igMatrix, selectedFields, causalStrength } = causalStore; - const { dataSubset } = dataContext; - - /** @deprecated FCI 已经迁移到 preconditionPag 参数,等到所有算法更新完可以删掉对应逻辑 */ - const precondition = useMemo(() => { - if (computing || igMatrix.length !== selectedFields.length) { - return []; - } - return modifiablePrecondition.reduce((list, decl) => { - const srcIdx = selectedFields.findIndex((f) => f.fid === decl.src); - const tarIdx = selectedFields.findIndex((f) => f.fid === decl.tar); - - if (srcIdx !== -1 && tarIdx !== -1) { - if (decl.type === 'directed-must-link' || decl.type === 'directed-must-not-link') { - list.push({ - src: decl.src, - tar: decl.tar, - type: decl.type === 'directed-must-link' ? 1 : -1, - }); - } else { - list.push({ - src: decl.src, - tar: decl.tar, - type: decl.type === 'must-link' ? 1 : -1, - }, { - src: decl.tar, - tar: decl.src, - type: decl.type === 'must-link' ? 1 : -1, - }); - } - } - - return list; - }, []); - }, [igMatrix, modifiablePrecondition, selectedFields, computing]); - - const preconditionPag = useMemo(() => { - if (computing || igMatrix.length !== selectedFields.length) { - return []; - } - return transformPreconditions(modifiablePrecondition, selectedFields); - }, [igMatrix, modifiablePrecondition, selectedFields, computing]); +const CausalModal: FC = () => { + const { causalStore } = useGlobalStore(); - const { appendFields2Group } = interactFieldGroups; - - const onFieldGroupSelect = useCallback( - (xFid: string, yFid: string) => { - causalStore.setFocusNodeIndex(fieldMetas.findIndex((f) => f.fid === xFid)); - appendFields2Group([xFid, yFid]); - }, - [appendFields2Group, causalStore, fieldMetas] - ); + const viewContext = useCausalViewContext(); - const resetExploringFieldsRef = useRef(() => interactFieldGroups.clearFieldGroup()); - resetExploringFieldsRef.current = () => interactFieldGroups.clearFieldGroup(); - - useEffect(() => { - resetExploringFieldsRef.current(); - }, [causalStrength]); + const appendFields2Group = useCallback((fidArr: string[]) => { + for (const fid of fidArr) { + viewContext?.selectNode(fid); + } + }, [viewContext]); - const edges = useMemo(() => { - return mergeCausalPag(causalStrength, modifiablePrecondition, fieldMetas); - }, [causalStrength, fieldMetas, modifiablePrecondition]); + const onFieldGroupSelect = useCallback((xFid: string, yFid: string) => { + appendFields2Group([xFid, yFid]); + }, [appendFields2Group]); - const listenerRef = useRef<{ onSubtreeSelected?: ExplorerProps['onNodeSelected'] }>({}); + const listenerRef = useRef<{ onSubtreeSelected?: (subtree: Subtree | null) => void }>({}); return (
- + { + if (causalStore.operator.busy) { + return; + } switch (matKey) { case MATRIX_TYPE.conditionalMutualInfo: - causalStore.computeIGCondMatrix(dataSubset, selectedFields); + causalStore.computeCondMutualMatrix(); break; case MATRIX_TYPE.causal: - causalStore.causalDiscovery(dataSubset, precondition, preconditionPag, functionalDependencies); + causalStore.run(); break; case MATRIX_TYPE.mutualInfo: default: - causalStore.computeIGMatrix(dataSubset, selectedFields); + causalStore.computeMutualMatrix(); break; } }} diagram={( )} />
- +
); diff --git a/packages/rath-client/src/pages/causal/step/datasetConfig.tsx b/packages/rath-client/src/pages/causal/step/datasetConfig.tsx index b55a2875..463eb9cf 100644 --- a/packages/rath-client/src/pages/causal/step/datasetConfig.tsx +++ b/packages/rath-client/src/pages/causal/step/datasetConfig.tsx @@ -1,17 +1,12 @@ import { observer } from 'mobx-react-lite'; -import React from 'react'; -import type { useDataViews } from '../hooks/dataViews'; +import type { FC } from 'react'; import DatasetPanel from '../datasetPanel'; -export interface CausalDatasetConfigProps { - dataContext: ReturnType; -} - -const CausalDatasetConfig: React.FC = ({ dataContext }) => { +const CausalDatasetConfig: FC = () => { return ( <> - + ); }; diff --git a/packages/rath-client/src/pages/causal/step/index.tsx b/packages/rath-client/src/pages/causal/step/index.tsx index 6e5213be..de00ffa2 100644 --- a/packages/rath-client/src/pages/causal/step/index.tsx +++ b/packages/rath-client/src/pages/causal/step/index.tsx @@ -2,11 +2,6 @@ import { DefaultButton, Icon, IconButton } from "@fluentui/react"; import { observer } from "mobx-react-lite"; import { Fragment, useEffect, useMemo, useState } from "react"; import styled from "styled-components"; -import type { useDataViews } from "../hooks/dataViews"; -import type { IFunctionalDep, ModifiableBgKnowledge } from "../config"; -import type { GraphNodeAttributes } from "../explorer/graph-utils"; -import type { IFieldMeta } from "../../../interfaces"; -import type { useInteractFieldGroups } from "../hooks/interactFieldGroup"; import CausalDatasetConfig from './datasetConfig'; import CausalFDConfig from './FDConfig'; import CausalModel from "./causalModel"; @@ -109,25 +104,7 @@ export const CausalSteps: readonly CausalStepOption[] = [ }, ]; -interface CausalStepPagerProps { - dataContext: ReturnType; - modifiablePrecondition: ModifiableBgKnowledge[]; - setModifiablePrecondition: (precondition: ModifiableBgKnowledge[] | ((prev: ModifiableBgKnowledge[]) => ModifiableBgKnowledge[])) => void; - functionalDependencies: IFunctionalDep[]; - setFunctionalDependencies: (fdArr: IFunctionalDep[] | ((prev: IFunctionalDep[]) => IFunctionalDep[])) => void; - renderNode: (node: Readonly) => GraphNodeAttributes | undefined; - interactFieldGroups: ReturnType; -} - -export const CausalStepPager = observer(function CausalStepPager ({ - dataContext, - modifiablePrecondition, - setModifiablePrecondition, - functionalDependencies, - setFunctionalDependencies, - renderNode, - interactFieldGroups, -}) { +export const CausalStepPager = observer(function CausalStepPager () { const [stepKey, setStepKey] = useState(CausalStep.DATASET_CONFIG); const [showHelp, setShowHelp] = useState(stepKey); @@ -227,24 +204,12 @@ export const CausalStepPager = observer(function CausalSte
{{ - [CausalStep.DATASET_CONFIG]: , + [CausalStep.DATASET_CONFIG]: , [CausalStep.FD_CONFIG]: ( - + ), [CausalStep.CAUSAL_MODEL]: ( - + ), }[curStep.key]} diff --git a/packages/rath-client/src/pages/dataSource/metaView/distChart.tsx b/packages/rath-client/src/pages/dataSource/metaView/distChart.tsx index d149486a..88cd677d 100644 --- a/packages/rath-client/src/pages/dataSource/metaView/distChart.tsx +++ b/packages/rath-client/src/pages/dataSource/metaView/distChart.tsx @@ -33,11 +33,13 @@ export interface DistributionChartProps { height?: number; maxItemInView?: number; dataSource: IRow[] + /** @default true */ + label?: boolean; } const DistributionChart: React.FC = (props) => { const chart = useRef(null); - const { x, y, dataSource, semanticType, width = 180, height = 80, maxItemInView = 10 } = props; + const { x, y, dataSource, semanticType, width = 180, height = 80, maxItemInView = 10, label = true } = props; const [view, setView] = useState(); // 是否有分箱的ordinal列 const hasBinIndex = useMemo(() => { @@ -103,12 +105,12 @@ const DistributionChart: React.FC = (props) => { x: { field: x, title: null, - axis: { + axis: label ? { // "labelAngle": 0, labelLimit: 52, "labelOverlap": "parity", ticks: false - }, + } : null, // axis: null, type: semanticType === 'quantitative' ? 'ordinal' : semanticType, sort: sortBy }, @@ -129,7 +131,7 @@ const DistributionChart: React.FC = (props) => { }).catch(console.error) } } - }, [x, y, sortBy, semanticType, width, height, maxItemInView]) + }, [x, y, sortBy, semanticType, width, height, maxItemInView, label]) useEffect(() => { if (view) { try { diff --git a/packages/rath-client/src/pages/dataSource/profilingView/metaDetail.tsx b/packages/rath-client/src/pages/dataSource/profilingView/metaDetail.tsx index 39f3a0c1..e538bb69 100644 --- a/packages/rath-client/src/pages/dataSource/profilingView/metaDetail.tsx +++ b/packages/rath-client/src/pages/dataSource/profilingView/metaDetail.tsx @@ -54,9 +54,13 @@ const DetailContainer = styled.div` interface MetaDetailProps { field?: IFieldMeta; + /** @default 200 */ + height?: number; + /** @default 620 */ + width?: number; } const MetaDetail: React.FC = (props) => { - const { field } = props; + const { field, width = 620, height = 200 } = props; const [selection, setSelection] = React.useState([]); const { dataSourceStore, commonStore, semiAutoStore } = useGlobalStore(); const { cleanedData } = dataSourceStore; @@ -134,8 +138,8 @@ const MetaDetail: React.FC = (props) => { dataSource={field.distribution} x="memberName" y="count" - height={200} - width={620} + height={height} + width={width} maxItemInView={1000} analyticType={field.analyticType} semanticType={field.semanticType} diff --git a/packages/rath-client/src/queries/distVis.ts b/packages/rath-client/src/queries/distVis.ts index ff6806c5..ec5a5227 100644 --- a/packages/rath-client/src/queries/distVis.ts +++ b/packages/rath-client/src/queries/distVis.ts @@ -18,6 +18,7 @@ export const geomTypeMap: { [key: string]: any } = { interface BaseVisProps { // dataSource: DataSource; pattern: IPattern; + /** @default false */ interactive?: boolean; resizeMode?: IResizeMode; width?: number; diff --git a/packages/rath-client/src/services/r-insight.ts b/packages/rath-client/src/services/r-insight.ts index f782f634..c37c0c3c 100644 --- a/packages/rath-client/src/services/r-insight.ts +++ b/packages/rath-client/src/services/r-insight.ts @@ -10,8 +10,8 @@ export const RInsightService = async (props: IRInsightExplainProps, mode: 'worke const { causalStore } = getGlobalStore(); if (mode === 'server') { - const { apiPrefix } = causalStore; - const res = await fetch(`${apiPrefix}/explain`, { + const { causalServer } = causalStore.operator; + const res = await fetch(`${causalServer}/explain`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/packages/rath-client/src/store/causalStore.ts b/packages/rath-client/src/store/causalStore.ts deleted file mode 100644 index 93f49c30..00000000 --- a/packages/rath-client/src/store/causalStore.ts +++ /dev/null @@ -1,301 +0,0 @@ -import { IDropdownOption } from '@fluentui/react'; -import { makeAutoObservable, observable, runInAction, toJS } from 'mobx'; -import { notify } from '../components/error'; -import type { IFieldMeta, IRow } from '../interfaces'; -import { - CAUSAL_ALGORITHM_FORM, - ICausalAlgorithm, - makeFormInitParams, - PC_PARAMS_FORM, - IAlgoSchema, - CAUSAL_ALGORITHM_OPTIONS, - BgKnowledge, - BgKnowledgePagLink, - IFunctionalDep, -} from '../pages/causal/config'; -import { causalService } from '../pages/causal/service'; -import resolveCausal, { CausalLinkDirection, findUnmatchedCausalResults, stringifyDirection } from '../utils/resolve-causal'; -import { getModelStorage, getModelStorageList, setModelStorage } from '../utils/storage'; -import { DataSourceStore } from './dataSourceStore'; - -enum CausalServerUrl { - local = 'http://localhost:8001', - // test = 'http://gateway.kanaries.cn:2080/causal', - test = 'http://dev02-thinkpad-t14-gen-2a.local:2281', -} -export class CausalStore { - public igMatrix: number[][] = []; - public igCondMatrix: number[][] = []; - public computing: boolean = false; - public showSettings: boolean = false; - public focusNodeIndex: number = 0; - public focusFieldIds: string[] = []; - /** Name of algorithm selected to be used in next call, modified in the settings panel */ - public causalAlgorithm: string = ICausalAlgorithm.PC; - public userModelKeys: string[] = []; - public showSemi: boolean = false; - /** Fields received from algorithm, the starting N items are equals to `inputFields`, and then there may have some extra trailing fields built during the process, the size of it is C (C >= N) */ - public causalFields: IFieldMeta[] = []; - /** An (N x N) matrix of flags representing the links between any two nodes */ - public causalStrength: CausalLinkDirection[][] = []; - /** asserts algorithm in keys of `causalStore.causalAlgorithmForm`. */ - public causalParams: { [algo: string]: { [key: string]: any } } = { - // alpha: 0.05, - // indep_test: IndepenenceTest.fisherZ, - // stable: true, - // uc_rule: UCRule.uc_supset, - // uc_priority: UCPriority.default, - // mvpc: false, - // catEncodeType: ICatEncodeType.none, // encoding for catecorical data - // quantEncodeType: IQuantEncodeType.none, // encoding for quantitative data - // keepOriginCat: true, - // keepOriginQuant: true - }; // save - - /** Keep the options synchorized with `CausalStore.causalAlgorithmForm` */ - private _causalAlgorithmOptions: IDropdownOption[] = CAUSAL_ALGORITHM_OPTIONS; - private _fetchedCausalAlgorithmForm: IAlgoSchema = Object.fromEntries(Object.entries(CAUSAL_ALGORITHM_FORM)); - public get causalAlgorithmOptions(): IDropdownOption[] { - return this._causalAlgorithmOptions; - // console.log(this.causalAlgorithmForm) - // for (let [key, schema] of this.causalAlgorithmForm.entries()) { - // options.push({ key, text: schema.title, ariaLabel: schema.description } as IDropdownOption) - // } return options; - } - public get causalAlgorithmForm(): IAlgoSchema { - return this._fetchedCausalAlgorithmForm; - } - public set causalAlgorithmForm(schema: IAlgoSchema) { - if (Object.keys(schema).length === 0) { - console.error('[causalAlgorithmForm]: schema is empty'); - return; - } - this._fetchedCausalAlgorithmForm = schema; - this._causalAlgorithmOptions = Object.entries(schema).map(([key, form]) => { - return { key: key, text: `${key}: ${form.title}` } as IDropdownOption; - }); - let firstAlgorithm = Object.entries(schema)[0]; - this.causalAlgorithm = firstAlgorithm[0]; - for (let entry of Object.entries(schema)) { - this.causalParams[entry[0]] = makeFormInitParams(entry[1]); - } - } - private causalServer = - decodeURIComponent(new URL(window.location.href).searchParams.get('causalServer') ?? '').replace(/\/$/, '') || - CausalServerUrl.test; // FIXME: - public get apiPrefix() { - return this.causalServer; - } - private dataSourceStore: DataSourceStore; - constructor(dataSourceStore: DataSourceStore) { - this.dataSourceStore = dataSourceStore; - this.causalAlgorithm = ICausalAlgorithm.PC; - this.causalParams[ICausalAlgorithm.PC] = makeFormInitParams(PC_PARAMS_FORM); - this.updateCausalAlgorithmList(dataSourceStore.fieldMetas); - makeAutoObservable(this, { - causalFields: observable.ref, - causalStrength: observable.ref, - igMatrix: observable.ref, - igCondMatrix: observable.ref, - focusFieldIds: observable.ref, - // @ts-ignore - dataSourceStore: false, - }); - } - public switchCausalAlgorithm(algorithm: string) { - if (this.causalAlgorithmForm[algorithm] !== undefined) { - this.causalAlgorithm = algorithm; - // this.causalParams[algorithm] = // makeFormInitParams(this.causalAlgorithmForm[algorithm]); - return true; - } else { - console.error(`[switchCausalAlgorithm error]: algorithm ${algorithm} not known.`); - return false; - } - } - public updateCausalAlgoAndParams(algorithm: string, params: CausalStore['causalParams']) { - if (this.switchCausalAlgorithm(algorithm)) { - this.causalParams[algorithm] = params; - } - } - public updateCausalParamsValue(key: string, value: any) { - this.causalParams[this.causalAlgorithm][key] = value; - } - public toggleSettings(show: boolean) { - this.showSettings = show; - } - public setFocusNodeIndex(index: number) { - this.focusNodeIndex = index; - } - public setFocusFieldIds(fids: string[]) { - this.focusFieldIds = fids; - } - public get selectedFields(): IFieldMeta[] { - return this.focusFieldIds - .map((fid) => this.dataSourceStore.fieldMetas.find((f) => f.fid === fid)) - .filter((f) => Boolean(f)) as IFieldMeta[]; - } - public async saveCausalModel() { - if (this.dataSourceStore.datasetId) { - return setModelStorage(this.dataSourceStore.datasetId, { - metas: this.dataSourceStore.fieldMetas, - causal: { - algorithm: this.causalAlgorithm, - causalMatrix: this.causalStrength, - corMatrix: this.igMatrix, - fieldIds: this.causalFields.map((f) => f.fid), - params: toJS(this.causalParams), - }, - }); - } - throw new Error('datasetId is not set'); - } - public async fetchCausalModel(datasetId: string) { - const model = await getModelStorage(datasetId); - if (model) { - const fieldMetas = this.dataSourceStore.fieldMetas; - this.causalParams = model.causal.params; - this.causalAlgorithm = model.causal.algorithm; - this.igMatrix = model.causal.corMatrix; - this.setCausalResult( - model.causal.fieldIds - .map((f) => fieldMetas.find((m) => m.fid === f)) - .filter((f) => Boolean(f)) as IFieldMeta[], - model.causal.causalMatrix - ); - } - } - public async getCausalModelList() { - const modelKeys = await getModelStorageList(); - this.userModelKeys = modelKeys; - } - public async computeIGMatrix(dataSource: IRow[], fields: IFieldMeta[]) { - this.computing = true; - const res = await causalService({ task: 'ig', dataSource, fields }); - runInAction(() => { - this.igMatrix = res; - this.computing = false; - }); - } - public async computeIGCondMatrix(dataSource: IRow[], fields: IFieldMeta[]) { - this.computing = true; - const res = await causalService({ task: 'ig_cond', dataSource, fields, matrix: this.igMatrix }); - runInAction(() => { - this.igCondMatrix = res; - this.computing = false; - }); - } - public async updateCausalAlgorithmList(fields: IFieldMeta[]) { - try { - const schema: IAlgoSchema = await fetch(`${this.causalServer}/algo/list`, { - method: 'POST', - body: JSON.stringify({ - fieldIds: fields.map((f) => f.fid), - fieldMetas: fields, - }), - headers: { - 'Content-Type': 'application/json', - }, - }).then((resp) => resp.json()); - this.causalAlgorithmForm = schema; - // for (let [algoName, algoSchema] of schema.entries()) { - // } - } catch (error) { - console.error('[CausalAlgorithmList error]:', error); - } - } - public setCausalResult(causalFields: IFieldMeta[], causalMatrix: CausalLinkDirection[][]) { - this.causalFields = causalFields; - this.causalStrength = causalMatrix; - } - public async causalDiscovery( - dataSource: IRow[], - /** @deprecated */ precondition: BgKnowledge[], - preconditionPag: BgKnowledgePagLink[], - funcDeps: IFunctionalDep[], - ) { - const fields = this.dataSourceStore.fieldMetas; - const focusFieldIds = this.focusFieldIds; - const algoName = this.causalAlgorithm; - const inputFields = focusFieldIds.map((fid) => fields.find((f) => f.fid === fid)! ?? fid); - if (inputFields.some((f) => typeof f === 'string')) { - notify({ - title: 'Causal Discovery Error', - type: 'error', - content: `Fields ${inputFields.filter((f) => typeof f === 'string').join(', ')} not found`, - }); - return; - } - try { - this.computing = true; - this.causalFields = []; - this.causalStrength = []; - const originFieldsLength = inputFields.length; - const res = await fetch(`${this.causalServer}/causal/${algoName}`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - dataSource, // encodeDiscrte(dataSource, fields), - fields, - focusedFields: focusFieldIds, - bgKnowledges: precondition, - bgKnowledgesPag: preconditionPag, - funcDeps, - params: this.causalParams[algoName], - }), - }); - const result = await res.json(); - if (result.success) { - const resultMatrix = (result.data.matrix as number[][]) - .slice(0, originFieldsLength) - .map((row) => row.slice(0, originFieldsLength)); - const causalMatrix = resolveCausal(resultMatrix); - const unmatched = findUnmatchedCausalResults(inputFields, preconditionPag, causalMatrix); - if (unmatched.length > 0 && process.env.NODE_ENV !== 'production') { - const getFieldName = (fid: string) => { - const field = inputFields.find(f => f.fid === fid); - return field?.name ?? fid; - }; - for (const info of unmatched) { - notify({ - title: 'Causal Result Not Matching', - type: 'error', - content: `Conflict in edge "${getFieldName(info.srcFid)} -> ${getFieldName(info.tarFid)}":\n` - + ` Expected: ${ - typeof info.expected === 'object' - ? ('not' in info.expected - ? `not ${stringifyDirection(info.expected.not)}` - : `one of ${info.expected.oneOf.map( - direction => stringifyDirection(direction) - ).join(', ')}` - ) - : stringifyDirection(info.expected) - }\n` - + ` Received: ${stringifyDirection(info.received)}`, - }); - } - } - this.setCausalResult(inputFields, causalMatrix); - } else { - throw new Error(result.message); - } - } catch (error) { - notify({ - title: 'Causal Discovery Error', - type: 'error', - content: `${error}`, - }); - } finally { - this.computing = false; - } - } - public async reRunCausalDiscovery( - dataSource: IRow[], - /** @deprecated */ precondition: BgKnowledge[], - preconditionPag: BgKnowledgePagLink[], - funcDeps: IFunctionalDep[], - ) { - this.causalDiscovery(dataSource, precondition, preconditionPag, funcDeps); - } -} diff --git a/packages/rath-client/src/store/causalStore/datasetStore.ts b/packages/rath-client/src/store/causalStore/datasetStore.ts new file mode 100644 index 00000000..b8d7df83 --- /dev/null +++ b/packages/rath-client/src/store/causalStore/datasetStore.ts @@ -0,0 +1,229 @@ +import produce from "immer"; +import { makeAutoObservable, observable, reaction, runInAction, toJS } from "mobx"; +import { combineLatest, from, map, Observable, share, Subject, switchAll, throttleTime } from "rxjs"; +import type { IFieldMeta, IFilter, ICol, IRow } from "../../interfaces"; +import { filterDataService } from "../../services"; +import { IteratorStorage, IteratorStorageMetaInfo } from "../../utils/iteStorage"; +import { focusedSample } from "../../utils/sample"; +import { baseDemoSample } from "../../utils/view-sample"; +import type { DataSourceStore } from "../dataSourceStore"; + + +const VIS_SUBSET_LIMIT = 400; +const SAMPLE_UPDATE_DELAY = 500; + + +export default class CausalDatasetStore { + + public datasetId: string | null; + + public allFields: readonly IFieldMeta[] = []; + + protected fieldIndices$ = new Subject(); + /** All fields to analyze */ + public fields: readonly IFieldMeta[] = []; + + protected filters$ = new Subject(); + public filters: readonly IFilter[] = []; + + public fullDataSize = 0; + public filteredDataSize = 0; + public sampleSize = 0; + + protected _sampleRate: number = 1; + public get sampleRate() { + return this._sampleRate; + } + public set sampleRate(value: number) { + this._sampleRate = Math.max(0, Math.min(1, value)); + } + + protected filteredData: IteratorStorage; + /** + * Rows used to do analysis. + * Never use it to decide a distinguishing because it will never changes. + * In view, use `visSample` instead. + */ + public sample: IteratorStorage; + public readonly sampleMetaInfo$: Observable; + /** Rows used to render sub charts */ + public visSample: readonly IRow[] = []; + + public readonly destroy: () => void; + + constructor(dataSourceStore: DataSourceStore) { + this.filteredData = new IteratorStorage({ itemKey: 'causalStoreFilteredData' }); + this.sample = new IteratorStorage({ itemKey: 'causalStoreSample' }); + + const allFields$ = new Subject(); + const fields$ = new Subject(); + const fullDataChangedSignal$ = new Subject<1>(); + const sampleRate$ = new Subject(); + + makeAutoObservable(this, { + allFields: observable.ref, + fields: observable.ref, + filters: observable.ref, + // @ts-expect-error non-public field + filteredData: false, + sample: false, + sampleMetaInfo$: false, + visSample: observable.ref, + destroy: false, + }); + + const filteredDataMetaInfo$ = combineLatest({ + _: fullDataChangedSignal$, + filters: this.filters$, + }).pipe( + map(({ filters }) => { + return from(filterDataService({ + computationMode: 'inline', + dataSource: dataSourceStore.cleanedData, + extData: new Map>(), + filters: toJS(filters) as IFilter[], + }).then(r => { + return this.filteredData.setAll(r.rows); + }).then(() => { + return this.filteredData.syncMetaInfoFromStorage(); + })) + }), + switchAll(), + share() + ); + + this.sampleMetaInfo$ = combineLatest({ + filteredDataMetaInfo: filteredDataMetaInfo$, + sampleRate: sampleRate$.pipe(throttleTime(SAMPLE_UPDATE_DELAY)), + fields: fields$, + }).pipe( + map(({ sampleRate, fields }) => { + const fullData = this.filteredData.getAll(); + return from( + fullData.then(rows => { + const indices = focusedSample(rows, fields, sampleRate * rows.length); + return indices.map(idx => rows[idx]); + }).then(rows => { + return this.sample.setAll(rows); + }).then(() => { + return this.sample.syncMetaInfoFromStorage(); + }) + ); + }), + switchAll(), + share() + ); + + const visSample$ = this.sampleMetaInfo$.pipe( + map(() => { + const fullData = this.sample.getAll(); + return from( + fullData.then(rows => { + return baseDemoSample(rows, VIS_SUBSET_LIMIT); + }) + ); + }), + switchAll(), + share() + ); + + const mobxReactions = [ + reaction(() => dataSourceStore.datasetId, id => { + runInAction(() => { + this.datasetId = id; + }); + this.filters$.next([]); + }), + reaction(() => dataSourceStore.cleanedData, cleanedData => { + fullDataChangedSignal$.next(1); + runInAction(() => { + this.fullDataSize = cleanedData.length; + }); + }), + reaction(() => dataSourceStore.fieldMetas, fieldMetas => { + allFields$.next(fieldMetas); + }), + reaction(() => this.sampleRate, sr => { + sampleRate$.next(sr); + }), + ]; + + const rxReactions = [ + // reset field selector + allFields$.subscribe(fields => { + runInAction(() => { + this.allFields = fields; + }); + // Choose the first 10 fields by default + this.fieldIndices$.next(fields.slice(0, 10).map((_, i) => i)); + }), + + // compute `fields` + this.fieldIndices$.subscribe((fieldIndices) => { + fields$.next(fieldIndices.map(index => this.allFields[index])); + }), + + // bind `fields` with observer + fields$.subscribe(fields => { + runInAction(() => { + this.fields = fields; + }); + }), + + // assign filters + this.filters$.subscribe(filters => { + runInAction(() => { + this.filters = filters; + }); + }), + + // update filteredData info + filteredDataMetaInfo$.subscribe(meta => { + runInAction(() => { + this.filteredDataSize = meta.length; + }); + }), + + // update sample info + this.sampleMetaInfo$.subscribe(meta => { + runInAction(() => { + this.sampleSize = meta.length; + }); + }), + + // update `visSample` + visSample$.subscribe(data => { + runInAction(() => { + this.visSample = data; + }); + }), + ]; + + // initialize data + this.datasetId = dataSourceStore.datasetId; + allFields$.next(dataSourceStore.fieldMetas); + sampleRate$.next(this.sampleRate); + fullDataChangedSignal$.next(1); + this.filters$.next([]); + + this.destroy = () => { + mobxReactions.forEach(dispose => dispose()); + rxReactions.forEach(subscription => subscription.unsubscribe()); + }; + } + + public selectFields(indices: readonly number[]) { + this.fieldIndices$.next(indices); + } + + public appendFilter(filter: IFilter) { + this.filters$.next(this.filters.concat([filter])); + } + + public removeFilter(index: number) { + this.filters$.next(produce(this.filters, draft => { + draft.splice(index, 1); + })); + } + +} diff --git a/packages/rath-client/src/store/causalStore/mainStore.ts b/packages/rath-client/src/store/causalStore/mainStore.ts new file mode 100644 index 00000000..2db18168 --- /dev/null +++ b/packages/rath-client/src/store/causalStore/mainStore.ts @@ -0,0 +1,179 @@ +import { action, makeAutoObservable, runInAction, toJS } from "mobx"; +import { notify } from "../../components/error"; +import type { PAG_NODE } from "../../pages/causal/config"; +import { getCausalModelStorage, getCausalModelStorageKeys, setCausalModelStorage } from "../../utils/storage"; +import type { DataSourceStore } from "../dataSourceStore"; +import CausalDatasetStore from "./datasetStore"; +import CausalModelStore from "./modelStore"; +import CausalOperatorStore from "./operatorStore"; +import { resolveCausality } from "./pag"; + + +export interface ICausalStoreSave { + readonly datasetId: string; + readonly fields: readonly string[]; + readonly causalModel: { + readonly algorithm: string; + readonly params: { readonly [key: string]: any }; + readonly causalityRaw: readonly (readonly PAG_NODE[])[]; + } | null; +} + +export default class CausalStore { + + public readonly dataset: CausalDatasetStore; + public readonly operator: CausalOperatorStore; + public readonly model: CausalModelStore; + + public get fields() { + return this.dataset.fields; + } + + public destroy() { + this.model.destroy(); + this.operator.destroy(); + this.dataset.destroy(); + } + + public saveKeys: string[] = []; + + public async checkout(saveKey: string) { + const save = await getCausalModelStorage(saveKey); + if (save) { + if (save.datasetId !== this.dataset.datasetId) { + notify({ + type: 'error', + title: 'Load Causal Model Failed', + content: `Dataset ID not match\nrequires: ${save.datasetId}\n: current:${this.dataset.datasetId}.`, + }); + return false; + } + const droppedFields = save.fields.filter(fid => { + return this.dataset.allFields.findIndex(f => f.fid === fid) === -1; + }); + if (droppedFields.length > 0) { + notify({ + type: 'error', + title: 'Load Causal Model Failed', + content: `${droppedFields.length} fields not found: ${droppedFields.join(', ')}.`, + }); + return false; + } + this.dataset.selectFields(save.fields.map( + fid => this.dataset.allFields.findIndex(f => f.fid === fid) + )); + if (save.causalModel) { + this.operator.updateConfig(save.causalModel.algorithm, save.causalModel.params); + runInAction(() => { + this.model.causalityRaw = save.causalModel!.causalityRaw; + this.model.causality = resolveCausality(save.causalModel!.causalityRaw, this.dataset.fields); + }); + } + return true; + } + notify({ + type: 'error', + title: 'Load Causal Model Failed', + content: `Save id ${saveKey} fields not found.`, + }); + return false; + } + + public async save(): Promise { + if (!this.dataset.datasetId) { + return false; + } + const save: ICausalStoreSave = { + datasetId: this.dataset.datasetId, + fields: this.fields.map(f => f.fid), + causalModel: this.operator.algorithm && this.model.causalityRaw ? { + algorithm: this.operator.algorithm, + params: toJS(this.operator.params[this.operator.algorithm]), + causalityRaw: this.model.causalityRaw, + } : null, + }; + await setCausalModelStorage(this.dataset.datasetId, save); + return true; + } + + public async updateSaveKeys() { + const modelKeys = await getCausalModelStorageKeys(); + runInAction(() => { + this.saveKeys = modelKeys; + }); + } + + constructor(dataSourceStore: DataSourceStore) { + this.dataset = new CausalDatasetStore(dataSourceStore); + this.operator = new CausalOperatorStore(dataSourceStore); + this.model = new CausalModelStore(this.dataset, this.operator); + + makeAutoObservable(this, { + dataset: false, + operator: false, + model: false, + checkout: action, + }); + } + + public selectFields(...args: Parameters) { + this.dataset.selectFields(...args); + } + + public appendFilter(...args: Parameters) { + this.dataset.appendFilter(...args); + } + + public removeFilter(...args: Parameters) { + this.dataset.removeFilter(...args); + } + + public async run() { + runInAction(() => { + this.model.causalityRaw = null; + this.model.causality = null; + }); + const result = await this.operator.causalDiscovery( + this.dataset.sample, + this.dataset.fields, + this.model.functionalDependencies, + this.model.assertionsAsPag, + ); + runInAction(() => { + this.model.causalityRaw = result?.raw ?? null; + this.model.causality = result?.pag ?? null; + }); + + return result; + } + + public async computeMutualMatrix() { + runInAction(() => { + this.model.mutualMatrix = null; + }); + const result = await this.operator.computeMutualMatrix(this.dataset.sample, this.dataset.fields); + runInAction(() => { + this.model.mutualMatrix = result; + }); + return result; + } + + public async computeCondMutualMatrix() { + if (!this.model.mutualMatrix) { + await this.computeMutualMatrix(); + } + const { mutualMatrix } = this.model; + if (!mutualMatrix) { + return null; + } + runInAction(() => { + this.model.condMutualMatrix = null; + }); + const result = await this.operator.computeCondMutualMatrix(this.dataset.sample, this.dataset.fields, mutualMatrix); + runInAction(() => { + this.model.condMutualMatrix = result; + }); + return result; + } + +} diff --git a/packages/rath-client/src/store/causalStore/modelStore.ts b/packages/rath-client/src/store/causalStore/modelStore.ts new file mode 100644 index 00000000..72262cf7 --- /dev/null +++ b/packages/rath-client/src/store/causalStore/modelStore.ts @@ -0,0 +1,319 @@ +import produce from "immer"; +import { makeAutoObservable, observable, reaction, runInAction } from "mobx"; +import { combineLatest, distinctUntilChanged, map, Subject, switchAll } from "rxjs"; +import type { IFieldMeta } from "../../interfaces"; +import type { IFunctionalDep, PagLink } from "../../pages/causal/config"; +import type CausalDatasetStore from "./datasetStore"; +import CausalOperatorStore from "./operatorStore"; +import { mergePAGs, transformAssertionsToPag, transformFuncDepsToPag, transformPagToAssertions } from "./pag"; + + +export enum NodeAssert { + FORBID_AS_CAUSE, + FORBID_AS_EFFECT, +} + +export type CausalModelNodeAssertion = { + fid: string; + assertion: NodeAssert; +}; + +export enum EdgeAssert { + TO_BE_RELEVANT, + TO_BE_NOT_RELEVANT, + TO_EFFECT, + TO_NOT_EFFECT, +} + +export type CausalModelEdgeAssertion = { + sourceFid: string; + targetFid: string; + assertion: EdgeAssert; +}; + +export type CausalModelAssertion = CausalModelNodeAssertion | CausalModelEdgeAssertion; + +export default class CausalModelStore { + + public readonly destroy: () => void; + + public generatedFDFromExtInfo: readonly IFunctionalDep[] = []; + public functionalDependencies: readonly IFunctionalDep[] = []; + public functionalDependenciesAsPag: readonly PagLink[] = []; + + protected assertions$ = new Subject(); + /** + * Modifiable assertions based on background knowledge of user, + * reset with the non-weak value of the causal result when the latter changes. + */ + public assertions: readonly CausalModelAssertion[] = []; + public assertionsAsPag: readonly PagLink[] = []; + + public mutualMatrix: readonly (readonly number[])[] | null = null; + public condMutualMatrix: readonly (readonly number[])[] | null = null; + + public causalityRaw: readonly (readonly number[])[] | null = null; + public causality: readonly PagLink[] | null = null; + /** causality + assertionsAsPag */ + public mergedPag: readonly PagLink[] = []; + + constructor(datasetStore: CausalDatasetStore, operatorStore: CausalOperatorStore) { + const fields$ = new Subject(); + const extFields$ = new Subject(); + const causality$ = new Subject(); + const assertionsPag$ = new Subject(); + + makeAutoObservable(this, { + destroy: false, + functionalDependencies: observable.ref, + generatedFDFromExtInfo: observable.ref, + assertions: observable.ref, + assertionsAsPag: observable.ref, + mutualMatrix: observable.ref, + causalityRaw: observable.ref, + causality: observable.ref, + mergedPag: observable.ref, + }); + + const mobxReactions = [ + reaction(() => datasetStore.fields, fields => { + fields$.next(fields); + runInAction(() => { + this.assertions = []; + this.assertionsAsPag = []; + this.mutualMatrix = null; + this.condMutualMatrix = null; + this.causalityRaw = null; + this.causality = null; + }); + }), + reaction(() => this.mutualMatrix, () => { + runInAction(() => { + this.condMutualMatrix = null; + }); + }), + reaction(() => this.functionalDependencies, funcDeps => { + runInAction(() => { + this.functionalDependenciesAsPag = transformFuncDepsToPag(funcDeps); + this.causalityRaw = null; + this.causality = null; + }); + }), + reaction(() => this.causality, () => { + this.synchronizeAssertionsWithResult(); + causality$.next(this.causality ?? []); + }), + ]; + + const rxReactions = [ + // find extInfo in fields + fields$.subscribe(fields => { + extFields$.next(fields.filter(f => Boolean(f.extInfo))); + }), + // auto update FD using extInfo + extFields$.pipe( + distinctUntilChanged((prev, curr) => { + return prev.length === curr.length && curr.every(f => prev.some(which => which.fid === f.fid)); + }), + map(extFields => { + return extFields.reduce((list, f) => { + if (f.extInfo) { + list.push({ + fid: f.fid, + params: f.extInfo.extFrom.map(from => ({ + fid: from, + })), + func: f.extInfo.extOpt, + extInfo: f.extInfo, + }); + } + return list; + }, []); + }), + ).subscribe(deps => { + runInAction(() => { + this.generatedFDFromExtInfo = deps; + }); + }), + // compute mutual matrix + combineLatest({ + dataSignal: datasetStore.sampleMetaInfo$, + fields: fields$, + }).pipe( + map(({ fields }) => operatorStore.computeMutualMatrix(datasetStore.sample, fields)), + switchAll() + ).subscribe(matrix => { + runInAction(() => { + this.mutualMatrix = matrix; + }); + }), + // update assertions + this.assertions$.subscribe(assertions => { + runInAction(() => { + this.assertions = assertions; + this.assertionsAsPag = transformAssertionsToPag(assertions, datasetStore.fields); + assertionsPag$.next(this.assertionsAsPag); + }); + }), + // compute merged pag + combineLatest({ + basis: causality$, + assertions: assertionsPag$, + }).pipe( + map(({ basis, assertions }) => mergePAGs(basis, assertions)) + ).subscribe(pag => { + runInAction(() => { + this.mergedPag = pag; + }); + }), + ]; + + fields$.next(datasetStore.fields); + + this.destroy = () => { + mobxReactions.forEach(dispose => dispose()); + rxReactions.forEach(subscription => subscription.unsubscribe()); + }; + } + + public updateFunctionalDependencies(functionalDependencies: readonly IFunctionalDep[]) { + this.functionalDependencies = functionalDependencies; + } + + public addFunctionalDependency(sourceFid: string, targetFid: string) { + this.functionalDependencies = produce(this.functionalDependencies, draft => { + const linked = draft.find(fd => fd.fid === targetFid); + if (linked && !linked.params.some(prm => prm.fid === sourceFid)) { + linked.params.push({ fid: sourceFid }); + if (!linked.func) { + linked.func = ''; + } else if (linked.func !== '') { + linked.func = ''; + } + } else { + draft.push({ + fid: targetFid, + params: [{ + fid: sourceFid, + }], + func: '', + }); + } + }); + } + + public removeFunctionalDependency(sourceFid: string, targetFid: string) { + this.functionalDependencies = produce(this.functionalDependencies, draft => { + const linkedIdx = draft.findIndex(fd => fd.fid === targetFid && fd.params.some(prm => prm.fid === sourceFid)); + if (linkedIdx !== -1) { + const linked = draft[linkedIdx]; + if (linked.params.length > 1) { + linked.params = linked.params.filter(prm => prm.fid !== sourceFid); + if (linked.func !== '') { + linked.func = ''; + } + } else { + draft.splice(linkedIdx, 1); + } + } + }); + } + + protected synchronizeAssertionsWithResult() { + const nodeAssertions = this.assertions.filter(decl => 'fid' in decl); + this.assertions$.next(this.causality ? nodeAssertions.concat(transformPagToAssertions(this.causality)) : []); + } + + public clearAssertions() { + this.assertions$.next([]); + } + + public addNodeAssertion(fid: string, assertion: NodeAssert): boolean { + const assertionsWithoutThisNode = this.assertions.filter(decl => { + if ('fid' in decl) { + return decl.fid !== fid; + } + return [decl.sourceFid, decl.targetFid].every(node => node !== fid); + }); + this.assertions$.next(assertionsWithoutThisNode.concat([{ + fid, + assertion, + }])); + return true; + } + + public removeNodeAssertion(fid: string): boolean { + const assertionIndex = this.assertions.findIndex(decl => 'fid' in decl && decl.fid === fid); + if (assertionIndex === -1) { + return false; + } + this.assertions$.next(produce(this.assertions, draft => { + draft.splice(assertionIndex, 1); + })); + return true; + } + + public revertNodeAssertion(fid: string) { + const assertionIndex = this.assertions.findIndex(decl => 'fid' in decl && decl.fid === fid); + if (assertionIndex === -1) { + return false; + } + this.assertions$.next(produce(this.assertions, draft => { + const decl = draft[assertionIndex] as CausalModelNodeAssertion; + decl.assertion = ({ + [NodeAssert.FORBID_AS_CAUSE]: NodeAssert.FORBID_AS_EFFECT, + [NodeAssert.FORBID_AS_EFFECT]: NodeAssert.FORBID_AS_CAUSE, + })[decl.assertion]; + })); + return true; + } + + public addEdgeAssertion(sourceFid: string, targetFid: string, assertion: EdgeAssert) { + if (sourceFid === targetFid || this.assertions.some(decl => 'fid' in decl && [sourceFid, targetFid].includes(decl.fid))) { + return false; + } + const assertionsWithoutThisEdge = this.assertions.filter( + decl => 'fid' in decl || !([decl.sourceFid, decl.targetFid].every(fid => [sourceFid, targetFid].includes(fid))) + ); + this.assertions$.next(assertionsWithoutThisEdge.concat([{ + sourceFid, + targetFid, + assertion, + }])); + } + + public removeEdgeAssertion(nodes: [string, string]) { + if (nodes[0] === nodes[1]) { + return false; + } + const assertionIndex = this.assertions.findIndex(decl => 'sourceFid' in decl && nodes.every(fid => [decl.sourceFid, decl.targetFid].includes(fid))); + if (assertionIndex === -1) { + return false; + } + this.assertions$.next(produce(this.assertions, draft => { + draft.splice(assertionIndex, 1); + })); + return true; + } + + public revertEdgeAssertion(nodes: [string, string]) { + if (nodes[0] === nodes[1]) { + return false; + } + const assertionIndex = this.assertions.findIndex(decl => 'sourceFid' in decl && nodes.every(fid => [decl.sourceFid, decl.targetFid].includes(fid))); + if (assertionIndex === -1) { + return false; + } + this.assertions$.next(produce(this.assertions, draft => { + const decl = draft[assertionIndex] as CausalModelEdgeAssertion; + decl.assertion = ({ + [EdgeAssert.TO_BE_RELEVANT]: EdgeAssert.TO_BE_NOT_RELEVANT, + [EdgeAssert.TO_BE_NOT_RELEVANT]: EdgeAssert.TO_BE_RELEVANT, + [EdgeAssert.TO_EFFECT]: EdgeAssert.TO_NOT_EFFECT, + [EdgeAssert.TO_NOT_EFFECT]: EdgeAssert.TO_EFFECT, + })[decl.assertion]; + })); + return true; + } + +} diff --git a/packages/rath-client/src/store/causalStore/operatorStore.ts b/packages/rath-client/src/store/causalStore/operatorStore.ts new file mode 100644 index 00000000..caf4fbab --- /dev/null +++ b/packages/rath-client/src/store/causalStore/operatorStore.ts @@ -0,0 +1,239 @@ +import type { IDropdownOption } from "@fluentui/react"; +import { makeAutoObservable, reaction, runInAction } from "mobx"; +import { distinctUntilChanged, Subject, switchAll } from "rxjs"; +import { getGlobalStore } from ".."; +import { notify } from "../../components/error"; +import type { IFieldMeta } from "../../interfaces"; +import { IAlgoSchema, IFunctionalDep, makeFormInitParams, PagLink, PAG_NODE } from "../../pages/causal/config"; +import { causalService } from "../../pages/causal/service"; +import type { IteratorStorage } from "../../utils/iteStorage"; +import type { DataSourceStore } from "../dataSourceStore"; +import { findUnmatchedCausalResults, resolveCausality } from "./pag"; + + +export default class CausalOperatorStore { + + public causalServer = ( + decodeURIComponent(new URL(window.location.href).searchParams.get('causalServer') ?? '').replace(/\/$/, '') + || 'http://gateway.kanaries.cn:2080/causal' + ); + + public busy = false; + + protected _causalAlgorithmForm: IAlgoSchema = {}; + public get causalAlgorithmForm(): IAlgoSchema { + return this._causalAlgorithmForm; + } + public params: { [algo: string]: { [key: string]: any } } = {}; + protected set causalAlgorithmForm(schema: IAlgoSchema) { + if (Object.keys(schema).length === 0) { + console.error('[causalAlgorithmForm]: schema is empty'); + return; + } + this._causalAlgorithmForm = schema; + } + public get causalAlgorithmOptions() { + return Object.entries(this._causalAlgorithmForm).map(([key, form]) => { + return { key: key, text: `${key}: ${form.title}` } as IDropdownOption; + }); + } + protected _algorithm: string | null = null; + public get algorithm() { + return this._algorithm; + } + public set algorithm(algoName: string | null) { + if (this.busy) { + return; + } else if (algoName === null) { + this._algorithm = null; + } else if (algoName in this._causalAlgorithmForm) { + this._algorithm = algoName; + } + } + + public readonly destroy: () => void; + + constructor(dataSourceStore: DataSourceStore) { + const allFields$ = new Subject(); + const dynamicFormSchema$ = new Subject>(); + + makeAutoObservable(this, { + destroy: false, + }); + + const mobxReactions = [ + reaction(() => dataSourceStore.fieldMetas, fieldMetas => { + allFields$.next(fieldMetas); + }), + // this reaction requires `makeAutoObservable` to be called before + reaction(() => this._causalAlgorithmForm, form => { + runInAction(() => { + this._algorithm = null; + this.params = {}; + for (const algoName of Object.keys(form)) { + this.params[algoName] = makeFormInitParams(form[algoName]); + } + const [firstAlgoName] = Object.keys(form); + if (firstAlgoName) { + this._algorithm = firstAlgoName; + } + }); + }), + ]; + + const rxReactions = [ + // fetch schema + allFields$.pipe( + distinctUntilChanged((prev, next) => { + return prev.length === next.length && next.every(f => prev.some(which => which.fid === f.fid)); + }), + ).subscribe(fields => { + runInAction(() => { + this.causalAlgorithmForm = {}; + }); + dynamicFormSchema$.next(this.fetchCausalAlgorithmList(fields)); + }), + // update form + dynamicFormSchema$.pipe( + switchAll() + ).subscribe(schema => { + runInAction(() => { + this.causalAlgorithmForm = schema ?? {}; + }); + }), + ]; + + this.destroy = () => { + mobxReactions.forEach(dispose => dispose()); + rxReactions.forEach(subscription => subscription.unsubscribe()); + }; + } + + protected async fetchCausalAlgorithmList(fields: readonly IFieldMeta[]): Promise { + try { + const schema: IAlgoSchema = await fetch(`${this.causalServer}/algo/list`, { + method: 'POST', + body: JSON.stringify({ + fieldIds: fields.map((f) => f.fid), + fieldMetas: fields, + }), + headers: { + 'Content-Type': 'application/json', + }, + }).then((resp) => resp.json()); + return schema; + } catch (error) { + console.error('[CausalAlgorithmList error]:', error); + return null; + } + } + + public async computeMutualMatrix(data: IteratorStorage, fields: readonly IFieldMeta[]): Promise { + const dataSource = await data.getAll(); + const res = await causalService({ task: 'ig', dataSource, fields }); + return res; + } + + public async computeCondMutualMatrix( + data: IteratorStorage, fields: readonly IFieldMeta[], mutualMatrix: readonly (readonly number[])[] + ): Promise { + const dataSource = await data.getAll(); + const res = await causalService({ task: 'ig_cond', dataSource, fields, matrix: mutualMatrix }); + return res; + } + + public async causalDiscovery( + data: IteratorStorage, + fields: readonly IFieldMeta[], + functionalDependencies: readonly IFunctionalDep[], + assertions: readonly PagLink[], + ): Promise<{ raw: number[][]; pag: PagLink[] } | null> { + if (this.busy) { + return null; + } + let causality: { raw: number[][]; pag: PagLink[] } | null = null; + const { fieldMetas: allFields } = getGlobalStore().dataSourceStore; + const focusedFields = fields.map(f => { + return allFields.findIndex(which => which.fid === f.fid); + }).filter(idx => idx !== -1); + const algoName = this._algorithm; + const inputFields = focusedFields.map(idx => allFields[idx]); + if (!algoName) { + notify({ + title: 'Causal Discovery Error', + type: 'error', + content: 'Algorithm is not chosen yet.', + }); + return null; + } + try { + runInAction(() => { + this.busy = true; + }); + const originFieldsLength = inputFields.length; + const dataSource = await data.getAll(); + const res = await fetch(`${this.causalServer}/causal/${algoName}`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + dataSource, + fields: allFields, + focusedFields: inputFields.map(f => f.fid), + bgKnowledgesPag: assertions, + funcDeps: functionalDependencies, + params: this.params[algoName], + }), + }); + const result = await res.json(); + if (result.success) { + const rawMatrix = result.data.matrix as PAG_NODE[][]; + const causalMatrix = rawMatrix + .slice(0, originFieldsLength) + .map((row) => row.slice(0, originFieldsLength)); + const causalPag = resolveCausality(causalMatrix, inputFields); + causality = { raw: causalMatrix, pag: causalPag }; + const unmatched = findUnmatchedCausalResults(assertions, causalPag); + if (unmatched.length > 0 && process.env.NODE_ENV !== 'production') { + const getFieldName = (fid: string) => { + const field = inputFields.find(f => f.fid === fid); + return field?.name ?? fid; + }; + for (const info of unmatched) { + notify({ + title: 'Causal Result Not Matching', + type: 'error', + content: `Conflict in edge "${getFieldName(info.srcFid)} -> ${getFieldName(info.tarFid)}":\n` + + ` Expected: ${info.expected.src_type} -> ${info.expected.tar_type}\n` + + ` Received: ${info.received.src_type} -> ${info.received.tar_type}`, + }); + } + } + } else { + throw new Error(result.message); + } + } catch (error) { + notify({ + title: 'Causal Discovery Error', + type: 'error', + content: `${error}`, + }); + } finally { + runInAction(() => { + this.busy = false; + }); + } + return causality; + } + + public updateConfig(algoName: string, params: typeof this.params[string]): boolean { + this.algorithm = algoName; + if (this._algorithm !== null && this._algorithm in this.params) { + this.params[this._algorithm] = params; + return true; + } + return false; + } + +} diff --git a/packages/rath-client/src/store/causalStore/pag.ts b/packages/rath-client/src/store/causalStore/pag.ts new file mode 100644 index 00000000..3e1bc410 --- /dev/null +++ b/packages/rath-client/src/store/causalStore/pag.ts @@ -0,0 +1,320 @@ +import type { IFieldMeta } from "../../interfaces"; +import { IFunctionalDep, PagLink, PAG_NODE } from "../../pages/causal/config"; +import { CausalModelAssertion, NodeAssert, EdgeAssert } from "./modelStore"; + + +export const transformAssertionsToPag = ( + assertions: readonly CausalModelAssertion[], fields: readonly IFieldMeta[] +): PagLink[] => { + return assertions.reduce((list, decl) => { + if ('fid' in decl) { + switch (decl.assertion) { + case NodeAssert.FORBID_AS_CAUSE: { + return list.concat(fields.filter(f => f.fid !== decl.fid).map(f => ({ + src: f.fid, + src_type: PAG_NODE.EMPTY, + tar: decl.fid, + tar_type: PAG_NODE.ARROW, + }))); + } + case NodeAssert.FORBID_AS_EFFECT: { + return list.concat(fields.filter(f => f.fid !== decl.fid).map(f => ({ + src: decl.fid, + src_type: PAG_NODE.EMPTY, + tar: f.fid, + tar_type: PAG_NODE.ARROW, + }))); + } + default: { + return list; + } + } + } + const srcIdx = fields.findIndex((f) => f.fid === decl.sourceFid); + const tarIdx = fields.findIndex((f) => f.fid === decl.targetFid); + + if (srcIdx !== -1 && tarIdx !== -1) { + switch (decl.assertion) { + case EdgeAssert.TO_BE_RELEVANT: { + list.push({ + src: decl.sourceFid, + tar: decl.targetFid, + src_type: PAG_NODE.CIRCLE, + tar_type: PAG_NODE.CIRCLE, + }); + break; + } + case EdgeAssert.TO_BE_NOT_RELEVANT: { + list.push({ + src: decl.sourceFid, + tar: decl.targetFid, + src_type: PAG_NODE.EMPTY, + tar_type: PAG_NODE.EMPTY, + }); + break; + } + case EdgeAssert.TO_EFFECT: { + list.push({ + src: decl.sourceFid, + tar: decl.targetFid, + src_type: PAG_NODE.BLANK, + tar_type: PAG_NODE.ARROW, + }); + break; + } + case EdgeAssert.TO_NOT_EFFECT: { + list.push({ + src: decl.sourceFid, + tar: decl.targetFid, + src_type: PAG_NODE.EMPTY, + tar_type: PAG_NODE.ARROW, + }); + break; + } + default: { + break; + } + } + } + + return list; + }, []); +}; + +export const transformFuncDepsToPag = (funcDeps: readonly IFunctionalDep[]): PagLink[] => { + return funcDeps.reduce((list, funcDep) => { + const { fid: tar } = funcDep; + for (const { fid: src } of funcDep.params) { + list.push({ + src, + tar, + src_type: PAG_NODE.BLANK, + tar_type: PAG_NODE.ARROW, + }); + } + return list; + }, []); +}; + +export const transformPagToAssertions = (pag: readonly PagLink[]): CausalModelAssertion[] => { + return pag.reduce((list, link) => { + if (link.src_type === PAG_NODE.BLANK && link.tar_type === PAG_NODE.ARROW) { + return list.concat([{ + sourceFid: link.src, + targetFid: link.tar, + assertion: EdgeAssert.TO_EFFECT, + }]); + } else if (link.tar_type === PAG_NODE.BLANK && link.src_type === PAG_NODE.ARROW) { + return list.concat([{ + sourceFid: link.tar, + targetFid: link.src, + assertion: EdgeAssert.TO_EFFECT, + }]); + } else if (link.src_type === PAG_NODE.BLANK && link.tar_type === PAG_NODE.BLANK) { + return list.concat([{ + sourceFid: link.src, + targetFid: link.tar, + assertion: EdgeAssert.TO_BE_RELEVANT, + }]); + } else if (link.src_type === PAG_NODE.ARROW && link.tar_type === PAG_NODE.ARROW) { + return list.concat([{ + sourceFid: link.src, + targetFid: link.tar, + assertion: EdgeAssert.TO_BE_RELEVANT, + }]); + } + return list; + }, []); +}; + +export const resolveCausality = (causality: readonly (readonly PAG_NODE[])[], fields: readonly IFieldMeta[]): PagLink[] => { + const links: PagLink[] = []; + + for (let i = 0; i < causality.length - 1; i += 1) { + for (let j = i + 1; j < causality.length; j += 1) { + const src = fields[i].fid; + const tar = fields[j].fid; + const src_type = causality[i][j]; + const tar_type = causality[j][i]; + if (src_type === PAG_NODE.BLANK && tar_type === PAG_NODE.ARROW) { + // i ----> j + links.push({ + src, + tar, + src_type, + tar_type, + }); + } else if (tar_type === PAG_NODE.BLANK && src_type === PAG_NODE.ARROW) { + // j ----> i + links.push({ + src: tar, + tar: src, + src_type: tar_type, + tar_type: src_type, + }); + } else if (src_type === PAG_NODE.BLANK && tar_type === PAG_NODE.BLANK) { + // i ----- j + links.push({ + src, + tar, + src_type, + tar_type, + }); + } else if (src_type === PAG_NODE.ARROW && tar_type === PAG_NODE.ARROW) { + // i <---> j + links.push({ + src, + tar, + src_type, + tar_type, + }); + } else if (src_type === PAG_NODE.CIRCLE && tar_type === PAG_NODE.ARROW) { + // i o---> j + links.push({ + src, + tar, + src_type, + tar_type, + }); + } else if (src_type === PAG_NODE.ARROW && tar_type === PAG_NODE.CIRCLE) { + // j o---> i + links.push({ + src: tar, + tar: src, + src_type: tar_type, + tar_type: src_type, + }); + } else if (tar_type === PAG_NODE.CIRCLE && src_type === PAG_NODE.CIRCLE) { + // i o---o j + links.push({ + src, + tar, + src_type, + tar_type, + }); + } + } + } + + return links; +}; + +export const mergePAGs = (pag1: readonly PagLink[], pag2: readonly PagLink[]): PagLink[] => { + return pag2.reduce((links, link) => { + const overloadIndex = links.findIndex(which => [which.src, which.tar].every(fid => [link.src, link.tar].some(node => node === fid))); + if (overloadIndex === -1) { + return links.concat([link]); + } + links.splice(overloadIndex, 1, link); + return links; + }, pag1.slice(0)).filter(link => ![link.src_type, link.tar_type].some(nodeType => nodeType === PAG_NODE.EMPTY)); +}; + +export interface ICausalDiff { + srcFid: string; + tarFid: string; + expected: Pick; + received: Pick; +} + +export const findUnmatchedCausalResults = ( + assertions: readonly PagLink[], + causality: readonly PagLink[], +): Readonly[] => { + const diffs: ICausalDiff[] = []; + + for (const decl of assertions) { + const link = causality.find(which => ( + (which.src === decl.src && which.tar === decl.tar) || (which.tar === decl.src && which.src === decl.tar) + )); + if ([decl.src_type, decl.tar_type].every(nodeType => nodeType === PAG_NODE.CIRCLE)) { + // EdgeAssert.TO_BE_RELEVANT + if (!link) { + diffs.push({ + srcFid: decl.src, + tarFid: decl.src, + expected: { + src_type: decl.src_type, + tar_type: decl.tar_type, + }, + received: { + src_type: PAG_NODE.EMPTY, + tar_type: PAG_NODE.EMPTY, + }, + }); + } + } else if ([decl.src_type, decl.tar_type].every(nodeType => nodeType === PAG_NODE.EMPTY)) { + // EdgeAssert.TO_BE_NOT_RELEVANT + if (link) { + diffs.push({ + srcFid: link.src, + tarFid: link.src, + expected: { + src_type: PAG_NODE.EMPTY, + tar_type: PAG_NODE.EMPTY, + }, + received: { + src_type: link.src_type, + tar_type: link.tar_type, + }, + }); + } + } else { + const sourceNode = decl.src_type === PAG_NODE.ARROW ? decl.tar : decl.src; + const targetNode = decl.src_type === PAG_NODE.ARROW ? decl.src : decl.tar; + const shouldEffect = (decl.src_type === PAG_NODE.ARROW ? decl.tar_type : decl.src_type) === PAG_NODE.BLANK; + if (shouldEffect) { + if (!link) { + diffs.push({ + srcFid: sourceNode, + tarFid: targetNode, + expected: { + src_type: PAG_NODE.BLANK, + tar_type: PAG_NODE.ARROW, + }, + received: { + src_type: PAG_NODE.EMPTY, + tar_type: PAG_NODE.EMPTY, + }, + }); + } else { + const sourceType = link.src === sourceNode ? link.src_type : link.tar_type; + const targetType = link.tar === targetNode ? link.tar_type : link.src_type; + if (targetType !== PAG_NODE.ARROW) { + diffs.push({ + srcFid: sourceNode, + tarFid: targetNode, + expected: { + src_type: PAG_NODE.BLANK, + tar_type: PAG_NODE.ARROW, + }, + received: { + src_type: sourceType, + tar_type: targetType, + }, + }); + } + } + } else if (link) { + const sourceType = link.src === sourceNode ? link.src_type : link.tar_type; + const targetType = link.tar === targetNode ? link.tar_type : link.src_type; + if (targetType === PAG_NODE.ARROW) { + diffs.push({ + srcFid: sourceNode, + tarFid: targetNode, + expected: { + src_type: PAG_NODE.EMPTY, + tar_type: PAG_NODE.ARROW, + }, + received: { + src_type: sourceType, + tar_type: targetType, + }, + }); + } + } + } + } + + return diffs; +}; diff --git a/packages/rath-client/src/store/causalStore/viewStore.ts b/packages/rath-client/src/store/causalStore/viewStore.ts new file mode 100644 index 00000000..a5776637 --- /dev/null +++ b/packages/rath-client/src/store/causalStore/viewStore.ts @@ -0,0 +1,282 @@ +import produce from "immer"; +import { makeAutoObservable, observable, reaction, runInAction } from "mobx"; +import { createContext, FC, useContext, useMemo, createElement, useEffect, useCallback } from "react"; +import { Subject, withLatestFrom } from "rxjs"; +import type { IFieldMeta } from "../../interfaces"; +import type { GraphNodeAttributes } from "../../pages/causal/explorer/graph-utils"; +import type { IPredictResult, PredictAlgorithm } from "../../pages/causal/predict"; +import type { IRInsightExplainResult } from "../../workers/insight/r-insight.worker"; +import type CausalStore from "./mainStore"; + + +export enum NodeSelectionMode { + NONE, + SINGLE, + MULTIPLE, +} + +export enum ExplorationKey { + // CAUSAL_BLAME = 'CausalBlame', + AUTO_VIS = 'AutoVis', + CROSS_FILTER = 'CrossFilter', + CAUSAL_INSIGHT = 'CausalInsight', + GRAPHIC_WALKER = 'GraphicWalker', + PREDICT = 'predict', +} + +export const ExplorationOptions = [ + // { key: ExplorationKey.CAUSAL_BLAME, text: '归因分析' }, + { key: ExplorationKey.AUTO_VIS, text: '变量概览' }, + { key: ExplorationKey.CROSS_FILTER, text: '因果验证' }, + { key: ExplorationKey.CAUSAL_INSIGHT, text: '可解释探索' }, + { key: ExplorationKey.GRAPHIC_WALKER, text: '可视化自助分析' }, + { key: ExplorationKey.PREDICT, text: '模型预测' }, +] as const; + +class CausalViewStore { + + public explorationKey = ExplorationKey.AUTO_VIS; + public graphNodeSelectionMode = NodeSelectionMode.MULTIPLE; + + protected selectedFidArr$ = new Subject(); + protected _selectedNodes: readonly IFieldMeta[] = []; + public get selectedFieldGroup() { + return this._selectedNodes.slice(0); + } + public get selectedField() { + return this._selectedNodes.at(0) ?? null; + } + + public shouldDisplayAlgorithmPanel = false; + + public onRenderNode: ((node: Readonly) => GraphNodeAttributes | undefined) | undefined; + public localWeights: Map> | undefined; + public predictCache: { + id: string; algo: PredictAlgorithm; startTime: number; completeTime: number; data: IPredictResult; + }[]; + + public readonly destroy: () => void; + + constructor(causalStore: CausalStore) { + this.onRenderNode = node => { + const value = 2 / (1 + Math.exp(-1 * node.features.entropy / 2)) - 1; + return { + style: { + stroke: `rgb(${Math.floor(95 * (1 - value))},${Math.floor(149 * (1 - value))},255)`, + }, + }; + }; + this.localWeights = undefined; + this.predictCache = []; + + const fields$ = new Subject(); + + makeAutoObservable(this, { + onRenderNode: observable.ref, + localWeights: observable.ref, + predictCache: observable.shallow, + // @ts-expect-error non-public field + _selectedNodes: observable.ref, + selectedFidArr$: false, + }); + + const mobxReactions = [ + reaction(() => causalStore.fields, fields => { + fields$.next(fields); + this.selectedFidArr$.next([]); + }), + reaction(() => causalStore.model.mergedPag, () => { + this.selectedFidArr$.next([]); + }), + reaction(() => this.explorationKey, explorationKey => { + runInAction(() => { + switch (explorationKey) { + // case ExplorationKey.CAUSAL_BLAME: + case ExplorationKey.CAUSAL_INSIGHT: + case ExplorationKey.PREDICT: { + this.graphNodeSelectionMode = NodeSelectionMode.SINGLE; + break; + } + case ExplorationKey.AUTO_VIS: + case ExplorationKey.CROSS_FILTER: { + this.graphNodeSelectionMode = NodeSelectionMode.MULTIPLE; + break; + } + default: { + this.graphNodeSelectionMode = NodeSelectionMode.NONE; + } + } + }); + }), + reaction(() => this.graphNodeSelectionMode, graphNodeSelectionMode => { + runInAction(() => { + switch (graphNodeSelectionMode) { + case NodeSelectionMode.SINGLE: { + this._selectedNodes = this._selectedNodes.slice(this._selectedNodes.length - 1); + break; + } + case NodeSelectionMode.MULTIPLE: { + break; + } + default: { + this._selectedNodes = []; + break; + } + } + }); + }), + ]; + + const rxReactions = [ + this.selectedFidArr$.subscribe(() => { + this.localWeights = undefined; + }), + this.selectedFidArr$.pipe( + withLatestFrom(fields$) + ).subscribe(([fidArr, fields]) => { + runInAction(() => { + this._selectedNodes = fidArr.reduce((nodes, fid) => { + const f = fields.find(which => which.fid === fid); + if (f) { + return nodes.concat([f]); + } else { + console.warn(`Select node warning: cannot find field ${fid}.`, fields); + } + return nodes; + }, []); + }); + }), + ]; + + fields$.next(causalStore.fields); + + this.destroy = () => { + mobxReactions.forEach(dispose => dispose()); + rxReactions.forEach(subscription => subscription.unsubscribe()); + }; + } + + public setExplorationKey(explorationKey: ExplorationKey) { + this.explorationKey = explorationKey; + } + + public setNodeSelectionMode(selectionMode: NodeSelectionMode) { + this.graphNodeSelectionMode = selectionMode; + } + + public toggleNodeSelected(fid: string) { + switch (this.graphNodeSelectionMode) { + case NodeSelectionMode.SINGLE: { + if (this.selectedField?.fid === fid) { + this.selectedFidArr$.next([]); + return false; + } else { + this.selectedFidArr$.next([fid]); + return true; + } + } + case NodeSelectionMode.MULTIPLE: { + const selectedFidArr = this.selectedFieldGroup.map(f => f.fid); + this.selectedFidArr$.next(produce(selectedFidArr, draft => { + const matchedIndex = draft.findIndex(f => f === fid); + if (matchedIndex !== -1) { + draft.splice(matchedIndex, 1); + } else { + draft.push(fid); + } + })); + break; + } + default: { + return undefined; + } + } + } + + public selectNode(fid: string) { + switch (this.graphNodeSelectionMode) { + case NodeSelectionMode.SINGLE: { + if (this.selectedField?.fid === fid) { + this.selectedFidArr$.next([]); + return false; + } else { + this.selectedFidArr$.next([fid]); + return true; + } + } + case NodeSelectionMode.MULTIPLE: { + const selectedFidArr = this.selectedFieldGroup.map(f => f.fid); + this.selectedFidArr$.next(produce(selectedFidArr, draft => { + const matchedIndex = draft.findIndex(f => f === fid); + if (matchedIndex === -1) { + draft.push(fid); + } + })); + break; + } + default: { + return undefined; + } + } + } + + public clearSelected() { + this.selectedFidArr$.next([]); + } + + public openAlgorithmPanel() { + this.shouldDisplayAlgorithmPanel = true; + } + + public closeAlgorithmPanel() { + this.shouldDisplayAlgorithmPanel = false; + } + + public setNodeRenderer(handleRender: typeof this.onRenderNode) { + this.onRenderNode = handleRender; + } + + public clearLocalWeights() { + this.localWeights = undefined; + } + + public setLocalWeights(irResult: IRInsightExplainResult) { + const weights = new Map>(); + for (const link of irResult.causalEffects) { + if (!weights.has(link.src)) { + weights.set(link.src, new Map()); + } + weights.get(link.src)!.set(link.tar, link.responsibility); + } + this.localWeights = weights; + } + + public pushPredictResult(result: typeof this.predictCache[number]) { + this.predictCache.push(result); + } + + public clearPredictResults() { + this.predictCache = []; + } + +} + + +const CausalViewContext = createContext(null); + +export const useCausalViewProvider = (causalStore: CausalStore): FC => { + const context = useMemo(() => new CausalViewStore(causalStore), [causalStore]); + + useEffect(() => { + const ref = context; + return () => { + ref.destroy(); + }; + }, [context]); + + return useCallback(function CausalViewProvider ({ children }) { + return createElement(CausalViewContext.Provider, { value: context }, children); + }, [context]); +}; + +export const useCausalViewContext = () => useContext(CausalViewContext); diff --git a/packages/rath-client/src/store/index.tsx b/packages/rath-client/src/store/index.tsx index 8dcb2a1b..9ca5c09e 100644 --- a/packages/rath-client/src/store/index.tsx +++ b/packages/rath-client/src/store/index.tsx @@ -9,7 +9,7 @@ import { SemiAutomationStore } from './semiAutomation/mainStore'; import { PainterStore } from './painterStore' import { CollectionStore } from './collectionStore' import DashboardStore from './dashboardStore'; -import { CausalStore } from './causalStore'; +import CausalStore from './causalStore/mainStore'; export interface StoreCollection { langStore: LangStore; dataSourceStore: DataSourceStore; @@ -47,7 +47,7 @@ const storeCol: StoreCollection = { painterStore, collectionStore, dashboardStore, - causalStore + causalStore, } const StoreContext = React.createContext(null!); diff --git a/packages/rath-client/src/utils/resolve-causal.ts b/packages/rath-client/src/utils/resolve-causal.ts index 64106e90..d10961f7 100644 --- a/packages/rath-client/src/utils/resolve-causal.ts +++ b/packages/rath-client/src/utils/resolve-causal.ts @@ -1,3 +1,5 @@ +/** @deprecated */ + import intl from 'react-intl-universal'; import type { IFieldMeta } from '../interfaces'; import { BgKnowledgePagLink, ModifiableBgKnowledge, PagLink, PAG_NODE } from '../pages/causal/config'; diff --git a/packages/rath-client/src/utils/sample.test.ts b/packages/rath-client/src/utils/sample.test.ts index 90a06c14..738ff838 100644 --- a/packages/rath-client/src/utils/sample.test.ts +++ b/packages/rath-client/src/utils/sample.test.ts @@ -159,6 +159,13 @@ describe('function focusedSample', () => { expect(sample.length).toBe(0); }); + it('Sample size test (float)', () => { + const { data: fullSet, fields } = createRandomData(8, 100); + const sampleSize = 33.3; + const sample = focusedSample(fullSet, fields, sampleSize); + + expect(sample.length).toBe(33); + }); it('Sample size test (small set, more than half)', () => { const { data: fullSet, fields } = createRandomData(6, 64); const sampleRate = 0.8; diff --git a/packages/rath-client/src/utils/sample.ts b/packages/rath-client/src/utils/sample.ts index 0210b738..65da3854 100644 --- a/packages/rath-client/src/utils/sample.ts +++ b/packages/rath-client/src/utils/sample.ts @@ -216,7 +216,7 @@ export const focusedSample = (fullSet: readonly IRow[], focusedFields: readonly } const hashed = hashAll(fullSet, focusedFields); const bins = treeSplit(hashed, fullSet.length / Math.sqrt(sampleSize)); - const indices = sampleBins(bins, fullSet.length, sampleSize); + const indices = sampleBins(bins, fullSet.length, Math.floor(sampleSize)); return indices; }; diff --git a/packages/rath-client/src/utils/storage.ts b/packages/rath-client/src/utils/storage.ts index c7806bdf..fe18d9f9 100644 --- a/packages/rath-client/src/utils/storage.ts +++ b/packages/rath-client/src/utils/storage.ts @@ -1,7 +1,8 @@ import localforage from 'localforage'; import { RESULT_STORAGE_SPLITOR, STORAGES, STORAGE_INSTANCE } from '../constants'; -import { IFieldMeta, IMuteFieldBase, IRow } from '../interfaces'; +import type { IFieldMeta, IMuteFieldBase, IRow } from '../interfaces'; +import type { ICausalStoreSave } from '../store/causalStore/mainStore'; import type { CausalLinkDirection } from './resolve-causal'; export interface IDBMeta { @@ -15,6 +16,7 @@ export interface IDBMeta { fields?: IMuteFieldBase[]; } +/** @deprecated */ export interface IModel { metas: IFieldMeta[]; causal: { @@ -192,35 +194,35 @@ export async function getDataConfig(name: string) { return ds; } -export async function setModelStorage (name: string, model: IModel) { +export async function setCausalModelStorage (saveId: string, model: ICausalStoreSave) { const modelBucket = localforage.createInstance({ name: STORAGE_INSTANCE, - storeName: STORAGES.MODEL - }) - await modelBucket.setItem(name, model); + storeName: STORAGES.CAUSAL_MODEL, + }); + await modelBucket.setItem(saveId, model); } -export async function deleteModelStorage (name: string, model: IModel) { +export async function deleteCausalModelStorage (saveId: string) { const modelBucket = localforage.createInstance({ name: STORAGE_INSTANCE, - storeName: STORAGES.MODEL - }) - await modelBucket.removeItem(name); + storeName: STORAGES.CAUSAL_MODEL, + }); + await modelBucket.removeItem(saveId); } -export async function getModelStorage (name: string): Promise { +export async function getCausalModelStorage (saveId: string): Promise { const modelBucket = localforage.createInstance({ name: STORAGE_INSTANCE, - storeName: STORAGES.MODEL - }) - return await modelBucket.getItem(name) as IModel; + storeName: STORAGES.CAUSAL_MODEL, + }); + return await modelBucket.getItem(saveId); } -export async function getModelStorageList (): Promise { +export async function getCausalModelStorageKeys (): Promise { const modelBucket = localforage.createInstance({ name: STORAGE_INSTANCE, - storeName: STORAGES.MODEL - }) + storeName: STORAGES.CAUSAL_MODEL, + }); return await modelBucket.keys(); } diff --git a/packages/rath-client/src/utils/view-sample.ts b/packages/rath-client/src/utils/view-sample.ts index 98edbcd6..ce160e56 100644 --- a/packages/rath-client/src/utils/view-sample.ts +++ b/packages/rath-client/src/utils/view-sample.ts @@ -73,6 +73,6 @@ export function viewSampling (data: IRow[], fields: IFieldMeta[], sampleSize: nu return samples } -export function baseDemoSample (data: IRow[], sampleSize: number): IRow[] { - return Sampling.reservoirSampling(data, sampleSize) +export function baseDemoSample (data: readonly IRow[], sampleSize: number): IRow[] { + return Sampling.reservoirSampling(data as IRow[], sampleSize) } diff --git a/packages/rath-client/src/workers/insight/r-insight.worker.ts b/packages/rath-client/src/workers/insight/r-insight.worker.ts index 9af2982c..5b0c2387 100644 --- a/packages/rath-client/src/workers/insight/r-insight.worker.ts +++ b/packages/rath-client/src/workers/insight/r-insight.worker.ts @@ -12,13 +12,13 @@ export interface IRInsightExplainSubspace { export interface IRInsightExplainProps { /** 因果图输入数据子集 */ - data: IRow[]; - fields: IFieldMeta[]; + data: readonly IRow[]; + fields: readonly IFieldMeta[]; causalModel: { /** 函数依赖 */ - funcDeps: IFunctionalDep[]; + funcDeps: readonly IFunctionalDep[]; /** 用户编辑后的因果图 */ - edges: PagLink[]; + edges: readonly PagLink[]; }; groups: { current: IRInsightExplainSubspace; diff --git a/packages/rath-client/src/workers/insight/utils.ts b/packages/rath-client/src/workers/insight/utils.ts index 4ff52669..51eeed9a 100644 --- a/packages/rath-client/src/workers/insight/utils.ts +++ b/packages/rath-client/src/workers/insight/utils.ts @@ -224,10 +224,10 @@ export const insightExplain = (props: IRInsightExplainProps): IRInsightExplainRe if (!measure) { continue; } - if (view.dimensions.some(dim => cramersV(data, dim, f.fid) >= RELATION_THRESHOLD)) { + if (view.dimensions.some(dim => cramersV(data.slice(0), dim, f.fid) >= RELATION_THRESHOLD)) { continue; } - const responsibility = diffGroups(data, indices1, indices2, f, { + const responsibility = diffGroups(data.slice(0), indices1, indices2, f, { field: measure, aggregate: target.op, }); diff --git a/services/causal-service/.gitignore b/services/causal-service/.gitignore index 3eb8214d..430786a6 100644 --- a/services/causal-service/.gitignore +++ b/services/causal-service/.gitignore @@ -1,3 +1,7 @@ __pycache__ .ipynb_checkpoints -causal-learn \ No newline at end of file +causal-learn + +# python virtual env +cms +causal_model.png diff --git a/services/causal-service/Dockerfile b/services/causal-service/Dockerfile index bf5c865b..8be7ae61 100644 --- a/services/causal-service/Dockerfile +++ b/services/causal-service/Dockerfile @@ -2,9 +2,9 @@ FROM python:3 WORKDIR /app COPY requirements.txt requirements.txt +RUN pip install -vvv -i https://mirrors.aliyun.com/pypi/simple --no-cache-dir -r /app/requirements.txt COPY causal-learn causal-learn -RUN pip install -i https://mirrors.aliyun.com/pypi/simple --no-cache-dir -r /app/requirements.txt && \ - pip install -i https://mirrors.aliyun.com/pypi/simple --no-cache-dir /app/causal-learn +RUN pip install -vvv -i https://mirrors.aliyun.com/pypi/simple --no-cache-dir /app/causal-learn COPY . . EXPOSE 8000 CMD gunicorn main:app --workers 16 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000 --reload diff --git a/services/causal-service/algorithms/FuncDepTest.py b/services/causal-service/algorithms/FuncDepTest.py index bd4b3565..c7ece899 100644 --- a/services/causal-service/algorithms/FuncDepTest.py +++ b/services/causal-service/algorithms/FuncDepTest.py @@ -18,14 +18,14 @@ class FuncDepTestParams(common.OptionalParams, title="FuncDepTest Algorithm"): # cg.G.graph[i,j] = cg.G.graph[j,i] = 1 indicates i <-> j. # """ indep_test: Optional[str] = Field( - default='chisq', title="独立性检验", #"Independence Test", + default='gsq', title="独立性检验", #"Independence Test", description="The independence test to use for causal discovery", options=common.getOpts(common.IDepTestItems), ) alpha: Optional[float] = Field( - default=math.log10(0.0005), title="log10(显著性阈值)", # "Alpha", + default=-12, title="log10(显著性阈值)", # "Alpha", description="desired significance level (float) in (0, 1). Default: log10(0.005).", - ge=-9, lt=0.0 + ge=-16, lt=0.0 ) orient: Optional[float] = Field( default='ANM', title="方向判断算法", @@ -39,11 +39,11 @@ class FuncDepTestParams(common.OptionalParams, title="FuncDepTest Algorithm"): class FuncDepTest(common.AlgoInterface): ParamType = FuncDepTestParams - def __init__(self, dataSource: List[common.IRow], fields: List[common.IFieldMeta], params: Optional[ParamType] = ParamType()): + def __init__(self, dataSource: List[common.IRow], fields: List[common.IFieldMeta], params: Optional[ParamType] = ParamType(), **kwargs): print("FuncDepTest", fields, params) super(FuncDepTest, self).__init__(dataSource=dataSource, fields=fields, params=params) - def calc(self, params: Optional[ParamType] = ParamType(), focusedFields: List[str] = [], bgKnowledgesPag: Optional[List[common.BgKnowledgePag]] = []): + def calc(self, params: Optional[ParamType] = ParamType(), focusedFields: List[str] = [], bgKnowledgesPag: Optional[List[common.BgKnowledgePag]] = [], **kwargs): array = self.selectArray(focusedFields=focusedFields, params=params) d = len(focusedFields) import itertools, numpy as np @@ -55,28 +55,29 @@ def calc(self, params: Optional[ParamType] = ParamType(), focusedFields: List[st o_test = lambda x, y: anm.cause_or_effect(x, y) # coef = np.corrcoef(array, rowvar=False) from causallearn.utils.cit import CIT - cit = CIT(array, 'fisherz') + # cit = CIT(array, 'fisherz') + cit = CIT(array, params.indep_test) coeff_p = np.zeros((d, d)) for i in range(d): for j in range(d): if i != j: coeff_p[i, j] = coeff_p[j, i] = cit(i, j, []) print(coeff_p) - linear_threshold = 1e-12 + linear_threshold = 1e-18 threshold = 10 ** params.o_alpha + max_samples = 128 + array = array[np.random.choice(range(array.shape[0]), min(array.shape[0], max_samples), replace=False).tolist(),:] for i in range(d): for j in range(i): - if linear_threshold < coeff_p[i, j] < 10 ** params.alpha: + if coeff_p[i, j] < 10 ** params.alpha: a, b = o_test(array[:, i:i+1], array[:, j:j+1]) print(f"indep: {i}, {j}, {coeff_p[i, j]}") print("Orient model p:", a, b) if a * threshold < b: - res[i, j], res[j, i] = 1, -1 - elif a > b * threshold: res[i, j], res[j, i] = -1, 1 - else: - res[i, j], res[j, i] = -1, -1 - elif coeff_p[i, j] <= linear_threshold: # linear - res[i, j], res[j, i] = 1, 1 + elif a > b * threshold: + res[i, j], res[j, i] = 1, -1 + # else: res[i, j], res[j, i] = -1, -1 + # elif coeff_p[i, j] <= linear_threshold: # linear res[i, j], res[j, i] = 1, 1 # for i in range(d): # for j in range(i): diff --git a/services/causal-service/algorithms/causallearn/PC.py b/services/causal-service/algorithms/causallearn/PC.py index 3d1f6b04..c699f3bd 100644 --- a/services/causal-service/algorithms/causallearn/PC.py +++ b/services/causal-service/algorithms/causallearn/PC.py @@ -54,6 +54,7 @@ class PCParams(OptionalParams, title="PC Algorithm"): class PC(AlgoInterface): ParamType = PCParams + dev_only = False def __init__(self, dataSource: List[IRow], fields: List[IFieldMeta], params: Optional[ParamType] = ParamType()): super(PC, self).__init__(dataSource=dataSource, fields=fields, params=params) diff --git a/services/causal-service/algorithms/causallearn/XLearner.py b/services/causal-service/algorithms/causallearn/XLearner.py index 32ed2135..e19d3171 100644 --- a/services/causal-service/algorithms/causallearn/XLearner.py +++ b/services/causal-service/algorithms/causallearn/XLearner.py @@ -322,6 +322,7 @@ class XLearnerParams(OptionalParams, title="XLearn"): class XLearner(AlgoInterface): ParamType = XLearnerParams + dev_only = False def __init__(self, dataSource: List[IRow], fields: List[IFieldMeta], params: Optional[ParamType] = ParamType()): super(XLearner, self).__init__(dataSource, fields, params) diff --git a/services/causal-service/algorithms/common.py b/services/causal-service/algorithms/common.py index 3137bb57..d66308bf 100644 --- a/services/causal-service/algorithms/common.py +++ b/services/causal-service/algorithms/common.py @@ -96,7 +96,7 @@ def checkLinearCorr(array: np.ndarray): print(U, s, VT, sep='\n') # raise Exception("The input array is linear correlated, some fields should be unselected.\n[to be optimized]") # array *= (1 + np.random.randn(*array.shape)*1e-3) - array *= (1 + np.random.randn(*array.shape) * 1e-3) + # array *= (1 + np.random.randn(*array.shape) * 1e-3) print("The input array is linear correlated, some fields should be unselected.\n[to be optimized]", file=sys.stderr) # if np.abs(s[-1] / s[0]) < 1e-4: # print("CheckLinearCorr", U, s, VT) @@ -196,7 +196,7 @@ def encodeCat(origin: pd.Series, fact: pd.Series, encodeType: str) -> pd.DataFra return pd.DataFrame(fact) def encodeQuant(x: pd.Series, encodeType: str) -> pd.DataFrame: - n, eps = 10, 1e-5 + n, eps = 16, 1e-5 if encodeType == 'bin': # encodeType.bin: width = x.max() - x.min() if width == 0: return pd.DataFrame(x) @@ -332,6 +332,7 @@ class CausalRequest(BaseModel, extra=Extra.allow): class AlgoInterface: ParamType = OptionalParams + dev_only = True cache_path = None # '/tmp/causal.json' verbose = False def __init__(self, dataSource: List[IRow], fields: List[IFieldMeta], diff --git a/services/causal-service/algorithms/dowhy/ExplainData.py b/services/causal-service/algorithms/dowhy/ExplainData.py index 8f63f30f..55ec5e02 100644 --- a/services/causal-service/algorithms/dowhy/ExplainData.py +++ b/services/causal-service/algorithms/dowhy/ExplainData.py @@ -287,13 +287,21 @@ def ExplainData(props: IDoWhy.IRInsightExplainProps) -> tp.List[IDoWhy.IRInsight results.append(IDoWhy.LinkInfo( src=props.view.dimensions[0], tar=props.view.measures[0].fid, src_type=-1, tar_type=1, description=IDoWhy.LinkInfoDescription(key='', data=descrip_data), - responsibility=significance_value(session.estimate.value, var=1.) + responsibility=session.estimate.value )) except Exception as e: print(str(e), file=sys.stderr) results.extend(explainData(props)) - # print("results =", results) + + sum2 = 0. + for res in results: + sum2 += res.responsibility * res.responsibility + vars = math.sqrt(sum2 / len(results)) + for res in results: + res.responsibility = significance_value(res.responsibility, vars) + + print("results =", results) return IDoWhy.IRInsightExplainResult( causalEffects=results diff --git a/services/causal-service/main.py b/services/causal-service/main.py index 462feedd..8058e98e 100644 --- a/services/causal-service/main.py +++ b/services/causal-service/main.py @@ -9,13 +9,18 @@ import interfaces as I import algorithms +debug = os.environ.get('mode', 'prod') == 'dev' +print("Development Mode" if debug else 'Production Mode', file=sys.stderr) app = FastAPI() origins = [ "*" ] +cors_regex = \ + "^(https?\://)?(([\w\-_\.]*\.)?kanaries\.\w*|rath[\w\-_]*\-kanaries\.vercel.app)(\:\d{1,})?$" if not debug else \ + "^(https?\://)?(([\w\-_\.]*\.)?kanaries\.\w*|rath[\w\-_]*\-kanaries\.vercel.app|localhost|192\.168\.\d{1,3}\.\d{1,3}|127\.0\.0\.1)(\:\d{1,})?$" app.add_middleware( CORSMiddleware, allow_origins=origins, - # allow_origin_regex="^https?\://([\w\-_\.]*\.kanaries\.\w*)(\:\d{1,})?$", - allow_origin_regex="^https?\://([\w\-_\.]*\.kanaries\.\w*|localhost)(\:\d{1,})?$", # dev only + # allow_origin_regex="^https?\://([\w\-_\.]*\.kanaries\.\w*|rath[\w\-_]*\-kanaries\.vercel.app)(\:\d{1,})?$", + allow_origin_regex=cors_regex, allow_credentials=False, allow_methods=["*"], allow_headers=["*"], @@ -84,7 +89,7 @@ async def algoList(req: AlgoListRequest, response: Response) -> Dict[str, I.Serv # print("/algo/list", req) return { algoName: getAlgoSchema(algoName, req) - for algoName in algorithms.DICT.keys() + for algoName, algo in algorithms.DICT.items() if algo.dev_only == False or debug == True } @app.post('/algo/list/{algoName}', response_model=I.ServiceSchemaResponse) @@ -149,7 +154,6 @@ async def algoSchema(algoName: str, response: Response): import sys import logging -debug = os.environ.get('dev', None) is not None def causal(algoName: str, item: algorithms.CausalRequest, response: Response) -> I.CausalAlgorithmResponse: try: method: I.AlgoInterface = algorithms.DICT.get(algoName)(item.dataSource, item.fields, item.params) diff --git a/services/causal-service/run-docker.sh b/services/causal-service/run-docker.sh index 92c1572e..29c8f528 100755 --- a/services/causal-service/run-docker.sh +++ b/services/causal-service/run-docker.sh @@ -3,7 +3,16 @@ PORT=8001 PORT2=2281 cur_dir=`dirname $0` cd $cur_dir -. fetch.sh && \ -docker build -t causal-server . && \ -docker run -d -p $PORT:8000 -p $PORT2:8000 --name run-causal-server causal-server +. $cur_dir/fetch.sh && \ +docker build -t causal-server . +docker ps -f name=run-causal-server | grep run-causal-server +if [ $? -eq 0 ] ; then +docker stop run-causal-server +docker rm run-causal-server +fi +docker run -d -p $PORT:8000 -p $PORT2:8000 --env mode=$mode --name run-causal-server causal-server + +# docker wait run-causal-server && \ +# docker rm run-causal-server && \ +# docker run -d -p $PORT:8000 -p $PORT2:8000 --name run-causal-server causal-server # docker run -it --rm -v $(pwd):/app -p $PORT:8000 -p $PORT2:8000 --name run-causal-server causal-server diff --git a/yarn.lock b/yarn.lock index f703136d..6e5242b8 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3128,18 +3128,6 @@ resolved "https://registry.npmmirror.com/@types/crypto-js/-/crypto-js-4.1.1.tgz#602859584cecc91894eb23a4892f38cfa927890d" integrity sha512-BG7fQKZ689HIoc5h+6D2Dgq1fABRa0RbBWKBd9SP/MVRVXROflpm5fhwyATX5duFmbStzyzyycPB8qUYKDH3NA== -"@types/d3-path@*": - version "3.0.0" - resolved "https://registry.npmmirror.com/@types/d3-path/-/d3-path-3.0.0.tgz#939e3a784ae4f80b1fde8098b91af1776ff1312b" - integrity sha512-0g/A+mZXgFkQxN3HniRDbXMN79K3CdTpLsevj+PXiTcb2hVyvkZUBg37StmgCQkaD84cUJ4uaDAWq7UJOQy2Tg== - -"@types/d3-shape@^3.1.0": - version "3.1.0" - resolved "https://registry.npmmirror.com/@types/d3-shape/-/d3-shape-3.1.0.tgz#1d87a6ddcf28285ef1e5c278ca4bdbc0658f3505" - integrity sha512-jYIYxFFA9vrJ8Hd4Se83YI6XF+gzDL1aC5DCsldai4XYYiVNdhtpGbA/GM6iyQ8ayhSp3a148LY34hy7A4TxZA== - dependencies: - "@types/d3-path" "*" - "@types/d3-timer@^2.0.0": version "2.0.1" resolved "https://registry.npmmirror.com/@types/d3-timer/-/d3-timer-2.0.1.tgz#ffb6620d290624f3726aa362c0c8a4b44c8d7200" @@ -5274,7 +5262,7 @@ customize-cra@^1.0.0: dependencies: lodash.flow "^3.5.0" -"d3-array@1 - 3", "d3-array@2 - 3", "d3-array@2.10.0 - 3", "d3-array@2.5.0 - 3", d3-array@^3.1.1, d3-array@^3.1.6: +"d3-array@1 - 3", "d3-array@2 - 3", "d3-array@2.10.0 - 3", "d3-array@2.5.0 - 3", d3-array@^3.1.1: version "3.2.0" resolved "https://registry.yarnpkg.com/d3-array/-/d3-array-3.2.0.tgz#15bf96cd9b7333e02eb8de8053d78962eafcff14" integrity sha512-3yXFQo0oG3QCxbF06rMPFyGRMGJNS7NvsV1+2joOjbBE+9xvWQ8+GcMJAjRCzw06zQ3/arXeJgbPYcjUCuC+3g== @@ -5298,16 +5286,6 @@ d3-color@1: resolved "https://registry.yarnpkg.com/d3-color/-/d3-color-3.1.0.tgz#395b2833dfac71507f12ac2f7af23bf819de24e2" integrity sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA== -d3-dag@^0.11.5: - version "0.11.5" - resolved "https://registry.npmmirror.com/d3-dag/-/d3-dag-0.11.5.tgz#666675d763770ca013d13b609eba5fc66d8d419e" - integrity sha512-sNHvYqjzDlvV2fyEkoOCSuLs2GeWliIg7pJcAiKXgtUSxl0kIX0C2q1J8JzzA9CQWptKxYtzxFCXiKptTW8qsQ== - dependencies: - d3-array "^3.1.6" - fastpriorityqueue "0.7.2" - javascript-lp-solver "0.4.24" - quadprog "^1.6.1" - d3-delaunay@^6.0.2: version "6.0.2" resolved "https://registry.yarnpkg.com/d3-delaunay/-/d3-delaunay-6.0.2.tgz#7fd3717ad0eade2fc9939f4260acfb503f984e92" @@ -6411,11 +6389,6 @@ fast-levenshtein@^2.0.6, fast-levenshtein@~2.0.6: resolved "https://registry.yarnpkg.com/fast-levenshtein/-/fast-levenshtein-2.0.6.tgz#3d8a5c66883a16a30ca8643e851f19baa7797917" integrity sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw== -fastpriorityqueue@0.7.2: - version "0.7.2" - resolved "https://registry.npmmirror.com/fastpriorityqueue/-/fastpriorityqueue-0.7.2.tgz#64dfee2c2adbc18c076cf7552dc4bfbef7befe3f" - integrity sha512-5DtIKh6vtOmEGkYdEPNNb+mxeYCnBiKbK3s4gq52l6cX8I5QaTDWWw0Wx/iYo80fVOblSycHu1/iJeqeNxG8Jw== - fastq@^1.6.0: version "1.13.0" resolved "https://registry.yarnpkg.com/fastq/-/fastq-1.13.0.tgz#616760f88a7526bdfc596b7cab8c18938c36b98c" @@ -7667,11 +7640,6 @@ jake@^10.8.5: filelist "^1.0.1" minimatch "^3.0.4" -javascript-lp-solver@0.4.24: - version "0.4.24" - resolved "https://registry.npmmirror.com/javascript-lp-solver/-/javascript-lp-solver-0.4.24.tgz#3bb5f8aa051f0bf04747e39130133a0f738f635c" - integrity sha512-5edoDKnMrt/u3M6GnZKDDIPxOyFOg+WrwDv8mjNiMC2DePhy2H9/FFQgf4ggywaXT1utvkxusJcjQUER72cZmA== - jest-changed-files@^27.5.1: version "27.5.1" resolved "https://registry.npmmirror.com/jest-changed-files/-/jest-changed-files-27.5.1.tgz#a348aed00ec9bf671cc58a66fcbe7c3dfd6a68f5" @@ -10612,11 +10580,6 @@ qs@^6.4.0: dependencies: side-channel "^1.0.4" -quadprog@^1.6.1: - version "1.6.1" - resolved "https://registry.npmmirror.com/quadprog/-/quadprog-1.6.1.tgz#1cd3b13700de9553ef939a6fa73d0d55ddb2f082" - integrity sha512-fN5Jkcjlln/b3pJkseDKREf89JkKIyu6cKIVXisgL6ocKPQ0yTp9n6NZUAq3otEPPw78WZMG9K0o9WsfKyMWJw== - querystring@^0.2.0: version "0.2.1" resolved "https://registry.yarnpkg.com/querystring/-/querystring-0.2.1.tgz#40d77615bb09d16902a85c3e38aa8b5ed761c2dd"