diff --git a/internal/langserver/handlers/hooks_module.go b/internal/langserver/handlers/hooks_module.go index 148384422..d6fa15981 100644 --- a/internal/langserver/handlers/hooks_module.go +++ b/internal/langserver/handlers/hooks_module.go @@ -3,51 +3,50 @@ package handlers import ( "context" "fmt" - "log" - "github.com/creachadair/jrpc2" "github.com/hashicorp/terraform-ls/internal/langserver/diagnostics" + "github.com/hashicorp/terraform-ls/internal/langserver/notifier" "github.com/hashicorp/terraform-ls/internal/langserver/session" "github.com/hashicorp/terraform-ls/internal/state" "github.com/hashicorp/terraform-ls/internal/telemetry" "github.com/hashicorp/terraform-schema/backend" ) -func sendModuleTelemetry(ctx context.Context, store *state.StateStore, telemetrySender telemetry.Sender) state.ModuleChangeHook { - return func(oldMod, newMod *state.Module) { - if newMod == nil { - // module is being removed - // TODO: Track module removal as an event - return +func sendModuleTelemetry(store *state.StateStore, telemetrySender telemetry.Sender) notifier.Hook { + return func(ctx context.Context, changes state.ModuleChanges) error { + if changes.IsRemoval { + // we ignore removed modules for now + return nil } - properties, hasChanged := moduleTelemetryData(oldMod, newMod, store) - if !hasChanged { - // avoid sending telemetry if nothing has changed - return + mod, err := notifier.ModuleFromContext(ctx) + if err != nil { + return err } - telemetrySender.SendEvent(ctx, "moduleData", properties) + properties, hasChanged := moduleTelemetryData(mod, changes, store) + if hasChanged { + telemetrySender.SendEvent(ctx, "moduleData", properties) + } + return nil } } -func moduleTelemetryData(oldMod, newMod *state.Module, store *state.StateStore) (map[string]interface{}, bool) { +func moduleTelemetryData(mod *state.Module, ch state.ModuleChanges, store *state.StateStore) (map[string]interface{}, bool) { properties := make(map[string]interface{}) - hasChanged := false + hasChanged := ch.CoreRequirements || ch.Backend || ch.ProviderRequirements || + ch.TerraformVersion || ch.InstalledProviders - if oldMod == nil || !oldMod.Meta.CoreRequirements.Equals(newMod.Meta.CoreRequirements) { - hasChanged = true - } - if len(newMod.Meta.CoreRequirements) > 0 { - properties["tfRequirements"] = newMod.Meta.CoreRequirements.String() + if !hasChanged { + return properties, false } - if oldMod == nil || !oldMod.Meta.Backend.Equals(newMod.Meta.Backend) { - hasChanged = true + if len(mod.Meta.CoreRequirements) > 0 { + properties["tfRequirements"] = mod.Meta.CoreRequirements.String() } - if newMod.Meta.Backend != nil { - properties["backend"] = newMod.Meta.Backend.Type - if data, ok := newMod.Meta.Backend.Data.(*backend.Remote); ok { + if mod.Meta.Backend != nil { + properties["backend"] = mod.Meta.Backend.Type + if data, ok := mod.Meta.Backend.Data.(*backend.Remote); ok { hostname := data.Hostname // anonymize any non-default hostnames @@ -58,13 +57,9 @@ func moduleTelemetryData(oldMod, newMod *state.Module, store *state.StateStore) properties["backend.remote.hostname"] = hostname } } - - if oldMod == nil || !oldMod.Meta.ProviderRequirements.Equals(newMod.Meta.ProviderRequirements) { - hasChanged = true - } - if len(newMod.Meta.ProviderRequirements) > 0 { + if len(mod.Meta.ProviderRequirements) > 0 { reqs := make(map[string]string, 0) - for pAddr, cons := range newMod.Meta.ProviderRequirements { + for pAddr, cons := range mod.Meta.ProviderRequirements { if telemetry.IsPublicProvider(pAddr) { reqs[pAddr.String()] = cons.String() continue @@ -80,20 +75,12 @@ func moduleTelemetryData(oldMod, newMod *state.Module, store *state.StateStore) } properties["providerRequirements"] = reqs } - - if oldMod == nil || !oldMod.TerraformVersion.Equal(newMod.TerraformVersion) { - hasChanged = true - } - if newMod.TerraformVersion != nil { - properties["tfVersion"] = newMod.TerraformVersion.String() - } - - if oldMod == nil || !oldMod.InstalledProviders.Equals(newMod.InstalledProviders) { - hasChanged = true + if mod.TerraformVersion != nil { + properties["tfVersion"] = mod.TerraformVersion.String() } - if len(newMod.InstalledProviders) > 0 { + if len(mod.InstalledProviders) > 0 { installedProviders := make(map[string]string, 0) - for pAddr, pv := range newMod.InstalledProviders { + for pAddr, pv := range mod.InstalledProviders { if telemetry.IsPublicProvider(pAddr) { versionString := "" if pv != nil { @@ -118,7 +105,7 @@ func moduleTelemetryData(oldMod, newMod *state.Module, store *state.StateStore) return nil, false } - modId, err := store.GetModuleID(newMod.Path) + modId, err := store.GetModuleID(mod.Path) if err != nil { return nil, false } @@ -127,73 +114,87 @@ func moduleTelemetryData(oldMod, newMod *state.Module, store *state.StateStore) return properties, true } -func updateDiagnostics(ctx context.Context, notifier *diagnostics.Notifier) state.ModuleChangeHook { - return func(oldMod, newMod *state.Module) { - oldDiags, newDiags := 0, 0 - if oldMod != nil { - oldDiags = oldMod.ModuleDiagnostics.Count() + oldMod.VarsDiagnostics.Count() - } - if newMod != nil { - newDiags = newMod.ModuleDiagnostics.Count() + newMod.VarsDiagnostics.Count() - } - - if oldDiags == 0 && newDiags == 0 { - return - } +func updateDiagnostics(dNotifier *diagnostics.Notifier) notifier.Hook { + return func(ctx context.Context, changes state.ModuleChanges) error { + if changes.Diagnostics { + mod, err := notifier.ModuleFromContext(ctx) + if err != nil { + return err + } - diags := diagnostics.NewDiagnostics() - diags.EmptyRootDiagnostic() + diags := diagnostics.NewDiagnostics() + diags.EmptyRootDiagnostic() - defer notifier.PublishHCLDiags(ctx, newMod.Path, diags) + defer dNotifier.PublishHCLDiags(ctx, mod.Path, diags) - if newMod != nil { - diags.Append("HCL", newMod.ModuleDiagnostics.AsMap()) - diags.Append("HCL", newMod.VarsDiagnostics.AutoloadedOnly().AsMap()) + if mod != nil { + diags.Append("HCL", mod.ModuleDiagnostics.AsMap()) + diags.Append("HCL", mod.VarsDiagnostics.AutoloadedOnly().AsMap()) + } } + return nil } } -func callClientCommand(ctx context.Context, clientRequester session.ClientCaller, logger *log.Logger, commandId string) state.ModuleChangeHook { - return func(oldMod, newMod *state.Module) { - var modPath string - if oldMod != nil { - modPath = oldMod.Path - } else { - modPath = newMod.Path +func callRefreshClientCommand(clientRequester session.ClientCaller, commandId string) notifier.Hook { + return func(ctx context.Context, changes state.ModuleChanges) error { + // TODO: avoid triggering if module calls/providers did not change + isOpen, err := notifier.ModuleIsOpen(ctx) + if err != nil { + return err } - _, err := clientRequester.Callback(ctx, commandId, nil) - if err != nil { - logger.Printf("Error calling %s for %s: %s", commandId, modPath, err) + if isOpen { + mod, err := notifier.ModuleFromContext(ctx) + if err != nil { + return err + } + + _, err = clientRequester.Callback(ctx, commandId, nil) + if err != nil { + return fmt.Errorf("Error calling %s for %s: %s", commandId, mod.Path, err) + } } + + return nil } } -func refreshCodeLens(ctx context.Context, clientRequester session.ClientCaller) state.ModuleChangeHook { - return func(oldMod, newMod *state.Module) { - oldOrigins, oldTargets := 0, 0 - if oldMod != nil { - oldOrigins = len(oldMod.RefOrigins) - oldTargets = len(oldMod.RefTargets) - } - newOrigins, newTargets := 0, 0 - if newMod != nil { - newOrigins = len(newMod.RefOrigins) - newTargets = len(newMod.RefTargets) - } - - if oldOrigins != newOrigins || oldTargets != newTargets { - clientRequester.Callback(ctx, "workspace/codeLens/refresh", nil) +func refreshCodeLens(clientRequester session.ClientCaller) notifier.Hook { + return func(ctx context.Context, changes state.ModuleChanges) error { + // TODO: avoid triggering for new targets outside of open module + if changes.ReferenceOrigins || changes.ReferenceTargets { + _, err := clientRequester.Callback(ctx, "workspace/codeLens/refresh", nil) + if err != nil { + return err + } } + return nil } } -func refreshSemanticTokens(ctx context.Context, svrCtx context.Context, logger *log.Logger) state.ModuleChangeHook { - return func(_, newMod *state.Module) { - jrpcsvc := jrpc2.ServerFromContext(ctx) - _, err := jrpcsvc.Callback(svrCtx, "workspace/semanticTokens/refresh", nil) +func refreshSemanticTokens(clientRequester session.ClientCaller) notifier.Hook { + return func(ctx context.Context, changes state.ModuleChanges) error { + isOpen, err := notifier.ModuleIsOpen(ctx) if err != nil { - logger.Printf("Error refreshing %s: %s", newMod.Path, err) + return err } + + localChanges := isOpen && (changes.TerraformVersion || changes.CoreRequirements || + changes.InstalledProviders || changes.ProviderRequirements) + + if localChanges || changes.ReferenceOrigins || changes.ReferenceTargets { + mod, err := notifier.ModuleFromContext(ctx) + if err != nil { + return err + } + + _, err = clientRequester.Callback(ctx, "workspace/semanticTokens/refresh", nil) + if err != nil { + return fmt.Errorf("Error refreshing %s: %s", mod.Path, err) + } + } + + return nil } } diff --git a/internal/langserver/handlers/service.go b/internal/langserver/handlers/service.go index 770f02cd1..88bd28e1c 100644 --- a/internal/langserver/handlers/service.go +++ b/internal/langserver/handlers/service.go @@ -18,6 +18,7 @@ import ( "github.com/hashicorp/terraform-ls/internal/document" "github.com/hashicorp/terraform-ls/internal/filesystem" "github.com/hashicorp/terraform-ls/internal/langserver/diagnostics" + "github.com/hashicorp/terraform-ls/internal/langserver/notifier" "github.com/hashicorp/terraform-ls/internal/langserver/session" ilsp "github.com/hashicorp/terraform-ls/internal/lsp" lsp "github.com/hashicorp/terraform-ls/internal/protocol" @@ -58,6 +59,7 @@ type service struct { stateStore *state.StateStore server session.Server diagsNotifier *diagnostics.Notifier + notifier *notifier.Notifier walkerCollector *module.WalkerCollector additionalHandlers map[string]rpch.Func @@ -420,9 +422,10 @@ func (svc *service) configureSessionDependencies(ctx context.Context, cfgOpts *s } svc.stateStore.SetLogger(svc.logger) - svc.stateStore.Modules.ChangeHooks = state.ModuleChangeHooks{ - updateDiagnostics(svc.sessCtx, svc.diagsNotifier), - sendModuleTelemetry(svc.sessCtx, svc.stateStore, svc.telemetry), + + moduleHooks := []notifier.Hook{ + updateDiagnostics(svc.diagsNotifier), + sendModuleTelemetry(svc.stateStore, svc.telemetry), } svc.closedDirIndexer = scheduler.NewScheduler(&closedDirJobStore{svc.stateStore.JobStore}, 1) @@ -438,26 +441,26 @@ func (svc *service) configureSessionDependencies(ctx context.Context, cfgOpts *s cc, err := ilsp.ClientCapabilities(ctx) if err == nil { if _, ok = lsp.ExperimentalClientCapabilities(cc.Experimental).ShowReferencesCommandId(); ok { - svc.stateStore.Modules.ChangeHooks = append(svc.stateStore.Modules.ChangeHooks, - refreshCodeLens(svc.sessCtx, svc.server)) + moduleHooks = append(moduleHooks, refreshCodeLens(svc.server)) } if commandId, ok := lsp.ExperimentalClientCapabilities(cc.Experimental).RefreshModuleProvidersCommandId(); ok { - svc.stateStore.Modules.ChangeHooks = append(svc.stateStore.Modules.ChangeHooks, - callClientCommand(svc.sessCtx, svc.server, svc.logger, commandId)) + moduleHooks = append(moduleHooks, callRefreshClientCommand(svc.server, commandId)) } if commandId, ok := lsp.ExperimentalClientCapabilities(cc.Experimental).RefreshModuleCallsCommandId(); ok { - svc.stateStore.Modules.ChangeHooks = append(svc.stateStore.Modules.ChangeHooks, - callClientCommand(svc.sessCtx, svc.server, svc.logger, commandId)) + moduleHooks = append(moduleHooks, callRefreshClientCommand(svc.server, commandId)) } if cc.Workspace.SemanticTokens.RefreshSupport { - svc.stateStore.Modules.ChangeHooks = append(svc.stateStore.Modules.ChangeHooks, - refreshSemanticTokens(ctx, svc.srvCtx, svc.logger)) + moduleHooks = append(moduleHooks, refreshSemanticTokens(svc.server)) } } + svc.notifier = notifier.NewNotifier(svc.stateStore.Modules, moduleHooks) + svc.notifier.SetLogger(svc.logger) + svc.notifier.Start(svc.sessCtx) + svc.modStore = svc.stateStore.Modules svc.schemaStore = svc.stateStore.ProviderSchemas diff --git a/internal/langserver/notifier/notifier.go b/internal/langserver/notifier/notifier.go new file mode 100644 index 000000000..51d545223 --- /dev/null +++ b/internal/langserver/notifier/notifier.go @@ -0,0 +1,109 @@ +package notifier + +import ( + "context" + "errors" + "io/ioutil" + "log" + + "github.com/hashicorp/terraform-ls/internal/state" +) + +type moduleCtxKey struct{} +type moduleIsOpenCtxKey struct{} + +type Notifier struct { + modStore ModuleStore + hooks []Hook + logger *log.Logger +} + +type ModuleStore interface { + AwaitNextChangeBatch(ctx context.Context) (state.ModuleChangeBatch, error) + ModuleByPath(path string) (*state.Module, error) +} + +type Hook func(ctx context.Context, changes state.ModuleChanges) error + +func NewNotifier(modStore ModuleStore, hooks []Hook) *Notifier { + return &Notifier{ + modStore: modStore, + hooks: hooks, + logger: defaultLogger, + } +} + +func (n *Notifier) SetLogger(logger *log.Logger) { + n.logger = logger +} + +func (n *Notifier) Start(ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + n.logger.Printf("stopping notifier: %s", ctx.Err()) + return + default: + } + + err := n.notify(ctx) + if err != nil { + n.logger.Printf("failed to notify a change batch: %s", err) + } + } + }() +} + +func (n *Notifier) notify(ctx context.Context) error { + changeBatch, err := n.modStore.AwaitNextChangeBatch(ctx) + if err != nil { + return err + } + + mod, err := n.modStore.ModuleByPath(changeBatch.DirHandle.Path()) + if err != nil { + return err + } + ctx = withModule(ctx, mod) + + ctx = withModuleIsOpen(ctx, changeBatch.IsDirOpen) + + for i, h := range n.hooks { + err = h(ctx, changeBatch.Changes) + if err != nil { + n.logger.Printf("hook error (%d): %s", i, err) + continue + } + } + + return nil +} + +func withModule(ctx context.Context, mod *state.Module) context.Context { + return context.WithValue(ctx, moduleCtxKey{}, mod) +} + +func ModuleFromContext(ctx context.Context) (*state.Module, error) { + mod, ok := ctx.Value(moduleCtxKey{}).(*state.Module) + if !ok { + return nil, errors.New("module data not found") + } + + return mod, nil +} + +func withModuleIsOpen(ctx context.Context, isOpen bool) context.Context { + return context.WithValue(ctx, moduleIsOpenCtxKey{}, isOpen) +} + +func ModuleIsOpen(ctx context.Context) (bool, error) { + isOpen, ok := ctx.Value(moduleIsOpenCtxKey{}).(bool) + if !ok { + return false, errors.New("module open state not found") + } + + return isOpen, nil +} + +var defaultLogger = log.New(ioutil.Discard, "", 0) diff --git a/internal/langserver/notifier/notifier_test.go b/internal/langserver/notifier/notifier_test.go new file mode 100644 index 000000000..75eb83ea8 --- /dev/null +++ b/internal/langserver/notifier/notifier_test.go @@ -0,0 +1,70 @@ +package notifier + +import ( + "context" + "fmt" + "io/ioutil" + "log" + "sync" + "testing" + "time" + + "github.com/hashicorp/terraform-ls/internal/document" + "github.com/hashicorp/terraform-ls/internal/state" +) + +func TestNotifier(t *testing.T) { + ctx, cancelFunc := context.WithCancel(context.Background()) + + var wg sync.WaitGroup + wg.Add(2) + + hookFunc := func(ctx context.Context, changes state.ModuleChanges) error { + wg.Done() + cancelFunc() + return nil + } + notifier := NewNotifier(mockModuleStore{modPath: t.TempDir()}, []Hook{ + hookFunc, + hookFunc, + }) + notifier.SetLogger(testLogger()) + + notifier.Start(ctx) + + wg.Wait() +} + +type mockModuleStore struct { + returned bool + modPath string +} + +func (mms mockModuleStore) AwaitNextChangeBatch(ctx context.Context) (state.ModuleChangeBatch, error) { + if mms.returned { + return state.ModuleChangeBatch{}, fmt.Errorf("no more batches") + } + defer func() { mms.returned = true }() + + return state.ModuleChangeBatch{ + DirHandle: document.DirHandleFromPath(mms.modPath), + FirstChangeTime: time.Date(2022, 5, 26, 0, 0, 0, 0, time.UTC), + }, nil +} + +func (mms mockModuleStore) ModuleByPath(path string) (*state.Module, error) { + if path != mms.modPath { + return nil, fmt.Errorf("unexpected path: %q", path) + } + + return &state.Module{ + Path: path, + }, nil +} + +func testLogger() *log.Logger { + if testing.Verbose() { + return log.Default() + } + return log.New(ioutil.Discard, "", 0) +} diff --git a/internal/state/documents.go b/internal/state/documents.go index cfcd21481..380403138 100644 --- a/internal/state/documents.go +++ b/internal/state/documents.go @@ -56,6 +56,10 @@ func (s *DocumentStore) OpenDocument(dh document.Handle, langId string, version if err != nil { return err } + err = updateModuleChangeDirOpenMark(txn, dh.Dir, true) + if err != nil { + return err + } txn.Commit() return nil @@ -142,6 +146,11 @@ func (s *DocumentStore) CloseDocument(dh document.Handle) error { return err } + err = updateModuleChangeDirOpenMark(txn, dh.Dir, false) + if err != nil { + return err + } + txn.Commit() return nil } @@ -175,8 +184,11 @@ func (s *DocumentStore) IsDocumentOpen(dh document.Handle) (bool, error) { func (s *DocumentStore) HasOpenDocuments(dirHandle document.DirHandle) (bool, error) { txn := s.db.Txn(false) + return dirHasOpenDocuments(txn, dirHandle) +} - obj, err := txn.First(s.tableName, "dir", dirHandle) +func dirHasOpenDocuments(txn *memdb.Txn, dirHandle document.DirHandle) (bool, error) { + obj, err := txn.First(documentsTableName, "dir", dirHandle) if err != nil { return false, err } diff --git a/internal/state/hooks.go b/internal/state/hooks.go deleted file mode 100644 index 2c6c03251..000000000 --- a/internal/state/hooks.go +++ /dev/null @@ -1,11 +0,0 @@ -package state - -type ModuleChangeHook func(oldMod, newMod *Module) - -type ModuleChangeHooks []ModuleChangeHook - -func (mh ModuleChangeHooks) notifyModuleChange(oldMod, newMod *Module) { - for _, h := range mh { - h(oldMod, newMod) - } -} diff --git a/internal/state/jobs.go b/internal/state/jobs.go index 5e7429c27..1648bcce5 100644 --- a/internal/state/jobs.go +++ b/internal/state/jobs.go @@ -125,6 +125,26 @@ func (js *JobStore) DequeueJobsForDir(dir document.DirHandle) error { return nil } +func jobsExistForDirHandle(txn *memdb.Txn, dir document.DirHandle) (<-chan struct{}, bool, error) { + wCh, runningObj, err := txn.FirstWatch(jobsTableName, "dir_state", dir, StateRunning) + if err != nil { + return nil, false, err + } + if runningObj != nil { + return wCh, true, nil + } + + queuedObj, err := txn.First(jobsTableName, "dir_state", dir, StateQueued) + if err != nil { + return nil, false, err + } + if queuedObj != nil { + return wCh, true, nil + } + + return nil, false, nil +} + func updateJobsDirOpenMark(txn *memdb.Txn, dirHandle document.DirHandle, isDirOpen bool) error { it, err := txn.Get(jobsTableName, "dir_state", dirHandle, StateQueued) if err != nil { diff --git a/internal/state/module.go b/internal/state/module.go index 2d4d9665d..4a77861f1 100644 --- a/internal/state/module.go +++ b/internal/state/module.go @@ -245,9 +245,10 @@ func (s *ModuleStore) Add(modPath string) error { return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(nil, mod) - }) + err = s.queueModuleChange(txn, nil, mod) + if err != nil { + return err + } txn.Commit() return nil @@ -267,10 +268,11 @@ func (s *ModuleStore) Remove(modPath string) error { return nil } - txn.Defer(func() { - oldMod := oldObj.(*Module) - go s.ChangeHooks.notifyModuleChange(oldMod, nil) - }) + oldMod := oldObj.(*Module) + err = s.queueModuleChange(txn, oldMod, nil) + if err != nil { + return err + } _, err = txn.DeleteAll(s.tableName, "id", modPath) if err != nil { @@ -419,9 +421,10 @@ func (s *ModuleStore) UpdateInstalledProviders(path string, pvs map[tfaddr.Provi return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(oldMod, mod) - }) + err = s.queueModuleChange(txn, oldMod, mod) + if err != nil { + return err + } txn.Commit() return nil @@ -484,10 +487,10 @@ func (s *ModuleStore) UpdateModManifest(path string, manifest *datadir.ModuleMan return err } - txn.Defer(func() { - s.logger.Printf("Queuing refresh for %s", path) - go s.ChangeHooks.notifyModuleChange(nil, mod) - }) + err = s.queueModuleChange(txn, nil, mod) + if err != nil { + return err + } txn.Commit() return nil @@ -508,9 +511,10 @@ func (s *ModuleStore) SetTerraformVersionState(path string, state op.OpState) er return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(nil, mod) - }) + err = s.queueModuleChange(txn, nil, mod) + if err != nil { + return err + } txn.Commit() return nil @@ -555,9 +559,10 @@ func (s *ModuleStore) FinishProviderSchemaLoading(path string, psErr error) erro return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(oldMod, mod) - }) + err = s.queueModuleChange(txn, oldMod, mod) + if err != nil { + return err + } txn.Commit() return nil @@ -584,19 +589,16 @@ func (s *ModuleStore) UpdateTerraformVersion(modPath string, tfVer *version.Vers return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(oldMod, mod) - }) + err = s.queueModuleChange(txn, oldMod, mod) + if err != nil { + return err + } err = updateProviderVersions(txn, modPath, pv) if err != nil { return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(nil, mod) - }) - txn.Commit() return nil } @@ -738,9 +740,10 @@ func (s *ModuleStore) UpdateMetadata(path string, meta *tfmod.Meta, mErr error) return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(oldMod, mod) - }) + err = s.queueModuleChange(txn, oldMod, mod) + if err != nil { + return err + } txn.Commit() return nil @@ -763,9 +766,10 @@ func (s *ModuleStore) UpdateModuleDiagnostics(path string, diags ast.ModDiags) e return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(oldMod, mod) - }) + err = s.queueModuleChange(txn, oldMod, mod) + if err != nil { + return err + } txn.Commit() return nil @@ -788,9 +792,10 @@ func (s *ModuleStore) UpdateVarsDiagnostics(path string, diags ast.VarsDiags) er return err } - txn.Defer(func() { - go s.ChangeHooks.notifyModuleChange(oldMod, mod) - }) + err = s.queueModuleChange(txn, oldMod, mod) + if err != nil { + return err + } txn.Commit() return nil diff --git a/internal/state/module_changes.go b/internal/state/module_changes.go new file mode 100644 index 000000000..f2a8013b6 --- /dev/null +++ b/internal/state/module_changes.go @@ -0,0 +1,256 @@ +package state + +import ( + "context" + "fmt" + "time" + + "github.com/hashicorp/go-memdb" + "github.com/hashicorp/terraform-ls/internal/document" +) + +type ModuleChangeBatch struct { + DirHandle document.DirHandle + FirstChangeTime time.Time + IsDirOpen bool + Changes ModuleChanges +} + +func (mcb ModuleChangeBatch) Copy() ModuleChangeBatch { + return ModuleChangeBatch{ + DirHandle: mcb.DirHandle, + FirstChangeTime: mcb.FirstChangeTime, + IsDirOpen: mcb.IsDirOpen, + Changes: mcb.Changes, + } +} + +type ModuleChanges struct { + // IsRemoval indicates whether this batch represents removal of a module + IsRemoval bool + + CoreRequirements bool + Backend bool + ProviderRequirements bool + TerraformVersion bool + InstalledProviders bool + Diagnostics bool + ReferenceOrigins bool + ReferenceTargets bool +} + +const maxTimespan = 1 * time.Second + +func (s *ModuleStore) queueModuleChange(txn *memdb.Txn, oldMod, newMod *Module) error { + var modHandle document.DirHandle + if oldMod != nil { + modHandle = document.DirHandleFromPath(oldMod.Path) + } else { + modHandle = document.DirHandleFromPath(newMod.Path) + } + obj, err := txn.First(moduleChangesTableName, "id", modHandle) + if err != nil { + return err + } + + var cb ModuleChangeBatch + if obj != nil { + batch := obj.(ModuleChangeBatch) + cb = batch.Copy() + } else { + // create new change batch + isDirOpen, err := dirHasOpenDocuments(txn, modHandle) + if err != nil { + return err + } + cb = ModuleChangeBatch{ + DirHandle: modHandle, + FirstChangeTime: s.TimeProvider(), + Changes: ModuleChanges{}, + IsDirOpen: isDirOpen, + } + } + + switch { + // new module added + case oldMod == nil && newMod != nil: + if len(newMod.Meta.CoreRequirements) > 0 { + cb.Changes.CoreRequirements = true + } + if newMod.Meta.Backend != nil { + cb.Changes.Backend = true + } + if len(newMod.Meta.ProviderRequirements) > 0 { + cb.Changes.ProviderRequirements = true + } + if newMod.TerraformVersion != nil { + cb.Changes.TerraformVersion = true + } + if len(newMod.InstalledProviders) > 0 { + cb.Changes.InstalledProviders = true + } + // module removed + case oldMod != nil && newMod == nil: + cb.Changes.IsRemoval = true + + if len(oldMod.Meta.CoreRequirements) > 0 { + cb.Changes.CoreRequirements = true + } + if oldMod.Meta.Backend != nil { + cb.Changes.Backend = true + } + if len(oldMod.Meta.ProviderRequirements) > 0 { + cb.Changes.ProviderRequirements = true + } + if oldMod.TerraformVersion != nil { + cb.Changes.TerraformVersion = true + } + if len(oldMod.InstalledProviders) > 0 { + cb.Changes.InstalledProviders = true + } + // module changed + default: + if !oldMod.Meta.CoreRequirements.Equals(newMod.Meta.CoreRequirements) { + cb.Changes.CoreRequirements = true + } + if !oldMod.Meta.Backend.Equals(newMod.Meta.Backend) { + cb.Changes.Backend = true + } + if !oldMod.Meta.ProviderRequirements.Equals(newMod.Meta.ProviderRequirements) { + cb.Changes.ProviderRequirements = true + } + if !oldMod.TerraformVersion.Equal(newMod.TerraformVersion) { + cb.Changes.TerraformVersion = true + } + if !oldMod.InstalledProviders.Equals(newMod.InstalledProviders) { + cb.Changes.InstalledProviders = true + } + } + + oldDiags, newDiags := 0, 0 + if oldMod != nil { + oldDiags = oldMod.ModuleDiagnostics.Count() + oldMod.VarsDiagnostics.Count() + } + if newMod != nil { + newDiags = newMod.ModuleDiagnostics.Count() + newMod.VarsDiagnostics.Count() + } + // Comparing diagnostics accurately could be expensive + // so we just treat any non-empty diags as a change + if oldDiags > 0 || newDiags > 0 { + cb.Changes.Diagnostics = true + } + + oldOrigins, oldTargets := 0, 0 + if oldMod != nil { + oldOrigins = len(oldMod.RefOrigins) + oldTargets = len(oldMod.RefTargets) + } + newOrigins, newTargets := 0, 0 + if newMod != nil { + newOrigins = len(newMod.RefOrigins) + newTargets = len(newMod.RefTargets) + } + if oldOrigins != newOrigins { + cb.Changes.ReferenceOrigins = true + } + if oldTargets != newTargets { + cb.Changes.ReferenceTargets = true + } + + // update change batch + _, err = txn.DeleteAll(moduleChangesTableName, "id", modHandle) + if err != nil { + return err + } + return txn.Insert(moduleChangesTableName, cb) +} + +func updateModuleChangeDirOpenMark(txn *memdb.Txn, dirHandle document.DirHandle, isDirOpen bool) error { + it, err := txn.Get(moduleChangesTableName, "id", dirHandle) + if err != nil { + return fmt.Errorf("failed to find module changes for %q: %w", dirHandle, err) + } + + for obj := it.Next(); obj != nil; obj = it.Next() { + batch := obj.(ModuleChangeBatch) + mcb := batch.Copy() + + _, err = txn.DeleteAll(moduleChangesTableName, "id", batch.DirHandle) + if err != nil { + return err + } + + mcb.IsDirOpen = isDirOpen + + err = txn.Insert(moduleChangesTableName, mcb) + if err != nil { + return err + } + } + + return nil +} + +func (ms *ModuleStore) AwaitNextChangeBatch(ctx context.Context) (ModuleChangeBatch, error) { + rTxn := ms.db.Txn(false) + wCh, obj, err := rTxn.FirstWatch(moduleChangesTableName, "time") + if err != nil { + return ModuleChangeBatch{}, err + } + + if obj == nil { + select { + case <-wCh: + case <-ctx.Done(): + return ModuleChangeBatch{}, ctx.Err() + } + + return ms.AwaitNextChangeBatch(ctx) + } + + batch := obj.(ModuleChangeBatch) + + timeout := batch.FirstChangeTime.Add(maxTimespan) + if time.Now().After(timeout) { + err := ms.deleteChangeBatch(batch) + if err != nil { + return ModuleChangeBatch{}, err + } + return batch, nil + } + + wCh, jobsExist, err := jobsExistForDirHandle(rTxn, batch.DirHandle) + if err != nil { + return ModuleChangeBatch{}, err + } + if !jobsExist { + err := ms.deleteChangeBatch(batch) + if err != nil { + return ModuleChangeBatch{}, err + } + return batch, nil + } + + select { + // wait for another job to get processed + case <-wCh: + // or for the remaining time to pass + case <-time.After(timeout.Sub(time.Now())): + // or context cancellation + case <-ctx.Done(): + return ModuleChangeBatch{}, ctx.Err() + } + + return ms.AwaitNextChangeBatch(ctx) +} + +func (ms *ModuleStore) deleteChangeBatch(batch ModuleChangeBatch) error { + txn := ms.db.Txn(true) + defer txn.Abort() + err := txn.Delete(moduleChangesTableName, batch) + if err != nil { + return err + } + txn.Commit() + return nil +} diff --git a/internal/state/module_changes_test.go b/internal/state/module_changes_test.go new file mode 100644 index 000000000..46500d5dd --- /dev/null +++ b/internal/state/module_changes_test.go @@ -0,0 +1,223 @@ +package state + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/go-version" + "github.com/hashicorp/terraform-ls/internal/document" + "github.com/hashicorp/terraform-ls/internal/job" + tfaddr "github.com/hashicorp/terraform-registry-address" +) + +func TestModuleChanges_dirOpenMark_openBeforeChange(t *testing.T) { + ss, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + + modPath := t.TempDir() + modHandle := document.DirHandleFromPath(modPath) + docHandle := document.Handle{ + Dir: modHandle, + Filename: "main.tf", + } + err = ss.DocumentStore.OpenDocument(docHandle, "terraform", 0, []byte{}) + if err != nil { + t.Fatal(err) + } + + err = ss.Modules.Add(modPath) + if err != nil { + t.Fatal(err) + } + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + batch, err := ss.Modules.AwaitNextChangeBatch(ctx) + if err != nil { + t.Fatal(err) + } + + if !batch.IsDirOpen { + t.Fatalf("expected dir to be open for change batch, given: %#v", batch) + } +} + +func TestModuleChanges_dirOpenMark_openAfterChange(t *testing.T) { + ss, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + + modPath := t.TempDir() + + err = ss.Modules.Add(modPath) + if err != nil { + t.Fatal(err) + } + + modHandle := document.DirHandleFromPath(modPath) + docHandle := document.Handle{ + Dir: modHandle, + Filename: "main.tf", + } + err = ss.DocumentStore.OpenDocument(docHandle, "terraform", 0, []byte{}) + if err != nil { + t.Fatal(err) + } + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + batch, err := ss.Modules.AwaitNextChangeBatch(ctx) + if err != nil { + t.Fatal(err) + } + + if !batch.IsDirOpen { + t.Fatalf("expected dir to be open for change batch, given: %#v", batch) + } +} + +func TestModuleChanges_AwaitNextChangeBatch_maxTimespan(t *testing.T) { + ss, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + + modPath := t.TempDir() + modHandle := document.DirHandleFromPath(modPath) + + _, err = ss.JobStore.EnqueueJob(job.Job{ + Func: func(ctx context.Context) error { + return nil + }, + Dir: modHandle, + Type: "test", + }) + if err != nil { + t.Fatal(err) + } + + err = ss.Modules.Add(modPath) + if err != nil { + t.Fatal(err) + } + + // confirm the method gets cancelled with pending job + // and less than maximum timespan to wait + ctx, cancelFunc := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelFunc() + + _, err = ss.Modules.AwaitNextChangeBatch(ctx) + if err == nil { + t.Fatal("expected timeout") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded error, given: %#v", err) + } + +} + +func TestModuleChanges_AwaitNextChangeBatch_multipleChanges(t *testing.T) { + ss, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + + ss.Modules.TimeProvider = testTimeProvider + + modPath := t.TempDir() + + err = ss.Modules.Add(modPath) + if err != nil { + t.Fatal(err) + } + + err = ss.Modules.UpdateTerraformVersion(modPath, testVersion(t, "1.0.0"), map[tfaddr.Provider]*version.Version{}, nil) + if err != nil { + t.Fatal(err) + } + + ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second) + defer cancelFunc() + batch, err := ss.Modules.AwaitNextChangeBatch(ctx) + if err != nil { + t.Fatal(err) + } + expectedBatch := ModuleChangeBatch{ + DirHandle: document.DirHandleFromPath(modPath), + FirstChangeTime: testTimeProvider(), + IsDirOpen: false, + Changes: ModuleChanges{ + TerraformVersion: true, + }, + } + if diff := cmp.Diff(expectedBatch, batch); diff != "" { + t.Fatalf("unexpected change batch: %s", diff) + } + + // verify that no more batches are available + ctx, cancelFunc = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelFunc() + _, err = ss.Modules.AwaitNextChangeBatch(ctx) + if err == nil { + t.Fatal("expected error on next batch read") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded error, given: %#v", err) + } +} + +func TestModuleChanges_AwaitNextChangeBatch_removal(t *testing.T) { + ss, err := NewStateStore() + if err != nil { + t.Fatal(err) + } + + ss.Modules.TimeProvider = testTimeProvider + + modPath := t.TempDir() + + err = ss.Modules.Add(modPath) + if err != nil { + t.Fatal(err) + } + err = ss.Modules.Remove(modPath) + if err != nil { + t.Fatal(err) + } + + ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second) + defer cancelFunc() + batch, err := ss.Modules.AwaitNextChangeBatch(ctx) + if err != nil { + t.Fatal(err) + } + expectedBatch := ModuleChangeBatch{ + DirHandle: document.DirHandleFromPath(modPath), + FirstChangeTime: testTimeProvider(), + IsDirOpen: false, + Changes: ModuleChanges{ + IsRemoval: true, + }, + } + if diff := cmp.Diff(expectedBatch, batch); diff != "" { + t.Fatalf("unexpected change batch: %s", diff) + } + + // verify that no more batches are available + ctx, cancelFunc = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelFunc() + _, err = ss.Modules.AwaitNextChangeBatch(ctx) + if err == nil { + t.Fatal("expected error on next batch read") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context deadline exceeded error, given: %#v", err) + } +} diff --git a/internal/state/state.go b/internal/state/state.go index 0626ef150..a31960220 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -18,6 +18,7 @@ const ( jobsTableName = "jobs" moduleTableName = "module" moduleIdsTableName = "module_ids" + moduleChangesTableName = "module_changes" providerSchemaTableName = "provider_schema" providerIdsTableName = "provider_ids" walkerPathsTableName = "walker_paths" @@ -146,6 +147,20 @@ var dbSchema = &memdb.DBSchema{ }, }, }, + moduleChangesTableName: { + Name: moduleChangesTableName, + Indexes: map[string]*memdb.IndexSchema{ + "id": { + Name: "id", + Unique: true, + Indexer: &DirHandleFieldIndexer{Field: "DirHandle"}, + }, + "time": { + Name: "time", + Indexer: &TimeFieldIndex{Field: "FirstChangeTime"}, + }, + }, + }, walkerPathsTableName: { Name: walkerPathsTableName, Indexes: map[string]*memdb.IndexSchema{ @@ -179,10 +194,17 @@ type StateStore struct { } type ModuleStore struct { - db *memdb.MemDB - ChangeHooks ModuleChangeHooks - tableName string - logger *log.Logger + db *memdb.MemDB + Changes *ModuleChangeStore + tableName string + logger *log.Logger + + // TimeProvider provides current time (for mocking time.Now in tests) + TimeProvider func() time.Time +} + +type ModuleChangeStore struct { + db *memdb.MemDB } type ModuleReader interface { @@ -228,10 +250,10 @@ func NewStateStore() (*StateStore, error) { nextJobClosedDirMu: &sync.Mutex{}, }, Modules: &ModuleStore{ - db: db, - ChangeHooks: make(ModuleChangeHooks, 0), - tableName: moduleTableName, - logger: defaultLogger, + db: db, + tableName: moduleTableName, + logger: defaultLogger, + TimeProvider: time.Now, }, ProviderSchemas: &ProviderSchemaStore{ db: db, diff --git a/internal/state/time_field_index.go b/internal/state/time_field_index.go new file mode 100644 index 000000000..8540131b2 --- /dev/null +++ b/internal/state/time_field_index.go @@ -0,0 +1,102 @@ +package state + +import ( + "encoding/binary" + "fmt" + "reflect" + "time" +) + +// See https://github.com/hashicorp/go-memdb/pull/117 + +// TimeFieldIndex is used to extract a time.Time field from an object using +// reflection and builds an index on that field. +type TimeFieldIndex struct { + Field string +} + +func (u *TimeFieldIndex) FromObject(obj interface{}) (bool, []byte, error) { + v := reflect.ValueOf(obj) + v = reflect.Indirect(v) // Dereference the pointer if any + + fv := v.FieldByName(u.Field) + if !fv.IsValid() { + return false, nil, + fmt.Errorf("field '%s' for %#v is invalid", u.Field, obj) + } + + // Check the type + k := fv.Kind() + + if ok := IsTimeType(k); !ok { + return false, nil, fmt.Errorf("field %q is of type %v; want a time."+ + "Time", u.Field, k) + } + + // Get the value and encode it + val := fv.Interface().(time.Time) + bufUnix := encodeInt(val.Unix(), 8) + bufNano := encodeInt(int64(val.Nanosecond()), 4) + buf := append(bufUnix, bufNano...) + + return true, buf, nil +} + +func (u *TimeFieldIndex) FromArgs(args ...interface{}) ([]byte, error) { + if len(args) != 1 { + return nil, fmt.Errorf("must provide only a single argument") + } + + v := reflect.ValueOf(args[0]) + if !v.IsValid() { + return nil, fmt.Errorf("%#v is invalid", args[0]) + } + + k := v.Kind() + + if ok := IsTimeType(k); !ok { + return nil, fmt.Errorf("arg is of type %v; want a time.Time", k) + } + + val := v.Interface().(time.Time) + bufUnix := encodeInt(val.Unix(), 8) + bufNano := encodeInt(int64(val.Nanosecond()), 4) + buf := append(bufUnix, bufNano...) + + return buf, nil +} + +func encodeInt(val int64, size int) []byte { + buf := make([]byte, size) + + // This bit flips the sign bit on any sized signed twos-complement integer, + // which when truncated to a uint of the same size will bias the value such + // that the maximum negative int becomes 0, and the maximum positive int + // becomes the maximum positive uint. + scaled := val ^ int64(-1<<(size*8-1)) + + switch size { + case 1: + buf[0] = uint8(scaled) + case 2: + binary.BigEndian.PutUint16(buf, uint16(scaled)) + case 4: + binary.BigEndian.PutUint32(buf, uint32(scaled)) + case 8: + binary.BigEndian.PutUint64(buf, uint64(scaled)) + default: + panic(fmt.Sprintf("unsupported int size parameter: %d", size)) + } + + return buf +} + +// IsTimeType returns whether the passed type is a type of time.Time. +func IsTimeType(k reflect.Kind) (okay bool) { + switch k { + case reflect.TypeOf(time.Time{}).Kind(): + return true + default: + return false + } +}