diff --git a/stream/__init__.py b/stream/__init__.py index e72d0583..cd346528 100644 --- a/stream/__init__.py +++ b/stream/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .stream import stream +from .stream import stream, portforward diff --git a/stream/stream.py b/stream/stream.py index 9bb59017..57bac758 100644 --- a/stream/stream.py +++ b/stream/stream.py @@ -17,9 +17,12 @@ from . import ws_client -def _websocket_reqeust(websocket_request, api_method, *args, **kwargs): +def _websocket_reqeust(websocket_request, force_kwargs, api_method, *args, **kwargs): """Override the ApiClient.request method with an alternative websocket based method and call the supplied Kubernetes API method with that in place.""" + if force_kwargs: + for kwarg, value in force_kwargs.items(): + kwargs[kwarg] = value api_client = api_method.__self__.api_client # old generated code's api client has config. new ones has configuration try: @@ -34,4 +37,5 @@ def _websocket_reqeust(websocket_request, api_method, *args, **kwargs): api_client.request = prev_request -stream = functools.partial(_websocket_reqeust, ws_client.websocket_call) +stream = functools.partial(_websocket_reqeust, ws_client.websocket_call, None) +portforward = functools.partial(_websocket_reqeust, ws_client.portforward_call, {'_preload_content':False}) diff --git a/stream/ws_client.py b/stream/ws_client.py index fa7f393e..69274d55 100644 --- a/stream/ws_client.py +++ b/stream/ws_client.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from kubernetes.client.rest import ApiException +from kubernetes.client.rest import ApiException, ApiValueError import certifi import collections import select +import socket import ssl +import threading import time import six @@ -225,6 +227,143 @@ def close(self, **kwargs): WSResponse = collections.namedtuple('WSResponse', ['data']) +class PortForward: + def __init__(self, websocket, ports): + """A websocket client with support for port forwarding. + + Port Forward command sends on 2 channels per port, a read/write + data channel and a read only error channel. Both channels are sent an + initial frame contaning the port number that channel is associated with. + """ + + self.websocket = websocket + self.ports = {} + for ix, port_number in enumerate(ports): + self.ports[port_number] = self._Port(ix, port_number) + threading.Thread( + name="Kubernetes port forward proxy", target=self._proxy, daemon=True + ).start() + + def socket(self, port_number): + if port_number not in self.ports: + raise ValueError("Invalid port number") + return self.ports[port_number].socket + + def error(self, port_number): + if port_number not in self.ports: + raise ValueError("Invalid port number") + return self.ports[port_number].error + + def close(self): + for port in self.ports.values(): + port.socket.close() + + class _Port: + def __init__(self, ix, number): + self.number = number + self.channel = bytes([ix * 2]) + s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM) + self.socket = self._Socket(s) + self.data = b'' + self.error = None + + class _Socket: + def __init__(self, socket): + self._socket = socket + + def __getattr__(self, name): + return getattr(self._socket, name) + + def setsockopt(self, level, optname, value): + # The following socket option is not valid with a socket created from socketpair, + # and is set when creating an SSLSocket from this socket. + if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY: + return + self._socket.setsockopt(level, optname, value) + + # Proxy all socket data between the python code and the kubernetes websocket. + def _proxy(self): + channel_ports = [] + channel_initialized = [] + python_ports = {} + rlist = [] + for port in self.ports.values(): + channel_ports.append(port) + channel_initialized.append(False) + channel_ports.append(port) + channel_initialized.append(False) + python_ports[port.python] = port + rlist.append(port.python) + rlist.append(self.websocket.sock) + kubernetes_data = b'' + while True: + wlist = [] + for port in self.ports.values(): + if port.data: + wlist.append(port.python) + if kubernetes_data: + wlist.append(self.websocket.sock) + r, w, _ = select.select(rlist, wlist, []) + for s in w: + if s == self.websocket.sock: + sent = self.websocket.sock.send(kubernetes_data) + kubernetes_data = kubernetes_data[sent:] + else: + port = python_ports[s] + sent = port.python.send(port.data) + port.data = port.data[sent:] + for s in r: + if s == self.websocket.sock: + opcode, frame = self.websocket.recv_data_frame(True) + if opcode == ABNF.OPCODE_CLOSE: + for port in self.ports.values(): + port.python.close() + return + if opcode == ABNF.OPCODE_BINARY: + if not frame.data: + raise RuntimeError("Unexpected frame data size") + channel = frame.data[0] + if channel >= len(channel_ports): + raise RuntimeError("Unexpected channel number: " + str(channel)) + port = channel_ports[channel] + if channel_initialized[channel]: + if channel % 2: + port.error = frame.data[1:].decode() + if port.python in rlist: + port.python.close() + rlist.remove(port.python) + port.data = b'' + else: + port.data += frame.data[1:] + else: + if len(frame.data) != 3: + raise RuntimeError( + "Unexpected initial channel frame data size" + ) + port_number = frame.data[1] + (frame.data[2] * 256) + if port_number != port.number: + raise RuntimeError( + "Unexpected port number in initial channel frame: " + str(port_number) + ) + channel_initialized[channel] = True + elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG): + raise RuntimeError("Unexpected websocket opcode: " + str(opcode)) + else: + port = python_ports[s] + data = port.python.recv(1024 * 1024) + if data: + kubernetes_data += ABNF.create_frame( + port.channel + data, + ABNF.OPCODE_BINARY, + ).format() + else: + port.python.close() + rlist.remove(s) + if len(rlist) == 1: + self.websocket.close() + return + + def get_websocket_url(url, query_params=None): parsed_url = urlparse(url) parts = list(parsed_url) @@ -302,3 +441,34 @@ def websocket_call(configuration, _method, url, **kwargs): return WSResponse('%s' % ''.join(client.read_all())) except (Exception, KeyboardInterrupt, SystemExit) as e: raise ApiException(status=0, reason=str(e)) + + +def portforward_call(configuration, _method, url, **kwargs): + """An internal function to be called in api-client when a websocket + connection is required for port forwarding. args and kwargs are the + parameters of apiClient.request method.""" + + query_params = kwargs.get("query_params") + + ports = [] + for key, value in query_params: + if key == 'ports': + for port in value.split(','): + try: + port = int(port) + if not (0 < port < 65536): + raise ValueError + ports.append(port) + except ValueError: + raise ApiValueError("Invalid port number `" + str(port) + "`") + if not ports: + raise ApiValueError("Missing required parameter `ports`") + + url = get_websocket_url(url, query_params) + headers = kwargs.get("headers") + + try: + websocket = create_websocket(configuration, url, headers) + return PortForward(websocket, ports) + except (Exception, KeyboardInterrupt, SystemExit) as e: + raise ApiException(status=0, reason=str(e))