Skip to content
This repository has been archived by the owner on Mar 13, 2022. It is now read-only.

Commit

Permalink
Merge pull request #210 from iciclespider/port-forward
Browse files Browse the repository at this point in the history
Implement port forwarding.
  • Loading branch information
k8s-ci-robot committed Sep 8, 2020
2 parents 471a678 + 5d39d0d commit 3dc7fe0
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 4 deletions.
2 changes: 1 addition & 1 deletion stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .stream import stream
from .stream import stream, portforward
8 changes: 6 additions & 2 deletions stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
from . import ws_client


def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
def _websocket_reqeust(websocket_request, force_kwargs, api_method, *args, **kwargs):
"""Override the ApiClient.request method with an alternative websocket based
method and call the supplied Kubernetes API method with that in place."""
if force_kwargs:
for kwarg, value in force_kwargs.items():
kwargs[kwarg] = value
api_client = api_method.__self__.api_client
# old generated code's api client has config. new ones has configuration
try:
Expand All @@ -34,4 +37,5 @@ def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
api_client.request = prev_request


stream = functools.partial(_websocket_reqeust, ws_client.websocket_call)
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call, None)
portforward = functools.partial(_websocket_reqeust, ws_client.portforward_call, {'_preload_content':False})
205 changes: 204 additions & 1 deletion stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from kubernetes.client.rest import ApiException
from kubernetes.client.rest import ApiException, ApiValueError

import certifi
import collections
import select
import socket
import ssl
import threading
import time

import six
Expand Down Expand Up @@ -225,6 +227,174 @@ def close(self, **kwargs):
WSResponse = collections.namedtuple('WSResponse', ['data'])


class PortForward:
def __init__(self, websocket, ports):
"""A websocket client with support for port forwarding.
Port Forward command sends on 2 channels per port, a read/write
data channel and a read only error channel. Both channels are sent an
initial frame contaning the port number that channel is associated with.
"""

self.websocket = websocket
self.local_ports = {}
for ix, port_number in enumerate(ports):
self.local_ports[port_number] = self._Port(ix, port_number)
# There is a thread run per PortForward instance which performs the translation between the
# raw socket data sent by the python application and the websocket protocol. This thread
# terminates after either side has closed all ports, and after flushing all pending data.
proxy = threading.Thread(
name="Kubernetes port forward proxy: %s" % ', '.join([str(port) for port in ports]),
target=self._proxy
)
proxy.daemon = True
proxy.start()

@property
def connected(self):
return self.websocket.connected

def socket(self, port_number):
if port_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.local_ports[port_number].socket

def error(self, port_number):
if port_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.local_ports[port_number].error

def close(self):
for port in self.local_ports.values():
port.socket.close()

class _Port:
def __init__(self, ix, port_number):
# The remote port number
self.port_number = port_number
# The websocket channel byte number for this port
self.channel = six.int2byte(ix * 2)
# A socket pair is created to provide a means of translating the data flow
# between the python application and the kubernetes websocket. The self.python
# half of the socket pair is used by the _proxy method to receive and send data
# to the running python application.
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
# The self.socket half of the pair is used by the python application to send
# and receive data to the eventual pod port. It is wrapped in the _Socket class
# because a socket pair is an AF_UNIX socket, not a AF_INET socket. This allows
# intercepting setting AF_INET socket options that would error against an AF_UNIX
# socket.
self.socket = self._Socket(s)
# Data accumulated from the websocket to be sent to the python application.
self.data = b''
# All data sent from kubernetes on the port error channel.
self.error = None

class _Socket:
def __init__(self, socket):
self._socket = socket

def __getattr__(self, name):
return getattr(self._socket, name)

def setsockopt(self, level, optname, value):
# The following socket option is not valid with a socket created from socketpair,
# and is set by the http.client.HTTPConnection.connect method.
if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY:
return
self._socket.setsockopt(level, optname, value)

# Proxy all socket data between the python code and the kubernetes websocket.
def _proxy(self):
channel_ports = []
channel_initialized = []
local_ports = {}
for port in self.local_ports.values():
# Setup the data channel for this port number
channel_ports.append(port)
channel_initialized.append(False)
# Setup the error channel for this port number
channel_ports.append(port)
channel_initialized.append(False)
port.python.setblocking(True)
local_ports[port.python] = port
# The data to send on the websocket socket
kubernetes_data = b''
while True:
rlist = [] # List of sockets to read from
wlist = [] # List of sockets to write to
if self.websocket.connected:
rlist.append(self.websocket)
if kubernetes_data:
wlist.append(self.websocket)
local_all_closed = True
for port in self.local_ports.values():
if port.python.fileno() != -1:
if port.error or not self.websocket.connected:
if port.data:
wlist.append(port.python)
local_all_closed = False
else:
port.python.close()
else:
rlist.append(port.python)
if port.data:
wlist.append(port.python)
local_all_closed = False
if local_all_closed and not (self.websocket.connected and kubernetes_data):
self.websocket.close()
return
r, w, _ = select.select(rlist, wlist, [])
for sock in r:
if sock == self.websocket:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
channel = six.byte2int(frame.data)
if channel >= len(channel_ports):
raise RuntimeError("Unexpected channel number: %s" % channel)
port = channel_ports[channel]
if channel_initialized[channel]:
if channel % 2:
if port.error is None:
port.error = ''
port.error += frame.data[1:].decode()
else:
port.data += frame.data[1:]
else:
if len(frame.data) != 3:
raise RuntimeError(
"Unexpected initial channel frame data size"
)
port_number = six.byte2int(frame.data[1:2]) + (six.byte2int(frame.data[2:3]) * 256)
if port_number != port.port_number:
raise RuntimeError(
"Unexpected port number in initial channel frame: %s" % port_number
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
raise RuntimeError("Unexpected websocket opcode: %s" % opcode)
else:
port = local_ports[sock]
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
port.channel + data,
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
for sock in w:
if sock == self.websocket:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = local_ports[sock]
sent = port.python.send(port.data)
port.data = port.data[sent:]


def get_websocket_url(url, query_params=None):
parsed_url = urlparse(url)
parts = list(parsed_url)
Expand Down Expand Up @@ -302,3 +472,36 @@ def websocket_call(configuration, _method, url, **kwargs):
return WSResponse('%s' % ''.join(client.read_all()))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))


def portforward_call(configuration, _method, url, **kwargs):
"""An internal function to be called in api-client when a websocket
connection is required for port forwarding. args and kwargs are the
parameters of apiClient.request method."""

query_params = kwargs.get("query_params")

ports = []
for param, value in query_params:
if param == 'ports':
for port in value.split(','):
try:
port_number = int(port)
except ValueError:
raise ApiValueError("Invalid port number: %s" % port)
if not (0 < port_number < 65536):
raise ApiValueError("Port number must be between 0 and 65536: %s" % port)
if port_number in ports:
raise ApiValueError("Duplicate port numbers: %s" % port)
ports.append(port_number)
if not ports:
raise ApiValueError("Missing required parameter `ports`")

url = get_websocket_url(url, query_params)
headers = kwargs.get("headers")

try:
websocket = create_websocket(configuration, url, headers)
return PortForward(websocket, ports)
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))

0 comments on commit 3dc7fe0

Please sign in to comment.