Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/query-core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export {
noop,
partialMatchKey,
replaceEqualDeep,
resolveEnabled,
shouldThrowError,
skipToken,
} from './utils'
Expand Down
29 changes: 26 additions & 3 deletions packages/query-core/src/queryObserver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ export class QueryObserver<
isRefetchError: isError && hasData,
isStale: isStale(query, options),
refetch: this.refetch,
promise: this.#currentThenable,
promise: tagThenable(this.#currentThenable, query.queryHash),
isEnabled: resolveEnabled(options.enabled, query) !== false,
}

Expand All @@ -612,7 +612,7 @@ export class QueryObserver<
const pending =
(this.#currentThenable =
nextResult.promise =
pendingThenable())
tagThenable(pendingThenable<TData>(), query.queryHash))

finalizeThenableIfPossible(pending)
}
Expand All @@ -632,7 +632,11 @@ export class QueryObserver<
}
break
case 'rejected':
if (!isErrorWithoutData || nextResult.error !== prevThenable.reason) {
if (
!isErrorWithoutData ||
nextResult.error !== prevThenable.reason ||
nextResult.fetchStatus === 'fetching'
) {
recreateThenable()
}
break
Expand Down Expand Up @@ -830,3 +834,22 @@ function shouldAssignObserverCurrentProperties<
// basically, just keep previous properties if nothing changed
return false
}

function tagThenable<TThenable extends Thenable<any>>(
thenable: TThenable,
queryHash: string,
): TThenable {
if (!Object.prototype.hasOwnProperty.call(thenable, 'queryHash')) {
Object.defineProperty(thenable, 'queryHash', {
value: queryHash,
enumerable: false,
configurable: true,
})
}
return thenable
}

/**
* @internal
*/
export type PromiseWithHash<T> = Promise<T> & { queryHash?: string }
55 changes: 54 additions & 1 deletion packages/react-query/src/QueryErrorResetBoundary.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
'use client'
import * as React from 'react'

import { useQueryClient } from './QueryClientProvider'

// CONTEXT
export type QueryErrorResetFunction = () => void
export type QueryErrorIsResetFunction = () => boolean
Expand All @@ -10,6 +12,7 @@ export interface QueryErrorResetBoundaryValue {
clearReset: QueryErrorClearResetFunction
isReset: QueryErrorIsResetFunction
reset: QueryErrorResetFunction
register: (queryHash: string) => void
}

function createValue(): QueryErrorResetBoundaryValue {
Expand All @@ -24,6 +27,7 @@ function createValue(): QueryErrorResetBoundaryValue {
isReset: () => {
return isReset
},
register: () => {},
}
}

Expand All @@ -47,10 +51,59 @@ export interface QueryErrorResetBoundaryProps {
export const QueryErrorResetBoundary = ({
children,
}: QueryErrorResetBoundaryProps) => {
const [value] = React.useState(() => createValue())
const client = useQueryClient()
const registeredQueries = React.useRef(new Set<string>())
const [value] = React.useState(() => {
const boundary = createValue()
return {
...boundary,
reset: () => {
boundary.reset()
const queryHashes = new Set(registeredQueries.current)
registeredQueries.current.clear()

void client.refetchQueries({
predicate: (query) =>
queryHashes.has(query.queryHash) && query.state.status === 'error',
type: 'active',
})
},
register: (queryHash: string) => {
registeredQueries.current.add(queryHash)
},
}
})
return (
<QueryErrorResetBoundaryContext.Provider value={value}>
{typeof children === 'function' ? children(value) : children}
</QueryErrorResetBoundaryContext.Provider>
)
}

/**
* @internal
*/
export function getQueryHash(query: any): string | undefined {
if (typeof query === 'object' && query !== null) {
if ('queryHash' in query) {
return query.queryHash
}
if (
'promise' in query &&
query.promise &&
typeof query.promise === 'object' &&
'queryHash' in query.promise
) {
return query.promise.queryHash
}
}
return undefined
}

export function useTrackQueryHash(query: any) {
const { register } = useQueryErrorResetBoundary()
const hash = getQueryHash(query)
if (hash) {
register(hash)
}
}
236 changes: 236 additions & 0 deletions packages/react-query/src/__tests__/QueryResetErrorBoundary.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
useQuery,
useSuspenseQueries,
useSuspenseQuery,
useTrackQueryHash,
} from '..'
import { renderWithClient } from './utils'

Expand Down Expand Up @@ -941,4 +942,239 @@ describe('QueryErrorResetBoundary', () => {
consoleMock.mockRestore()
})
})

