Skip to content

Commit

Permalink
added initial compile unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed May 25, 2023
1 parent 50b1c63 commit 7092c0e
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 24 deletions.
7 changes: 7 additions & 0 deletions wren/.mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,10 @@ packages:
FileManager:
MaterializationManager:
PrepareManager:
github.com/kaskada-ai/kaskada/gen/proto/go/kaskada/kaskada/v1alpha:
config:
dir: "{{.InterfaceDir}}"
interfaces:
ComputeServiceClient:
FileServiceClient:
PreparationServiceClient:
18 changes: 12 additions & 6 deletions wren/client/compute_clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,30 @@ import (
"github.com/rs/zerolog/log"
)

type ComputeClients interface {
NewFileServiceClient(ctx context.Context) FileServiceClient
NewPrepareServiceClient(ctx context.Context) PrepareServiceClient
NewComputeServiceClient(ctx context.Context) ComputeServiceClient
}

// ComputeClients is the container to hold client for communicating with compute services
type ComputeClients struct {
type computeClients struct {
fileServiceConfig *HostConfig
prepareServiceConfig *HostConfig
computeServiceConfig *HostConfig
}

// CreateComputeClients initializes the computeClients
func CreateComputeClients(fileServiceConfig *HostConfig, prepareServiceConfig *HostConfig, computeServiceConfig *HostConfig) *ComputeClients {
return &ComputeClients{
func CreateComputeClients(fileServiceConfig *HostConfig, prepareServiceConfig *HostConfig, computeServiceConfig *HostConfig) ComputeClients {
return &computeClients{
fileServiceConfig: fileServiceConfig,
prepareServiceConfig: prepareServiceConfig,
computeServiceConfig: computeServiceConfig,
}
}

// FileServiceClient creates a new FileServiceClient from the configuration and context
func (c *ComputeClients) FileServiceClient(ctx context.Context) FileServiceClient {
func (c *computeClients) NewFileServiceClient(ctx context.Context) FileServiceClient {
conn, err := connection(ctx, c.fileServiceConfig)
if err != nil {
log.Ctx(ctx).Fatal().Err(err).Interface("host_config", c.fileServiceConfig).Msg("unable to dial FileServiceClient")
Expand All @@ -35,7 +41,7 @@ func (c *ComputeClients) FileServiceClient(ctx context.Context) FileServiceClien
}

// PrepareServiceClient creates a new PrepareServiceClient from the configuration and context
func (c *ComputeClients) PrepareServiceClient(ctx context.Context) PrepareServiceClient {
func (c *computeClients) NewPrepareServiceClient(ctx context.Context) PrepareServiceClient {
conn, err := connection(ctx, c.prepareServiceConfig)
if err != nil {
log.Ctx(ctx).Fatal().Err(err).Interface("host_config", c.prepareServiceConfig).Msg("unable to dial PrepareServiceClient")
Expand All @@ -47,7 +53,7 @@ func (c *ComputeClients) PrepareServiceClient(ctx context.Context) PrepareServic
}

// ComputeServiceClient creates a new ComputeServiceClient from the configuration and context
func (c *ComputeClients) ComputeServiceClient(ctx context.Context) ComputeServiceClient {
func (c *computeClients) NewComputeServiceClient(ctx context.Context) ComputeServiceClient {
conn, err := connection(ctx, c.computeServiceConfig)
if err != nil {
log.Ctx(ctx).Fatal().Err(err).Interface("host_config", c.computeServiceConfig).Msg("unable to dial ComputeServiceClient")
Expand Down
7 changes: 4 additions & 3 deletions wren/compute/compile_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func NewCompileManager(computeClients *client.ComputeClients, kaskadaTableClient
kaskadaViewClient: *kaskadaViewClient,
}
}

func (m *compileManager) CompileEntMaterialization(ctx context.Context, owner *ent.Owner, materialization *ent.Materialization) (*v1alpha.CompileResponse, []*v1alpha.View, error) {
compileRequest := &compileRequest{
Expression: materialization.Expression,
Expand Down Expand Up @@ -95,7 +96,7 @@ func (m *compileManager) CompileV1Query(ctx context.Context, owner *ent.Owner, q
}

func (m *compileManager) CompileV2Query(ctx context.Context, owner *ent.Owner, expression string, views []*v2alpha.QueryView, queryConfig *v2alpha.QueryConfig) (*v1alpha.CompileResponse, []*v1alpha.View, error) {

compileRequest := &compileRequest{
Expression: expression,
Views: make([]*v1alpha.WithView, len(views)),
Expand Down Expand Up @@ -208,7 +209,7 @@ func (m *compileManager) compile(ctx context.Context, owner *ent.Owner, request
compileRequest.ExpressionKind = v1alpha.CompileRequest_EXPRESSION_KIND_COMPLETE
}

computeClient := m.computeClients.ComputeServiceClient(ctx)
computeClient := m.computeClients.NewComputeServiceClient(ctx)
defer computeClient.Close()

subLogger.Info().Interface("request", compileRequest).Msg("sending compile request")
Expand Down Expand Up @@ -265,7 +266,7 @@ func (m *compileManager) getFormulaMap(ctx context.Context, owner *ent.Owner, re
formulaMap[requestView.Name] = &v1alpha.Formula{
Name: requestView.Name,
Formula: requestView.Expression,
SourceLocation: fmt.Sprintf("Requested View %s", requestView.Name),
SourceLocation: fmt.Sprintf("Requested View: %s", requestView.Name),
}
}

Expand Down
195 changes: 195 additions & 0 deletions wren/compute/compile_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package compute

import (
"context"

"github.com/google/uuid"
v1alpha "github.com/kaskada-ai/kaskada/gen/proto/go/kaskada/kaskada/v1alpha"
"github.com/kaskada-ai/kaskada/wren/ent"
"github.com/kaskada-ai/kaskada/wren/internal"
"github.com/stretchr/testify/mock"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("CompileManager", func() {

var (
ctx context.Context
owner *ent.Owner

defaultUUID = uuid.MustParse("00000000-0000-0000-0000-000000000000")

mockComputeServiceClient *v1alpha.MockComputeServiceClient
mockFileServiceClient *v1alpha.MockFileServiceClient
mockPreparationServiceClient *v1alpha.MockPreparationServiceClient
mockKaskadaTableClient *internal.MockKaskadaTableClient
mockKaskadaViewClient *internal.MockKaskadaViewClient

objectStoreDestination = &v1alpha.Destination{
Destination: &v1alpha.Destination_ObjectStore{
ObjectStore: &v1alpha.ObjectStoreDestination{
FileType: v1alpha.FileType_FILE_TYPE_CSV,
OutputPrefixUri: "gs://some-bucket/some-prefix",
},
},
}

sliceRequest = &v1alpha.SliceRequest{
Slice: &v1alpha.SliceRequest_Percent{
Percent: &v1alpha.SliceRequest_PercentSlice{
Percent: 42,
},
},
}

persistedViews = []*ent.KaskadaView{
{
Name: "persisted_view",
Expression: "persisted_view_expression",
},
{
Name: "overwritten_view",
Expression: "overwritten_view_expression",
},
}

persistedTables = []*ent.KaskadaTable{
{
ID: defaultUUID,
Name: "persisted_table1",
MergedSchema: &v1alpha.Schema{},
},
{
ID: defaultUUID,
Name: "persisted_table2",
MergedSchema: &v1alpha.Schema{},
},
}
)

BeforeEach(func() {
ctx = context.Background()
owner = &ent.Owner{}

mockKaskadaTableClient = internal.NewMockKaskadaTableClient(GinkgoT())
mockKaskadaViewClient = internal.NewMockKaskadaViewClient(GinkgoT())
mockComputeServiceClient = v1alpha.NewMockComputeServiceClient(GinkgoT())
mockFileServiceClient = v1alpha.NewMockFileServiceClient(GinkgoT())
mockPreparationServiceClient = v1alpha.NewMockPreparationServiceClient(GinkgoT())

})

Context("CompileEntMaterialization", func() {
It("should compile a materialization", func() {
mockKaskadaViewClient.EXPECT().GetAllKaskadaViews(ctx, owner).Return(persistedViews, nil)

mockKaskadaTableClient.EXPECT().GetAllKaskadaTables(ctx, owner).Return(persistedTables, nil)

entMaterialization := &ent.Materialization{
Name: "ent_materialization",
Expression: "ent_materialization_expression",
Destination: objectStoreDestination,
SliceRequest: sliceRequest,
WithViews: &v1alpha.WithViews{
Views: []*v1alpha.WithView{
{
Name: "with_view",
Expression: "with_view_expression",
},
{
Name: "overwritten_view",
Expression: "overwritten_view_expression2",
},
},
},
}

computeTables := []*v1alpha.ComputeTable{
{
Config: &v1alpha.TableConfig{
Name: "persisted_table1",
Uuid: defaultUUID.String(),
},
Metadata: &v1alpha.TableMetadata{
Schema: &v1alpha.Schema{},
},
FileSets: []*v1alpha.ComputeTable_FileSet{},
},
{
Config: &v1alpha.TableConfig{
Name: "persisted_table2",
Uuid: defaultUUID.String(),
},
Metadata: &v1alpha.TableMetadata{
Schema: &v1alpha.Schema{},
},
FileSets: []*v1alpha.ComputeTable_FileSet{},
},
}

formulas := []*v1alpha.Formula{
{
Name: "persisted_view",
Formula: "persisted_view_expression",
SourceLocation: "Persisted View: persisted_view",
},
{
Name: "overwritten_view",
Formula: "overwritten_view_expression2",
SourceLocation: "Requested View: overwritten_view",
},
{
Name: "with_view",
Formula: "with_view_expression",
SourceLocation: "Requested View: with_view",
},
}

compileRequest := &v1alpha.CompileRequest{
Experimental: false,
FeatureSet: &v1alpha.FeatureSet{
Formulas: formulas,
Query: entMaterialization.Expression,
},
PerEntityBehavior: v1alpha.PerEntityBehavior_PER_ENTITY_BEHAVIOR_FINAL,
SliceRequest: sliceRequest,
Tables: computeTables,
ExpressionKind: v1alpha.CompileRequest_EXPRESSION_KIND_COMPLETE,
}

compileResponse := &v1alpha.CompileResponse{
FreeNames: []string{"with_view", "overwritten_view", "persisted_table1"},
}

mockComputeServiceClient.EXPECT().Compile(mock.Anything, compileRequest).Return(compileResponse, nil)

computeClients := newMockComputeServiceClients(mockFileServiceClient, mockPreparationServiceClient, mockComputeServiceClient)
compManager := &compileManager{
computeClients: computeClients,
kaskadaTableClient: mockKaskadaTableClient,
kaskadaViewClient: mockKaskadaViewClient,
}

compileResponse, views, err := compManager.CompileEntMaterialization(ctx, owner, entMaterialization)
Expect(err).ToNot(HaveOccurred())
Expect(compileResponse).ToNot(BeNil())
Expect(views).ToNot(BeNil())

expectedViews := []*v1alpha.View{
{
ViewName: "with_view",
Expression: "with_view_expression",
},
{
ViewName: "overwritten_view",
Expression: "overwritten_view_expression2",
},
}

Expect(views).To(Equal(expectedViews))

})
})
})
8 changes: 4 additions & 4 deletions wren/compute/compute_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type computeManager struct {
CompileManager

prepareManager PrepareManager
computeClients *client.ComputeClients
computeClients client.ComputeClients
errGroup *errgroup.Group
dataTokenClient internal.DataTokenClient
kaskadaTableClient internal.KaskadaTableClient
Expand All @@ -60,7 +60,7 @@ type computeManager struct {
func NewComputeManager(errGroup *errgroup.Group, compileManager *CompileManager, computeClients *client.ComputeClients, dataTokenClient *internal.DataTokenClient, kaskadaTableClient *internal.KaskadaTableClient, materializationClient *internal.MaterializationClient, objectStoreClient *client.ObjectStoreClient, prepareManager *PrepareManager) ComputeManager {
return &computeManager{
CompileManager: *compileManager,
computeClients: computeClients,
computeClients: *computeClients,
errGroup: errGroup,
dataTokenClient: *dataTokenClient,
kaskadaTableClient: *kaskadaTableClient,
Expand Down Expand Up @@ -102,7 +102,7 @@ func (m *computeManager) InitiateQuery(queryContext *QueryContext) (client.Compu
return nil, nil, err
}

queryClient := m.computeClients.ComputeServiceClient(queryContext.ctx)
queryClient := m.computeClients.NewComputeServiceClient(queryContext.ctx)

subLogger.Info().Bool("incremental_enabled", queryContext.compileResp.IncrementalEnabled).Bool("is_current_data_token", queryContext.isCurrentDataToken).Msg("Populating snapshot config if needed")
if queryContext.compileResp.IncrementalEnabled && queryContext.isCurrentDataToken && queryContext.compileResp.PlanHash != nil {
Expand Down Expand Up @@ -346,7 +346,7 @@ func (m *computeManager) processMaterializations(requestCtx context.Context, own
// gets the current snapshot cache buster
func (m *computeManager) getSnapshotCacheBuster(ctx context.Context) (*int32, error) {
subLogger := log.Ctx(ctx).With().Str("method", "manager.getSnapshotCacheBuster").Logger()
queryClient := m.computeClients.ComputeServiceClient(ctx)
queryClient := m.computeClients.NewComputeServiceClient(ctx)
defer queryClient.Close()

res, err := queryClient.GetCurrentSnapshotVersion(ctx, &v1alpha.GetCurrentSnapshotVersionRequest{})
Expand Down
Loading

0 comments on commit 7092c0e

Please sign in to comment.