diff --git a/acceptance/acceptance_test.go b/acceptance/acceptance_test.go index b4b27f201c..91ad09e9e1 100644 --- a/acceptance/acceptance_test.go +++ b/acceptance/acceptance_test.go @@ -20,6 +20,7 @@ import ( "github.com/databricks/cli/internal/testutil" "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/testdiff" + "github.com/databricks/cli/libs/testserver" "github.com/databricks/databricks-sdk-go" "github.com/stretchr/testify/require" ) @@ -107,7 +108,7 @@ func testAccept(t *testing.T, InprocessMode bool, singleTest string) int { cloudEnv := os.Getenv("CLOUD_ENV") if cloudEnv == "" { - server := StartServer(t) + server := testserver.New(t) AddHandlers(server) // Redirect API access to local server: t.Setenv("DATABRICKS_HOST", server.URL) diff --git a/acceptance/cmd_server_test.go b/acceptance/cmd_server_test.go index 28feec1bd5..3f5a6356eb 100644 --- a/acceptance/cmd_server_test.go +++ b/acceptance/cmd_server_test.go @@ -8,10 +8,11 @@ import ( "testing" "github.com/databricks/cli/internal/testcli" + "github.com/databricks/cli/libs/testserver" "github.com/stretchr/testify/require" ) -func StartCmdServer(t *testing.T) *TestServer { +func StartCmdServer(t *testing.T) *testserver.Server { server := StartServer(t) server.Handle("/", func(r *http.Request) (any, error) { q := r.URL.Query() diff --git a/acceptance/server_test.go b/acceptance/server_test.go index dbc55c03fd..66de5dcbfa 100644 --- a/acceptance/server_test.go +++ b/acceptance/server_test.go @@ -1,73 +1,25 @@ package acceptance_test import ( - "encoding/json" "net/http" - "net/http/httptest" "testing" + "github.com/databricks/cli/libs/testserver" "github.com/databricks/databricks-sdk-go/service/catalog" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/iam" "github.com/databricks/databricks-sdk-go/service/workspace" ) -type TestServer struct { - *httptest.Server - Mux *http.ServeMux -} - -type HandlerFunc func(r *http.Request) (any, error) - -func NewTestServer() *TestServer { - mux := http.NewServeMux() - server := httptest.NewServer(mux) - - return &TestServer{ - Server: server, - Mux: mux, - } -} - -func (s *TestServer) Handle(pattern string, handler HandlerFunc) { - s.Mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { - resp, err := handler(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - - var respBytes []byte - - respString, ok := resp.(string) - if ok { - respBytes = []byte(respString) - } else { - respBytes, err = json.MarshalIndent(resp, "", " ") - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - } - - if _, err := w.Write(respBytes); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - }) -} - -func StartServer(t *testing.T) *TestServer { - server := NewTestServer() +func StartServer(t *testing.T) *testserver.Server { + server := testserver.New(t) t.Cleanup(func() { server.Close() }) return server } -func AddHandlers(server *TestServer) { +func AddHandlers(server *testserver.Server) { server.Handle("GET /api/2.0/policies/clusters/list", func(r *http.Request) (any, error) { return compute.ListPoliciesResponse{ Policies: []compute.Policy{ diff --git a/libs/testserver/server.go b/libs/testserver/server.go new file mode 100644 index 0000000000..10269af8fe --- /dev/null +++ b/libs/testserver/server.go @@ -0,0 +1,63 @@ +package testserver + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + + "github.com/databricks/cli/internal/testutil" +) + +type Server struct { + *httptest.Server + Mux *http.ServeMux + + t testutil.TestingT +} + +func New(t testutil.TestingT) *Server { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + + return &Server{ + Server: server, + Mux: mux, + t: t, + } +} + +type HandlerFunc func(req *http.Request) (resp any, err error) + +func (s *Server) Close() { + s.Server.Close() +} + +func (s *Server) Handle(pattern string, handler HandlerFunc) { + s.Mux.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { + resp, err := handler(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + + var respBytes []byte + + respString, ok := resp.(string) + if ok { + respBytes = []byte(respString) + } else { + respBytes, err = json.MarshalIndent(resp, "", " ") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + + if _, err := w.Write(respBytes); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) +}