Skip to content

Commit

Permalink
Ensure scan functions use long-lived memory context
Browse files Browse the repository at this point in the history
PostgreSQL scan functions might allocate memory that needs to live for
the duration of the scan. This applies also to functions that are
called during the scan, such as getting the next tuple. To avoid
situations when such functions are accidentally called on, e.g., a
short-lived per-tuple context, add a explicit scan memory context to
the Scanner interface that wraps the PostgreSQL scan API.
  • Loading branch information
erimatnor committed Mar 16, 2022
1 parent d1e02df commit ab6bbf5
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 14 deletions.
4 changes: 1 addition & 3 deletions src/scan_iterator.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ts_scan_iterator_scan_key_init(ScanIterator *iterator, AttrNumber attributeNumbe
* sure the scan key is initialized on the long-lived scankey memory
* context.
*/
oldmcxt = MemoryContextSwitchTo(iterator->scankey_mcxt);
oldmcxt = MemoryContextSwitchTo(iterator->ctx.internal.scan_mcxt);
ScanKeyInit(&iterator->scankey[iterator->ctx.nkeys++],
attributeNumber,
strategy,
Expand All @@ -56,7 +56,5 @@ ts_scan_iterator_scan_key_init(ScanIterator *iterator, AttrNumber attributeNumbe
TSDLLEXPORT void
ts_scan_iterator_rescan(ScanIterator *iterator)
{
MemoryContext oldmcxt = MemoryContextSwitchTo(iterator->scankey_mcxt);
ts_scanner_rescan(&iterator->ctx, NULL);
MemoryContextSwitchTo(oldmcxt);
}
7 changes: 2 additions & 5 deletions src/scan_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ typedef struct ScanIterator
{
ScannerCtx ctx;
TupleInfo *tinfo;
MemoryContext scankey_mcxt;
ScanKeyData scankey[EMBEDDED_SCAN_KEY_SIZE];
} ScanIterator;

Expand All @@ -28,15 +27,15 @@ typedef struct ScanIterator
.ctx = { \
.internal = { \
.ended = true, \
.scan_mcxt = CurrentMemoryContext, \
}, \
.table = catalog_get_table_id(ts_catalog_get(), catalog_table_id), \
.nkeys = 0, \
.scandirection = ForwardScanDirection, \
.lockmode = lock_mode, \
.result_mctx = mctx, \
.flags = SCANNER_F_NOFLAGS, \
}, \
.scankey_mcxt = CurrentMemoryContext, \
}, \
}

static inline TupleInfo *
Expand Down Expand Up @@ -78,9 +77,7 @@ ts_scan_iterator_alloc_result(const ScanIterator *iterator, Size size)
static inline void
ts_scan_iterator_start_scan(ScanIterator *iterator)
{
MemoryContext oldmcxt = MemoryContextSwitchTo(iterator->scankey_mcxt);
ts_scanner_start_scan(&(iterator)->ctx);
MemoryContextSwitchTo(oldmcxt);
}

static inline TupleInfo *
Expand Down
50 changes: 47 additions & 3 deletions src/scanner.c
Original file line number Diff line number Diff line change
Expand Up @@ -172,13 +172,16 @@ TSDLLEXPORT void
ts_scanner_rescan(ScannerCtx *ctx, const ScanKey scankey)
{
Scanner *scanner = scanner_ctx_get_scanner(ctx);
MemoryContext oldmcxt;

/* If scankey is NULL, the existing scan key was already updated or the
* old should be reused */
if (NULL != scankey)
memcpy(ctx->scankey, scankey, sizeof(*ctx->scankey));

oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt);
scanner->rescan(ctx);
MemoryContextSwitchTo(oldmcxt);
}

static void
Expand All @@ -187,6 +190,9 @@ prepare_scan(ScannerCtx *ctx)
ctx->internal.ended = false;
ctx->internal.registered_snapshot = false;

if (ctx->internal.scan_mcxt == NULL)
ctx->internal.scan_mcxt = CurrentMemoryContext;

if (ctx->snapshot == NULL)
{
/*
Expand All @@ -213,20 +219,29 @@ prepare_scan(ScannerCtx *ctx)
* hypertables compared to regular tables under SERIALIZABLE
* mode.
*/
MemoryContext oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt);
ctx->snapshot = RegisterSnapshot(GetSnapshotData(SnapshotSelf));
ctx->internal.registered_snapshot = true;
MemoryContextSwitchTo(oldmcxt);
}
}

TSDLLEXPORT Relation
ts_scanner_open(ScannerCtx *ctx)
{
Scanner *scanner = scanner_ctx_get_scanner(ctx);
MemoryContext oldmcxt;
Relation rel;

Assert(NULL == ctx->tablerel);

prepare_scan(ctx);
Assert(ctx->internal.scan_mcxt != NULL);
oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt);
rel = scanner->openscan(ctx);
MemoryContextSwitchTo(oldmcxt);

return scanner->openscan(ctx);
return rel;
}

