From ab6bbf5ec04edc83e4f7e2f6d697d9a67dc17f5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Nordstr=C3=B6m?= Date: Wed, 16 Mar 2022 11:10:29 +0100 Subject: [PATCH] Ensure scan functions use long-lived memory context PostgreSQL scan functions might allocate memory that needs to live for the duration of the scan. This applies also to functions that are called during the scan, such as getting the next tuple. To avoid situations when such functions are accidentally called on, e.g., a short-lived per-tuple context, add a explicit scan memory context to the Scanner interface that wraps the PostgreSQL scan API. --- src/scan_iterator.c | 4 +--- src/scan_iterator.h | 7 ++---- src/scanner.c | 50 ++++++++++++++++++++++++++++++++++++++++--- src/scanner.h | 6 ++++++ src/telemetry/stats.c | 3 --- 5 files changed, 56 insertions(+), 14 deletions(-) diff --git a/src/scan_iterator.c b/src/scan_iterator.c index f2a1d6e06dc..2b2dd45e7b1 100644 --- a/src/scan_iterator.c +++ b/src/scan_iterator.c @@ -44,7 +44,7 @@ ts_scan_iterator_scan_key_init(ScanIterator *iterator, AttrNumber attributeNumbe * sure the scan key is initialized on the long-lived scankey memory * context. */ - oldmcxt = MemoryContextSwitchTo(iterator->scankey_mcxt); + oldmcxt = MemoryContextSwitchTo(iterator->ctx.internal.scan_mcxt); ScanKeyInit(&iterator->scankey[iterator->ctx.nkeys++], attributeNumber, strategy, @@ -56,7 +56,5 @@ ts_scan_iterator_scan_key_init(ScanIterator *iterator, AttrNumber attributeNumbe TSDLLEXPORT void ts_scan_iterator_rescan(ScanIterator *iterator) { - MemoryContext oldmcxt = MemoryContextSwitchTo(iterator->scankey_mcxt); ts_scanner_rescan(&iterator->ctx, NULL); - MemoryContextSwitchTo(oldmcxt); } diff --git a/src/scan_iterator.h b/src/scan_iterator.h index 099e781f94c..42ba300b07f 100644 --- a/src/scan_iterator.h +++ b/src/scan_iterator.h @@ -18,7 +18,6 @@ typedef struct ScanIterator { ScannerCtx ctx; TupleInfo *tinfo; - MemoryContext scankey_mcxt; ScanKeyData scankey[EMBEDDED_SCAN_KEY_SIZE]; } ScanIterator; @@ -28,6 +27,7 @@ typedef struct ScanIterator .ctx = { \ .internal = { \ .ended = true, \ + .scan_mcxt = CurrentMemoryContext, \ }, \ .table = catalog_get_table_id(ts_catalog_get(), catalog_table_id), \ .nkeys = 0, \ @@ -35,8 +35,7 @@ typedef struct ScanIterator .lockmode = lock_mode, \ .result_mctx = mctx, \ .flags = SCANNER_F_NOFLAGS, \ - }, \ - .scankey_mcxt = CurrentMemoryContext, \ + }, \ } static inline TupleInfo * @@ -78,9 +77,7 @@ ts_scan_iterator_alloc_result(const ScanIterator *iterator, Size size) static inline void ts_scan_iterator_start_scan(ScanIterator *iterator) { - MemoryContext oldmcxt = MemoryContextSwitchTo(iterator->scankey_mcxt); ts_scanner_start_scan(&(iterator)->ctx); - MemoryContextSwitchTo(oldmcxt); } static inline TupleInfo * diff --git a/src/scanner.c b/src/scanner.c index 84e98561751..e04241e7270 100644 --- a/src/scanner.c +++ b/src/scanner.c @@ -172,13 +172,16 @@ TSDLLEXPORT void ts_scanner_rescan(ScannerCtx *ctx, const ScanKey scankey) { Scanner *scanner = scanner_ctx_get_scanner(ctx); + MemoryContext oldmcxt; /* If scankey is NULL, the existing scan key was already updated or the * old should be reused */ if (NULL != scankey) memcpy(ctx->scankey, scankey, sizeof(*ctx->scankey)); + oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt); scanner->rescan(ctx); + MemoryContextSwitchTo(oldmcxt); } static void @@ -187,6 +190,9 @@ prepare_scan(ScannerCtx *ctx) ctx->internal.ended = false; ctx->internal.registered_snapshot = false; + if (ctx->internal.scan_mcxt == NULL) + ctx->internal.scan_mcxt = CurrentMemoryContext; + if (ctx->snapshot == NULL) { /* @@ -213,8 +219,10 @@ prepare_scan(ScannerCtx *ctx) * hypertables compared to regular tables under SERIALIZABLE * mode. */ + MemoryContext oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt); ctx->snapshot = RegisterSnapshot(GetSnapshotData(SnapshotSelf)); ctx->internal.registered_snapshot = true; + MemoryContextSwitchTo(oldmcxt); } } @@ -222,11 +230,18 @@ TSDLLEXPORT Relation ts_scanner_open(ScannerCtx *ctx) { Scanner *scanner = scanner_ctx_get_scanner(ctx); + MemoryContext oldmcxt; + Relation rel; Assert(NULL == ctx->tablerel); + prepare_scan(ctx); + Assert(ctx->internal.scan_mcxt != NULL); + oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt); + rel = scanner->openscan(ctx); + MemoryContextSwitchTo(oldmcxt); - return scanner->openscan(ctx); + return rel; } /* @@ -240,6 +255,7 @@ ts_scanner_start_scan(ScannerCtx *ctx) InternalScannerCtx *ictx = &ctx->internal; Scanner *scanner; TupleDesc tuple_desc; + MemoryContext oldmcxt; if (ictx->started) { @@ -269,6 +285,9 @@ ts_scanner_start_scan(ScannerCtx *ctx) ctx->index = RelationGetRelid(ctx->indexrel); } + Assert(ctx->internal.scan_mcxt != NULL); + oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt); + scanner = scanner_ctx_get_scanner(ctx); scanner->beginscan(ctx); @@ -277,6 +296,7 @@ ts_scanner_start_scan(ScannerCtx *ctx) ictx->tinfo.scanrel = ctx->tablerel; ictx->tinfo.mctx = ctx->result_mctx == NULL ? CurrentMemoryContext : ctx->result_mctx; ictx->tinfo.slot = MakeSingleTupleTableSlot(tuple_desc, table_slot_callbacks(ctx->tablerel)); + MemoryContextSwitchTo(oldmcxt); /* Call pre-scan handler, if any. */ if (ctx->prescan != NULL) @@ -307,6 +327,11 @@ scanner_cleanup(ScannerCtx *ctx) ExecDropSingleTupleTableSlot(ictx->tinfo.slot); ictx->tinfo.slot = NULL; } + + if (NULL != ictx->scan_mcxt) + { + ictx->scan_mcxt = NULL; + } } TSDLLEXPORT void @@ -314,6 +339,7 @@ ts_scanner_end_scan(ScannerCtx *ctx) { InternalScannerCtx *ictx = &ctx->internal; Scanner *scanner = scanner_ctx_get_scanner(ctx); + MemoryContext oldmcxt; if (ictx->ended) return; @@ -322,7 +348,10 @@ ts_scanner_end_scan(ScannerCtx *ctx) if (ctx->postscan != NULL) ctx->postscan(ictx->tinfo.count, ctx->data); + oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt); scanner->endscan(ctx); + MemoryContextSwitchTo(oldmcxt); + scanner_cleanup(ctx); ictx->ended = true; ictx->started = false; @@ -348,7 +377,14 @@ ts_scanner_next(ScannerCtx *ctx) { InternalScannerCtx *ictx = &ctx->internal; Scanner *scanner = scanner_ctx_get_scanner(ctx); - bool is_valid = ts_scanner_limit_reached(ctx) ? false : scanner->getnext(ctx); + bool is_valid = false; + + if (!ts_scanner_limit_reached(ctx)) + { + MemoryContext oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt); + is_valid = scanner->getnext(ctx); + MemoryContextSwitchTo(oldmcxt); + } while (is_valid) { @@ -375,7 +411,15 @@ ts_scanner_next(ScannerCtx *ctx) /* stop at a valid tuple */ return &ictx->tinfo; } - is_valid = ts_scanner_limit_reached(ctx) ? false : scanner->getnext(ctx); + + if (ts_scanner_limit_reached(ctx)) + is_valid = false; + else + { + MemoryContext oldmcxt = MemoryContextSwitchTo(ctx->internal.scan_mcxt); + is_valid = scanner->getnext(ctx); + MemoryContextSwitchTo(oldmcxt); + } } if (!(ctx->flags & SCANNER_F_NOEND)) diff --git a/src/scanner.h b/src/scanner.h index a74e52023ef..6cbf5d90057 100644 --- a/src/scanner.h +++ b/src/scanner.h @@ -85,6 +85,12 @@ typedef struct InternalScannerCtx { TupleInfo tinfo; ScanDesc scan; + /* + * PG scan functions must be called on a memory context that lives + * throughout the entire scan. Use the scan_mcxt to ensure that + * functions aren't called on, e.g., a per-tuple context. + */ + MemoryContext scan_mcxt; bool registered_snapshot; bool started; bool ended; diff --git a/src/telemetry/stats.c b/src/telemetry/stats.c index 35084e16697..e5c5efdcf5a 100644 --- a/src/telemetry/stats.c +++ b/src/telemetry/stats.c @@ -338,13 +338,11 @@ get_chunk_compression_stats(StatsContext *statsctx, const Chunk *chunk, Form_compression_chunk_size compr_stats) { TupleInfo *ti; - MemoryContext oldmcxt; if (!ts_chunk_is_compressed(chunk)) return false; /* Need to execute the scan functions on the long-lived memory context */ - oldmcxt = MemoryContextSwitchTo(statsctx->compressed_chunk_stats_iterator.scankey_mcxt); ts_scan_iterator_scan_key_reset(&statsctx->compressed_chunk_stats_iterator); ts_scan_iterator_scan_key_init(&statsctx->compressed_chunk_stats_iterator, Anum_compression_chunk_size_pkey_chunk_id, @@ -353,7 +351,6 @@ get_chunk_compression_stats(StatsContext *statsctx, const Chunk *chunk, Int32GetDatum(chunk->fd.id)); ts_scan_iterator_start_or_restart_scan(&statsctx->compressed_chunk_stats_iterator); ti = ts_scan_iterator_next(&statsctx->compressed_chunk_stats_iterator); - MemoryContextSwitchTo(oldmcxt); if (ti) {