From 78b1462938706aedf88058aef1a64d5d5012907a Mon Sep 17 00:00:00 2001 From: Chris Marchesi Date: Wed, 7 Feb 2018 10:50:37 -0800 Subject: [PATCH] r/membership: Fix import crash on incorrect ID An incorrect two-part ID supplied to github_membership during import results in a crash - this is because the import functionality is passing this directly to the Read function of the resource, for which IDs are programatically controlled and a panic there would be more telling of a much more serious error worthy of panicking over. This adds an internal validation function for two-part IDs, and wraps the resource's import functionality to use it to validate incoming IDs before passing them to Read. --- github/resource_github_membership.go | 13 ++++++++- github/util.go | 14 ++++++++++ github/util_test.go | 41 ++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) diff --git a/github/resource_github_membership.go b/github/resource_github_membership.go index 50bc2f164c..e28d833553 100644 --- a/github/resource_github_membership.go +++ b/github/resource_github_membership.go @@ -15,7 +15,7 @@ func resourceGithubMembership() *schema.Resource { Update: resourceGithubMembershipUpdate, Delete: resourceGithubMembershipDelete, Importer: &schema.ResourceImporter{ - State: schema.ImportStatePassthrough, + State: resourceGithubMembershipImport, }, Schema: map[string]*schema.Schema{ @@ -89,3 +89,14 @@ func resourceGithubMembershipDelete(d *schema.ResourceData, meta interface{}) er return err } + +func resourceGithubMembershipImport(d *schema.ResourceData, meta interface{}) ([]*schema.ResourceData, error) { + // All we do here is validate that the import string is in a correct enough + // format to be parsed. parseTwoPartID will panic if it's missing elements, + // and is used otherwise in places where that should never happen, so we want + // to keep it that way. + if err := validateTwoPartID(d.Id()); err != nil { + return nil, err + } + return []*schema.ResourceData{d}, nil +} diff --git a/github/util.go b/github/util.go index d8f07df5a2..a606cc90fa 100644 --- a/github/util.go +++ b/github/util.go @@ -1,6 +1,7 @@ package github import ( + "errors" "fmt" "strconv" "strings" @@ -46,6 +47,19 @@ func parseTwoPartID(id string) (string, string) { return parts[0], parts[1] } +// validateTwoPartID performs a quick validation of a two-part ID, designed for +// use when validation has not been previously possible, such as importing. +func validateTwoPartID(id string) error { + if id == "" { + return errors.New("no ID supplied. Please supply an ID format matching organization:username") + } + parts := strings.Split(id, ":") + if len(parts) != 2 { + return fmt.Errorf("incorrectly formatted ID %q. Please supply an ID format matching organization:username", id) + } + return nil +} + // format the strings into an id `a:b` func buildTwoPartID(a, b *string) string { return fmt.Sprintf("%s:%s", *a, *b) diff --git a/github/util_test.go b/github/util_test.go index 5d58407ca8..fd4f4a42b2 100644 --- a/github/util_test.go +++ b/github/util_test.go @@ -53,3 +53,44 @@ func TestAccGithubUtilTwoPartID(t *testing.T) { t.Fatalf("Expected parsed part two bar, actual: %s", parsedPartTwo) } } + +func TestAccValidateTwoPartID(t *testing.T) { + cases := []struct { + name string + id string + expectedErr string + }{ + { + name: "valid", + id: "foo:bar", + }, + { + name: "blank ID", + id: "", + expectedErr: "no ID supplied. Please supply an ID format matching organization:username", + }, + { + name: "not enough parts", + id: "foo", + expectedErr: "incorrectly formatted ID \"foo\". Please supply an ID format matching organization:username", + }, + { + name: "too many parts", + id: "foo:bar:baz", + expectedErr: "incorrectly formatted ID \"foo:bar:baz\". Please supply an ID format matching organization:username", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := validateTwoPartID(tc.id) + switch { + case err != nil && tc.expectedErr == "": + t.Fatalf("expected no error, got %q", err) + case err != nil && tc.expectedErr != "": + if err.Error() != tc.expectedErr { + t.Fatalf("expected error to be %q, got %q", tc.expectedErr, err.Error()) + } + } + }) + } +}