Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed case when the overridden type of column does not work #175

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/db/postgres/context/pg_catalog.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ func getTables(

// Assigning columns, pk and fk for each table
for _, t := range tables {
if len(t.Columns) > 0 {
// Columns were already initialized during the transformer initialization
continue
}
columns, err := getColumnsConfig(ctx, tx, t.Oid, version)
if err != nil {
return nil, nil, fmt.Errorf("unable to collect table columns: %w", err)
Expand Down
6 changes: 6 additions & 0 deletions internal/db/postgres/context/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ func validateAndBuildTablesConfig(
}
table.Columns = columns

pkColumns, err := getPrimaryKeyColumns(ctx, tx, table.Oid)
if err != nil {
return nil, nil, fmt.Errorf("unable to collect primary key columns: %w", err)
}
table.PrimaryKey = pkColumns

// Assigning overridden column types for driver initialization
if tableCfg.ColumnsTypeOverride != nil {
for _, c := range table.Columns {
Expand Down
7 changes: 7 additions & 0 deletions pkg/toolkit/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,10 @@ func (c *Column) GetType() (string, Oid) {
}
return c.TypeName, c.TypeOid
}

func (c *Column) GetTypeOid() Oid {
if c.OverriddenTypeName != "" {
return c.OverriddenTypeOid
}
return c.TypeOid
}
6 changes: 3 additions & 3 deletions pkg/toolkit/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func (d *Driver) EncodeValueByColumnIdx(idx int, src any, buf []byte) ([]byte, e
return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx)
}
c := d.Table.Columns[idx]
oid := uint32(c.TypeOid)
oid := uint32(c.GetTypeOid())
if c.OverriddenTypeOid != 0 {
oid = uint32(c.OverriddenTypeOid)
}
Expand Down Expand Up @@ -158,7 +158,7 @@ func (d *Driver) ScanValueByColumnIdx(idx int, src []byte, dest any) error {
return fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx)
}
c := d.Table.Columns[idx]
oid := uint32(c.TypeOid)
oid := uint32(c.GetTypeOid())
if c.OverriddenTypeOid != 0 {
oid = uint32(c.OverriddenTypeOid)
}
Expand Down Expand Up @@ -189,7 +189,7 @@ func (d *Driver) DecodeValueByColumnIdx(idx int, src []byte) (any, error) {
return nil, fmt.Errorf("index out ouf range: must be between 0 and %d received %d", d.maxIdx, idx)
}
c := d.Table.Columns[idx]
oid := uint32(c.TypeOid)
oid := uint32(c.GetTypeOid())
if c.OverriddenTypeOid != 0 {
oid = uint32(c.OverriddenTypeOid)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/toolkit/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Meta struct {
Table *Table `json:"table"`
Parameters *Parameters `json:"parameters"`
Types []*Type `json:"types"`
ColumnTypeOverrides map[string]string `json:"column_type_overrides"`
ColumnsTypeOverride map[string]string `json:"columns_type_override"`
}

type Parameters struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/toolkit/static_parameter.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ func scanValue(driver *Driver, definition *ParameterDefinition, rawValue ParamsV

var typeOid uint32
if linkedColumnParameter != nil {
typeOid = uint32(linkedColumnParameter.Column.TypeOid)
typeOid = uint32(linkedColumnParameter.Column.GetTypeOid())
} else {
t, ok := driver.GetTypeMap().TypeForName(definition.CastDbType)
if !ok {
Expand Down