diff --git a/.size-snapshot.json b/.size-snapshot.json index 17491b4ee6..41cf8b21ac 100644 --- a/.size-snapshot.json +++ b/.size-snapshot.json @@ -77,9 +77,9 @@ } }, "middleware.js": { - "bundled": 5547, - "minified": 2945, - "gzipped": 1293, + "bundled": 6548, + "minified": 3299, + "gzipped": 1388, "treeshaked": { "rollup": { "code": 0, diff --git a/src/middleware.ts b/src/middleware.ts index 7530a48f1b..3097a52782 100644 --- a/src/middleware.ts +++ b/src/middleware.ts @@ -204,6 +204,43 @@ type PersistOptions = { migrate?: (persistedState: any, version: number) => S | Promise } +interface Thenable { + then( + onFulfilled: (value: Value) => V | Promise | Thenable + ): Thenable + catch( + onRejected: (reason: Error) => V | Promise | Thenable + ): Thenable +} + +const toThenable = ( + fn: (input: Input) => Result | Promise | Thenable +) => (input: Input): Thenable => { + try { + const result = fn(input) + if (result instanceof Promise) { + return result as Thenable + } + return { + then(onFulfilled) { + return toThenable(onFulfilled)(result as Result) + }, + catch(_onRejected) { + return this as Thenable + }, + } + } catch (e) { + return { + then(_onFulfilled) { + return this as Thenable + }, + catch(onRejected) { + return toThenable(onRejected)(e) + }, + } + } +} + export const persist = ( config: StateCreator, options: PersistOptions @@ -211,8 +248,8 @@ export const persist = ( const { name, getStorage = () => localStorage, - serialize = JSON.stringify, - deserialize = JSON.parse, + serialize = JSON.stringify as (state: StorageValue) => string, + deserialize = JSON.parse as (str: string) => StorageValue, blacklist, whitelist, onRehydrateStorage, @@ -241,7 +278,9 @@ export const persist = ( ) } - const setItem = async () => { + const thenableSerialize = toThenable(serialize) + + const setItem = (): Thenable => { const state = { ...get() } if (whitelist) { @@ -253,7 +292,18 @@ export const persist = ( blacklist.forEach((key) => delete state[key]) } - return storage?.setItem(name, await serialize({ state, version })) + let errorInSync: Error | undefined + const thenable = thenableSerialize({ state, version }) + .then((serializedValue) => + (storage as StateStorage).setItem(name, serializedValue) + ) + .catch((e) => { + errorInSync = e + }) + if (errorInSync) { + throw errorInSync + } + return thenable } const savedSetState = api.setState @@ -264,38 +314,52 @@ export const persist = ( } // rehydrate initial state with existing stored state - ;(async () => { - const postRehydrationCallback = onRehydrateStorage?.(get()) || undefined - - try { - const storageValue = await storage.getItem(name) + // a workaround to solve the issue of not storing rehydrated state in sync storage + // the set(state) value would be later overridden with initial state by create() + // to avoid this, we merge the state from localStorage into the initial state. + let stateFromStorageInSync: S | undefined + const postRehydrationCallback = onRehydrateStorage?.(get()) || undefined + // bind is used to avoid `TypeError: Illegal invocation` error + toThenable(storage.getItem.bind(storage))(name) + .then((storageValue) => { if (storageValue) { - const deserializedStorageValue = await deserialize(storageValue) - - // if versions mismatch, run migration + return deserialize(storageValue) + } + }) + .then((deserializedStorageValue) => { + if (deserializedStorageValue) { if (deserializedStorageValue.version !== version) { - const migratedState = await migrate?.( - deserializedStorageValue.state, - deserializedStorageValue.version - ) - if (migratedState) { - set(migratedState) - await setItem() + if (migrate) { + return migrate( + deserializedStorageValue.state, + deserializedStorageValue.version + ) } + console.error( + `State loaded from storage couldn't be migrated since no migrate function was provided` + ) } else { + stateFromStorageInSync = deserializedStorageValue.state set(deserializedStorageValue.state) } } - } catch (e) { + }) + .then((migratedState) => { + if (migratedState) { + stateFromStorageInSync = migratedState as S + set(migratedState as PartialState) + return setItem() + } + }) + .then(() => { + postRehydrationCallback?.(get(), undefined) + }) + .catch((e: Error) => { postRehydrationCallback?.(undefined, e) - return - } - - postRehydrationCallback?.(get(), undefined) - })() + }) - return config( + const configResult = config( (...args) => { set(...args) void setItem() @@ -303,4 +367,8 @@ export const persist = ( get, api ) + + return stateFromStorageInSync + ? { ...configResult, ...stateFromStorageInSync } + : configResult } diff --git a/tests/persist.test.tsx b/tests/persist.test.tsx index 96b5002039..a5a105ab64 100644 --- a/tests/persist.test.tsx +++ b/tests/persist.test.tsx @@ -151,7 +151,7 @@ it('can migrate persisted state', async () => { name: 'test-storage', version: 13, getStorage: () => storage, - migrate: (state, version) => { + migrate: async (state, version) => { migrateCallCount++ expect(state.count).toBe(42) expect(version).toBe(12) diff --git a/tests/persistSync.test.tsx b/tests/persistSync.test.tsx new file mode 100644 index 0000000000..e8539d963f --- /dev/null +++ b/tests/persistSync.test.tsx @@ -0,0 +1,201 @@ +import create from '../src/index' +import { persist } from '../src/middleware' + +const consoleError = console.error +afterEach(() => { + console.error = consoleError +}) + +describe('persist middleware with sync configuration', () => { + it('can rehydrate state', () => { + let postRehydrationCallbackCallCount = 0 + + const storage = { + getItem: (name: string) => + JSON.stringify({ + state: { count: 42, name }, + version: 0, + }), + setItem: () => {}, + } + + const useStore = create( + persist( + () => ({ + count: 0, + name: 'empty', + }), + { + name: 'test-storage', + getStorage: () => storage, + onRehydrateStorage: () => (state, error) => { + postRehydrationCallbackCallCount++ + expect(error).toBeUndefined() + expect(state?.count).toBe(42) + expect(state?.name).toBe('test-storage') + }, + } + ) + ) + + expect(useStore.getState()).toEqual({ count: 42, name: 'test-storage' }) + expect(postRehydrationCallbackCallCount).toBe(1) + }) + + it('can throw rehydrate error', () => { + let postRehydrationCallbackCallCount = 0 + + const storage = { + getItem: () => { + throw new Error('getItem error') + }, + setItem: () => {}, + } + + create( + persist(() => ({ count: 0 }), { + name: 'test-storage', + getStorage: () => storage, + onRehydrateStorage: () => (_, e) => { + postRehydrationCallbackCallCount++ + expect(e?.message).toBe('getItem error') + }, + }) + ) + + expect(postRehydrationCallbackCallCount).toBe(1) + }) + + it('can persist state', () => { + let setItemCallCount = 0 + + const storage = { + getItem: () => null, + setItem: (name: string, value: string) => { + setItemCallCount++ + expect(name).toBe('test-storage') + expect(value).toBe( + JSON.stringify({ + state: { count: 42 }, + version: 0, + }) + ) + }, + } + + const useStore = create( + persist(() => ({ count: 0 }), { + name: 'test-storage', + getStorage: () => storage, + onRehydrateStorage: () => (_, error) => { + expect(error).toBeUndefined() + }, + }) + ) + + expect(useStore.getState()).toEqual({ count: 0 }) + useStore.setState({ count: 42 }) + expect(useStore.getState()).toEqual({ count: 42 }) + expect(setItemCallCount).toBe(1) + }) + + it('can migrate persisted state', () => { + let migrateCallCount = 0 + let setItemCallCount = 0 + + const storage = { + getItem: () => + JSON.stringify({ + state: { count: 42 }, + version: 12, + }), + setItem: (_: string, value: string) => { + setItemCallCount++ + expect(value).toBe( + JSON.stringify({ + state: { count: 99 }, + version: 13, + }) + ) + }, + } + + const useStore = create( + persist(() => ({ count: 0 }), { + name: 'test-storage', + version: 13, + getStorage: () => storage, + onRehydrateStorage: () => (_, error) => { + expect(error).toBeUndefined() + }, + migrate: (state, version) => { + migrateCallCount++ + expect(state.count).toBe(42) + expect(version).toBe(12) + return { count: 99 } + }, + }) + ) + + expect(useStore.getState()).toEqual({ count: 99 }) + expect(migrateCallCount).toBe(1) + expect(setItemCallCount).toBe(1) + }) + + it.only('can correclty handle a missing migrate function', () => { + console.error = jest.fn() + const storage = { + getItem: () => + JSON.stringify({ + state: { count: 42 }, + version: 12, + }), + setItem: (_: string, value: string) => {}, + } + + const useStore = create( + persist(() => ({ count: 0 }), { + name: 'test-storage', + version: 13, + getStorage: () => storage, + onRehydrateStorage: () => (_, error) => { + expect(error).toBeUndefined() + }, + }) + ) + + expect(useStore.getState()).toEqual({ count: 0 }) + expect(console.error).toHaveBeenCalled() + }) + + it('can throw migrate error', () => { + let postRehydrationCallbackCallCount = 0 + + const storage = { + getItem: () => + JSON.stringify({ + state: {}, + version: 12, + }), + setItem: () => {}, + } + + const useStore = create( + persist(() => ({ count: 0 }), { + name: 'test-storage', + version: 13, + getStorage: () => storage, + migrate: () => { + throw new Error('migrate error') + }, + onRehydrateStorage: () => (_, e) => { + postRehydrationCallbackCallCount++ + expect(e?.message).toBe('migrate error') + }, + }) + ) + + expect(useStore.getState()).toEqual({ count: 0 }) + expect(postRehydrationCallbackCallCount).toBe(1) + }) +})