Skip to content

Commit

Permalink
Address 3rd round comments & add UTs
Browse files Browse the repository at this point in the history
Signed-off-by: Yongming Ding <dyongming@vmware.com>
  • Loading branch information
dreamtalen committed Dec 6, 2022
1 parent 6d3889a commit 0756809
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 31 deletions.
7 changes: 5 additions & 2 deletions snowflake/Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
GO ?= go
BINDIR := $(CURDIR)/bin

all: bin
all: udfs bin

.PHONY: udfs
udfs:
make -C udfs/udfs/

.PHONY: bin
bin:
make -C udfs/udfs/
$(GO) build -o $(BINDIR)/theia-sf antrea.io/theia/snowflake

.PHONY: .coverage
Expand Down
14 changes: 7 additions & 7 deletions snowflake/pkg/infra/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (

"antrea.io/theia/snowflake/database"
sf "antrea.io/theia/snowflake/pkg/snowflake"
utils "antrea.io/theia/snowflake/pkg/utils"
fileutils "antrea.io/theia/snowflake/pkg/utils/file"
"antrea.io/theia/snowflake/udfs"
)

Expand Down Expand Up @@ -85,7 +85,7 @@ func installPulumiCLI(ctx context.Context, logger logr.Logger, dir string) error
if err := os.MkdirAll(filepath.Join(dir, "pulumi"), 0755); err != nil {
return err
}
if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
return err
}

Expand Down Expand Up @@ -118,7 +118,7 @@ func installMigrateSnowflakeCLI(ctx context.Context, logger logr.Logger, dir str
return fmt.Errorf("OS / arch combination is not supported: %s / %s", operatingSystem, arch)
}
url := fmt.Sprintf("https://github.com/antoninbas/migrate-snowflake/releases/download/%s/migrate-snowflake_%s_%s.tar.gz", migrateSnowflakeVersion, migrateSnowflakeVersion, target)
if err := utils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
if err := fileutils.DownloadAndUntar(ctx, logger, url, dir); err != nil {
return err
}

