Skip to content

Commit

Permalink
Merge pull request #3310 from tamird/pgwire-logictest
Browse files Browse the repository at this point in the history
sql: use embedded certs
  • Loading branch information
tamird committed Dec 4, 2015
2 parents 258e4f9 + 971dac6 commit 80e77b1
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 68 deletions.
19 changes: 5 additions & 14 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ GO ?= go
GOFLAGS :=
# Set to 1 to use static linking for all builds (including tests).
STATIC :=
# The cockroach image to be used for starting Docker containers
# during acceptance tests. Usually cockroachdb/cockroach{,-dev}
# depending on the context.
COCKROACH_IMAGE :=

RUN := run

# Variables to be overridden on the command line, e.g.
# make test PKG=./storage TESTFLAGS=--vmodule=multiraft=1
Expand All @@ -43,20 +37,17 @@ DUPLFLAGS := -t 100

ifeq ($(STATIC),1)
# The netgo build tag instructs the net package to try to build a
# Go-only resolver.
# Go-only resolver. As of Go 1.5, netgo is the default...but apparently
# not when using cgo (???).
TAGS += netgo
# The installsuffix makes sure we actually get the netgo build, see
# https://github.com/golang/go/issues/9369#issuecomment-69864440
GOFLAGS += -installsuffix netgo
LDFLAGS += -extldflags "-static"
LDFLAGS += -extldflags '-static'
endif

.PHONY: all
all: build test check

# On a release build, rebuild everything (except stdlib)
# to make sure that the 'release' build tag is taken
# into account.
# On a release build, rebuild everything to make sure that the
# 'release' build tag is taken into account.
.PHONY: release
release: TAGS += release
release: GOFLAGS += -a
Expand Down
27 changes: 12 additions & 15 deletions sql/logic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,25 @@ func (t *logicTest) setUser(tempDir, user string) {
if err != nil {
t.Fatal(err)
}
dir, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
certDir := filepath.Join(filepath.Dir(dir), "resource", security.EmbeddedCertsDir)

certPath := security.ClientCertPath(certDir, user)
keyPath := security.ClientKeyPath(certDir, user)
caPath := filepath.Join(security.EmbeddedCertsDir, "ca.crt")
certPath := security.ClientCertPath(security.EmbeddedCertsDir, user)
keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, user)

// `github.com/lib/pq` requires that private key file permissions are
// "u=rw (0600) or less".
tmpKeyPath := tempRestrictedCopy(t.T, keyPath, tempDir)
// Copy these assets to disk from embedded strings, so this test can
// run from a standalone binary.
tempCAPath, _ := tempRestrictedCopy(t.T, tempDir, caPath, "TestLogic_ca")
tempCertPath, _ := tempRestrictedCopy(t.T, tempDir, certPath, "TestLogic_cert")
tempKeyPath, _ := tempRestrictedCopy(t.T, tempDir, keyPath, "TestLogic_key")

pgUrl := url.URL{
Scheme: "postgres",
User: url.User(user),
Host: net.JoinHostPort(host, port),
RawQuery: fmt.Sprintf("sslmode=verify-full&sslrootcert=%s&sslcert=%s&sslkey=%s",
url.QueryEscape(filepath.Join(certDir, "ca.crt")),
url.QueryEscape(certPath),
url.QueryEscape(tmpKeyPath),
url.QueryEscape(tempCAPath),
url.QueryEscape(tempCertPath),
url.QueryEscape(tempKeyPath),
),
}

