From 1dd275626c04c87699cac1d26528fbadbf06881d Mon Sep 17 00:00:00 2001 From: "Patrick J. McNerthney" Date: Sun, 23 Aug 2020 13:34:41 -1000 Subject: [PATCH] Implement port forwarding. --- stream/__init__.py | 2 +- stream/stream.py | 27 +++-- stream/ws_client.py | 289 +++++++++++++++++++++++++++++++++++--------- 3 files changed, 252 insertions(+), 66 deletions(-) diff --git a/stream/__init__.py b/stream/__init__.py index e72d0583..cd346528 100644 --- a/stream/__init__.py +++ b/stream/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .stream import stream +from .stream import stream, portforward diff --git a/stream/stream.py b/stream/stream.py index 6d5f05f8..968aa7e8 100644 --- a/stream/stream.py +++ b/stream/stream.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import types + from . import ws_client @@ -19,19 +21,22 @@ def stream(func, *args, **kwargs): """Stream given API call using websocket. Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()""" - def _intercept_request_call(*args, **kwargs): - # old generated code's api client has config. new ones has - # configuration - try: - config = func.__self__.api_client.configuration - except AttributeError: - config = func.__self__.api_client.config + api_client = func.__self__.api_client + prev_request = api_client.request + try: + api_client.request = types.MethodType(ws_client.websocket_call, api_client) + return func(*args, **kwargs) + finally: + api_client.request = prev_request - return ws_client.websocket_call(config, *args, **kwargs) - prev_request = func.__self__.api_client.request +def portforward(func, *args, **kwargs): + kwargs['_preload_content'] = False + api_client = func.__self__.api_client + prev_request = api_client.request try: - func.__self__.api_client.request = _intercept_request_call + api_client.request = types.MethodType(ws_client.portforward_call, api_client) return func(*args, **kwargs) finally: - func.__self__.api_client.request = prev_request + api_client.request = prev_request + diff --git a/stream/ws_client.py b/stream/ws_client.py index 2b599381..2b6e842a 100644 --- a/stream/ws_client.py +++ b/stream/ws_client.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from kubernetes.client.rest import ApiException +from kubernetes.client.rest import ApiException, ApiValueError import certifi import collections import select +import socket import ssl +import threading import time import six import yaml -from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse +from six.moves.urllib.parse import urlencode, urlparse, urlunparse from six import StringIO from websocket import WebSocket, ABNF, enableTrace @@ -51,47 +53,13 @@ def __init__(self, configuration, url, headers, capture_all): like port forwarding can forward different pods' streams to different channels. """ - enableTrace(False) - header = [] self._connected = False self._channels = {} if capture_all: self._all = StringIO() else: self._all = _IgnoredIO() - - # We just need to pass the Authorization, ignore all the other - # http headers we get from the generated code - if headers and 'authorization' in headers: - header.append("authorization: %s" % headers['authorization']) - - if headers and 'sec-websocket-protocol' in headers: - header.append("sec-websocket-protocol: %s" % - headers['sec-websocket-protocol']) - else: - header.append("sec-websocket-protocol: v4.channel.k8s.io") - - if url.startswith('wss://') and configuration.verify_ssl: - ssl_opts = { - 'cert_reqs': ssl.CERT_REQUIRED, - 'ca_certs': configuration.ssl_ca_cert or certifi.where(), - } - if configuration.assert_hostname is not None: - ssl_opts['check_hostname'] = configuration.assert_hostname - else: - ssl_opts = {'cert_reqs': ssl.CERT_NONE} - - if configuration.cert_file: - ssl_opts['certfile'] = configuration.cert_file - if configuration.key_file: - ssl_opts['keyfile'] = configuration.key_file - - self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False) - if configuration.proxy: - proxy_url = urlparse(configuration.proxy) - self.sock.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port) - else: - self.sock.connect(url, header=header) + self.sock = create_websocket(configuration, url, headers=headers) self._connected = True def peek_channel(self, channel, timeout=0): @@ -259,44 +227,257 @@ def close(self, **kwargs): WSResponse = collections.namedtuple('WSResponse', ['data']) -def get_websocket_url(url): +class PortForwardClient: + def __init__(self, websocket, ports): + """A websocket client with support for port forwarding. + + Port Forward command sends on 2 channels per port, a read/write + data channel and a read only error channel. Both channels are sent an + initial frame contaning the port number that channel is associated with. + """ + + self.websocket = websocket + self.ports = {} + for ix, port_number in enumerate(ports): + self.ports[port_number] = self.Port(ix, port_number) + threading.Thread( + name="Kubernetes port forward proxy", target=self._proxy, daemon=True + ).start() + + def socket(self, port_number): + if port_number not in self.ports: + raise ValueError("Invalid port number") + return self.ports[port_number].socket + + def error(self, port_number): + if port_number not in self.ports: + raise ValueError("Invalid port number") + return self.ports[port_number].error + + def close(self): + for port in self.ports.values(): + port.socket.close() + + class Port: + def __init__(self, ix, number): + self.number = number + self.channel = bytes([ix * 2]) + s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + self.socket = self.Socket(s) + self.data = b'' + self.error = None + + class Socket: + def __init__(self, socket): + self._socket = socket + + def __getattr__(self, name): + return getattr(self._socket, name) + + def setsockopt(self, level, optname, value): + # The following socket option is not valid with a socket created from socketpair, + # and is set when creating an SSLSocket from this socket. + if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY: + return + self._socket.setsockopt(level, optname, value) + + # Proxy all socket data between the python code and the kubernetes websocket. + def _proxy(self): + channel_ports = [] + channel_initialized = [] + python_ports = {} + rlist = [] + for port in self.ports.values(): + channel_ports.append(port) + channel_initialized.append(False) + channel_ports.append(port) + channel_initialized.append(False) + python_ports[port.python] = port + rlist.append(port.python) + rlist.append(self.websocket.sock) + kubernetes_data = b'' + while True: + wlist = [] + for port in self.ports.values(): + if port.data: + wlist.append(port.python) + if kubernetes_data: + wlist.append(self.websocket.sock) + r, w, _ = select.select(rlist, wlist, []) + for s in w: + if s == self.websocket.sock: + sent = self.websocket.sock.send(kubernetes_data) + kubernetes_data = kubernetes_data[sent:] + else: + port = python_ports[s] + sent = port.python.send(port.data) + port.data = port.data[sent:] + for s in r: + if s == self.websocket.sock: + opcode, frame = self.websocket.recv_data_frame(True) + if opcode == ABNF.OPCODE_CLOSE: + for port in self.ports.values(): + port.python.close() + return + if opcode == ABNF.OPCODE_BINARY: + if not frame.data: + raise RuntimeError("Unexpected frame data size") + channel = frame.data[0] + if channel >= len(channel_ports): + raise RuntimeError("Unexpected channel number: " + str(channel)) + port = channel_ports[channel] + if channel_initialized[channel]: + if channel % 2: + port.error = frame.data[1:].decode() + if port.python in rlist: + port.python.close() + rlist.remove(port.python) + port.data = b'' + else: + port.data += frame.data[1:] + else: + if len(frame.data) != 3: + raise RuntimeError( + "Unexpected initial channel frame data size" + ) + port_number = frame.data[1] + (frame.data[2] * 256) + if port_number != port.number: + raise RuntimeError( + "Unexpected port number in initial channel frame: " + str(port_number) + ) + channel_initialized[channel] = True + elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG): + raise RuntimeError("Unexpected websocket opcode: " + str(opcode)) + else: + port = python_ports[s] + data = port.python.recv(1024 * 1024) + if data: + kubernetes_data += ABNF.create_frame( + port.channel + data, + ABNF.OPCODE_BINARY, + ).format() + else: + port.python.close() + rlist.remove(s) + if len(rlist) == 1: + self.websocket.close() + return + + +def get_websocket_url(url, query_params=None): parsed_url = urlparse(url) parts = list(parsed_url) if parsed_url.scheme == 'http': parts[0] = 'ws' elif parsed_url.scheme == 'https': parts[0] = 'wss' + if query_params: + query = [] + for key, value in query_params: + if key == 'command' and isinstance(value, list): + for command in value: + query.append((key, command)) + else: + query.append((key, value)) + if query: + parts[4] = urlencode(query) return urlunparse(parts) -def websocket_call(configuration, *args, **kwargs): +def create_websocket(configuration, url, headers=None): + enableTrace(False) + + # We just need to pass the Authorization, ignore all the other + # http headers we get from the generated code + header = [] + if headers and 'authorization' in headers: + header.append("authorization: %s" % headers['authorization']) + if headers and 'sec-websocket-protocol' in headers: + header.append("sec-websocket-protocol: %s" % + headers['sec-websocket-protocol']) + else: + header.append("sec-websocket-protocol: v4.channel.k8s.io") + + if url.startswith('wss://') and configuration.verify_ssl: + ssl_opts = { + 'cert_reqs': ssl.CERT_REQUIRED, + 'ca_certs': configuration.ssl_ca_cert or certifi.where(), + } + if configuration.assert_hostname is not None: + ssl_opts['check_hostname'] = configuration.assert_hostname + else: + ssl_opts = {'cert_reqs': ssl.CERT_NONE} + + if configuration.cert_file: + ssl_opts['certfile'] = configuration.cert_file + if configuration.key_file: + ssl_opts['keyfile'] = configuration.key_file + + websocket = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False) + if configuration.proxy: + proxy_url = urlparse(configuration.proxy) + websocket.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port) + else: + websocket.connect(url, header=header) + return websocket + + +def _configuration(api_client): + # old generated code's api client has config. new ones has + # configuration + try: + return api_client.configuration + except AttributeError: + return api_client.config + + +def websocket_call(api_client, _method, url, **kwargs): """An internal function to be called in api-client when a websocket connection is required. args and kwargs are the parameters of apiClient.request method.""" - url = args[1] + url = get_websocket_url(url, kwargs.get("query_params")) + headers = kwargs.get("headers") _request_timeout = kwargs.get("_request_timeout", 60) _preload_content = kwargs.get("_preload_content", True) capture_all = kwargs.get("capture_all", True) - headers = kwargs.get("headers") - - # Expand command parameter list to indivitual command params - query_params = [] - for key, value in kwargs.get("query_params", {}): - if key == 'command' and isinstance(value, list): - for command in value: - query_params.append((key, command)) - else: - query_params.append((key, value)) - - if query_params: - url += '?' + urlencode(query_params) try: - client = WSClient(configuration, get_websocket_url(url), headers, capture_all) + client = WSClient(_configuration(api_client), url, headers, capture_all) if not _preload_content: return client client.run_forever(timeout=_request_timeout) return WSResponse('%s' % ''.join(client.read_all())) except (Exception, KeyboardInterrupt, SystemExit) as e: raise ApiException(status=0, reason=str(e)) + + +def portforward_call(api_client, _method, url, **kwargs): + """An internal function to be called in api-client when a websocket + connection is required for port forwarding. args and kwargs are the + parameters of apiClient.request method.""" + + query_params = kwargs.get("query_params") + + ports = [] + for key, value in query_params: + if key == 'ports': + for port in value.split(','): + try: + port = int(port) + if not (0 < port < 65536): + raise ValueError + ports.append(port) + except ValueError: + raise ApiValueError("Invalid port number `" + str(port) + "`") + if not ports: + raise ApiValueError("Missing required parameter `ports`") + + url = get_websocket_url(url, query_params) + headers = kwargs.get("headers") + + try: + websocket = create_websocket(_configuration(api_client), url, headers) + return PortForwardClient(websocket, ports) + except (Exception, KeyboardInterrupt, SystemExit) as e: + raise ApiException(status=0, reason=str(e))