/*
Expand All @@ -240,6 +255,7 @@ ts_scanner_start_scan(ScannerCtx *ctx)
InternalScannerCtx *ictx = &ctx->internal;
Scanner *scanner;
TupleDesc tuple_desc;
MemoryContext oldmcxt;

if (ictx->started)
{
Expand Down Expand Up @@ -269,6 +285,9 @@ ts_scanner_start_scan(ScannerCtx *ctx)
ctx->index = RelationGetRelid(ctx->indexrel);
}

Assert(ctx->internal.scan_mcxt != NULL);
oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt);

scanner = scanner_ctx_get_scanner(ctx);
scanner->beginscan(ctx);

Expand All @@ -277,6 +296,7 @@ ts_scanner_start_scan(ScannerCtx *ctx)
ictx->tinfo.scanrel = ctx->tablerel;
ictx->tinfo.mctx = ctx->result_mctx == NULL ? CurrentMemoryContext : ctx->result_mctx;
ictx->tinfo.slot = MakeSingleTupleTableSlot(tuple_desc, table_slot_callbacks(ctx->tablerel));
MemoryContextSwitchTo(oldmcxt);

/* Call pre-scan handler, if any. */
if (ctx->prescan != NULL)
Expand Down Expand Up @@ -307,13 +327,19 @@ scanner_cleanup(ScannerCtx *ctx)
ExecDropSingleTupleTableSlot(ictx->tinfo.slot);
ictx->tinfo.slot = NULL;
}

if (NULL != ictx->scan_mcxt)
{
ictx->scan_mcxt = NULL;
}
}

TSDLLEXPORT void
ts_scanner_end_scan(ScannerCtx *ctx)
{
InternalScannerCtx *ictx = &ctx->internal;
Scanner *scanner = scanner_ctx_get_scanner(ctx);
MemoryContext oldmcxt;

if (ictx->ended)
return;
Expand All @@ -322,7 +348,10 @@ ts_scanner_end_scan(ScannerCtx *ctx)
if (ctx->postscan != NULL)
ctx->postscan(ictx->tinfo.count, ctx->data);

oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt);
scanner->endscan(ctx);
MemoryContextSwitchTo(oldmcxt);

scanner_cleanup(ctx);
ictx->ended = true;
ictx->started = false;
Expand All @@ -348,7 +377,14 @@ ts_scanner_next(ScannerCtx *ctx)
{
InternalScannerCtx *ictx = &ctx->internal;
Scanner *scanner = scanner_ctx_get_scanner(ctx);
bool is_valid = ts_scanner_limit_reached(ctx) ? false : scanner->getnext(ctx);
bool is_valid = false;

if (!ts_scanner_limit_reached(ctx))
{
MemoryContext oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt);
is_valid = scanner->getnext(ctx);
MemoryContextSwitchTo(oldmcxt);
}

while (is_valid)
{
Expand All @@ -375,7 +411,15 @@ ts_scanner_next(ScannerCtx *ctx)
/* stop at a valid tuple */
return &ictx->tinfo;
}
is_valid = ts_scanner_limit_reached(ctx) ? false : scanner->getnext(ctx);

if (ts_scanner_limit_reached(ctx))
is_valid = false;
else
{
MemoryContext oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt);
is_valid = scanner->getnext(ctx);
MemoryContextSwitchTo(oldmcxt);
}
}

if (!(ctx->flags & SCANNER_F_NOEND))
Expand Down
6 changes: 6 additions & 0 deletions src/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ typedef struct InternalScannerCtx
{
TupleInfo tinfo;
ScanDesc scan;
/*
* PG scan functions must be called on a memory context that lives
* throughout the entire scan. Use the scan_mcxt to ensure that
* functions aren't called on, e.g., a per-tuple context.
*/
MemoryContext scan_mcxt;
bool registered_snapshot;
bool started;
bool ended;
Expand Down
3 changes: 0 additions & 3 deletions src/telemetry/stats.c
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,11 @@ get_chunk_compression_stats(StatsContext *statsctx, const Chunk *chunk,
Form_compression_chunk_size compr_stats)
{
TupleInfo *ti;
MemoryContext oldmcxt;

if (!ts_chunk_is_compressed(chunk))
return false;

/* Need to execute the scan functions on the long-lived memory context */
oldmcxt = MemoryContextSwitchTo(statsctx->compressed_chunk_stats_iterator.scankey_mcxt);
ts_scan_iterator_scan_key_reset(&statsctx->compressed_chunk_stats_iterator);
ts_scan_iterator_scan_key_init(&statsctx->compressed_chunk_stats_iterator,
Anum_compression_chunk_size_pkey_chunk_id,
Expand All @@ -353,7 +351,6 @@ get_chunk_compression_stats(StatsContext *statsctx, const Chunk *chunk,
Int32GetDatum(chunk->fd.id));
ts_scan_iterator_start_or_restart_scan(&statsctx->compressed_chunk_stats_iterator);
ti = ts_scan_iterator_next(&statsctx->compressed_chunk_stats_iterator);
MemoryContextSwitchTo(oldmcxt);

if (ti)
{
Expand Down

0 comments on commit ab6bbf5

Please sign in to comment.