Expand Down Expand Up @@ -254,8 +252,7 @@ func (t *logicTest) run(path string) {
// MySQL or Postgres instance.
t.srv = setupTestServer(t.T)

// `github.com/lib/pq` requires that private key file permissions are
// "u=rw (0600) or less".
// Make a temporary directory to hold all our on-disk crypto assets.
tempDir, err := ioutil.TempDir(os.TempDir(), "TestLogic")
if err != nil {
t.Fatal(err)
Expand Down
32 changes: 24 additions & 8 deletions sql/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"database/sql"
"fmt"
"io/ioutil"
"path/filepath"
"os"
"testing"

"github.com/cockroachdb/cockroach/client"
Expand Down Expand Up @@ -123,16 +123,32 @@ func cleanup(s *server.TestServer, db *sql.DB) {
cleanupTestServer(s)
}

// `github.com/lib/pq` requires that private key file permissions are
// "u=rw (0600) or less".
func tempRestrictedCopy(t *testing.T, keyPath, tempDir string) string {
key, err := ioutil.ReadFile(keyPath)
func tempRestrictedCopy(t *testing.T, tempdir, path, prefix string) (string, func()) {
contents, err := securitytest.Asset(path)
if err != nil {
t.Fatal(err)
}
tmpKeyPath := filepath.Join(tempDir, "tempRestrictedCopy")
if err := ioutil.WriteFile(tmpKeyPath, key, 0600); err != nil {

tempFile, err := ioutil.TempFile(tempdir, prefix)
if err != nil {
t.Fatal(err)
}
if err := tempFile.Close(); err != nil {
t.Fatal(err)
}
tempPath := tempFile.Name()
if err := os.Remove(tempPath); err != nil {
t.Fatal(err)
}
// `github.com/lib/pq` requires that private key file permissions are
// "u=rw (0600) or less".
if err := ioutil.WriteFile(tempPath, contents, 0600); err != nil {
t.Fatal(err)
}
return tmpKeyPath
return tempPath, func() {
if err := os.Remove(tempPath); err != nil {
// Not Fatal() because we might already be panicking.
t.Error(err)
}
}
}
47 changes: 16 additions & 31 deletions sql/pgwire_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@ package sql_test
import (
"database/sql"
"fmt"
"io/ioutil"
"net"
"net/url"
"os"
"path/filepath"
"testing"

"github.com/lib/pq"
Expand All @@ -36,8 +34,8 @@ import (
"github.com/cockroachdb/cockroach/util/leaktest"
)

func trivialQuery(datasource string) error {
db, err := sql.Open("postgres", datasource)
func trivialQuery(pgUrl url.URL) error {
db, err := sql.Open("postgres", pgUrl.String())
if err != nil {
return err
}
Expand All @@ -51,29 +49,16 @@ func trivialQuery(datasource string) error {
func TestPGWire(t *testing.T) {
defer leaktest.AfterTest(t)

dir, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
certDir := filepath.Join(filepath.Dir(dir), "resource", security.EmbeddedCertsDir)

certUser := server.TestUser
certPath := security.ClientCertPath(certDir, certUser)
keyPath := security.ClientKeyPath(certDir, certUser)
certPath := security.ClientCertPath(security.EmbeddedCertsDir, certUser)
keyPath := security.ClientKeyPath(security.EmbeddedCertsDir, certUser)

// `github.com/lib/pq` requires that private key file permissions are
// "u=rw (0600) or less".
tempDir, err := ioutil.TempDir(os.TempDir(), "TestPGWire")
if err != nil {
t.Fatal(err)
}
defer func() {
if err := os.RemoveAll(tempDir); err != nil {
// Not Fatal() because we might already be panicking.
t.Error(err)
}
}()
tmpKeyPath := tempRestrictedCopy(t, keyPath, tempDir)
// Copy these assets to disk from embedded strings, so this test can
// run from a standalone binary.
tempCertPath, tempCertCleanup := tempRestrictedCopy(t, os.TempDir(), certPath, "TestPGWire_cert")
defer tempCertCleanup()
tempKeyPath, tempKeyCleanup := tempRestrictedCopy(t, os.TempDir(), keyPath, "TestPGWire_key")
defer tempKeyCleanup()

for _, insecure := range [...]bool{true, false} {
ctx := server.NewTestContext()
Expand All @@ -89,7 +74,7 @@ func TestPGWire(t *testing.T) {
Scheme: "postgres",
Host: net.JoinHostPort(host, port),
}
if err := trivialQuery(basePgUrl.String()); err != nil {
if err := trivialQuery(basePgUrl); err != nil {
if insecure {
if err != pq.ErrSSLNotSupported {
t.Fatal(err)
Expand All @@ -104,7 +89,7 @@ func TestPGWire(t *testing.T) {
{
disablePgUrl := basePgUrl
disablePgUrl.RawQuery = "sslmode=disable"
err := trivialQuery(disablePgUrl.String())
err := trivialQuery(disablePgUrl)
if insecure {
if err != nil {
t.Fatal(err)
Expand All @@ -119,7 +104,7 @@ func TestPGWire(t *testing.T) {
{
requirePgUrlNoCert := basePgUrl
requirePgUrlNoCert.RawQuery = "sslmode=require"
err := trivialQuery(requirePgUrlNoCert.String())
err := trivialQuery(requirePgUrlNoCert)
if insecure {
if err != pq.ErrSSLNotSupported {
t.Fatal(err)
Expand All @@ -136,10 +121,10 @@ func TestPGWire(t *testing.T) {
requirePgUrlWithCert := basePgUrl
requirePgUrlWithCert.User = url.User(optUser)
requirePgUrlWithCert.RawQuery = fmt.Sprintf("sslmode=require&sslcert=%s&sslkey=%s",
url.QueryEscape(certPath),
url.QueryEscape(tmpKeyPath),
url.QueryEscape(tempCertPath),
url.QueryEscape(tempKeyPath),
)
err := trivialQuery(requirePgUrlWithCert.String())
err := trivialQuery(requirePgUrlWithCert)
if insecure {
if err != pq.ErrSSLNotSupported {
t.Fatal(err)
Expand Down

0 comments on commit 80e77b1

Please sign in to comment.