Skip to content

Commit

Permalink
Support custom SASL mechanisms
Browse files Browse the repository at this point in the history
There is some interest in supporting various SASL mechanisms not
currently included in the library:

* dpkp#2110 (DMS)
* dpkp#2204 (SSPI)
* dpkp#2232 (AWS_MSK_IAM)

Adding these mechanisms in the core library may be undesirable due to:

* Increased maintenance burden.
* Unavailable testing environments.
* Vendor specificity.

This commit provides a quick prototype for a pluggable SASL system.

---

**Example**

To define a custom SASL mechanism a module must implement two methods:

```py

def validate_config(conn):
    # Check configuration values, available libraries, etc.
    assert conn.config['vendor_specific_setting'] is not None, (
        'vendor_specific_setting required when sasl_mechanism=MY_SASL'
    )

def try_authenticate(conn, future):
    # Do authentication routine and return resolved Future with failed
    # or succeeded state.
```

And then the custom mechanism should be registered before initializing
a KafkaAdminClient, KafkaConsumer, or KafkaProducer:

```py

import kafka.sasl
from kafka import KafkaProducer

import my_sasl

kafka.sasl.register_mechanism('MY_SASL', my_sasl)

producer = KafkaProducer(sasl_mechanism='MY_SASL')
```

---

**Notes**

**ABCs**

This prototype does not implement an ABC for custom SASL mechanisms.
Using an ABC would reduce a few of the explicit assertions involved with
registering a mechanism and is a viable option. Due to differing feature
sets between py2/py3 this option was not explored, but shouldn't be
difficult.

**Private Methods**

This prototype relies on some methods that are currently marked as
**private** in `BrokerConnection`.

* `._can_send_recv`
* `._lock`
* `._recv_bytes_blocking`
* `._send_bytes_blocking`

A pluggable system would require stable interfaces for these actions.

**Alternative Approach**

If the module-scoped dict modification in `register_mechanism` feels too
clunky maybe the addtional mechanisms can be specified via an argument
when initializing one of the `Kafka*` classes?
  • Loading branch information
mattoberle committed Aug 20, 2021
1 parent f0a57a6 commit 03c357f
Show file tree
Hide file tree
Showing 6 changed files with 378 additions and 255 deletions.
274 changes: 19 additions & 255 deletions kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import copy
import errno
import io
import logging
from random import shuffle, uniform

Expand All @@ -14,25 +13,26 @@
from kafka.vendor import selectors34 as selectors

import socket
import struct
import threading
import time

from kafka.vendor import six

from kafka import sasl
import kafka.errors as Errors
from kafka.future import Future
from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.oauth.abstract import AbstractTokenProvider
from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest
from kafka.protocol.admin import (
DescribeAclsRequest_v2,
DescribeClientQuotasRequest,
SaslHandShakeRequest,
)
from kafka.protocol.commit import OffsetFetchRequest
from kafka.protocol.offset import OffsetRequest
from kafka.protocol.produce import ProduceRequest
from kafka.protocol.metadata import MetadataRequest
from kafka.protocol.fetch import FetchRequest
from kafka.protocol.parser import KafkaProtocol
from kafka.protocol.types import Int32, Int8
from kafka.scram import ScramClient
from kafka.version import __version__


Expand Down Expand Up @@ -83,6 +83,12 @@ class SSLWantWriteError(Exception):
gssapi = None
GSSError = None

# needed for AWS_MSK_IAM authentication:
try:
from botocore.session import Session as BotoSession
except ImportError:
# no botocore available, will disable AWS_MSK_IAM mechanism
BotoSession = None

