diff --git a/.gitignore b/.gitignore index 5e82a3dd..0c9302eb 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,3 @@ venv .cache # Binaries cmd/greenmask/greenmask -pkg/toolkit/test/test diff --git a/internal/db/postgres/cmd/validate_utils/json_document_test.go b/internal/db/postgres/cmd/validate_utils/json_document_test.go index 0f74db51..d3987716 100644 --- a/internal/db/postgres/cmd/validate_utils/json_document_test.go +++ b/internal/db/postgres/cmd/validate_utils/json_document_test.go @@ -15,26 +15,6 @@ import ( "github.com/greenmaskio/greenmask/pkg/toolkit" ) -type testTransformer struct{} - -func (tt *testTransformer) Init(ctx context.Context) error { - return nil -} - -func (tt *testTransformer) Done(ctx context.Context) error { - return nil -} - -func (tt *testTransformer) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { - return nil, nil -} - -func (tt *testTransformer) GetAffectedColumns() map[int]string { - return map[int]string{ - 1: "name", - } -} - func TestJsonDocument_GetAffectedColumns(t *testing.T) { tab, _, _ := getTableAndRows() jd := NewJsonDocument(tab, true, true) @@ -87,6 +67,26 @@ func TestJsonDocument_GetRecords(t *testing.T) { //r.SetRow(row) } +type testTransformer struct{} + +func (tt *testTransformer) Init(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Done(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { + return nil, nil +} + +func (tt *testTransformer) GetAffectedColumns() map[int]string { + return map[int]string{ + 1: "name", + } +} + func getTableAndRows() (table *entries.Table, original, transformed [][]byte) { tableDef := ` diff --git a/internal/db/postgres/dumpers/transformation_pipeline_test.go b/internal/db/postgres/dumpers/transformation_pipeline_test.go new file mode 100644 index 00000000..77b35678 --- /dev/null +++ b/internal/db/postgres/dumpers/transformation_pipeline_test.go @@ -0,0 +1,108 @@ +package dumpers + +import ( + "bytes" + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" + "github.com/greenmaskio/greenmask/pkg/toolkit" +) + +func TestTransformationPipeline_Dump(t *testing.T) { + termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer termCancel() + table := getTable() + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + driver := getDriver(table.Table) + table.Driver = driver + when, warns := toolkit.NewWhenCond("", driver, nil) + require.Empty(t, warns) + tt := &testTransformer{} + tc := &utils.TransformerContext{ + Transformer: tt, + When: when, + } + table.TransformersContext = []*utils.TransformerContext{tc} + + buf := bytes.NewBuffer(nil) + + pipeline, err := NewTransformationPipeline(gtx, eg, table, buf) + require.NoError(t, err) + require.NoError(t, pipeline.Init(termCtx)) + data := []byte("1\t2023-08-27 00:00:00.000000") + err = pipeline.Dump(ctx, data) + require.NoError(t, err) + require.NoError(t, pipeline.Done(termCtx)) + require.NoError(t, pipeline.CompleteDump()) + require.Equal(t, tt.callsCount, 1) + require.Equal(t, buf.String(), "2\t2023-08-27 00:00:00.00000\n\\.\n\n") +} + +func TestTransformationPipeline_Dump_with_transformer_cond(t *testing.T) { + termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer termCancel() + table := getTable() + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + driver := getDriver(table.Table) + table.Driver = driver + when, warns := toolkit.NewWhenCond("record.id != 1", driver, make(map[string]any)) + require.Empty(t, warns) + tt := &testTransformer{} + tc := &utils.TransformerContext{ + Transformer: tt, + When: when, + } + table.TransformersContext = []*utils.TransformerContext{tc} + + buf := bytes.NewBuffer(nil) + + pipeline, err := NewTransformationPipeline(gtx, eg, table, buf) + require.NoError(t, err) + require.NoError(t, pipeline.Init(termCtx)) + data := []byte("1\t2023-08-27 00:00:00.000000") + err = pipeline.Dump(ctx, data) + require.NoError(t, err) + require.NoError(t, pipeline.Done(termCtx)) + require.NoError(t, pipeline.CompleteDump()) + require.Equal(t, tt.callsCount, 0) + require.Equal(t, buf.String(), "1\t2023-08-27 00:00:00.00000\n\\.\n\n") +} + +func TestTransformationPipeline_Dump_with_table_cond(t *testing.T) { + termCtx, termCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer termCancel() + table := getTable() + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + driver := getDriver(table.Table) + table.Driver = driver + when, warns := toolkit.NewWhenCond("", driver, make(map[string]any)) + require.Empty(t, warns) + tt := &testTransformer{} + tc := &utils.TransformerContext{ + Transformer: tt, + When: when, + } + table.TransformersContext = []*utils.TransformerContext{tc} + table.When = "record.id != 1" + + buf := bytes.NewBuffer(nil) + + pipeline, err := NewTransformationPipeline(gtx, eg, table, buf) + require.NoError(t, err) + require.NoError(t, pipeline.Init(termCtx)) + data := []byte("1\t2023-08-27 00:00:00.000000") + err = pipeline.Dump(ctx, data) + require.NoError(t, err) + require.NoError(t, pipeline.Done(termCtx)) + require.NoError(t, pipeline.CompleteDump()) + require.Equal(t, tt.callsCount, 0) + require.Equal(t, buf.String(), "1\t2023-08-27 00:00:00.00000\n\\.\n\n") +} diff --git a/internal/db/postgres/dumpers/transformation_window.go b/internal/db/postgres/dumpers/transformation_window.go index 6424c201..2fd72b10 100644 --- a/internal/db/postgres/dumpers/transformation_window.go +++ b/internal/db/postgres/dumpers/transformation_window.go @@ -81,6 +81,7 @@ func (tw *transformationWindow) tryAdd(table *entries.Table, t *utils.Transforme return true } +// init - runs all transformers in the goroutines and waits for the ac.ch signal to run the transformer func (tw *transformationWindow) init() { for _, ac := range tw.window { func(ac *asyncContext) { @@ -105,10 +106,13 @@ func (tw *transformationWindow) init() { } } +// close - closes the done channel to stop the transformers goroutines func (tw *transformationWindow) close() { close(tw.done) } +// Transform - runs the transformation for the record in the window. This function checks the when +// condition of the transformer and if true sends a signal to the transformer goroutine to run the transformation func (tw *transformationWindow) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { tw.r = r for _, ac := range tw.window { diff --git a/internal/db/postgres/dumpers/transformation_window_test.go b/internal/db/postgres/dumpers/transformation_window_test.go new file mode 100644 index 00000000..1b411c9f --- /dev/null +++ b/internal/db/postgres/dumpers/transformation_window_test.go @@ -0,0 +1,158 @@ +package dumpers + +import ( + "context" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + "github.com/greenmaskio/greenmask/internal/db/postgres/entries" + "github.com/greenmaskio/greenmask/internal/db/postgres/transformers/utils" + "github.com/greenmaskio/greenmask/pkg/toolkit" + "github.com/greenmaskio/greenmask/pkg/toolkit/testutils" +) + +func TestTransformationWindow_tryAdd(t *testing.T) { + ctx := context.Background() + eg, gtx := errgroup.WithContext(ctx) + tw := newTransformationWindow(gtx, eg) + tc := utils.TransformerContext{ + Transformer: &testTransformer{}, + } + table := getTable() + require.True(t, tw.tryAdd(table, &tc)) + require.False(t, tw.tryAdd(table, &tc)) +} + +func TestTransformationWindow_Transform(t *testing.T) { + mainCtx, mainCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer mainCancel() + eg, gtx := errgroup.WithContext(mainCtx) + tw := newTransformationWindow(gtx, eg) + when, warns := toolkit.NewWhenCond("", nil, nil) + require.Empty(t, warns) + tc := utils.TransformerContext{ + Transformer: &testTransformer{}, + When: when, + } + table := getTable() + require.True(t, tw.tryAdd(table, &tc)) + + driver := getDriver(table.Table) + record := toolkit.NewRecord(driver) + row := testutils.NewTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000"}) + record.SetRow(row) + tw.init() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := tw.Transform(ctx, record) + require.NoError(t, err) + v, err := record.GetRawColumnValueByName("id") + require.NoError(t, err) + require.False(t, v.IsNull) + require.Equal(t, []byte("2"), v.Data) + tw.close() + require.NoError(t, eg.Wait()) +} + +func TestTransformationWindow_Transform_with_cond(t *testing.T) { + table := getTable() + driver := getDriver(table.Table) + record := toolkit.NewRecord(driver) + when, warns := toolkit.NewWhenCond("record.id != 1", driver, make(map[string]any)) + require.Empty(t, warns) + mainCtx, mainCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer mainCancel() + eg, gtx := errgroup.WithContext(mainCtx) + tw := newTransformationWindow(gtx, eg) + tt := &testTransformer{} + tc := utils.TransformerContext{ + Transformer: tt, + When: when, + } + require.True(t, tw.tryAdd(table, &tc)) + + row := testutils.NewTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000"}) + record.SetRow(row) + tw.init() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := tw.Transform(ctx, record) + require.NoError(t, err) + require.Equal(t, 0, tt.callsCount) + v, err := record.GetRawColumnValueByName("id") + require.NoError(t, err) + require.False(t, v.IsNull) + require.Equal(t, []byte("1"), v.Data) + tw.close() + require.NoError(t, eg.Wait()) +} + +type testTransformer struct { + callsCount int +} + +func (tt *testTransformer) Init(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Done(ctx context.Context) error { + return nil +} + +func (tt *testTransformer) Transform(ctx context.Context, r *toolkit.Record) (*toolkit.Record, error) { + tt.callsCount++ + err := r.SetColumnValueByName("id", 2) + if err != nil { + return nil, err + } + return r, nil +} + +func (tt *testTransformer) GetAffectedColumns() map[int]string { + return map[int]string{ + 1: "name", + } +} + +func getDriver(table *toolkit.Table) *toolkit.Driver { + driver, _, err := toolkit.NewDriver(table, nil) + if err != nil { + panic(err.Error()) + } + return driver +} + +func getTable() *entries.Table { + return &entries.Table{ + Table: &toolkit.Table{ + Schema: "public", + Name: "test", + Oid: 1224, + Columns: []*toolkit.Column{ + { + Name: "id", + TypeName: "int2", + TypeOid: pgtype.Int2OID, + Num: 1, + NotNull: true, + Length: -1, + }, + { + Name: "created_at", + TypeName: "timestamp", + TypeOid: pgtype.TimestampOID, + Num: 1, + NotNull: true, + Length: -1, + }, + }, + Constraints: []toolkit.Constraint{}, + }, + } +} diff --git a/internal/db/postgres/transformers/utils/definition.go b/internal/db/postgres/transformers/utils/definition.go index e512c17c..62556c45 100644 --- a/internal/db/postgres/transformers/utils/definition.go +++ b/internal/db/postgres/transformers/utils/definition.go @@ -111,7 +111,7 @@ func (d *TransformerDefinition) Instance( Transformer: t, StaticParameters: staticParams, DynamicParameters: dynamicParams, - when: when, + When: when, }, res, nil } @@ -119,9 +119,9 @@ type TransformerContext struct { Transformer Transformer StaticParameters map[string]*toolkit.StaticParameter DynamicParameters map[string]*toolkit.DynamicParameter - when *toolkit.WhenCond + When *toolkit.WhenCond } func (tc *TransformerContext) EvaluateWhen(r *toolkit.Record) (bool, error) { - return tc.when.Evaluate(r) + return tc.When.Evaluate(r) } diff --git a/pkg/toolkit/expt_test.go b/pkg/toolkit/expt_test.go index b5d26572..c90ed8d0 100644 --- a/pkg/toolkit/expt_test.go +++ b/pkg/toolkit/expt_test.go @@ -9,9 +9,7 @@ import ( func TestWhenCond_Evaluate(t *testing.T) { driver := getDriver() record := NewRecord(driver) - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000", testNullSeq, `{"a": 1}`, "123.0"}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", testNullSeq, `{"a": 1}`, "123.0"}) record.SetRow(row) type test struct { diff --git a/pkg/toolkit/record_test.go b/pkg/toolkit/record_test.go index 87292270..a17134a8 100644 --- a/pkg/toolkit/record_test.go +++ b/pkg/toolkit/record_test.go @@ -80,9 +80,7 @@ func getDriver() *Driver { } func TestRecord_ScanAttribute(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000", ""}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", ""}) driver := getDriver() r := NewRecord(driver) r.SetRow(row) @@ -94,9 +92,7 @@ func TestRecord_ScanAttribute(t *testing.T) { } func TestRecord_GetAttribute_date(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000", "1234", ""}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", ""}) driver := getDriver() r := NewRecord(driver) r.SetRow(row) @@ -108,9 +104,7 @@ func TestRecord_GetAttribute_date(t *testing.T) { } func TestRecord_GetAttribute_text(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000000", "1234", ""}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000000", "1234", ""}) driver := getDriver() r := NewRecord(driver) r.SetRow(row) @@ -142,9 +136,7 @@ func TestRecord_GetAttribute_text(t *testing.T) { //} func TestRecord_Encode(t *testing.T) { - row := &TestRowDriver{ - row: []string{"1", "2023-08-27 00:00:00.000001", "test", "", ""}, - } + row := newTestRowDriver([]string{"1", "2023-08-27 00:00:00.000001", "test", "", ""}) expected := []byte("2\t2023-08-29 00:00:00.000002\t\\N\t\t") driver := getDriver() r := NewRecord(driver) diff --git a/pkg/toolkit/testutils_test.go b/pkg/toolkit/testutils.go similarity index 94% rename from pkg/toolkit/testutils_test.go rename to pkg/toolkit/testutils.go index d74441ff..31a26bae 100644 --- a/pkg/toolkit/testutils_test.go +++ b/pkg/toolkit/testutils.go @@ -21,6 +21,10 @@ type TestRowDriver struct { row []string } +func newTestRowDriver(row []string) *TestRowDriver { + return &TestRowDriver{row: row} +} + func (trd *TestRowDriver) GetColumn(idx int) (*RawValue, error) { val := trd.row[idx] if val == testNullSeq { diff --git a/pkg/toolkit/testutils/testutils.go b/pkg/toolkit/testutils/testutils.go new file mode 100644 index 00000000..d06d624a --- /dev/null +++ b/pkg/toolkit/testutils/testutils.go @@ -0,0 +1,68 @@ +// Copyright 2023 Greenmask +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutils + +import "github.com/greenmaskio/greenmask/pkg/toolkit" + +var NullSeq = "\\N" +var Delim byte = '\t' + +type TestRowDriver struct { + row []string +} + +func NewTestRowDriver(row []string) *TestRowDriver { + return &TestRowDriver{row: row} +} + +func (trd *TestRowDriver) GetColumn(idx int) (*toolkit.RawValue, error) { + val := trd.row[idx] + if val == NullSeq { + return toolkit.NewRawValue(nil, true), nil + } + return toolkit.NewRawValue([]byte(val), false), nil +} + +func (trd *TestRowDriver) SetColumn(idx int, v *toolkit.RawValue) error { + if v.IsNull { + trd.row[idx] = NullSeq + } else { + trd.row[idx] = string(v.Data) + } + return nil +} + +func (trd *TestRowDriver) Encode() ([]byte, error) { + var res []byte + for idx, v := range trd.row { + res = append(res, []byte(v)...) + if idx != len(trd.row)-1 { + res = append(res, Delim) + } + } + return res, nil +} + +func (trd *TestRowDriver) Decode([]byte) error { + panic("is not implemented") +} + +func (trd *TestRowDriver) Length() int { + return len(trd.row) +} + +func (trd *TestRowDriver) Clean() { + +} diff --git a/pkg/toolkit/test/test.go b/tests/external_transformer/test.go similarity index 99% rename from pkg/toolkit/test/test.go rename to tests/external_transformer/test.go index 6929e339..dfed5ba4 100644 --- a/pkg/toolkit/test/test.go +++ b/tests/external_transformer/test.go @@ -1,4 +1,5 @@ // Copyright 2023 Greenmask + // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License.