Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement per-session Kerberos auth settings. Update Kerberos example. #632

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 90 additions & 32 deletions examples/kerberos/main.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,39 @@
package main

import (
"C" // Import cgo to avoid project-wide go fmt failures.
"database/sql"
"errors"
"flag"
"fmt"
"log"
"os"

"C" // Import cgo to avoid project-wide go fmt failures.

"github.com/jcmturner/gokrb5/v8/client"
"github.com/jcmturner/gokrb5/v8/config"
"github.com/jcmturner/gokrb5/v8/credentials"
"github.com/jcmturner/gokrb5/v8/gssapi"
"github.com/jcmturner/gokrb5/v8/spnego"

go_ora "github.com/sijms/go-ora/v2"
"github.com/sijms/go-ora/v2/advanced_nego"
)

type KerberosAuth struct{}
type KerberosAuth struct {
ccache string
}

func (kerb KerberosAuth) Authenticate(server, service string) ([]byte, error) {
conf, err := config.Load("/etc/krb5.conf")
func (kerb *KerberosAuth) Authenticate(server, service string) ([]byte, error) {
krb5conf := os.Getenv("KRB5_CONFIG")
if krb5conf == "" {
krb5conf = "/etc/krb5.conf"
}
conf, err := config.Load(krb5conf)
if err != nil {
return nil, err
}
ccache, err := credentials.LoadCCache("/tmp/krb5cc_1000")
ccache, err := credentials.LoadCCache(kerb.ccache)
if err != nil {
return nil, err
}
Expand All @@ -46,48 +56,96 @@ func (kerb KerberosAuth) Authenticate(server, service string) ([]byte, error) {
func usage() {
fmt.Println()
fmt.Println("kerberos")
fmt.Println(" a program to test kerberos5 authentication.")
fmt.Println(" a program to test kerberos authentication.")
fmt.Println()
fmt.Println("Usage:")
fmt.Println(` kerberos -server server_url`)
fmt.Println("Flags:")
flag.PrintDefaults()
fmt.Println()
fmt.Println("Example:")
fmt.Println(` kerberos -server "oracle://user:pass@server/service_name"`)
fmt.Println(` kerberos -host mydb.example.com`)
fmt.Println()
}

func main() {
var server string
type options struct {
host string
port int
serviceName string
kerberosCacheFile string
useGlobalAuth bool
}

func (o *options) validate() error {
if o.host == "" {
return errors.New("-host option is missing")
}
if o.port <= 0 || o.port > 65535 {
return errors.New("-port option is missing")
}
if o.serviceName == "" {
return errors.New("-service option is missing")
}
return nil
}

func parseOptions() *options {
var opts options

flag.StringVar(&opts.host, "host", "", "Oracle server host. REQUIRED.")
flag.IntVar(&opts.port, "port", 1521, "Oracle server port.")
flag.StringVar(&opts.serviceName, "service", "", "Database service name. REQUIRED.")
flag.StringVar(&opts.kerberosCacheFile, "ccache", "/tmp/krb5cc_1000", "Kerberos ticket cache file.")
flag.BoolVar(&opts.useGlobalAuth, "global_auth", false, "Configure Kerberos authentication via global variable.")

flag.StringVar(&server, "server", "", "Server's URL, oracle://user:pass@server/service_name")
flag.Parse()

connStr := os.ExpandEnv(server)
if connStr == "" {
fmt.Println("Missing -server option")
return &opts
}

func main() {
opts := parseOptions()
err := opts.validate()
if err != nil {
fmt.Printf("Error: %v\n", err)
usage()
os.Exit(1)
}
fmt.Println("Connection string: ", connStr)
advanced_nego.SetKerberosAuth(&KerberosAuth{})
//options := map[string]string{
// "TRACE FILE": "trace.log",
// "AUTH TYPE": "KERBEROS",
//}
conn, err := sql.Open("oracle", connStr)
if err != nil {
log.Fatalln("cannot connect: ", err)
}
defer func() {
err = conn.Close()

urlOpts := map[string]string{"AUTH TYPE": "KERBEROS"}
connString := go_ora.BuildUrl(opts.host, opts.port, opts.serviceName, "", "", urlOpts)

fmt.Printf("Connection string: %s\n", connString)

auth := &KerberosAuth{ccache: opts.kerberosCacheFile}

// open connection in preferred way
var sqlConn *sql.DB
if opts.useGlobalAuth {
advanced_nego.SetKerberosAuth(auth)
sqlConn, err = sql.Open("oracle", connString)
if err != nil {
fmt.Println("Can't close connection: ", err)
log.Fatalf("Cannot connect: %v", err)
}
}()
err = conn.Ping()
} else {
connector := go_ora.NewConnector(connString).(*go_ora.OracleConnector)
connector.WithKerberosAuth(auth)
sqlConn = sql.OpenDB(connector)
}
defer sqlConn.Close()

// verify the connection
err = sqlConn.Ping()
if err != nil {
log.Fatalf("Can't ping connection: %v", err)
} else {
fmt.Println("PING ok.")
}
// report username as seen by the database.
row := sqlConn.QueryRow("SELECT USER FROM DUAL")
var user string
err = row.Scan(&user)
if err != nil {
fmt.Println("Can't ping connection: ", err)
return
fmt.Println("Query read error:", err)
} else {
fmt.Println("Reported user:", user)
}
}
23 changes: 15 additions & 8 deletions v2/advanced_nego/advanced_nego.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ import (

var version = 0xB200200

type KerberosAuthInterface interface {
Authenticate(server, service string) ([]byte, error)
}
// KerberosAuthInterface is an alias for configurations.KerberosAuthInterface, maintained for backwards compatibility.
type KerberosAuthInterface = configurations.KerberosAuthInterface

var kerberosAuth KerberosAuthInterface = nil

Expand Down Expand Up @@ -171,9 +170,17 @@ func (nego *AdvNego) Read() error {
}
}
if authKerberos {
if kerberosAuth == nil {
return errors.New("advanced negotiation error: you need to call SetKerberosAuth with valid interface before use kerberos5 authentication")
// Validate configuration
if kerberosAuth == nil && nego.negoInfo.Kerberos == nil {
return fmt.Errorf("advanced negotiation error: Kerberos authenticator not set; call SetKerberosAuth to set it globally or WithKerberosAuth to set it per session")
}

// Prefer session-specific Kerberos auth object
auth := nego.negoInfo.Kerberos
if auth == nil {
auth = kerberosAuth
}

if authServ, ok := nego.serviceList[1].(*authService); ok {
authServ.writeHeader(4)
nego.comm.writeVersion(authServ.getVersion())
Expand All @@ -184,7 +191,7 @@ func (nego *AdvNego) Read() error {
if err != nil {
return err
}
return nego.kerberosHandshake(authServ)
return nego.kerberosHandshake(auth, authServ)
}
}
if authNTS {
Expand Down Expand Up @@ -273,7 +280,7 @@ func (nego *AdvNego) StartServices() error {
return nil
}

func (nego *AdvNego) kerberosHandshake(authServ *authService) error {
func (nego *AdvNego) kerberosHandshake(kerberos KerberosAuthInterface, authServ *authService) error {
header, err := nego.readHeader()
if err != nil {
return err
Expand Down Expand Up @@ -301,7 +308,7 @@ func (nego *AdvNego) kerberosHandshake(authServ *authService) error {
if len(serverHostName) == 0 {
return errors.New("kerberos negotiation error: Server hostname not received")
}
ticketData, err := kerberosAuth.Authenticate(serverHostName, serviceName)
ticketData, err := kerberos.Authenticate(serverHostName, serviceName)
if err != nil {
return err
}
Expand Down
6 changes: 6 additions & 0 deletions v2/configurations/connect_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@ const (
STREAM LobFetch = 1
)

type KerberosAuthInterface interface {
Authenticate(server, service string) ([]byte, error)
}

type AdvNegoServiceInfo struct {
AuthService []string
EncServiceLevel int
IntServiceLevel int
// Kerberos is an optional session-specific auth, which will be preferred over the global interface if present.
Kerberos KerberosAuthInterface
}
type ConnectionConfig struct {
ClientInfo
Expand Down
9 changes: 9 additions & 0 deletions v2/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ type OracleConnector struct {
connectString string
dialer configurations.DialerContext
tlsConfig *tls.Config
kerberos configurations.KerberosAuthInterface
}

func NewConnector(connString string) driver.Connector {
Expand Down Expand Up @@ -149,6 +150,9 @@ func (connector *OracleConnector) Connect(ctx context.Context) (driver.Conn, err
if conn.connOption.TLSConfig == nil {
conn.connOption.TLSConfig = connector.tlsConfig
}
if conn.connOption.Kerberos == nil {
conn.connOption.Kerberos = connector.kerberos
}
err = conn.OpenWithContext(ctx)
if err != nil {
return nil, err
Expand All @@ -172,6 +176,11 @@ func (connector *OracleConnector) WithTLSConfig(config *tls.Config) {
connector.tlsConfig = config
}

// WithKerberosAuth sets the Kerberos authenticator to be used by this connector. It does not enable the Kerberos; set AUTH TYPE to KERBEROS to do so.
func (connector *OracleConnector) WithKerberosAuth(auth configurations.KerberosAuthInterface) {
connector.kerberos = auth
}

// Open return a new open connection
func (driver *OracleDriver) Open(name string) (driver.Conn, error) {
conn, err := NewConnection(name, driver.connOption)
Expand Down
Loading