diff --git a/wren/compute/compute_manager.go b/wren/compute/compute_manager.go index 3a7e842be..34c761489 100644 --- a/wren/compute/compute_manager.go +++ b/wren/compute/compute_manager.go @@ -22,6 +22,7 @@ import ( "github.com/kaskada-ai/kaskada/wren/client" "github.com/kaskada-ai/kaskada/wren/customerrors" "github.com/kaskada-ai/kaskada/wren/ent" + "github.com/kaskada-ai/kaskada/wren/ent/materialization" "github.com/kaskada-ai/kaskada/wren/internal" "github.com/kaskada-ai/kaskada/wren/utils" ) @@ -39,7 +40,9 @@ type ComputeManager interface { InitiateQuery(queryContext *QueryContext) (client.ComputeServiceClient, v1alpha.ComputeService_ExecuteClient, error) SaveComputeSnapshots(queryContext *QueryContext, computeSnapshots []*v1alpha.ComputeSnapshot) - // materialization related + // Runs all existing file-based materializations for the given owner + // Note: this exists in the ComputeManager interface instead of the MaterializationManager interface because + // it runs materializations in a similar way to InitiateQuery RunMaterializations(ctx context.Context, owner *ent.Owner) } @@ -201,6 +204,8 @@ func (m *computeManager) SaveComputeSnapshots(queryContext *QueryContext, comput } // Runs all saved materializations on current data inside a go-routine that attempts to finish before shutdown +// TODO: After sparrow supports long-running materializations from file-based sources +// remove all the code related to this method func (m *computeManager) RunMaterializations(requestCtx context.Context, owner *ent.Owner) { m.errGroup.Go(func() error { return m.processMaterializations(requestCtx, owner) }) } @@ -229,7 +234,7 @@ func (m *computeManager) processMaterializations(requestCtx context.Context, own subLogger.Error().Err(err).Msg("issue getting current prepare cache buster") } - materializations, err := m.materializationClient.GetAllMaterializations(ctx, owner) + materializations, err := m.materializationClient.GetMaterializationsBySourceType(ctx, owner, materialization.SourceTypeFiles) if err != nil { subLogger.Error().Err(err).Msg("error listing materializations") return nil diff --git a/wren/compute/file_manager.go b/wren/compute/file_manager.go index bec9ada9f..7740af67e 100644 --- a/wren/compute/file_manager.go +++ b/wren/compute/file_manager.go @@ -12,8 +12,12 @@ import ( ) type FileManager interface { - // metadata related + + // GetFileSchema returns the schema of the file at the given URI. 
GetFileSchema(ctx context.Context, fileInput internal.FileInput) (*v1alpha.Schema, error) + + // GetPulsarSchema returns the schema of the pulsar topic + GetPulsarSchema(ctx context.Context, pulsarConfig *v1alpha.PulsarConfig) (*v1alpha.Schema, error) } type fileManager struct { @@ -27,7 +31,7 @@ func NewFileManager(computeClients *client.ComputeClients) FileManager { } func (m *fileManager) GetFileSchema(ctx context.Context, fileInput internal.FileInput) (*v1alpha.Schema, error) { - subLogger := log.Ctx(ctx).With().Str("method", "manager.GetFileSchema").Str("uri", fileInput.GetURI()).Str("type", fileInput.GetExtension()).Logger() + subLogger := log.Ctx(ctx).With().Str("method", "fileManager.GetFileSchema").Str("uri", fileInput.GetURI()).Str("type", fileInput.GetExtension()).Logger() // Send the metadata request to the FileService var sourceData *v1alpha.SourceData @@ -42,6 +46,24 @@ func (m *fileManager) GetFileSchema(ctx context.Context, fileInput internal.File sourceData = &v1alpha.SourceData{Source: &v1alpha.SourceData_ParquetPath{ParquetPath: fileInput.GetURI()}} } + return m.getSchema(ctx, sourceData) +} + +func (m *fileManager) GetPulsarSchema(ctx context.Context, pulsarConfig *v1alpha.PulsarConfig) (*v1alpha.Schema, error) { + sourceData := &v1alpha.SourceData{ + Source: &v1alpha.SourceData_PulsarSubscription{ + PulsarSubscription: &v1alpha.PulsarSubscription{ + Config: pulsarConfig, + }, + }, + } + + return m.getSchema(ctx, sourceData) +} + +func (m *fileManager) getSchema(ctx context.Context, sourceData *v1alpha.SourceData) (*v1alpha.Schema, error) { + subLogger := log.Ctx(ctx).With().Str("method", "fileManager.getSchema").Logger() + fileClient := m.computeClients.NewFileServiceClient(ctx) defer fileClient.Close() diff --git a/wren/compute/materialization_manager.go b/wren/compute/materialization_manager.go index c86fd3835..697e682c1 100644 --- a/wren/compute/materialization_manager.go +++ b/wren/compute/materialization_manager.go @@ -7,6 +7,7 @@ import ( "github.com/kaskada-ai/kaskada/wren/client" "github.com/kaskada-ai/kaskada/wren/customerrors" "github.com/kaskada-ai/kaskada/wren/ent" + "github.com/kaskada-ai/kaskada/wren/ent/materialization" "github.com/kaskada-ai/kaskada/wren/internal" "github.com/rs/zerolog/log" ) @@ -22,6 +23,9 @@ type MaterializationManager interface { // GetMaterializationStatus gets the status of a materialization on the compute backend GetMaterializationStatus(ctx context.Context, materializationID string) (*v1alpha.ProgressInformation, error) + + // ReconcileMaterializations reconciles the materializations in the database with the materializations on the compute backend + ReconcileMaterializations(ctx context.Context) error } type materializationManager struct { @@ -30,14 +34,19 @@ type materializationManager struct { computeClients client.ComputeClients kaskadaTableClient internal.KaskadaTableClient materializationClient internal.MaterializationClient + + // this is used to keep track of which materializations are currently running on the compute backend + // so that if a materialization is deleted from the database, we can stop it the next time we reconcile + runningMaterializations map[string]interface{} } func NewMaterializationManager(compileManager *CompileManager, computeClients *client.ComputeClients, kaskadaTableClient *internal.KaskadaTableClient, materializationClient *internal.MaterializationClient) MaterializationManager { return &materializationManager{ - CompileManager: *compileManager, - computeClients: *computeClients, - 
kaskadaTableClient: *kaskadaTableClient, - materializationClient: *materializationClient, + CompileManager: *compileManager, + computeClients: *computeClients, + kaskadaTableClient: *kaskadaTableClient, + materializationClient: *materializationClient, + runningMaterializations: map[string]interface{}{}, } } @@ -110,6 +119,77 @@ func (m *materializationManager) GetMaterializationStatus(ctx context.Context, m return statusResponse.Progress, nil } +// ReconcileMaterializations reconciles the materializations in the database with the materializations on the compute backend +// After running this function, all stream-based materializations in the database will be running on the compute backend +// and all deleted materializations will be stopped +func (m *materializationManager) ReconcileMaterializations(ctx context.Context) error { + subLogger := log.Ctx(ctx).With().Str("method", "materializationManager.ReconcileMaterializations").Logger() + + allStreamMaterializations, err := m.materializationClient.GetAllMaterializationsBySourceType(ctx, materialization.SourceTypeStreams) + if err != nil { + subLogger.Error().Err(err).Msg("failed to get all stream materializations") + return err + } + + // find all materializations in the database and start any that are not running + // we keep a map of materialization_id=>nil to keep track of which materializations are running + newRunningMaterializations := make(map[string]interface{}) + for _, streamMaterialization := range allStreamMaterializations { + materializationID := streamMaterialization.ID.String() + owner := streamMaterialization.Edges.Owner + + isRunning := false + // check to see if the materialization was running in the previous iteration + if _, found := m.runningMaterializations[materializationID]; found { + // verify that the materialization is still running + progressInfo, err := m.GetMaterializationStatus(ctx, materializationID) + if err != nil { + log.Error().Err(err).Str("id", materializationID).Msg("failed to get materialization status") + } + isRunning = progressInfo != nil + } + + if isRunning { + newRunningMaterializations[materializationID] = nil + } else { + log.Debug().Str("id", materializationID).Msg("found materialization that is not running, attempting to start it") + + compileResp, _, err := m.CompileEntMaterialization(ctx, owner, streamMaterialization) + if err != nil { + log.Error().Err(err).Str("id", materializationID).Msg("issue compiling materialization") + } else { + err = m.StartMaterialization(ctx, owner, materializationID, compileResp, streamMaterialization.Destination) + if err != nil { + log.Error().Err(err).Str("id", materializationID).Msg("failed to start materialization") + } else { + log.Debug().Str("id", materializationID).Msg("started materialization") + newRunningMaterializations[materializationID] = nil + } + } + } + } + + // find all materializations that were running the previous time this method was called + // but no longer exist in the database, and stop any that are found. a materialization can + // still be running after deletion due to a race: it may be deleted from the database after + // this method has started but before it has finished. since this method is called + // periodically, such a materialization will eventually be stopped.
+ for materializationID := range m.runningMaterializations { + if _, found := newRunningMaterializations[materializationID]; !found { + log.Debug().Str("id", materializationID).Msg("found materialization that no longer exists, attempting to stop it") + err := m.StopMaterialization(ctx, materializationID) + if err != nil { + log.Error().Err(err).Str("id", materializationID).Msg("failed to stop materialization") + newRunningMaterializations[materializationID] = nil + } else { + log.Debug().Str("id", materializationID).Msg("stopped materialization") + } + } + } + m.runningMaterializations = newRunningMaterializations + return nil +} + func (m *materializationManager) getMaterializationTables(ctx context.Context, owner *ent.Owner, compileResp *v1alpha.CompileResponse) ([]*v1alpha.ComputeTable, error) { subLogger := log.Ctx(ctx).With().Str("method", "materializationManager.getMaterializationTables").Logger() diff --git a/wren/internal/interface.go b/wren/internal/interface.go index 9517855a7..f201f9eea 100644 --- a/wren/internal/interface.go +++ b/wren/internal/interface.go @@ -9,6 +9,7 @@ import ( v1alpha "github.com/kaskada-ai/kaskada/gen/proto/go/kaskada/kaskada/v1alpha" "github.com/kaskada-ai/kaskada/wren/ent" "github.com/kaskada-ai/kaskada/wren/ent/kaskadafile" + "github.com/kaskada-ai/kaskada/wren/ent/materialization" "github.com/kaskada-ai/kaskada/wren/ent/predicate" "github.com/kaskada-ai/kaskada/wren/ent/schema" "github.com/kaskada-ai/kaskada/wren/property" @@ -80,10 +81,11 @@ type MaterializationClient interface { CreateMaterialization(ctx context.Context, owner *ent.Owner, newMaterialization *ent.Materialization, dependencies []*ent.MaterializationDependency) (*ent.Materialization, error) DeleteMaterialization(ctx context.Context, owner *ent.Owner, view *ent.Materialization) error GetAllMaterializations(ctx context.Context, owner *ent.Owner) ([]*ent.Materialization, error) + GetAllMaterializationsBySourceType(ctx context.Context, sourceType materialization.SourceType) ([]*ent.Materialization, error) GetMaterialization(ctx context.Context, owner *ent.Owner, id uuid.UUID) (*ent.Materialization, error) GetMaterializationByName(ctx context.Context, owner *ent.Owner, name string) (*ent.Materialization, error) - GetMaterializationsFromNames(ctx context.Context, owner *ent.Owner, names []string) (map[string]*ent.Materialization, error) GetMaterializationsWithDependency(ctx context.Context, owner *ent.Owner, name string, dependencyType schema.DependencyType) ([]*ent.Materialization, error) + GetMaterializationsBySourceType(ctx context.Context, owner *ent.Owner, sourceType materialization.SourceType) ([]*ent.Materialization, error) ListMaterializations(ctx context.Context, owner *ent.Owner, searchTerm string, pageSize int, offset int) ([]*ent.Materialization, error) UpdateDataVersion(ctx context.Context, materialization *ent.Materialization, newDataVersion int64) (*ent.Materialization, error) IncrementVersion(ctx context.Context, materialization *ent.Materialization) (*ent.Materialization, error) diff --git a/wren/internal/materialization_client.go b/wren/internal/materialization_client.go index cdd24ba83..64bde7cc1 100644 --- a/wren/internal/materialization_client.go +++ b/wren/internal/materialization_client.go @@ -10,7 +10,6 @@ import ( "github.com/kaskada-ai/kaskada/wren/ent" "github.com/kaskada-ai/kaskada/wren/ent/materialization" "github.com/kaskada-ai/kaskada/wren/ent/materializationdependency" - "github.com/kaskada-ai/kaskada/wren/ent/predicate" 
"github.com/kaskada-ai/kaskada/wren/ent/schema" ) @@ -57,6 +56,7 @@ func (c *materializationClient) CreateMaterialization(ctx context.Context, owner SetAnalysis(newMaterialization.Analysis). SetDataVersionID(newMaterialization.DataVersionID). SetVersion(newMaterialization.Version). + SetSourceType(newMaterialization.SourceType). Save(ctx) if err != nil { @@ -155,31 +155,6 @@ func (c *materializationClient) GetMaterializationByName(ctx context.Context, ow return materialization, nil } -func (c *materializationClient) GetMaterializationsFromNames(ctx context.Context, owner *ent.Owner, names []string) (map[string]*ent.Materialization, error) { - subLogger := log.Ctx(ctx).With(). - Str("method", "materializationClient.GetMaterializationsFromNames"). - Logger() - - predicates := make([]predicate.Materialization, 0, len(names)) - - for _, name := range names { - predicates = append(predicates, materialization.Name(name)) - } - - materializations, err := owner.QueryMaterializations().Where(materialization.Or(predicates...)).All(ctx) - if err != nil { - subLogger.Error().Err(err).Msg("issue getting materializations") - return nil, err - } - - materializationMap := map[string]*ent.Materialization{} - - for _, materialization := range materializations { - materializationMap[materialization.Name] = materialization - } - - return materializationMap, nil -} func (c *materializationClient) GetMaterializationsWithDependency(ctx context.Context, owner *ent.Owner, name string, dependencyType schema.DependencyType) ([]*ent.Materialization, error) { subLogger := log.Ctx(ctx).With(). @@ -206,6 +181,34 @@ func (c *materializationClient) GetMaterializationsWithDependency(ctx context.Co return materializations, nil } +func (c *materializationClient) GetMaterializationsBySourceType(ctx context.Context, owner *ent.Owner, sourceType materialization.SourceType) ([]*ent.Materialization, error) { + subLogger := log.Ctx(ctx).With(). + Str("method", "materializationClient.GetMaterializationsBySourceType"). + Str("source_type", string(sourceType)). + Logger() + + materializations, err := owner.QueryMaterializations().Where(materialization.SourceTypeEQ(sourceType)).WithOwner().All(ctx) + if err != nil { + subLogger.Error().Err(err).Msg("issue listing materializations") + return nil, err + } + return materializations, nil +} + +func (c *materializationClient) GetAllMaterializationsBySourceType(ctx context.Context, sourceType materialization.SourceType) ([]*ent.Materialization, error) { + subLogger := log.Ctx(ctx).With(). + Str("method", "materializationClient.GetAllMaterializationsBySourceType"). + Str("source_type", string(sourceType)). + Logger() + + materializations, err := c.entClient.Materialization.Query().Where(materialization.SourceTypeEQ(sourceType)).All(ctx) + if err != nil { + subLogger.Error().Err(err).Msg("issue listing materializations") + return nil, err + } + return materializations, nil +} + func (c *materializationClient) ListMaterializations(ctx context.Context, owner *ent.Owner, searchTerm string, pageSize int, offset int) ([]*ent.Materialization, error) { subLogger := log.Ctx(ctx).With(). Str("method", "materializationClient.ListMaterializations"). 
diff --git a/wren/main.go b/wren/main.go index 451f9092b..f34fb9bce 100644 --- a/wren/main.go +++ b/wren/main.go @@ -351,6 +351,18 @@ func main() { return nil }) + // periodically reconcile materializations to ensure the ones that are supposed to be running are running + g.Go(func() error { + for { + time.Sleep(5 * time.Second) + + err := materializationManager.ReconcileMaterializations(ctx) + if err != nil { + return err + } + } + }) + // wait until shutdown signal occurs select { case <-interrupt: diff --git a/wren/service/materialization.go b/wren/service/materialization.go index e0e2c828b..74f5289dd 100644 --- a/wren/service/materialization.go +++ b/wren/service/materialization.go @@ -14,6 +14,7 @@ import ( "github.com/kaskada-ai/kaskada/wren/compute" "github.com/kaskada-ai/kaskada/wren/customerrors" "github.com/kaskada-ai/kaskada/wren/ent" + "github.com/kaskada-ai/kaskada/wren/ent/materialization" "github.com/kaskada-ai/kaskada/wren/ent/schema" "github.com/kaskada-ai/kaskada/wren/internal" ) @@ -155,12 +156,19 @@ func (s *materializationService) CreateMaterialization(ctx context.Context, requ } func (s *materializationService) createMaterialization(ctx context.Context, owner *ent.Owner, request *v1alpha.CreateMaterializationRequest) (*v1alpha.CreateMaterializationResponse, error) { - subLogger := log.Ctx(ctx).With().Str("method", "materializationService.createMaterialization").Str("expression", request.Materialization.Expression).Logger() + if request.Materialization == nil { + return nil, customerrors.NewInvalidArgumentErrorWithCustomText("missing materialization definition") + } + + if request.Materialization.Expression == "" { + return nil, customerrors.NewInvalidArgumentErrorWithCustomText("missing materialization expression") + } if request.Materialization.Destination == nil { return nil, customerrors.NewInvalidArgumentErrorWithCustomText("missing materialization destination") } + subLogger := log.Ctx(ctx).With().Str("method", "materializationService.createMaterialization").Str("expression", request.Materialization.Expression).Logger() compileResp, _, err := s.materializationManager.CompileV1Materialization(ctx, owner, request.Materialization) if err != nil { subLogger.Error().Err(err).Msg("issue compiling materialization") @@ -181,6 +189,26 @@ func (s *materializationService) createMaterialization(ctx context.Context, owne if err != nil { return nil, err } + + sourceType := materialization.SourceTypeUnspecified + for _, table := range tableMap { + var newSourceType materialization.SourceType + switch table.Source.Source.(type) { + case *v1alpha.Source_Kaskada: + newSourceType = materialization.SourceTypeFiles + case *v1alpha.Source_Pulsar: + newSourceType = materialization.SourceTypeStreams + default: + log.Error().Msgf("unknown source type %T", table.Source.Source) + return nil, customerrors.NewInternalError("unknown table source type") + } + if sourceType == materialization.SourceTypeUnspecified { + sourceType = newSourceType + } else if sourceType != newSourceType { + return nil, customerrors.NewInvalidArgumentErrorWithCustomText("cannot materialize tables from different source types") + } + } + viewMap, err := s.kaskadaViewClient.GetKaskadaViewsFromNames(ctx, owner, compileResp.FreeNames) if err != nil { return nil, err } @@ -229,15 +257,28 @@ func (s *materializationService) createMaterialization(ctx context.Context, owne SliceRequest: sliceRequest, Analysis: getAnalysisFromCompileResponse(compileResp), DataVersionID: dataVersionID, + SourceType: sourceType, } -
materialization, err := s.materializationClient.CreateMaterialization(ctx, owner, newMaterialization, dependencies) + createdMaterialization, err := s.materializationClient.CreateMaterialization(ctx, owner, newMaterialization, dependencies) if err != nil { return nil, err } - subLogger.Debug().Msg("running materializations") - s.computeManager.RunMaterializations(ctx, owner) + switch sourceType { + case materialization.SourceTypeFiles: + subLogger.Debug().Msg("running materializations") + s.computeManager.RunMaterializations(ctx, owner) + case materialization.SourceTypeStreams: + subLogger.Debug().Msg("adding materialization to compute") + err := s.materializationManager.StartMaterialization(ctx, owner, createdMaterialization.ID.String(), compileResp, createdMaterialization.Destination) + if err != nil { + return nil, err + } + default: + log.Error().Msgf("unknown source type %T", sourceType) + return nil, customerrors.NewInternalError("unknown table source type") + } // Get the newly computed materialization and its associated data token. // @@ -246,7 +287,7 @@ func (s *materializationService) createMaterialization(ctx context.Context, owne // // We could also store the `data_token_id` (or DataToken itself) on the materialization, // which would allow us to skip the secondary lookup of the token from version. - computedMaterialization, err := s.materializationClient.GetMaterialization(ctx, owner, materialization.ID) + computedMaterialization, err := s.materializationClient.GetMaterialization(ctx, owner, createdMaterialization.ID) if err != nil { return nil, err } @@ -254,7 +295,7 @@ func (s *materializationService) createMaterialization(ctx context.Context, owne if err != nil { return nil, err } - return &v1alpha.CreateMaterializationResponse{Materialization: materializationProto, Analysis: materialization.Analysis}, nil + return &v1alpha.CreateMaterializationResponse{Materialization: materializationProto, Analysis: createdMaterialization.Analysis}, nil } func (s *materializationService) DeleteMaterialization(ctx context.Context, request *v1alpha.DeleteMaterializationRequest) (*v1alpha.DeleteMaterializationResponse, error) { @@ -267,12 +308,20 @@ func (s *materializationService) DeleteMaterialization(ctx context.Context, requ } func (s *materializationService) deleteMaterialization(ctx context.Context, owner *ent.Owner, request *v1alpha.DeleteMaterializationRequest) (*v1alpha.DeleteMaterializationResponse, error) { - materialization, err := s.materializationClient.GetMaterializationByName(ctx, owner, request.MaterializationName) + foundMaterialization, err := s.materializationClient.GetMaterializationByName(ctx, owner, request.MaterializationName) if err != nil { return nil, err } - err = s.materializationClient.DeleteMaterialization(ctx, owner, materialization) + if foundMaterialization.SourceType == materialization.SourceTypeStreams { + err := s.materializationManager.StopMaterialization(ctx, foundMaterialization.ID.String()) + if err != nil { + subLogger := log.Ctx(ctx).With().Str("method", "materializationService.deleteMaterialization").Logger() + subLogger.Warn().Err(err).Str("materialization_id", foundMaterialization.ID.String()).Msg("unable to stop materialization on engine") + } + } + + err = s.materializationClient.DeleteMaterialization(ctx, owner, foundMaterialization) if err != nil { return nil, err } diff --git a/wren/service/materialization_test.go b/wren/service/materialization_test.go new file mode 100644 index 000000000..64fb26a09 --- /dev/null +++ 
b/wren/service/materialization_test.go @@ -0,0 +1,304 @@ +package service + +import ( + "context" + + "github.com/google/uuid" + v1alpha "github.com/kaskada-ai/kaskada/gen/proto/go/kaskada/kaskada/v1alpha" + "github.com/kaskada-ai/kaskada/wren/compute" + "github.com/kaskada-ai/kaskada/wren/ent" + "github.com/kaskada-ai/kaskada/wren/ent/materialization" + "github.com/kaskada-ai/kaskada/wren/internal" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + mock "github.com/stretchr/testify/mock" +) + +var _ = Describe("MaterializationService", func() { + var ( + owner *ent.Owner + objectStoreDestination = &v1alpha.Destination{ + Destination: &v1alpha.Destination_ObjectStore{ + ObjectStore: &v1alpha.ObjectStoreDestination{ + FileType: v1alpha.FileType_FILE_TYPE_CSV, + OutputPrefixUri: "gs://some-bucket/some-prefix", + }, + }, + } + pulsarDestination = &v1alpha.Destination{ + Destination: &v1alpha.Destination_Pulsar{ + Pulsar: &v1alpha.PulsarDestination{ + Config: &v1alpha.PulsarConfig{}, + }, + }, + } + ) + + BeforeEach(func() { + owner = &ent.Owner{} + }) + + Context("CreateMaterialization", func() { + Context("missing materialization", func() { + + It("should throw error", func() { + materializationService := &materializationService{ + UnimplementedMaterializationServiceServer: v1alpha.UnimplementedMaterializationServiceServer{}, + } + response, err := materializationService.createMaterialization(context.Background(), owner, &v1alpha.CreateMaterializationRequest{}) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(Equal("missing materialization definition")) + Expect(response).Should(BeNil()) + }) + }) + + Context("missing expression", func() { + It("should throw error", func() { + materializationService := &materializationService{ + UnimplementedMaterializationServiceServer: v1alpha.UnimplementedMaterializationServiceServer{}, + } + response, err := materializationService.createMaterialization(context.Background(), owner, &v1alpha.CreateMaterializationRequest{ + Materialization: &v1alpha.Materialization{ + Expression: "", + Destination: objectStoreDestination, + }, + }) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(Equal("missing materialization expression")) + Expect(response).Should(BeNil()) + }) + }) + + Context("missing destination", func() { + It("should throw error", func() { + materializationService := &materializationService{ + UnimplementedMaterializationServiceServer: v1alpha.UnimplementedMaterializationServiceServer{}, + } + + response, err := materializationService.createMaterialization(context.Background(), owner, &v1alpha.CreateMaterializationRequest{ + Materialization: &v1alpha.Materialization{ + Expression: "nachos", + }, + }) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(Equal("missing materialization destination")) + Expect(response).Should(BeNil()) + }) + }) + + Context("the compiled materialization includes both pulsar and object store sources", func() { + It("should throw error", func() { + freeNames := []string{"file_backed_source", "pulsar_backed_source"} + + mockMaterializationManager := compute.NewMockMaterializationManager(GinkgoT()) + newMaterialization := &v1alpha.Materialization{ + Expression: "nachos", + Destination: pulsarDestination, + } + + expectedCompileResponse := &v1alpha.CompileResponse{ + MissingNames: []string{}, + FenlDiagnostics: nil, + Plan: &v1alpha.ComputePlan{}, + FreeNames: freeNames, + } + mockMaterializationManager.EXPECT().CompileV1Materialization(mock.Anything, owner, 
newMaterialization).Return(expectedCompileResponse, nil, nil) + + mockKaskadaTableClient := internal.NewMockKaskadaTableClient(GinkgoT()) + tablesResponse := map[string]*ent.KaskadaTable{ + "file_backed_source": { + Name: "file_backed_source", + Source: &v1alpha.Source{Source: &v1alpha.Source_Kaskada{Kaskada: &v1alpha.KaskadaSource{}}}, + }, + "pulsar_backed_source": { + Name: "pulsar_backed_source", + Source: &v1alpha.Source{Source: &v1alpha.Source_Pulsar{Pulsar: &v1alpha.PulsarSource{}}}, + }, + } + + mockKaskadaTableClient.EXPECT().GetKaskadaTablesFromNames(mock.Anything, owner, freeNames).Return(tablesResponse, nil) + + materializationService := &materializationService{ + UnimplementedMaterializationServiceServer: v1alpha.UnimplementedMaterializationServiceServer{}, + computeManager: nil, + materializationManager: mockMaterializationManager, + kaskadaTableClient: mockKaskadaTableClient, + } + + response, err := materializationService.createMaterialization(context.Background(), owner, &v1alpha.CreateMaterializationRequest{ + Materialization: newMaterialization, + }) + Expect(err).Should(HaveOccurred()) + Expect(err.Error()).Should(Equal("cannot materialize tables from different source types")) + Expect(response).Should(BeNil()) + }) + }) + + Context("the compiled query includes only object store sources", func() { + It("creates a file-based materialization", func() { + freeNames := []string{"file_backed_source1", "file_backed_source2"} + expression := "file_backed_to_pulsar" + materializationName := "NAME_" + expression + + newMaterialization := &v1alpha.Materialization{ + MaterializationName: materializationName, + Expression: expression, + Destination: pulsarDestination, + } + + mockMaterializationManager := compute.NewMockMaterializationManager(GinkgoT()) + expectedCompileResponse := &v1alpha.CompileResponse{ + MissingNames: []string{}, + FenlDiagnostics: nil, + Plan: &v1alpha.ComputePlan{}, + FreeNames: freeNames, + } + mockMaterializationManager.EXPECT().CompileV1Materialization(mock.Anything, owner, newMaterialization).Return(expectedCompileResponse, nil, nil) + + mockComputeManager := compute.NewMockComputeManager(GinkgoT()) + mockComputeManager.EXPECT().RunMaterializations(mock.Anything, owner) + + mockKaskadaTableClient := internal.NewMockKaskadaTableClient(GinkgoT()) + tablesResponse := map[string]*ent.KaskadaTable{ + "file_backed_source1": { + Name: "file_backed_source1", + Source: &v1alpha.Source{Source: &v1alpha.Source_Kaskada{Kaskada: &v1alpha.KaskadaSource{}}}, + }, + "file_backed_source2": { + Name: "file_backed_source2", + Source: &v1alpha.Source{Source: &v1alpha.Source_Kaskada{Kaskada: &v1alpha.KaskadaSource{}}}, + }, + } + + mockKaskadaTableClient.EXPECT().GetKaskadaTablesFromNames(mock.Anything, owner, freeNames).Return(tablesResponse, nil) + + mockKaskadaViewClient := internal.NewMockKaskadaViewClient(GinkgoT()) + mockKaskadaViewClient.EXPECT().GetKaskadaViewsFromNames(mock.Anything, owner, freeNames).Return(map[string]*ent.KaskadaView{}, nil) + + mockMaterializationClient := internal.NewMockMaterializationClient(GinkgoT()) + newMaterializationResponse := &ent.Materialization{ + ID: uuid.New(), + Name: materializationName, + Expression: expression, + Version: int64(0), + WithViews: &v1alpha.WithViews{Views: []*v1alpha.WithView{}}, + Destination: pulsarDestination, + SliceRequest: &v1alpha.SliceRequest{}, + Analysis: getAnalysisFromCompileResponse(expectedCompileResponse), + DataVersionID: int64(0), + SourceType: materialization.SourceTypeFiles, + } + + 
mockMaterializationClient.EXPECT().CreateMaterialization(mock.Anything, owner, mock.Anything, mock.Anything).Return(newMaterializationResponse, nil) + mockMaterializationClient.EXPECT().GetMaterialization(mock.Anything, owner, newMaterializationResponse.ID).Return(newMaterializationResponse, nil) + + mockDataTokenClient := internal.NewMockDataTokenClient(GinkgoT()) + dataTokenFromVersionResponse := &ent.DataToken{ID: uuid.New()} + mockDataTokenClient.EXPECT().GetDataTokenFromVersion(mock.Anything, owner, newMaterializationResponse.DataVersionID).Return(dataTokenFromVersionResponse, nil) + + materializationService := &materializationService{ + UnimplementedMaterializationServiceServer: v1alpha.UnimplementedMaterializationServiceServer{}, + computeManager: mockComputeManager, + materializationManager: mockMaterializationManager, + dataTokenClient: mockDataTokenClient, + materializationClient: mockMaterializationClient, + kaskadaTableClient: mockKaskadaTableClient, + kaskadaViewClient: mockKaskadaViewClient, + } + + response, err := materializationService.createMaterialization(context.Background(), owner, &v1alpha.CreateMaterializationRequest{ + Materialization: newMaterialization, + }) + Expect(err).ShouldNot(HaveOccurred()) + Expect(response).ShouldNot(BeNil()) + Expect(response.Materialization).ShouldNot(BeNil()) + Expect(response.Materialization.DataTokenId).Should(Equal(dataTokenFromVersionResponse.ID.String())) + Expect(response.Analysis).ShouldNot(BeNil()) + }) + }) + + Context("the compiled query includes only stream sources", func() { + It("creates a stream-based materialization", func() { + freeNames := []string{"pulsar_backed_source1", "pulsar_backed_source2"} + expression := "pulsar_backed_to_object_store" + materializationName := "NAME_" + expression + + newMaterialization := &v1alpha.Materialization{ + MaterializationName: materializationName, + Expression: expression, + Destination: pulsarDestination, + } + + mockMaterializationManager := compute.NewMockMaterializationManager(GinkgoT()) + expectedCompileResponse := &v1alpha.CompileResponse{ + MissingNames: []string{}, + FenlDiagnostics: nil, + Plan: &v1alpha.ComputePlan{}, + FreeNames: freeNames, + } + mockMaterializationManager.EXPECT().CompileV1Materialization(mock.Anything, owner, newMaterialization).Return(expectedCompileResponse, nil, nil) + + mockKaskadaTableClient := internal.NewMockKaskadaTableClient(GinkgoT()) + tablesResponse := map[string]*ent.KaskadaTable{ + "pulsar_backed_source1": { + Name: "pulsar_backed_source1", + Source: &v1alpha.Source{Source: &v1alpha.Source_Pulsar{Pulsar: &v1alpha.PulsarSource{}}}, + }, + "pulsar_backed_source2": { + Name: "pulsar_backed_source2", + Source: &v1alpha.Source{Source: &v1alpha.Source_Pulsar{Pulsar: &v1alpha.PulsarSource{}}}, + }, + } + + mockKaskadaTableClient.EXPECT().GetKaskadaTablesFromNames(mock.Anything, owner, freeNames).Return(tablesResponse, nil) + + mockKaskadaViewClient := internal.NewMockKaskadaViewClient(GinkgoT()) + mockKaskadaViewClient.EXPECT().GetKaskadaViewsFromNames(mock.Anything, owner, freeNames).Return(map[string]*ent.KaskadaView{}, nil) + + mockMaterializationClient := internal.NewMockMaterializationClient(GinkgoT()) + newMaterializationResponse := &ent.Materialization{ + ID: uuid.New(), + Name: materializationName, + Expression: expression, + Version: int64(0), + WithViews: &v1alpha.WithViews{Views: []*v1alpha.WithView{}}, + Destination: objectStoreDestination, + SliceRequest: &v1alpha.SliceRequest{}, + Analysis:
getAnalysisFromCompileResponse(expectedCompileResponse), + DataVersionID: int64(0), + SourceType: materialization.SourceTypeStreams, + } + + mockMaterializationClient.EXPECT().CreateMaterialization(mock.Anything, owner, mock.Anything, mock.Anything).Return(newMaterializationResponse, nil) + mockMaterializationClient.EXPECT().GetMaterialization(mock.Anything, owner, newMaterializationResponse.ID).Return(newMaterializationResponse, nil) + + mockMaterializationManager.EXPECT().StartMaterialization(mock.Anything, owner, newMaterializationResponse.ID.String(), expectedCompileResponse, objectStoreDestination).Return(nil) + + mockDataTokenClient := internal.NewMockDataTokenClient(GinkgoT()) + dataTokenFromVersionResponse := &ent.DataToken{ID: uuid.New()} + mockDataTokenClient.EXPECT().GetDataTokenFromVersion(mock.Anything, owner, newMaterializationResponse.DataVersionID).Return(dataTokenFromVersionResponse, nil) + + materializationService := &materializationService{ + UnimplementedMaterializationServiceServer: v1alpha.UnimplementedMaterializationServiceServer{}, + computeManager: nil, + materializationManager: mockMaterializationManager, + dataTokenClient: mockDataTokenClient, + materializationClient: mockMaterializationClient, + kaskadaTableClient: mockKaskadaTableClient, + kaskadaViewClient: mockKaskadaViewClient, + } + + response, err := materializationService.createMaterialization(context.Background(), owner, &v1alpha.CreateMaterializationRequest{ + Materialization: newMaterialization, + }) + Expect(err).ShouldNot(HaveOccurred()) + Expect(response).ShouldNot(BeNil()) + Expect(response.Materialization).ShouldNot(BeNil()) + Expect(response.Materialization.DataTokenId).Should(Equal(dataTokenFromVersionResponse.ID.String())) + Expect(response.Analysis).ShouldNot(BeNil()) + }) + }) + }) +}) diff --git a/wren/service/service_suite_test.go b/wren/service/service_suite_test.go new file mode 100644 index 000000000..b3d922b57 --- /dev/null +++ b/wren/service/service_suite_test.go @@ -0,0 +1,13 @@ +package service_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) + +func TestService(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Service Suite") +} diff --git a/wren/service/table.go b/wren/service/table.go index eaa7c2033..81b93e90e 100644 --- a/wren/service/table.go +++ b/wren/service/table.go @@ -157,6 +157,7 @@ func (t *tableService) CreateTable(ctx context.Context, request *v1alpha.CreateT } func (t *tableService) createTable(ctx context.Context, owner *ent.Owner, request *v1alpha.CreateTableRequest) (*v1alpha.CreateTableResponse, error) { + subLogger := log.Ctx(ctx).With().Str("method", "tableService.createTable").Logger() table := request.Table // if no table source passed in request, set it to Kaskada source @@ -178,6 +179,24 @@ func (t *tableService) createTable(ctx context.Context, owner *ent.Owner, reques newTable.SubsortColumnName = &request.Table.SubsortColumnName.Value } + switch s := table.Source.Source.(type) { + case *v1alpha.Source_Kaskada: // if the table soure is kaskada, do nothing + case *v1alpha.Source_Pulsar: // if the table soure is pulsar, validate the schema + streamSchema, err := t.fileManager.GetPulsarSchema(ctx, s.Pulsar.Config) + if err != nil { + subLogger.Error().Err(err).Msg("issue getting schema for file") + return nil, reMapSparrowError(ctx, err) + } + + err = t.validateSchema(ctx, *newTable, streamSchema) + if err != nil { + return nil, err + } + newTable.MergedSchema = streamSchema + default: + return nil, customerrors.NewInvalidArgumentError("invalid source type") + } + kaskadaTable, err := t.kaskadaTableClient.CreateKaskadaTable(ctx, owner, newTable) if err != nil { return nil, err @@ -293,7 +312,13 @@ func (t *tableService) loadFileIntoTable(ctx context.Context, owner *ent.Owner, return nil, customerrors.NewNotFoundErrorWithCustomText(fmt.Sprintf("file: %s not found by the kaskada service", fileInput.GetURI())) } - fileSchema, err := t.validateFileSchema(ctx, *kaskadaTable, fileInput) + fileSchema, err := t.fileManager.GetFileSchema(ctx, fileInput) + if err != nil { + subLogger.Error().Err(err).Msg("issue getting schema for file") + return nil, reMapSparrowError(ctx, err) + } + + err = t.validateSchema(ctx, *kaskadaTable, fileSchema) if err != nil { return nil, err } @@ -331,41 +356,35 @@ func (t *tableService) loadFileIntoTable(ctx context.Context, owner *ent.Owner, return newDataToken, nil } -// validateFileSchema performs the validation of the a file vs the desired table +// validateSchema performs the validation of the a file vs the desired table // assumes that the provided table and schema are up to date and valid // TODO: return all the return all the potential issues in a single response. Similar to how the request validation works. 
-func (t *tableService) validateFileSchema(ctx context.Context, kaskadaTable ent.KaskadaTable, fileInput internal.FileInput) (*v1alpha.Schema, error) { - subLogger := log.Ctx(ctx).With().Str("method", "table.validateFileSchema").Logger() - - fileSchema, err := t.fileManager.GetFileSchema(ctx, fileInput) - if err != nil { - subLogger.Error().Err(err).Msg("issue getting schema for file") - return nil, reMapSparrowError(ctx, err) - } +func (t *tableService) validateSchema(ctx context.Context, kaskadaTable ent.KaskadaTable, schema *v1alpha.Schema) error { + subLogger := log.Ctx(ctx).With().Str("method", "table.validateSchema").Logger() - if kaskadaTable.MergedSchema != nil && !proto.Equal(kaskadaTable.MergedSchema, fileSchema) { - subLogger.Warn().Interface("table_schema", kaskadaTable.MergedSchema).Interface("file_schema", fileSchema).Str("file_uri", fileInput.GetURI()).Msg("new file doesn't match schema of table") - return nil, customerrors.NewFailedPreconditionError("file schema does not match previous files") + if kaskadaTable.MergedSchema != nil && !proto.Equal(kaskadaTable.MergedSchema, schema) { + subLogger.Warn().Interface("table_schema", kaskadaTable.MergedSchema).Interface("schema", schema).Msg("new schema doesn't match schema of table") + return customerrors.NewFailedPreconditionError("schema does not match previous schema") } columnNames := map[string]interface{}{} - for _, field := range fileSchema.Fields { + for _, field := range schema.Fields { columnNames[field.Name] = nil } if _, ok := columnNames[kaskadaTable.EntityKeyColumnName]; !ok { - return nil, customerrors.NewFailedPreconditionError(fmt.Sprintf("file does not contain entity key column: '%s'", kaskadaTable.EntityKeyColumnName)) + return customerrors.NewFailedPreconditionError(fmt.Sprintf("schema does not contain entity key column: '%s'", kaskadaTable.EntityKeyColumnName)) } if kaskadaTable.SubsortColumnName != nil { if _, ok := columnNames[*kaskadaTable.SubsortColumnName]; !ok { - return nil, customerrors.NewFailedPreconditionError(fmt.Sprintf("file does not contain subsort column: '%s'", *kaskadaTable.SubsortColumnName)) + return customerrors.NewFailedPreconditionError(fmt.Sprintf("schema does not contain subsort column: '%s'", *kaskadaTable.SubsortColumnName)) } } if _, ok := columnNames[kaskadaTable.TimeColumnName]; !ok { - return nil, customerrors.NewFailedPreconditionError(fmt.Sprintf("file does not contain time column: '%s'", kaskadaTable.TimeColumnName)) + return customerrors.NewFailedPreconditionError(fmt.Sprintf("schema does not contain time column: '%s'", kaskadaTable.TimeColumnName)) } - return fileSchema, nil + return nil } func reMapSparrowError(ctx context.Context, err error) error {
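One design note on the polling loop added to main.go: it sleeps unconditionally and only exits when a reconcile pass returns an error, so it does not observe shutdown directly. A context-aware alternative is sketched below; it assumes the same errgroup g and cancellable ctx that main.go already uses, and is an illustration rather than part of this change.

```go
// Sketch (assumes g is an errgroup.Group and ctx is cancelled on shutdown):
// poll on a ticker and exit promptly when ctx is done.
g.Go(func() error {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err() // shutdown requested; stop polling
		case <-ticker.C:
			if err := materializationManager.ReconcileMaterializations(ctx); err != nil {
				return err
			}
		}
	}
})
```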