diff --git a/ci/run.sh b/ci/run.sh index 4dfd9a4..9c86a51 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -26,6 +26,12 @@ test_pgaudit_zig() { run_unit_tests ./examples/pgaudit_zig } +test_spi_sql() { + local rc=0 + run_regression_tests ./examples/spi_sql || rc=1 + return $rc +} + extension_build() { cwd=$(pwd) cd "$1" || return 1 @@ -63,6 +69,7 @@ run_unit_tests() { run_test_suites() { for t in "$@"; do + echo "" echo "# Run $t" if ! $t; then return 1 @@ -87,13 +94,14 @@ main() { extension_build ./examples/char_count_zig || fail "Failed to build char_count_zig" extension_build ./examples/pgaudit_zig || fail "Failed to build pgaudit_zig" + extension_build ./examples/spi_sql || fail "Failed to build spi_sql" echo "Start PostgreSQL" pgstart || fail "Failed to start PostgreSQL" trap pgstop TERM INT EXIT ok=true - run_test_suites test_pgzx test_char_count_zig test_pgaudit_zig || ok=false + run_test_suites test_pgzx test_char_count_zig test_pgaudit_zig test_spi_sql || ok=false if ! $ok; then printf "\n\nServer log:" diff --git a/examples/spi_sql/README.md b/examples/spi_sql/README.md new file mode 100644 index 0000000..6c2f256 --- /dev/null +++ b/examples/spi_sql/README.md @@ -0,0 +1,19 @@ +# spi_sql - Sample extension using SPI to execute SQL statements. + +This is a sample PostgreSQL extension to test SPI (Server Programming Interface) SQL execution in Zig. The extension provides a number of methods used by the test suite to verify that SPI access if functional. + +## Testing + +The extension uses PostgreSQL regression testing suite, which calls some of the exported functions in the extension itself. + +The extension sets up a sample table with entries that are used by the tests. + +``` +zig build -freference-trace -p $PG_HOME +``` + +Run regression tests: + +``` +zig build -freference-trace -p $PG_HOME pg_regress +``` diff --git a/examples/spi_sql/build.zig b/examples/spi_sql/build.zig new file mode 100644 index 0000000..6e5b8ea --- /dev/null +++ b/examples/spi_sql/build.zig @@ -0,0 +1,55 @@ +const std = @import("std"); + +// Load pgzx build support. The build utilities use pg_config to find all dependencies +// and provide functions go create and test extensions. +const PGBuild = @import("pgzx").Build; + +pub fn build(b: *std.Build) void { + // Project meta data + const name = "spi_sql"; + const version = .{ .major = 0, .minor = 1 }; + + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + // Load the pgzx module and initialize the build utilities + const dep_pgzx = b.dependency("pgzx", .{ .target = target, .optimize = optimize }); + const pgzx = dep_pgzx.module("pgzx"); + var pgbuild = PGBuild.create(b, .{ .target = target, .optimize = optimize }); + + const build_options = b.addOptions(); + build_options.addOption(bool, "testfn", b.option(bool, "testfn", "Register test function") orelse false); + + // Register the dependency with the build system + // and add pgzx as module dependency. + { + const ext = pgbuild.addInstallExtension(.{ + .name = name, + .version = version, + .root_source_file = .{ + .path = "src/main.zig", + }, + .root_dir = ".", + }); + ext.lib.root_module.addImport("pgzx", pgzx); + ext.lib.root_module.addOptions("build_options", build_options); + + b.getInstallStep().dependOn(&ext.step); + } + + // Configure pg_regress based testing for the current extension. + { + const extest = pgbuild.addRegress(.{ + .db_user = "postgres", + .db_port = 5432, + .root_dir = ".", + .scripts = &[_][]const u8{ + "spi_sql_test", + }, + }); + + // Make regression tests available to `zig build` + var regress = b.step("pg_regress", "Run regression tests"); + regress.dependOn(&extest.step); + } +} diff --git a/examples/spi_sql/build.zig.zon b/examples/spi_sql/build.zig.zon new file mode 100644 index 0000000..d36293d --- /dev/null +++ b/examples/spi_sql/build.zig.zon @@ -0,0 +1,13 @@ +.{ + .name = "spi_sql", + .version = "0.1.0", + .paths = .{ + "extension", + "src", + }, + .dependencies = .{ + .pgzx = .{ + .path = "./../..", + }, + }, +} diff --git a/examples/spi_sql/expected/spi_sql_test.out b/examples/spi_sql/expected/spi_sql_test.out new file mode 100644 index 0000000..3e4dbd5 --- /dev/null +++ b/examples/spi_sql/expected/spi_sql_test.out @@ -0,0 +1,66 @@ +CREATE EXTENSION spi_sql; +SELECT spi_sql.query_by_id(0); -- return 'Hello' + query_by_id +------------- + Hello +(1 row) + +SELECT spi_sql.query_by_id(1); -- return 'World' + query_by_id +------------- + World +(1 row) + +SELECT spi_sql.query_by_id(2); -- Fail +ERROR: Unknown id: 2 +SELECT spi_sql.query_by_value('Hello'); -- return 0 + query_by_value +---------------- + 0 +(1 row) + +SELECT spi_sql.query_by_value('World'); -- return 1 + query_by_value +---------------- + 1 +(1 row) + +SELECT spi_sql.query_by_value('test'); -- FAIL +ERROR: Value 'test' not found +SELECT spi_sql.ins_value(2, 'test'); + ins_value +----------- + 2 +(1 row) + +SELECT spi_sql.query_by_id(2); -- return 'test' + query_by_id +------------- + test +(1 row) + +SELECT spi_sql.query_by_value('test'); -- return 2 + query_by_value +---------------- + 2 +(1 row) + +SELECT spi_sql.test_iter(); +INFO: id: 0, value: Hello +INFO: id: 1, value: World +INFO: id: 2, value: test + test_iter +----------- + +(1 row) + +SELECT spi_sql.test_rows_of(); +INFO: id: 0, value: Hello +INFO: id: 1, value: World +INFO: id: 2, value: test + test_rows_of +-------------- + +(1 row) + +DROP EXTENSION spi_sql; diff --git a/examples/spi_sql/extension/spi_sql--0.1.sql b/examples/spi_sql/extension/spi_sql--0.1.sql new file mode 100644 index 0000000..d4561df --- /dev/null +++ b/examples/spi_sql/extension/spi_sql--0.1.sql @@ -0,0 +1,24 @@ +CREATE TABLE tbl ( + id serial not null primary key, + value text +); + +INSERT INTO tbl (id, value) VALUES + (0, 'Hello'), + (1, 'World') +; + +CREATE FUNCTION query_by_id(int4) RETURNS TEXT +AS '$libdir/spi_sql' LANGUAGE C VOLATILE; + +CREATE FUNCTION query_by_value(TEXT) RETURNS INT4 +AS '$libdir/spi_sql' LANGUAGE C VOLATILE; + +CREATE FUNCTION ins_value(INT4, TEXT) RETURNS INT4 +AS '$libdir/spi_sql' LANGUAGE C VOLATILE; + +CREATE FUNCTION test_iter() RETURNS VOID +AS '$libdir/spi_sql' LANGUAGE C VOLATILE; + +CREATE FUNCTION test_rows_of() RETURNS VOID +AS '$libdir/spi_sql' LANGUAGE C VOLATILE; diff --git a/examples/spi_sql/extension/spi_sql.control b/examples/spi_sql/extension/spi_sql.control new file mode 100644 index 0000000..318ab48 --- /dev/null +++ b/examples/spi_sql/extension/spi_sql.control @@ -0,0 +1,6 @@ +comment = 'pgzx: SPI SQL test extension' +default_version = '0.1' +module_pathname = '$libdir/spi_sql' +relocatable = false +superuser = false +schema = 'spi_sql' diff --git a/examples/spi_sql/sql/spi_sql_test.sql b/examples/spi_sql/sql/spi_sql_test.sql new file mode 100644 index 0000000..4b60841 --- /dev/null +++ b/examples/spi_sql/sql/spi_sql_test.sql @@ -0,0 +1,19 @@ +CREATE EXTENSION spi_sql; + +SELECT spi_sql.query_by_id(0); -- return 'Hello' +SELECT spi_sql.query_by_id(1); -- return 'World' +SELECT spi_sql.query_by_id(2); -- Fail + +SELECT spi_sql.query_by_value('Hello'); -- return 0 +SELECT spi_sql.query_by_value('World'); -- return 1 +SELECT spi_sql.query_by_value('test'); -- FAIL + +SELECT spi_sql.ins_value(2, 'test'); +SELECT spi_sql.query_by_id(2); -- return 'test' +SELECT spi_sql.query_by_value('test'); -- return 2 + +SELECT spi_sql.test_iter(); + +SELECT spi_sql.test_rows_of(); + +DROP EXTENSION spi_sql; diff --git a/examples/spi_sql/src/main.zig b/examples/spi_sql/src/main.zig new file mode 100644 index 0000000..89d09ce --- /dev/null +++ b/examples/spi_sql/src/main.zig @@ -0,0 +1,128 @@ +const std = @import("std"); +const pgzx = @import("pgzx"); +const pg = pgzx.c; + +comptime { + pgzx.PG_MODULE_MAGIC(); + + pgzx.PG_FUNCTION_V1("query_by_id", query_by_id); + pgzx.PG_FUNCTION_V1("query_by_value", query_by_value); + pgzx.PG_FUNCTION_V1("ins_value", ins_value); + pgzx.PG_FUNCTION_V1("test_iter", test_iter); + pgzx.PG_FUNCTION_V1("test_rows_of", test_rows_of); +} + +const SCHEMA_NAME = "spi_sql"; +const TABLE_NAME = SCHEMA_NAME ++ ".tbl"; + +fn query_by_id(id: u32) ![]const u8 { + const QUERY = "SELECT value FROM " ++ TABLE_NAME ++ " WHERE id = $1"; + + try pgzx.spi.connect(); + defer pgzx.spi.finish(); + + var rows = try pgzx.spi.query(QUERY, .{ + .limit = 1, + .args = .{ + .types = &[_]pg.Oid{pg.INT4OID}, + .values = &[_]pg.NullableDatum{try pgzx.datum.toNullableDatum(id)}, + }, + }); + defer rows.deinit(); + + if (!rows.next()) { + return pgzx.elog.Error(@src(), "Unknown id: {d}", .{id}); + } + + var value: []const u8 = undefined; + try rows.scan(.{&value}); + return value; +} + +fn query_by_value(value: []const u8) !u32 { + const QUERY = "SELECT id FROM " ++ TABLE_NAME ++ " WHERE value = $1"; + + try pgzx.spi.connect(); + defer pgzx.spi.finish(); + + // Use `RowsOf` to implicitey scan the result without having to declare temporary variables. + + var rows = pgzx.spi.RowsOf(u32).init(try pgzx.spi.query(QUERY, .{ + .limit = 1, + .args = .{ + .types = &[_]pg.Oid{pg.TEXTOID}, + .values = &[_]pg.NullableDatum{try pgzx.datum.toNullableDatum(value)}, + }, + })); + defer rows.deinit(); + + if (try rows.next()) |id| { + return id; + } + return pgzx.elog.Error(@src(), "Value '{s}' not found", .{value}); +} + +fn ins_value(id: u32, value: []const u8) !u32 { + const STMT = "INSERT INTO " ++ TABLE_NAME ++ " (id, value) VALUES ($1, $2) RETURNING id"; + + try pgzx.spi.connect(); + defer pgzx.spi.finish(); + + var rows = pgzx.spi.RowsOf(u32).init(try pgzx.spi.query(STMT, .{ + .args = .{ + .types = &[_]pg.Oid{ + pg.INT4OID, + pg.TEXTOID, + }, + .values = &[_]pg.NullableDatum{ + try pgzx.datum.toNullableDatum(id), + try pgzx.datum.toNullableDatum(value), + }, + }, + })); + defer rows.deinit(); + + if (try rows.next()) |ret_id| { + return ret_id; + } + unreachable; +} + +fn test_iter() !void { + const QUERY = "SELECT id, value FROM " ++ TABLE_NAME; + const Record = struct { + id: u32, + value: []const u8, + }; + + try pgzx.spi.connect(); + defer pgzx.spi.finish(); + + var rows = try pgzx.spi.query(QUERY, .{}); + defer rows.deinit(); + + while (rows.next()) { + var rec: Record = undefined; + + try rows.scan(.{&rec}); + pgzx.elog.Info(@src(), "id: {d}, value: {s}", .{ rec.id, rec.value }); + } +} + +fn test_rows_of() !void { + const QUERY = "SELECT id, value FROM " ++ TABLE_NAME; + const Record = struct { + id: u32, + value: []const u8, + }; + + try pgzx.spi.connect(); + defer pgzx.spi.finish(); + + var rows = pgzx.spi.RowsOf(Record).init(try pgzx.spi.query(QUERY, .{})); + defer rows.deinit(); + + while (try rows.next()) |rec| { + pgzx.elog.Info(@src(), "id: {d}, value: {s}", .{ rec.id, rec.value }); + } +} diff --git a/src/pgzx/datum.zig b/src/pgzx/datum.zig index 0755634..5d1bf6e 100644 --- a/src/pgzx/datum.zig +++ b/src/pgzx/datum.zig @@ -168,7 +168,7 @@ pub const Float32 = ConvNoFail(f32, c.DatumGetFloat4, c.Float4GetDatum); pub const Float64 = ConvNoFail(f64, c.DatumGetFloat8, c.Float8GetDatum); pub const SliceU8 = Conv([]const u8, getDatumTextSlice, sliceToDatumText); -pub const SliceU8Z = Conv([:0]const u8, getDatumTextSliceZ, sliceToDatumText); +pub const SliceU8Z = Conv([:0]const u8, getDatumTextSliceZ, sliceToDatumTextZ); pub const PGDatum = ConvNoFail(c.Datum, idDatum, idDatum); const PGNullableDatum = struct { diff --git a/src/pgzx/spi.zig b/src/pgzx/spi.zig index 46bf48f..f7de0d3 100644 --- a/src/pgzx/spi.zig +++ b/src/pgzx/spi.zig @@ -2,7 +2,7 @@ const std = @import("std"); const mem = @import("mem.zig"); const err = @import("err.zig"); const c = @import("c.zig"); -const fmgr = @import("fmgr.zig"); +const datum = @import("datum.zig"); pub fn connect() err.PGError!void { const status = c.SPI_connect(); @@ -40,7 +40,9 @@ pub const ExecOptions = struct { args: ?Args = null, }; -pub fn exec(sql: [:0]const u8, options: ExecOptions) !c_int { +pub const SPIError = err.PGError || std.mem.Allocator.Error; + +pub fn exec(sql: [:0]const u8, options: ExecOptions) SPIError!c_int { if (options.args) |args| { if (args.types.len != args.values.len) { return err.PGError.SPIArgument; @@ -89,7 +91,7 @@ pub fn exec(sql: [:0]const u8, options: ExecOptions) !c_int { } } -pub fn query(sql: [:0]const u8, options: ExecOptions) !Rows { +pub fn query(sql: [:0]const u8, options: ExecOptions) SPIError!Rows { _ = try exec(sql, options); return Rows.init(); } @@ -195,10 +197,10 @@ pub fn convProcessed(comptime T: type, row: c_int, col: c_int) !T { pub fn convBinValue(comptime T: type, table: *c.SPITupleTable, row: usize, col: c_int) !T { // TODO: check index? - var datum: c.NullableDatum = undefined; - datum.value = c.SPI_getbinval(table.*.vals[row], table.*.tupdesc, col, @ptrCast(&datum.isnull)); + var nd: c.NullableDatum = undefined; + nd.value = c.SPI_getbinval(table.*.vals[row], table.*.tupdesc, col, @ptrCast(&nd.isnull)); try checkStatus(c.SPI_result); - return try fmgr.conv.fromNullableDatum(T, datum); + return try datum.fromNullableDatum(T, nd); } fn checkStatus(st: c_int) err.PGError!void {