AFI_NAMES = {
socket.AF_UNSPEC: "unspecified",
Expand Down Expand Up @@ -227,7 +233,6 @@ class BrokerConnection(object):
'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512")

def __init__(self, host, port, afi, **configs):
self.host = host
Expand Down Expand Up @@ -260,22 +265,10 @@ def __init__(self, host, port, afi, **configs):
assert ssl_available, "Python wasn't built with SSL support"

if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'):
assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, (
'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS))
if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'):
assert self.config['sasl_plain_username'] is not None, (
'sasl_plain_username required for PLAIN or SCRAM sasl'
)
assert self.config['sasl_plain_password'] is not None, (
'sasl_plain_password required for PLAIN or SCRAM sasl'
)
if self.config['sasl_mechanism'] == 'GSSAPI':
assert gssapi is not None, 'GSSAPI lib not available'
assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl'
if self.config['sasl_mechanism'] == 'OAUTHBEARER':
token_provider = self.config['sasl_oauth_token_provider']
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'
assert self.config['sasl_mechanism'] in sasl.MECHANISMS, (
'sasl_mechanism must be one of {}'.format(', '.join(sasl.MECHANISMS.keys()))
)
sasl.MECHANISMS[self.config['sasl_mechanism']].validate_config(self)
# This is not a general lock / this class is not generally thread-safe yet
# However, to avoid pushing responsibility for maintaining
# per-connection locks to the upstream client, we will use this lock to
Expand Down Expand Up @@ -553,19 +546,9 @@ def _handle_sasl_handshake_response(self, future, response):
Errors.UnsupportedSaslMechanismError(
'Kafka broker does not support %s sasl mechanism. Enabled mechanisms are: %s'
% (self.config['sasl_mechanism'], response.enabled_mechanisms)))
elif self.config['sasl_mechanism'] == 'PLAIN':
return self._try_authenticate_plain(future)
elif self.config['sasl_mechanism'] == 'GSSAPI':
return self._try_authenticate_gssapi(future)
elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
return self._try_authenticate_oauth(future)
elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"):
return self._try_authenticate_scram(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
'kafka-python does not support SASL mechanism %s' %
self.config['sasl_mechanism']))

try_authenticate = sasl.MECHANISMS[self.config['sasl_mechanism']].try_authenticate
return try_authenticate(self, future)

