From 3f263987bcb7a53e8fc9ddbe04b03a9cd1b91844 Mon Sep 17 00:00:00 2001 From: Vadim Voitenko Date: Tue, 20 Aug 2024 14:47:22 +0300 Subject: [PATCH] fixed: fixed case when overridden type of column does not work --- internal/db/postgres/context/pg_catalog.go | 4 ++++ internal/db/postgres/context/table.go | 6 ++++++ pkg/toolkit/column.go | 7 +++++++ pkg/toolkit/driver.go | 6 +++--- pkg/toolkit/meta.go | 2 +- pkg/toolkit/static_parameter.go | 2 +- 6 files changed, 22 insertions(+), 5 deletions(-) diff --git a/internal/db/postgres/context/pg_catalog.go b/internal/db/postgres/context/pg_catalog.go index c7a68c6a..bf43aed2 100644 --- a/internal/db/postgres/context/pg_catalog.go +++ b/internal/db/postgres/context/pg_catalog.go @@ -154,6 +154,10 @@ func getTables( // Assigning columns, pk and fk for each table for _, t := range tables { + if len(t.Columns) > 0 { + // Columns were already initialized during the transformer initialization + continue + } columns, err := getColumnsConfig(ctx, tx, t.Oid, version) if err != nil { return nil, nil, fmt.Errorf("unable to collect table columns: %w", err) diff --git a/internal/db/postgres/context/table.go b/internal/db/postgres/context/table.go index 5096dfb0..ad14a4b9 100644 --- a/internal/db/postgres/context/table.go +++ b/internal/db/postgres/context/table.go @@ -89,6 +89,12 @@ func validateAndBuildTablesConfig( } table.Columns = columns + pkColumns, err := getPrimaryKeyColumns(ctx, tx, table.Oid) + if err != nil { + return nil, nil, fmt.Errorf("unable to collect primary key columns: %w", err) + } + table.PrimaryKey = pkColumns + // Assigning overridden column types for driver initialization if tableCfg.ColumnsTypeOverride != nil { for _, c := range table.Columns { diff --git a/pkg/toolkit/column.go b/pkg/toolkit/column.go index 51cae52b..ed017c2c 100644 --- a/pkg/toolkit/column.go +++ b/pkg/toolkit/column.go @@ -47,3 +47,10 @@ func (c *Column) GetType() (string, Oid) { } return c.TypeName, c.TypeOid } + +func (c *Column) GetTypeOid() Oid { + if c.OverriddenTypeName != "" { + return c.OverriddenTypeOid + } + return c.TypeOid +} diff --git a/pkg/toolkit/driver.go b/pkg/toolkit/driver.go index c55dfbdf..8e89cd88 100644 --- a/pkg/toolkit/driver.go +++ b/pkg/toolkit/driver.go @@ -130,7 +130,7 @@ func (d *Driver) EncodeValueByColumnIdx(idx int, src any, buf []byte) ([]byte, e return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx) } c := d.Table.Columns[idx] - oid := uint32(c.TypeOid) + oid := uint32(c.GetTypeOid()) if c.OverriddenTypeOid != 0 { oid = uint32(c.OverriddenTypeOid) } @@ -158,7 +158,7 @@ func (d *Driver) ScanValueByColumnIdx(idx int, src []byte, dest any) error { return fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx) } c := d.Table.Columns[idx] - oid := uint32(c.TypeOid) + oid := uint32(c.GetTypeOid()) if c.OverriddenTypeOid != 0 { oid = uint32(c.OverriddenTypeOid) } @@ -189,7 +189,7 @@ func (d *Driver) DecodeValueByColumnIdx(idx int, src []byte) (any, error) { return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx) } c := d.Table.Columns[idx] - oid := uint32(c.TypeOid) + oid := uint32(c.GetTypeOid()) if c.OverriddenTypeOid != 0 { oid = uint32(c.OverriddenTypeOid) } diff --git a/pkg/toolkit/meta.go b/pkg/toolkit/meta.go index 4ab19739..c5ce5417 100644 --- a/pkg/toolkit/meta.go +++ b/pkg/toolkit/meta.go @@ -18,7 +18,7 @@ type Meta struct { Table *Table `json:"table"` Parameters *Parameters `json:"parameters"` Types []*Type `json:"types"` - ColumnTypeOverrides map[string]string `json:"column_type_overrides"` + ColumnsTypeOverride map[string]string `json:"columns_type_override"` } type Parameters struct { diff --git a/pkg/toolkit/static_parameter.go b/pkg/toolkit/static_parameter.go index a2b4976d..7bc28eab 100644 --- a/pkg/toolkit/static_parameter.go +++ b/pkg/toolkit/static_parameter.go @@ -301,7 +301,7 @@ func scanValue(driver *Driver, definition *ParameterDefinition, rawValue ParamsV var typeOid uint32 if linkedColumnParameter != nil { - typeOid = uint32(linkedColumnParameter.Column.TypeOid) + typeOid = uint32(linkedColumnParameter.Column.GetTypeOid()) } else { t, ok := driver.GetTypeMap().TypeForName(definition.CastDbType) if !ok {