diff --git a/src/vanilla/utils/atomFamily.ts b/src/vanilla/utils/atomFamily.ts index 81b0098613..71832e8882 100644 --- a/src/vanilla/utils/atomFamily.ts +++ b/src/vanilla/utils/atomFamily.ts @@ -1,11 +1,19 @@ import type { Atom } from '../../vanilla.ts' type ShouldRemove = (createdAt: number, param: Param) => boolean +type Cleanup = () => void +type Callback = (event: { + type: 'CREATE' | 'REMOVE' + param: Param + atom: AtomType +}) => void export interface AtomFamily { (param: Param): AtomType + getParams(): Iterable remove(param: Param): void setShouldRemove(shouldRemove: ShouldRemove | null): void + unstable_listen(callback: Callback): Cleanup } export function atomFamily>( @@ -20,6 +28,7 @@ export function atomFamily>( type CreatedAt = number // in milliseconds let shouldRemove: ShouldRemove | null = null const atoms: Map = new Map() + const listeners = new Set>() const createAtom = (param: Param) => { let item: [AtomType, CreatedAt] | undefined if (areEqual === undefined) { @@ -43,16 +52,40 @@ export function atomFamily>( } const newAtom = initializeAtom(param) + notifyListeners('CREATE', param, newAtom) atoms.set(param, [newAtom, Date.now()]) return newAtom } + function notifyListeners( + type: 'CREATE' | 'REMOVE', + param: Param, + atom: AtomType, + ) { + for (const listener of listeners) { + listener({ type, param, atom }) + } + } + + createAtom.unstable_listen = (callback: Callback) => { + listeners.add(callback) + return () => { + listeners.delete(callback) + } + } + + createAtom.getParams = () => atoms.keys() + createAtom.remove = (param: Param) => { if (areEqual === undefined) { + if (!atoms.has(param)) return + const [atom] = atoms.get(param)! + notifyListeners('REMOVE', param, atom) atoms.delete(param) } else { - for (const [key] of atoms) { + for (const [key, [atom]] of atoms) { if (areEqual(key, param)) { + notifyListeners('REMOVE', key, atom) atoms.delete(key) break } @@ -63,8 +96,9 @@ export function atomFamily>( createAtom.setShouldRemove = (fn: ShouldRemove | null) => { shouldRemove = fn if (!shouldRemove) return - for (const [key, value] of atoms) { - if (shouldRemove(value[1], key)) { + for (const [key, [atom, createdAt]] of atoms) { + if (shouldRemove(createdAt, key)) { + notifyListeners('REMOVE', key, atom) atoms.delete(key) } } diff --git a/tests/vanilla/utils/atomFamily.test.ts b/tests/vanilla/utils/atomFamily.test.ts new file mode 100644 index 0000000000..01d9caf322 --- /dev/null +++ b/tests/vanilla/utils/atomFamily.test.ts @@ -0,0 +1,94 @@ +import { expect, it, vi } from 'vitest' +import { type Atom, atom, createStore } from 'jotai/vanilla' +import { atomFamily } from 'jotai/vanilla/utils' + +it('should create atoms with different params', () => { + const store = createStore() + const aFamily = atomFamily((param: number) => atom(param)) + + expect(store.get(aFamily(1))).toEqual(1) + expect(store.get(aFamily(2))).toEqual(2) +}) + +it('should remove atoms', () => { + const store = createStore() + const initializeAtom = vi.fn((param: number) => atom(param)) + const aFamily = atomFamily(initializeAtom) + + expect(store.get(aFamily(1))).toEqual(1) + expect(store.get(aFamily(2))).toEqual(2) + aFamily.remove(2) + initializeAtom.mockClear() + expect(store.get(aFamily(1))).toEqual(1) + expect(initializeAtom).toHaveBeenCalledTimes(0) + expect(store.get(aFamily(2))).toEqual(2) + expect(initializeAtom).toHaveBeenCalledTimes(1) +}) + +it('should remove atoms with custom comparator', () => { + const store = createStore() + const initializeAtom = vi.fn((param: number) => atom(param)) + const aFamily = atomFamily(initializeAtom, (a, b) => a === b) + + expect(store.get(aFamily(1))).toEqual(1) + expect(store.get(aFamily(2))).toEqual(2) + expect(store.get(aFamily(3))).toEqual(3) + aFamily.remove(2) + initializeAtom.mockClear() + expect(store.get(aFamily(1))).toEqual(1) + expect(initializeAtom).toHaveBeenCalledTimes(0) + expect(store.get(aFamily(2))).toEqual(2) + expect(initializeAtom).toHaveBeenCalledTimes(1) +}) + +it('should remove atoms with custom shouldRemove', () => { + const store = createStore() + const initializeAtom = vi.fn((param: number) => atom(param)) + const aFamily = atomFamily>(initializeAtom) + expect(store.get(aFamily(1))).toEqual(1) + expect(store.get(aFamily(2))).toEqual(2) + expect(store.get(aFamily(3))).toEqual(3) + aFamily.setShouldRemove((_createdAt, param) => param % 2 === 0) + initializeAtom.mockClear() + expect(store.get(aFamily(1))).toEqual(1) + expect(initializeAtom).toHaveBeenCalledTimes(0) + expect(store.get(aFamily(2))).toEqual(2) + expect(initializeAtom).toHaveBeenCalledTimes(1) + expect(store.get(aFamily(3))).toEqual(3) + expect(initializeAtom).toHaveBeenCalledTimes(1) +}) + +it('should notify listeners', () => { + const aFamily = atomFamily((param: number) => atom(param)) + const listener = vi.fn(() => {}) + type Event = { type: 'CREATE' | 'REMOVE'; param: number; atom: Atom } + const unsubscribe = aFamily.unstable_listen(listener) + const atom1 = aFamily(1) + expect(listener).toHaveBeenCalledTimes(1) + const eventCreate = listener.mock.calls[0]?.at(0) as unknown as Event + if (!eventCreate) throw new Error('eventCreate is undefined') + expect(eventCreate.type).toEqual('CREATE') + expect(eventCreate.param).toEqual(1) + expect(eventCreate.atom).toEqual(atom1) + listener.mockClear() + aFamily.remove(1) + expect(listener).toHaveBeenCalledTimes(1) + const eventRemove = listener.mock.calls[0]?.at(0) as unknown as Event + expect(eventRemove.type).toEqual('REMOVE') + expect(eventRemove.param).toEqual(1) + expect(eventRemove.atom).toEqual(atom1) + unsubscribe() + listener.mockClear() + aFamily(2) + expect(listener).toHaveBeenCalledTimes(0) +}) + +it('should return all params', () => { + const store = createStore() + const aFamily = atomFamily((param: number) => atom(param)) + + expect(store.get(aFamily(1))).toEqual(1) + expect(store.get(aFamily(2))).toEqual(2) + expect(store.get(aFamily(3))).toEqual(3) + expect(Array.from(aFamily.getParams())).toEqual([1, 2, 3]) +})