Skip to content

Commit

Permalink
add the missing pieces to execute client-based plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
francoismichel committed Apr 22, 2024
1 parent d4b153f commit cde645f
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 96 deletions.
19 changes: 17 additions & 2 deletions auth/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package auth
import (
"net/http"

"github.com/francoismichel/ssh3"
client_config "github.com/francoismichel/ssh3/client/config"
"github.com/quic-go/quic-go/http3"
"golang.org/x/crypto/ssh/agent"
)

/*
Expand All @@ -31,7 +33,20 @@ type ServerAuthPlugin func(username string, identityStr string) (RequestIdentity

// Updates `request` with the correct authentication material so that an SSH3 conversation
// can be established by performing the request
type ClientAuthPluginFunc func(request *http.Request, clientOpts *client_config.Config, roundTripper *http3.RoundTripper) error
type GetAuthMethodsFunc func(request *http.Request, clientConfig *client_config.Config, roundTripper *http3.RoundTripper) ([]ClientAuthMethod, error)

type ClientAuthMethod interface {
// PrepareRequestForAuth updated the provided request with the needed headers
// for authentication.
// The method must not alter the request method (must always be CONNECT) nor the
// Host/:origin, User-Agent or :path headers.
// The agent is the connected SSH agent if it exists, nil otherwise
// The provided roundTripper can be used to perform requests with the server to prepare
// the authentication process.
// username is the username to authenticate
// conversation is the Conversation we want to establish
PrepareRequestForAuth(request *http.Request, sshAgent agent.ExtendedAgent, roundTripper *http3.RoundTripper, username string, conversation *ssh3.Conversation) error
}

type ClientAuthPlugin struct {
// A plugin can define one or more new SSH3 config options.
Expand All @@ -41,5 +56,5 @@ type ClientAuthPlugin struct {
// (good practice: "<your_repo_name>[-<option_name>]")
PluginOptions map[client_config.OptionName]client_config.OptionParser

PluginFunc ClientAuthPluginFunc
PluginFunc GetAuthMethodsFunc
}
181 changes: 102 additions & 79 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/francoismichel/ssh3/auth/oidc"
client_config "github.com/francoismichel/ssh3/client/config"
"github.com/francoismichel/ssh3/client/winsize"
"github.com/francoismichel/ssh3/internal"
ssh3Messages "github.com/francoismichel/ssh3/message"
"github.com/francoismichel/ssh3/util"
)
Expand Down Expand Up @@ -213,16 +214,16 @@ type Client struct {
*ssh3.Conversation
}

func Dial(ctx context.Context, options *client_config.Config, qconn quic.EarlyConnection,
func Dial(ctx context.Context, config *client_config.Config, qconn quic.EarlyConnection,
roundTripper *http3.RoundTripper,
sshAgent agent.ExtendedAgent) (*Client, error) {

hostUrl := url.URL{}
hostUrl.Scheme = "https"
hostUrl.Host = options.URLHostnamePort()
hostUrl.Path = options.UrlPath()
hostUrl.Host = config.URLHostnamePort()
hostUrl.Path = config.UrlPath()
urlQuery := hostUrl.Query()
urlQuery.Set("user", options.Username())
urlQuery.Set("user", config.Username())
hostUrl.RawQuery = urlQuery.Encode()
requestUrl := hostUrl.String()

Expand Down Expand Up @@ -270,97 +271,119 @@ func Dial(ctx context.Context, options *client_config.Config, qconn quic.EarlyCo
}
req.Proto = "ssh3"

var identity ssh3.Identity
for _, method := range options.AuthMethods() {
switch m := method.(type) {
case *ssh3.PasswordAuthMethod:
log.Debug().Msgf("try password-based auth")
fmt.Printf("password for %s:", hostUrl.String())
password, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
// TODO: replace this by a loop actually performing the requests for qeach auth method of each plugin
foundSuitableAuthPlugin := false
plugins := internal.GetClientAuthPlugins()
for _, plugin := range plugins {
authMethods, err := plugin.PluginFunc(req, config, roundTripper)
if err != nil {
return nil, err
}
for _, authMethod := range authMethods {
err = authMethod.PrepareRequestForAuth(req, sshAgent, roundTripper, config.Username(), conv)
if err != nil {
log.Error().Msgf("could not get password: %s", err)
log.Error().Msgf("error when preparing request for auth plugin %T: %s", plugin, err)
return nil, err
}
identity = m.IntoIdentity(string(password))
case *ssh3.PrivkeyFileAuthMethod:
log.Debug().Msgf("try file-based pubkey auth using file %s", m.Filename())
identity, err = m.IntoIdentityWithoutPassphrase()
// could not identify without passphrase, try agent authentication by using the key's public key
if passphraseErr, ok := err.(*ssh.PassphraseMissingError); ok {
// the pubkey may be contained in the privkey file
pubkey := passphraseErr.PublicKey
if pubkey == nil {
// if it is not the case, try to find a .pub equivalent, like OpenSSH does
pubkeyBytes, err := os.ReadFile(fmt.Sprintf("%s.pub", m.Filename()))
if err == nil {
filePubkey, _, _, _, err := ssh.ParseAuthorizedKey(pubkeyBytes)
foundSuitableAuthPlugin = true
log.Debug().Msgf("found suitable auth plugin")
}
}

if !foundSuitableAuthPlugin {

var identity ssh3.Identity
for _, method := range config.AuthMethods() {
switch m := method.(type) {
case *ssh3.PasswordAuthMethod:
log.Debug().Msgf("try password-based auth")
fmt.Printf("password for %s:", hostUrl.String())
password, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
log.Error().Msgf("could not get password: %s", err)
return nil, err
}
identity = m.IntoIdentity(string(password))
case *ssh3.PrivkeyFileAuthMethod:
log.Debug().Msgf("try file-based pubkey auth using file %s", m.Filename())
identity, err = m.IntoIdentityWithoutPassphrase()
// could not identify without passphrase, try agent authentication by using the key's public key
if passphraseErr, ok := err.(*ssh.PassphraseMissingError); ok {
// the pubkey may be contained in the privkey file
pubkey := passphraseErr.PublicKey
if pubkey == nil {
// if it is not the case, try to find a .pub equivalent, like OpenSSH does
pubkeyBytes, err := os.ReadFile(fmt.Sprintf("%s.pub", m.Filename()))
if err == nil {
pubkey = filePubkey
filePubkey, _, _, _, err := ssh.ParseAuthorizedKey(pubkeyBytes)
if err == nil {
pubkey = filePubkey
}
}
}
}

// now, try to see of the agent manages this key
foundAgentKey := false
if pubkey != nil {
for _, agentKey := range agentKeys {
if bytes.Equal(agentKey.Marshal(), pubkey.Marshal()) {
log.Debug().Msgf("found key in agent: %s", agentKey)
identity = ssh3.NewAgentAuthMethod(pubkey).IntoIdentity(sshAgent)
foundAgentKey = true
break
// now, try to see of the agent manages this key
foundAgentKey := false
if pubkey != nil {
for _, agentKey := range agentKeys {
if bytes.Equal(agentKey.Marshal(), pubkey.Marshal()) {
log.Debug().Msgf("found key in agent: %s", agentKey)
identity = ssh3.NewAgentAuthMethod(pubkey).IntoIdentity(sshAgent)
foundAgentKey = true
break
}
}
}
}

// key not handled by agent, let's try to decrypt it ourselves
if !foundAgentKey {
fmt.Printf("passphrase for private key stored in %s:", m.Filename())
var passphraseBytes []byte
passphraseBytes, err = term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
log.Error().Msgf("could not get passphrase: %s", err)
return nil, err
}
passphrase := string(passphraseBytes)
identity, err = m.IntoIdentityPassphrase(passphrase)
if err != nil {
log.Error().Msgf("could not load private key: %s", err)
return nil, err
// key not handled by agent, let's try to decrypt it ourselves
if !foundAgentKey {
fmt.Printf("passphrase for private key stored in %s:", m.Filename())
var passphraseBytes []byte
passphraseBytes, err = term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
log.Error().Msgf("could not get passphrase: %s", err)
return nil, err
}
passphrase := string(passphraseBytes)
identity, err = m.IntoIdentityPassphrase(passphrase)
if err != nil {
log.Error().Msgf("could not load private key: %s", err)
return nil, err
}
}
} else if err != nil {
log.Warn().Msgf("Could not load private key: %s", err)
}
} else if err != nil {
log.Warn().Msgf("Could not load private key: %s", err)
}
case *ssh3.AgentAuthMethod:
log.Debug().Msgf("try ssh-agent-based auth")
identity = m.IntoIdentity(sshAgent)
case *ssh3.OidcAuthMethod:
log.Debug().Msgf("try OIDC auth to issuer %s", m.OIDCConfig().IssuerUrl)
token, err := oidc.Connect(context.Background(), m.OIDCConfig(), m.OIDCConfig().IssuerUrl, m.DoPKCE())
if err != nil {
log.Error().Msgf("could not get token: %s", err)
return nil, err
case *ssh3.AgentAuthMethod:
log.Debug().Msgf("try ssh-agent-based auth")
identity = m.IntoIdentity(sshAgent)
case *ssh3.OidcAuthMethod:
log.Debug().Msgf("try OIDC auth to issuer %s", m.OIDCConfig().IssuerUrl)
token, err := oidc.Connect(context.Background(), m.OIDCConfig(), m.OIDCConfig().IssuerUrl, m.DoPKCE())
if err != nil {
log.Error().Msgf("could not get token: %s", err)
return nil, err
}
identity = m.IntoIdentity(token)
}
identity = m.IntoIdentity(token)
// currently only tries a single identity (the first one), but the goal is to
// try several identities, similarly to OpenSSH
log.Debug().Msgf("we only try the first specified auth method for now")
break
}
// currently only tries a single identity (the first one), but the goal is to
// try several identities, similarly to OpenSSH
log.Debug().Msgf("we only try the first specified auth method for now")
break
}

if identity == nil {
return nil, NoSuitableIdentity{}
}
if identity == nil {
return nil, NoSuitableIdentity{}
}

log.Debug().Msgf("try the following Identity: %s", identity)
err = identity.SetAuthorizationHeader(req, options.Username(), conv)
if err != nil {
log.Error().Msgf("could not set authorization header in HTTP request: %s", err)
return nil, err
log.Debug().Msgf("try the following Identity: %s", identity)
err = identity.SetAuthorizationHeader(req, config.Username(), conv)
if err != nil {
log.Error().Msgf("could not set authorization header in HTTP request: %s", err)
return nil, err
}
}

log.Debug().Msgf("establish conversation with the server")
Expand Down
10 changes: 6 additions & 4 deletions client/config/config.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package options
package config

import (
"fmt"
Expand Down Expand Up @@ -94,9 +94,11 @@ type OptionParser interface {
// This keyword is used when parsing the SSH config.
OptionConfigName() string

// returns the Option[T] represented by this CLI argument.
// Option() will always be called *after* having parsed the CLI args using flag.Parse()
Parse(string) (Option, error)
// returns the Option represented by this list of config values
// the values are all retrieved from the config.
// values contain several entries if the keyword (see `OptionConfigName`) is
// present several times in the config.
Parse(values []string) (Option, error)
}

// CLIOptionParser defines a parser that can be hooked in the CLI flags
Expand Down
18 changes: 9 additions & 9 deletions client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ type privkeyFileIdentity struct {
}

func (i *privkeyFileIdentity) SetAuthorizationHeader(req *http.Request, username string, conversation *Conversation) error {
bearerToken, err := buildJWTBearerToken(i.signingMethod, i.privkey, username, conversation)
bearerToken, err := BuildJWTBearerToken(i.signingMethod, i.privkey, username, conversation)
if err != nil {
return err
}
Expand Down Expand Up @@ -204,7 +204,7 @@ func (i *agentBasedIdentity) SetAuthorizationHeader(req *http.Request, username
Key: i.pubkey,
}

bearerToken, err := buildJWTBearerToken(signingMethod, i.pubkey, username, conversation)
bearerToken, err := BuildJWTBearerToken(signingMethod, i.pubkey, username, conversation)
if err != nil {
return err
}
Expand Down Expand Up @@ -301,19 +301,19 @@ func GetConfigForHost(host string, config *ssh_config.Config, pluginsOptionsPars
}

pluginOptions = make(map[client_config.OptionName]client_config.Option)
log.Debug().Msgf("parsing options using option parsers")
log.Debug().Msgf("parsing options using option parsers: %+v", pluginsOptionsParsers)
for optionName, optionParser := range pluginsOptionsParsers {
log.Debug().Msgf("search for option %s (%s) in config", optionName, optionParser.OptionConfigName())
var optionValue string
optionValue, err = config.Get(host, optionParser.OptionConfigName())
var optionValues []string
optionValues, err = config.GetAll(host, optionParser.OptionConfigName())
if err != nil {
log.Error().Msgf("config.Get returned an error: %s", err)
return
}
if optionValue != "" {
if optionValues != nil {
var option client_config.Option
log.Debug().Msgf("found value for %s: %s", optionParser.OptionConfigName(), optionValue)
option, err = optionParser.Parse(optionValue)
log.Debug().Msgf("found value(s) for %s: %s", optionParser.OptionConfigName(), optionValues)
option, err = optionParser.Parse(optionValues)
if err != nil {
log.Error().Msgf("config option parser returned an error: %s", err)
return
Expand All @@ -325,7 +325,7 @@ func GetConfigForHost(host string, config *ssh_config.Config, pluginsOptionsPars
return hostname, port, user, urlPath, authMethodsToTry, pluginOptions, nil
}

func buildJWTBearerToken(signingMethod jwt.SigningMethod, key interface{}, username string, conversation *Conversation) (string, error) {
func BuildJWTBearerToken(signingMethod jwt.SigningMethod, key interface{}, username string, conversation *Conversation) (string, error) {
convID := conversation.ConversationID()
b64ConvID := base64.StdEncoding.EncodeToString(convID[:])
token := jwt.NewWithClaims(signingMethod, jwt.MapClaims{
Expand Down
4 changes: 2 additions & 2 deletions cmd/ssh3.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ func getConnectionMaterialFromURL(hostUrl *url.URL, sshConfig *ssh_config.Config
for k, v := range cliOptions {
if _, ok := pluginOptionsFromConfig[k]; ok {
log.Debug().Msgf("override config option %s by the value provided by the CLI", k)
pluginOptionsFromConfig[k] = v
}
pluginOptionsFromConfig[k] = v
}

options, err := client_config.NewConfig(configOptions.Username(), configOptions.Hostname(), configOptions.Port(), configOptions.UrlPath(), authMethods, configOptions.Options())
Expand Down Expand Up @@ -355,7 +355,7 @@ func (v *FlagValue) Set(s string) (err error) {
}
}
v.val = s
v.parsedOption, err = v.CLIOptionParser.Parse(s)
v.parsedOption, err = v.CLIOptionParser.Parse([]string{s})
if err != nil {
return err
}
Expand Down
7 changes: 7 additions & 0 deletions internal/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,10 @@ func GetPluginsClientOptionsParsers() (parsers map[client_config.OptionName]clie
}
return parsers, nil
}

func GetClientAuthPlugins() (plugins []auth.ClientAuthPlugin) {
for _, plugin := range clientRegistry.plugins {
plugins = append(plugins, plugin)
}
return plugins
}

0 comments on commit cde645f

Please sign in to comment.