Skip to content
This repository has been archived by the owner on Mar 13, 2022. It is now read-only.

Commit

Permalink
Implement port forwarding.
Browse files Browse the repository at this point in the history
  • Loading branch information
iciclespider committed Aug 24, 2020
1 parent 54d188f commit 1dd2756
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 66 deletions.
2 changes: 1 addition & 1 deletion stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .stream import stream
from .stream import stream, portforward
27 changes: 16 additions & 11 deletions stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import types

from . import ws_client


def stream(func, *args, **kwargs):
"""Stream given API call using websocket.
Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()"""

def _intercept_request_call(*args, **kwargs):
# old generated code's api client has config. new ones has
# configuration
try:
config = func.__self__.api_client.configuration
except AttributeError:
config = func.__self__.api_client.config
api_client = func.__self__.api_client
prev_request = api_client.request
try:
api_client.request = types.MethodType(ws_client.websocket_call, api_client)
return func(*args, **kwargs)
finally:
api_client.request = prev_request

return ws_client.websocket_call(config, *args, **kwargs)

prev_request = func.__self__.api_client.request
def portforward(func, *args, **kwargs):
kwargs['_preload_content'] = False
api_client = func.__self__.api_client
prev_request = api_client.request
try:
func.__self__.api_client.request = _intercept_request_call
api_client.request = types.MethodType(ws_client.portforward_call, api_client)
return func(*args, **kwargs)
finally:
func.__self__.api_client.request = prev_request
api_client.request = prev_request

289 changes: 235 additions & 54 deletions stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from kubernetes.client.rest import ApiException
from kubernetes.client.rest import ApiException, ApiValueError

import certifi
import collections
import select
import socket
import ssl
import threading
import time

import six
import yaml

from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse
from six.moves.urllib.parse import urlencode, urlparse, urlunparse
from six import StringIO

