diff --git a/conn.go b/conn.go index 964551cd..1fb4474c 100644 --- a/conn.go +++ b/conn.go @@ -378,6 +378,14 @@ func (c *Conn) ConnectionState() (rv ConnectionState) { return } +func (c *Conn) AlpnSelected() string { + var buf *C.uchar + var bufLen C.uint + C.SSL_get0_alpn_selected(c.ssl, &buf, &bufLen) + protoBytes := C.GoBytes(unsafe.Pointer(buf), C.int(bufLen)) + return string(protoBytes) +} + func (c *Conn) shutdown() func() error { c.mtx.Lock() defer c.mtx.Unlock() diff --git a/ctx.go b/ctx.go index 33befc40..b75a1c00 100644 --- a/ctx.go +++ b/ctx.go @@ -18,9 +18,11 @@ package openssl import "C" import ( + "bytes" "errors" "fmt" "io/ioutil" + "math" "os" "runtime" "sync" @@ -37,12 +39,13 @@ var ( ) type Ctx struct { - ctx *C.SSL_CTX - cert *Certificate - chain []*Certificate - key PrivateKey - verify_cb VerifyCallback - sni_cb TLSExtServernameCallback + ctx *C.SSL_CTX + cert *Certificate + chain []*Certificate + key PrivateKey + verify_cb VerifyCallback + sni_cb TLSExtServernameCallback + alpnProtos []string ticket_store_mu sync.Mutex ticket_store *TicketStore @@ -65,6 +68,7 @@ func newCtx(method *C.SSL_METHOD) (*Ctx, error) { runtime.SetFinalizer(c, func(c *Ctx) { C.SSL_CTX_free(c.ctx) }) + C.X_SSL_CTX_set_ecdh_auto(ctx, 1) return c, nil } @@ -494,6 +498,33 @@ func (c *Ctx) SetTLSExtServernameCallback(sni_cb TLSExtServernameCallback) { C.X_SSL_CTX_set_tlsext_servername_callback(c.ctx, (*[0]byte)(C.sni_cb)) } +//export go_alpn_cb +func go_alpn_cb(p unsafe.Pointer, ssl *C.SSL, out **C.uchar, outLen *C.uchar, in *C.uchar, inLen C.int, arg unsafe.Pointer) SSLTLSExtErr { + ctx := (*Ctx)(p) + //see details https://www.openssl.org/docs/man1.0.2/ssl/SSL_CTX_set_alpn_select_cb.html + inBytes := C.GoBytes(unsafe.Pointer(in), inLen) + for _, proto := range ctx.alpnProtos { + if len(proto) > math.MaxUint8 { + continue //couln't match because prefix len + } + protoLen := byte(len(proto)) + inProto := append([]byte{protoLen}, []byte(proto)...) + ind := bytes.Index(inBytes, inProto) + if ind < 0 { + continue + } + *out = (*C.uchar)(unsafe.Pointer(uintptr(unsafe.Pointer(in)) + uintptr(ind+1))) + *outLen = C.uchar(protoLen) + return SSLTLSExtErrOK + } + return SSLTLSEXTErrNoAck +} + +func (c *Ctx) SetAlpn(protos []string) { + c.alpnProtos = protos + C.SSL_CTX_set_alpn_select_cb(c.ctx, (*[0]byte)(C.alpn_cb), nil) +} + func (c *Ctx) SetSessionId(session_id []byte) error { runtime.LockOSThread() defer runtime.UnlockOSThread() diff --git a/shim.c b/shim.c index 6e680841..fffc1af9 100644 --- a/shim.c +++ b/shim.c @@ -104,6 +104,10 @@ int X_EVP_DigestVerify(EVP_MD_CTX *ctx, const unsigned char *sigret, */ #if OPENSSL_VERSION_NUMBER >= 0x1010000fL +int X_SSL_CTX_set_ecdh_auto(SSL_CTX *ctx, int onoff) { + return 1; +} + void X_BIO_set_data(BIO* bio, void* data) { BIO_set_data(bio, data); } @@ -229,6 +233,10 @@ int X_PEM_write_bio_PrivateKey_traditional(BIO *bio, EVP_PKEY *key, const EVP_CI */ #if OPENSSL_VERSION_NUMBER < 0x1010000fL +int X_SSL_CTX_set_ecdh_auto(SSL_CTX *ctx, int onoff) { + return SSL_CTX_set_ecdh_auto(ctx, onoff); +} + static int x_bio_create(BIO *b) { b->shutdown = 1; b->init = 1; @@ -439,6 +447,12 @@ int X_SSL_verify_cb(int ok, X509_STORE_CTX* store) { return go_ssl_verify_cb_thunk(p, ok, store); } +int alpn_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg) { + SSL_CTX* ssl_ctx = SSL_get_SSL_CTX(ssl); + void* p = SSL_CTX_get_ex_data(ssl_ctx, get_ssl_ctx_idx()); + return go_alpn_cb(p, ssl, (unsigned char **)out, (unsigned char *)outlen, (unsigned char *)in, inlen, arg); +} + const SSL_METHOD *X_SSLv23_method() { return SSLv23_method(); } diff --git a/shim.h b/shim.h index b792822b..0572a9e6 100644 --- a/shim.h +++ b/shim.h @@ -63,6 +63,7 @@ extern const SSL_METHOD *X_TLSv1_2_method(); extern int sni_cb(SSL *ssl_conn, int *ad, void *arg); #endif extern int X_SSL_verify_cb(int ok, X509_STORE_CTX* store); +extern int alpn_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg); /* SSL_CTX methods */ extern int X_SSL_CTX_new_index(); @@ -89,6 +90,7 @@ extern int X_SSL_CTX_set_tlsext_ticket_key_cb(SSL_CTX *sslctx, extern int X_SSL_CTX_ticket_key_cb(SSL *s, unsigned char key_name[16], unsigned char iv[EVP_MAX_IV_LENGTH], EVP_CIPHER_CTX *cctx, HMAC_CTX *hctx, int enc); +extern int X_SSL_CTX_set_ecdh_auto(SSL_CTX *ctx, int onoff); /* BIO methods */ extern int X_BIO_get_flags(BIO *b);