Skip to content

Commit

Permalink
fix: prevent crash for legacy provider lookups (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
radeksimko authored Apr 30, 2021
1 parent 1821ce9 commit 9209786
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
1 change: 1 addition & 0 deletions internal/langserver/handlers/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func (svc *service) Assigner() (jrpc2.Assigner, error) {
if err != nil {
return nil, err
}
store.SetLogger(svc.logger)

err = schemas.PreloadSchemasToStore(store.ProviderSchemas)
if err != nil {
Expand Down
8 changes: 6 additions & 2 deletions internal/state/provider_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func updateProviderVersions(txn *memdb.Txn, modPath string, pv map[tfaddr.Provid
}

func (s *ProviderSchemaStore) AddLocalSchema(modPath string, addr tfaddr.Provider, schema *tfschema.ProviderSchema) error {
s.logger.Printf("PSS: adding local schema (%s, %s): %p", modPath, addr, schema)
txn := s.db.Txn(true)
defer txn.Abort()

Expand Down Expand Up @@ -190,6 +191,7 @@ func (s *ProviderSchemaStore) AddPreloadedSchema(addr tfaddr.Provider, pv *versi
}

func (s *ProviderSchemaStore) ProviderSchema(modPath string, addr tfaddr.Provider, vc version.Constraints) (*tfschema.ProviderSchema, error) {
s.logger.Printf("PSS: getting provider schema (%s, %s, %s)", modPath, addr, vc)
txn := s.db.Txn(false)

it, err := txn.Get(s.tableName, "id_prefix", addr)
Expand Down Expand Up @@ -233,7 +235,9 @@ func (s *ProviderSchemaStore) ProviderSchema(modPath string, addr tfaddr.Provide
}
if obj != nil {
ps := obj.(*ProviderSchema)
return ps.Schema, err
if ps.Schema != nil {
return ps.Schema, nil
}
}

// Last we just try to loosely match the provider type
Expand Down Expand Up @@ -264,7 +268,7 @@ func (s *ProviderSchemaStore) ProviderSchema(modPath string, addr tfaddr.Provide

sort.Stable(ss)

return ss.schemas[0].Schema, err
return ss.schemas[0].Schema, nil
}

type ModuleLookupFunc func(string) (*Module, error)
Expand Down
58 changes: 58 additions & 0 deletions internal/state/provider_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,64 @@ func TestStateStore_AddPreloadedSchema_duplicate(t *testing.T) {
}
}

// Test a scenario where Terraform 0.13+ produced schema with non-legacy
// addresses but lookup is still done via legacy address
func TestStateStore_IncompleteSchema_legacyLookup(t *testing.T) {
s, err := NewStateStore()
if err != nil {
t.Fatal(err)
}

modPath := t.TempDir()
err = s.Modules.Add(modPath)
if err != nil {
t.Fatal(err)
}

addr := tfaddr.Provider{
Hostname: tfaddr.DefaultRegistryHost,
Namespace: "hashicorp",
Type: "aws",
}
pv := testVersion(t, "3.2.0")

pvs := map[tfaddr.Provider]*version.Version{
addr: pv,
}

// obtaining versions typically takes less time than schema itself
// so we test that "incomplete" state is handled correctly too

err = s.Modules.UpdateTerraformVersion(modPath, testVersion(t, "0.13.0"), pvs, nil)
if err != nil {
t.Fatal(err)
}

_, err = s.ProviderSchemas.ProviderSchema(modPath, tfaddr.NewLegacyProvider("aws"), testConstraint(t, ">= 1.0"))
if err == nil {
t.Fatal("expected error when requesting incomplete schema")
}
expectedErr := &NoSchemaError{}
if !errors.As(err, &expectedErr) {
t.Fatalf("unexpected error: %#v", err)
}

// next attempt (after schema is actually obtained) should not fail

err = s.ProviderSchemas.AddLocalSchema(modPath, addr, &tfschema.ProviderSchema{})
if err != nil {
t.Fatal(err)
}

ps, err := s.ProviderSchemas.ProviderSchema(modPath, tfaddr.NewLegacyProvider("aws"), testConstraint(t, ">= 1.0"))
if err != nil {
t.Fatal(err)
}
if ps == nil {
t.Fatal("expected provider schema not to be nil")
}
}

func TestStateStore_AddLocalSchema_duplicate(t *testing.T) {
s, err := NewStateStore()
if err != nil {
Expand Down
14 changes: 14 additions & 0 deletions internal/state/state.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package state

import (
"io/ioutil"
"log"

"github.com/hashicorp/go-memdb"
"github.com/hashicorp/go-version"
tfaddr "github.com/hashicorp/terraform-registry-address"
Expand Down Expand Up @@ -52,6 +55,7 @@ type StateStore struct {
type ModuleStore struct {
db *memdb.MemDB
tableName string
logger *log.Logger
}

type ModuleReader interface {
Expand All @@ -63,6 +67,7 @@ type ModuleReader interface {
type ProviderSchemaStore struct {
db *memdb.MemDB
tableName string
logger *log.Logger
}

type SchemaReader interface {
Expand All @@ -79,10 +84,19 @@ func NewStateStore() (*StateStore, error) {
Modules: &ModuleStore{
db: db,
tableName: moduleTableName,
logger: defaultLogger,
},
ProviderSchemas: &ProviderSchemaStore{
db: db,
tableName: providerSchemaTableName,
logger: defaultLogger,
},
}, nil
}

func (s *StateStore) SetLogger(logger *log.Logger) {
s.Modules.logger = logger
s.ProviderSchemas.logger = logger
}

var defaultLogger = log.New(ioutil.Discard, "", 0)

0 comments on commit 9209786

Please sign in to comment.