diff --git a/src/vanilla/utils/atomFamily.ts b/src/vanilla/utils/atomFamily.ts
index 81b0098613..71832e8882 100644
--- a/src/vanilla/utils/atomFamily.ts
+++ b/src/vanilla/utils/atomFamily.ts
@@ -1,11 +1,19 @@
import type { Atom } from '../../vanilla.ts'
type ShouldRemove = (createdAt: number, param: Param) => boolean
+type Cleanup = () => void
+type Callback = (event: {
+ type: 'CREATE' | 'REMOVE'
+ param: Param
+ atom: AtomType
+}) => void
export interface AtomFamily {
(param: Param): AtomType
+ getParams(): Iterable
remove(param: Param): void
setShouldRemove(shouldRemove: ShouldRemove | null): void
+ unstable_listen(callback: Callback): Cleanup
}
export function atomFamily>(
@@ -20,6 +28,7 @@ export function atomFamily>(
type CreatedAt = number // in milliseconds
let shouldRemove: ShouldRemove | null = null
const atoms: Map = new Map()
+ const listeners = new Set>()
const createAtom = (param: Param) => {
let item: [AtomType, CreatedAt] | undefined
if (areEqual === undefined) {
@@ -43,16 +52,40 @@ export function atomFamily>(
}
const newAtom = initializeAtom(param)
+ notifyListeners('CREATE', param, newAtom)
atoms.set(param, [newAtom, Date.now()])
return newAtom
}
+ function notifyListeners(
+ type: 'CREATE' | 'REMOVE',
+ param: Param,
+ atom: AtomType,
+ ) {
+ for (const listener of listeners) {
+ listener({ type, param, atom })
+ }
+ }
+
+ createAtom.unstable_listen = (callback: Callback) => {
+ listeners.add(callback)
+ return () => {
+ listeners.delete(callback)
+ }
+ }
+
+ createAtom.getParams = () => atoms.keys()
+
createAtom.remove = (param: Param) => {
if (areEqual === undefined) {
+ if (!atoms.has(param)) return
+ const [atom] = atoms.get(param)!
+ notifyListeners('REMOVE', param, atom)
atoms.delete(param)
} else {
- for (const [key] of atoms) {
+ for (const [key, [atom]] of atoms) {
if (areEqual(key, param)) {
+ notifyListeners('REMOVE', key, atom)
atoms.delete(key)
break
}
@@ -63,8 +96,9 @@ export function atomFamily>(
createAtom.setShouldRemove = (fn: ShouldRemove | null) => {
shouldRemove = fn
if (!shouldRemove) return
- for (const [key, value] of atoms) {
- if (shouldRemove(value[1], key)) {
+ for (const [key, [atom, createdAt]] of atoms) {
+ if (shouldRemove(createdAt, key)) {
+ notifyListeners('REMOVE', key, atom)
atoms.delete(key)
}
}
diff --git a/tests/vanilla/utils/atomFamily.test.ts b/tests/vanilla/utils/atomFamily.test.ts
new file mode 100644
index 0000000000..01d9caf322
--- /dev/null
+++ b/tests/vanilla/utils/atomFamily.test.ts
@@ -0,0 +1,94 @@
+import { expect, it, vi } from 'vitest'
+import { type Atom, atom, createStore } from 'jotai/vanilla'
+import { atomFamily } from 'jotai/vanilla/utils'
+
+it('should create atoms with different params', () => {
+ const store = createStore()
+ const aFamily = atomFamily((param: number) => atom(param))
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+})
+
+it('should remove atoms', () => {
+ const store = createStore()
+ const initializeAtom = vi.fn((param: number) => atom(param))
+ const aFamily = atomFamily(initializeAtom)
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ aFamily.remove(2)
+ initializeAtom.mockClear()
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(initializeAtom).toHaveBeenCalledTimes(0)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+})
+
+it('should remove atoms with custom comparator', () => {
+ const store = createStore()
+ const initializeAtom = vi.fn((param: number) => atom(param))
+ const aFamily = atomFamily(initializeAtom, (a, b) => a === b)
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(store.get(aFamily(3))).toEqual(3)
+ aFamily.remove(2)
+ initializeAtom.mockClear()
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(initializeAtom).toHaveBeenCalledTimes(0)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+})
+
+it('should remove atoms with custom shouldRemove', () => {
+ const store = createStore()
+ const initializeAtom = vi.fn((param: number) => atom(param))
+ const aFamily = atomFamily>(initializeAtom)
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(store.get(aFamily(3))).toEqual(3)
+ aFamily.setShouldRemove((_createdAt, param) => param % 2 === 0)
+ initializeAtom.mockClear()
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(initializeAtom).toHaveBeenCalledTimes(0)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+ expect(store.get(aFamily(3))).toEqual(3)
+ expect(initializeAtom).toHaveBeenCalledTimes(1)
+})
+
+it('should notify listeners', () => {
+ const aFamily = atomFamily((param: number) => atom(param))
+ const listener = vi.fn(() => {})
+ type Event = { type: 'CREATE' | 'REMOVE'; param: number; atom: Atom }
+ const unsubscribe = aFamily.unstable_listen(listener)
+ const atom1 = aFamily(1)
+ expect(listener).toHaveBeenCalledTimes(1)
+ const eventCreate = listener.mock.calls[0]?.at(0) as unknown as Event
+ if (!eventCreate) throw new Error('eventCreate is undefined')
+ expect(eventCreate.type).toEqual('CREATE')
+ expect(eventCreate.param).toEqual(1)
+ expect(eventCreate.atom).toEqual(atom1)
+ listener.mockClear()
+ aFamily.remove(1)
+ expect(listener).toHaveBeenCalledTimes(1)
+ const eventRemove = listener.mock.calls[0]?.at(0) as unknown as Event
+ expect(eventRemove.type).toEqual('REMOVE')
+ expect(eventRemove.param).toEqual(1)
+ expect(eventRemove.atom).toEqual(atom1)
+ unsubscribe()
+ listener.mockClear()
+ aFamily(2)
+ expect(listener).toHaveBeenCalledTimes(0)
+})
+
+it('should return all params', () => {
+ const store = createStore()
+ const aFamily = atomFamily((param: number) => atom(param))
+
+ expect(store.get(aFamily(1))).toEqual(1)
+ expect(store.get(aFamily(2))).toEqual(2)
+ expect(store.get(aFamily(3))).toEqual(3)
+ expect(Array.from(aFamily.getParams())).toEqual([1, 2, 3])
+})