def _send_bytes(self, data):
"""Send some data via non-blocking IO
Expand Down Expand Up @@ -619,225 +602,6 @@ def _recv_bytes_blocking(self, n):
finally:
self._sock.settimeout(0.0)

def _try_authenticate_plain(self, future):
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
log.warning('%s: Sending username and password in the clear', self)

data = b''
# Send PLAIN credentials per RFC-4616
msg = bytes('\0'.join([self.config['sasl_plain_username'],
self.config['sasl_plain_username'],
self.config['sasl_plain_password']]).encode('utf-8'))
size = Int32.encode(len(msg))

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
self._send_bytes_blocking(size + msg)

# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)

except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

if data != b'\x00\x00\x00\x00':
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
return future.failure(error)

log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
return future.success(True)

def _try_authenticate_scram(self, future):
if self.config['security_protocol'] == 'SASL_PLAINTEXT':
log.warning('%s: Exchanging credentials in the clear', self)

scram_client = ScramClient(
self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism']
)

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
client_first = scram_client.first_message().encode('utf-8')
size = Int32.encode(len(client_first))
self._send_bytes_blocking(size + client_first)

(data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
server_first = self._recv_bytes_blocking(data_len).decode('utf-8')
scram_client.process_server_first_message(server_first)

client_final = scram_client.final_message().encode('utf-8')
size = Int32.encode(len(client_final))
self._send_bytes_blocking(size + client_final)

(data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4))
server_final = self._recv_bytes_blocking(data_len).decode('utf-8')
scram_client.process_server_final_message(server_final)

except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

log.info(
'%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism']
)
return future.success(True)

def _try_authenticate_gssapi(self, future):
kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name
gssapi_name = gssapi.Name(
auth_id,
name_type=gssapi.NameType.hostbased_service
).canonicalize(gssapi.MechType.kerberos)
log.debug('%s: GSSAPI name: %s', self, gssapi_name)

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
# Establish security context and negotiate protection level
# For reference RFC 2222, section 7.2.1
try:
# Exchange tokens until authentication either succeeds or fails
client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate')
received_token = None
while not client_ctx.complete:
# calculate an output token from kafka token (or None if first iteration)
output_token = client_ctx.step(received_token)

# pass output token to kafka, or send empty response if the security
# context is complete (output token is None in that case)
if output_token is None:
self._send_bytes_blocking(Int32.encode(0))
else:
msg = output_token
size = Int32.encode(len(msg))
self._send_bytes_blocking(size + msg)

# The server will send a token back. Processing of this token either
# establishes a security context, or it needs further token exchange.
# The gssapi will be able to identify the needed next step.
# The connection is closed on failure.
header = self._recv_bytes_blocking(4)
(token_size,) = struct.unpack('>i', header)
received_token = self._recv_bytes_blocking(token_size)

# Process the security layer negotiation token, sent by the server
# once the security context is established.

# unwraps message containing supported protection levels and msg size
msg = client_ctx.unwrap(received_token).message
# Kafka currently doesn't support integrity or confidentiality security layers, so we
# simply set QoP to 'auth' only (first octet). We reuse the max message size proposed
# by the server
msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:]
# add authorization identity to the response, GSS-wrap and send it
msg = client_ctx.wrap(msg + auth_id.encode(), False).message
size = Int32.encode(len(msg))
self._send_bytes_blocking(size + msg)

except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True
except Exception as e:
err = e
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name)
return future.success(True)

def _try_authenticate_oauth(self, future):
data = b''

msg = bytes(self._build_oauth_client_request().encode("utf-8"))
size = Int32.encode(len(msg))

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
# Send SASL OAuthBearer request with OAuth token
self._send_bytes_blocking(size + msg)

# The server will send a zero sized message (that is Int32(0)) on success.
# The connection is closed on failure
data = self._recv_bytes_blocking(4)

except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

if data != b'\x00\x00\x00\x00':
error = Errors.AuthenticationFailedError('Unrecognized response during authentication')
return future.failure(error)

log.info('%s: Authenticated via OAuth', self)
return future.success(True)

def _build_oauth_client_request(self):
token_provider = self.config['sasl_oauth_token_provider']
return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions())

def _token_extensions(self):
"""
Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER
initial request.
"""
token_provider = self.config['sasl_oauth_token_provider']

# Only run if the #extensions() method is implemented by the clients Token Provider class
# Builds up a string separated by \x01 via a dict of key value pairs
if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0:
msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()])
return "\x01" + msg
else:
return ""

def blacked_out(self):
"""
Return true if we are disconnected from the given node and can't
Expand Down
53 changes: 53 additions & 0 deletions kafka/sasl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging

from kafka.sasl import gssapi, oauthbearer, plain, scram

log = logging.getLogger(__name__)

MECHANISMS = {
'GSSAPI': gssapi,
'OAUTHBEARER': oauthbearer,
'PLAIN': plain,
'SCRAM-SHA-256': scram,
'SCRAM-SHA-512': scram,
}


def register_mechanism(key, module):
"""
Registers a custom SASL mechanism that can be used via sasl_mechanism={key}.
Example:
import kakfa.sasl
from kafka import KafkaProducer
from mymodule import custom_sasl
kafka.sasl.register_mechanism('CUSTOM_SASL', custom_sasl)
producer = KafkaProducer(sasl_mechanism='CUSTOM_SASL')
Arguments:
key (str): The name of the mechanism returned by the broker and used
in the sasl_mechanism config value.
module (module): A module that implements the following methods...
def validate_config(conn: BrokerConnection): -> None:
# Raises an AssertionError for missing or invalid conifg values.
def try_authenticate(conn: BrokerConncetion, future: -> Future):
# Executes authentication routine and returns a resolved Future.
Raises:
AssertionError: The registered module does not define a required method.
"""
assert callable(getattr(module, 'validate_config', None)), (
'Custom SASL mechanism {} must implement method #validate_config()'
.format(key)
)
assert callable(getattr(module, 'try_authenticate', None)), (
'Custom SASL mechanism {} must implement method #try_authenticate()'
.format(key)
)
if key in MECHANISMS:
log.warning('Overriding existing SASL mechanism {}'.format(key))

MECHANISMS[key] = module
Loading

0 comments on commit 03c357f

Please sign in to comment.