Expand Down Expand Up @@ -282,7 +282,7 @@ func (m *Manager) run(ctx context.Context, destroy bool) (*Result, error) {
warehouseName := m.warehouseName
if !destroy {
logger.Info("Copying database migrations to disk")
if err := utils.WriteEmbedDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil {
if err := fileutils.WriteFSDirToDisk(ctx, logger, database.Migrations, database.MigrationsPath, filepath.Join(workdir, migrationsDir)); err != nil {
return nil, err
}
logger.Info("Copied database migrations to disk")
Expand Down Expand Up @@ -445,7 +445,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
}

// Download and stage Kubernetes python client for policy recommendation udf
k8sPythonClientFilePath, err := utils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName)
k8sPythonClientFilePath, err := fileutils.Download(ctx, logger, k8sPythonClientUrl, workdir, k8sPythonClientFileName)
if err != nil {
return err
}
Expand All @@ -463,7 +463,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa

logger.Info("Copying UDFs to disk")
udfsDirPath := filepath.Join(workdir, udfsDir)
if err := utils.WriteEmbedDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil {
if err := fileutils.WriteFSDirToDisk(ctx, logger, udfs.UdfsFs, udfs.UdfsPath, udfsDirPath); err != nil {
return err
}
logger.Info("Copied UDFs to disk")
Expand Down Expand Up @@ -521,7 +521,7 @@ func createUdfs(ctx context.Context, logger logr.Logger, databaseName string, wa
return fmt.Errorf("version placeholder '%s' not found in SQL file", udfVersionPlaceholder)
}
query = strings.ReplaceAll(query, udfVersionPlaceholder, version)
_, err = sfClient.ExecMultiStatementQuery(ctx, query, false)
err = sfClient.ExecMultiStatement(ctx, query)
if err != nil {
return fmt.Errorf("error when creating UDF: %w", err)
}
Expand Down
24 changes: 13 additions & 11 deletions snowflake/pkg/snowflake/snowflake.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,28 @@ func (c *client) UseDatabase(ctx context.Context, name string) error {

func (c *client) UseSchema(ctx context.Context, name string) error {
query := fmt.Sprintf("USE SCHEMA %s", name)
c.logger.Info("Snowflake query", "query", query)
c.logger.V(2).Info("Snowflake query", "query", query)
_, err := c.db.ExecContext(ctx, query)
return err
}

func (c *client) StageFile(ctx context.Context, path string, stage string) error {
query := fmt.Sprintf("PUT file://%s @%s AUTO_COMPRESS = FALSE OVERWRITE = TRUE", path, stage)
c.logger.Info("Snowflake query", "query", query)
c.logger.V(2).Info("Snowflake query", "query", query)
_, err := c.db.ExecContext(ctx, query)
return err
}

func (c *client) ExecMultiStatementQuery(ctx context.Context, query string, result bool) (*sql.Rows, error) {
func (c *client) ExecMultiStatement(ctx context.Context, query string) error {
multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0)
c.logger.Info("Snowflake query", "query", query)
if !result {
_, err := c.db.ExecContext(multi_statement_context, query)
return nil, err
} else {
rows, err := c.db.QueryContext(multi_statement_context, query)
return rows, err
}
c.logger.V(2).Info("Snowflake query", "query", query)
_, err := c.db.ExecContext(multi_statement_context, query)
return err
}

func (c *client) QueryMultiStatement(ctx context.Context, query string) (*sql.Rows, error) {
multi_statement_context, _ := gosnowflake.WithMultiStatement(ctx, 0)
c.logger.V(2).Info("Snowflake query", "query", query)
rows, err := c.db.QueryContext(multi_statement_context, query)
return rows, err
}
2 changes: 1 addition & 1 deletion snowflake/pkg/udfs/udfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func RunUdf(ctx context.Context, logger logr.Logger, query string, databaseName
return nil, err
}

rows, err := sfClient.ExecMultiStatementQuery(ctx, query, true)
rows, err := sfClient.QueryMultiStatement(ctx, query)
if err != nil {
return nil, fmt.Errorf("error when running UDF: %w", err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package utils
package file

import (
"archive/tar"
Expand Down Expand Up @@ -102,17 +102,17 @@ func DownloadAndUntar(ctx context.Context, logger logr.Logger, url string, dir s
return nil
}

func WriteEmbedDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, embedPath string, dest string) error {
func WriteFSDirToDisk(ctx context.Context, logger logr.Logger, fsys fs.FS, fsysPath string, dest string) error {
if err := os.MkdirAll(dest, 0755); err != nil {
return err
}

return fs.WalkDir(fsys, embedPath, func(path string, d fs.DirEntry, err error) error {
return fs.WalkDir(fsys, fsysPath, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}

outpath := filepath.Join(dest, strings.TrimPrefix(path, embedPath))
outpath := filepath.Join(dest, strings.TrimPrefix(path, fsysPath))

if d.IsDir() {
os.MkdirAll(outpath, 0755)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,32 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package infra
package file

import (
"context"
"os"
"path/filepath"
"testing"

"antrea.io/theia/snowflake/database"
"github.com/go-logr/logr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"antrea.io/theia/snowflake/database"
)

func TestWriteMigrationsToDisk(t *testing.T) {
func TestWriteFSDirToDisk(t *testing.T) {
var logger logr.Logger
tempDir, err := os.MkdirTemp("", "antrea-pulumi-test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)
err = writeMigrationsToDisk(database.Migrations, database.MigrationsPath, filepath.Join(tempDir, migrationsDir))
err = WriteFSDirToDisk(context.TODO(), logger, database.Migrations, database.MigrationsPath, filepath.Join(tempDir, database.MigrationsPath))
require.NoError(t, err)
entries, err := database.Migrations.ReadDir(database.MigrationsPath)
require.NoError(t, err)
for _, entry := range entries {
_, err := os.Stat(filepath.Join(tempDir, migrationsDir, entry.Name()))
_, err := os.Stat(filepath.Join(tempDir, database.MigrationsPath, entry.Name()))
assert.NoErrorf(t, err, "Migration file %s not exist", entry.Name())
}
}
2 changes: 1 addition & 1 deletion snowflake/pkg/utils/timestamps/timestamps.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func ParseTimestamp(t string, now time.Time) (string, error) {
return now, fmt.Errorf("bad timestamp: %s", t)
}()
if err != nil {
return "", nil
return "", err
}
return ts.UTC().Format(time.RFC3339), nil
}
65 changes: 65 additions & 0 deletions snowflake/pkg/utils/timestamps/timestamps_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2022 Antrea Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package timestamps

import (
"fmt"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestParseTimestamp(t *testing.T) {
now := time.Now()
nowTimestamp := now.UTC().Format(time.RFC3339)
for _, tc := range []struct {
name string
inputTimestamp string
expectedTimestamp string
expectedError error
}{
{
name: "Successful case 1",
inputTimestamp: "now-1h",
expectedTimestamp: now.Add(-time.Hour).UTC().Format(time.RFC3339),
expectedError: nil,
},
{
name: "Successful case 2",
inputTimestamp: "now",
expectedTimestamp: nowTimestamp,
expectedError: nil,
},
{
name: "Successful case 3",
inputTimestamp: "",
expectedTimestamp: nowTimestamp,
expectedError: nil,
},
{
name: "Failed case",
inputTimestamp: "now-1c",
expectedTimestamp: "",
expectedError: fmt.Errorf("bad timestamp: now-1c"),
},
} {
t.Run(tc.name, func(t *testing.T) {
timestamp, err := ParseTimestamp(tc.inputTimestamp, now)
assert.Equal(t, tc.expectedTimestamp, timestamp)
assert.Equal(t, tc.expectedError, err)
})
}
}

0 comments on commit 0756809

Please sign in to comment.