describe('Scoped Registry', () => {
it('should isolate resets between different boundaries', async () => {
const consoleMock = vi
.spyOn(console, 'error')
.mockImplementation(() => undefined)
const key1 = queryKey()
const key2 = queryKey()
let count1 = 0
let count2 = 0

function Comp1() {
useQuery({
queryKey: key1,
queryFn: async () => {
await sleep(10)
count1++
throw new Error('fail1')
},
retry: false,
throwOnError: true,
})
return null
}

function Comp2() {
useQuery({
queryKey: key2,
queryFn: async () => {
await sleep(10)
count2++
throw new Error('fail2')
},
retry: false,
throwOnError: true,
})
return null
}

const rendered = renderWithClient(
queryClient,
<>
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<div>
<button onClick={resetErrorBoundary}>reset1</button>
</div>
)}
>
<React.Suspense fallback="loading1">
<Comp1 />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<div>
<button onClick={resetErrorBoundary}>reset2</button>
</div>
)}
>
<React.Suspense fallback="loading2">
<Comp2 />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>
</>,
)

await vi.advanceTimersByTimeAsync(11)
expect(rendered.getByText('reset1')).toBeInTheDocument()
expect(rendered.getByText('reset2')).toBeInTheDocument()
expect(count1).toBe(1)
expect(count2).toBe(1)

fireEvent.click(rendered.getByText('reset1'))

await vi.advanceTimersByTimeAsync(11)
expect(count1).toBe(2)
expect(count2).toBe(1)

consoleMock.mockRestore()
})

it('should clear registry after reset', async () => {
const consoleMock = vi
.spyOn(console, 'error')
.mockImplementation(() => undefined)
const key = queryKey()
let count = 0

function Comp() {
useQuery({
queryKey: key,
queryFn: async () => {
await sleep(10)
count++
throw new Error('fail')
},
retry: false,
throwOnError: true,
})
return null
}

const rendered = renderWithClient(
queryClient,
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<div>
<button onClick={resetErrorBoundary}>reset</button>
</div>
)}
>
<React.Suspense fallback="loading">
<Comp />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>,
)

await vi.advanceTimersByTimeAsync(11)
expect(rendered.getByText('reset')).toBeInTheDocument()
expect(count).toBe(1)

fireEvent.click(rendered.getByText('reset'))
await vi.advanceTimersByTimeAsync(11)
expect(count).toBe(2)

consoleMock.mockRestore()
})

it('should handle StrictMode double registration gracefully', async () => {
const key = queryKey()
let count = 0

function Comp() {
useQuery({
queryKey: key,
queryFn: async () => {
await sleep(10)
count++
return 'ok'
},
})
return null
}

renderWithClient(
queryClient,
<React.StrictMode>
<QueryErrorResetBoundary>
<Comp />
</QueryErrorResetBoundary>
</React.StrictMode>,
)

await vi.advanceTimersByTimeAsync(11)
expect(count).toBeGreaterThanOrEqual(1)
})

it('should support tracking queries outside the boundary via useTrackQueryHash', async () => {
const consoleMock = vi
.spyOn(console, 'error')
.mockImplementation(() => undefined)
const key = queryKey()
let count = 0

function Child() {
const { data } = useSuspenseQuery({
queryKey: key,
queryFn: async () => {
await sleep(10)
count++
if (count === 1) {
throw new Error('fail')
}
return 'ok'
},
retry: false,
})
return <div>{data}</div>
}

function TrackedChild() {
const hash = queryClient
.getQueryCache()
.build(queryClient, { queryKey: key }).queryHash
useTrackQueryHash({ queryHash: hash })
return null
}

const rendered = renderWithClient(
queryClient,
<QueryErrorResetBoundary>
{({ reset }) => (
<ErrorBoundary
onReset={reset}
fallbackRender={({ resetErrorBoundary }) => (
<button onClick={resetErrorBoundary}>retry</button>
)}
>
<React.Suspense fallback="loading">
<TrackedChild />
<Child />
</React.Suspense>
</ErrorBoundary>
)}
</QueryErrorResetBoundary>,
)

await act(() => vi.advanceTimersByTimeAsync(11))
expect(rendered.getByText('retry')).toBeInTheDocument()
expect(count).toBe(1)

fireEvent.click(rendered.getByText('retry'))
await act(() => vi.advanceTimersByTimeAsync(11))
expect(count).toBe(2)
expect(rendered.getByText('ok')).toBeInTheDocument()

consoleMock.mockRestore()
})
})
})
Loading
Loading