from websocket import WebSocket, ABNF, enableTrace
Expand Down Expand Up @@ -51,47 +53,13 @@ def __init__(self, configuration, url, headers, capture_all):
like port forwarding can forward different pods' streams to different
channels.
"""
enableTrace(False)
header = []
self._connected = False
self._channels = {}
if capture_all:
self._all = StringIO()
else:
self._all = _IgnoredIO()

# We just need to pass the Authorization, ignore all the other
# http headers we get from the generated code
if headers and 'authorization' in headers:
header.append("authorization: %s" % headers['authorization'])

if headers and 'sec-websocket-protocol' in headers:
header.append("sec-websocket-protocol: %s" %
headers['sec-websocket-protocol'])
else:
header.append("sec-websocket-protocol: v4.channel.k8s.io")

if url.startswith('wss://') and configuration.verify_ssl:
ssl_opts = {
'cert_reqs': ssl.CERT_REQUIRED,
'ca_certs': configuration.ssl_ca_cert or certifi.where(),
}
if configuration.assert_hostname is not None:
ssl_opts['check_hostname'] = configuration.assert_hostname
else:
ssl_opts = {'cert_reqs': ssl.CERT_NONE}

if configuration.cert_file:
ssl_opts['certfile'] = configuration.cert_file
if configuration.key_file:
ssl_opts['keyfile'] = configuration.key_file

self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
if configuration.proxy:
proxy_url = urlparse(configuration.proxy)
self.sock.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port)
else:
self.sock.connect(url, header=header)
self.sock = create_websocket(configuration, url, headers=headers)
self._connected = True

def peek_channel(self, channel, timeout=0):
Expand Down Expand Up @@ -259,44 +227,257 @@ def close(self, **kwargs):
WSResponse = collections.namedtuple('WSResponse', ['data'])


def get_websocket_url(url):
class PortForwardClient:
def __init__(self, websocket, ports):
"""A websocket client with support for port forwarding.
Port Forward command sends on 2 channels per port, a read/write
data channel and a read only error channel. Both channels are sent an
initial frame contaning the port number that channel is associated with.
"""

self.websocket = websocket
self.ports = {}
for ix, port_number in enumerate(ports):
self.ports[port_number] = self.Port(ix, port_number)
threading.Thread(
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
).start()

def socket(self, port_number):
if port_number not in self.ports:
raise ValueError("Invalid port number")
return self.ports[port_number].socket

def error(self, port_number):
if port_number not in self.ports:
raise ValueError("Invalid port number")
return self.ports[port_number].error

def close(self):
for port in self.ports.values():
port.socket.close()

class Port:
def __init__(self, ix, number):
self.number = number
self.channel = bytes([ix * 2])
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket = self.Socket(s)
self.data = b''
self.error = None

class Socket:
def __init__(self, socket):
self._socket = socket

def __getattr__(self, name):
return getattr(self._socket, name)

def setsockopt(self, level, optname, value):
# The following socket option is not valid with a socket created from socketpair,
# and is set when creating an SSLSocket from this socket.
if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY:
return
self._socket.setsockopt(level, optname, value)

# Proxy all socket data between the python code and the kubernetes websocket.
def _proxy(self):
channel_ports = []
channel_initialized = []
python_ports = {}
rlist = []
for port in self.ports.values():
channel_ports.append(port)
channel_initialized.append(False)
channel_ports.append(port)
channel_initialized.append(False)
python_ports[port.python] = port
rlist.append(port.python)
rlist.append(self.websocket.sock)
kubernetes_data = b''
while True:
wlist = []
for port in self.ports.values():
if port.data:
wlist.append(port.python)
if kubernetes_data:
wlist.append(self.websocket.sock)
r, w, _ = select.select(rlist, wlist, [])
for s in w:
if s == self.websocket.sock:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = python_ports[s]
sent = port.python.send(port.data)
port.data = port.data[sent:]
for s in r:
if s == self.websocket.sock:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_CLOSE:
for port in self.ports.values():
port.python.close()
return
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
channel = frame.data[0]
if channel >= len(channel_ports):
raise RuntimeError("Unexpected channel number: " + str(channel))
port = channel_ports[channel]
if channel_initialized[channel]:
if channel % 2:
port.error = frame.data[1:].decode()
if port.python in rlist:
port.python.close()
rlist.remove(port.python)
port.data = b''
else:
port.data += frame.data[1:]
else:
if len(frame.data) != 3:
raise RuntimeError(
"Unexpected initial channel frame data size"
)
port_number = frame.data[1] + (frame.data[2] * 256)
if port_number != port.number:
raise RuntimeError(
"Unexpected port number in initial channel frame: " + str(port_number)
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG):
raise RuntimeError("Unexpected websocket opcode: " + str(opcode))
else:
port = python_ports[s]
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
port.channel + data,
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
rlist.remove(s)
if len(rlist) == 1:
self.websocket.close()
return


def get_websocket_url(url, query_params=None):
parsed_url = urlparse(url)
parts = list(parsed_url)
if parsed_url.scheme == 'http':
parts[0] = 'ws'
elif parsed_url.scheme == 'https':
parts[0] = 'wss'
if query_params:
query = []
for key, value in query_params:
if key == 'command' and isinstance(value, list):
for command in value:
query.append((key, command))
else:
query.append((key, value))
if query:
parts[4] = urlencode(query)
return urlunparse(parts)


def websocket_call(configuration, *args, **kwargs):
def create_websocket(configuration, url, headers=None):
enableTrace(False)

# We just need to pass the Authorization, ignore all the other
# http headers we get from the generated code
header = []
if headers and 'authorization' in headers:
header.append("authorization: %s" % headers['authorization'])
if headers and 'sec-websocket-protocol' in headers:
header.append("sec-websocket-protocol: %s" %
headers['sec-websocket-protocol'])
else:
header.append("sec-websocket-protocol: v4.channel.k8s.io")

if url.startswith('wss://') and configuration.verify_ssl:
ssl_opts = {
'cert_reqs': ssl.CERT_REQUIRED,
'ca_certs': configuration.ssl_ca_cert or certifi.where(),
}
if configuration.assert_hostname is not None:
ssl_opts['check_hostname'] = configuration.assert_hostname
else:
ssl_opts = {'cert_reqs': ssl.CERT_NONE}

if configuration.cert_file:
ssl_opts['certfile'] = configuration.cert_file
if configuration.key_file:
ssl_opts['keyfile'] = configuration.key_file

websocket = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
if configuration.proxy:
proxy_url = urlparse(configuration.proxy)
websocket.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port)
else:
websocket.connect(url, header=header)
return websocket


def _configuration(api_client):
# old generated code's api client has config. new ones has
# configuration
try:
return api_client.configuration
except AttributeError:
return api_client.config


def websocket_call(api_client, _method, url, **kwargs):
"""An internal function to be called in api-client when a websocket
connection is required. args and kwargs are the parameters of
apiClient.request method."""

url = args[1]
url = get_websocket_url(url, kwargs.get("query_params"))
headers = kwargs.get("headers")
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
capture_all = kwargs.get("capture_all", True)
headers = kwargs.get("headers")

# Expand command parameter list to indivitual command params
query_params = []
for key, value in kwargs.get("query_params", {}):
if key == 'command' and isinstance(value, list):
for command in value:
query_params.append((key, command))
else:
query_params.append((key, value))

if query_params:
url += '?' + urlencode(query_params)

try:
client = WSClient(configuration, get_websocket_url(url), headers, capture_all)
client = WSClient(_configuration(api_client), url, headers, capture_all)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
return WSResponse('%s' % ''.join(client.read_all()))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))


def portforward_call(api_client, _method, url, **kwargs):
"""An internal function to be called in api-client when a websocket
connection is required for port forwarding. args and kwargs are the
parameters of apiClient.request method."""

query_params = kwargs.get("query_params")

ports = []
for key, value in query_params:
if key == 'ports':
for port in value.split(','):
try:
port = int(port)
if not (0 < port < 65536):
raise ValueError
ports.append(port)
except ValueError:
raise ApiValueError("Invalid port number `" + str(port) + "`")
if not ports:
raise ApiValueError("Missing required parameter `ports`")

url = get_websocket_url(url, query_params)
headers = kwargs.get("headers")

try:
websocket = create_websocket(_configuration(api_client), url, headers)
return PortForwardClient(websocket, ports)
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))

0 comments on commit 1dd2756

Please sign in to comment.