Skip to content
This repository has been archived by the owner on Mar 13, 2022. It is now read-only.

Implement port forwarding. #210

Merged
merged 7 commits into from
Sep 8, 2020
2 changes: 1 addition & 1 deletion stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .stream import stream
from .stream import stream, portforward
8 changes: 6 additions & 2 deletions stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
from . import ws_client


def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
def _websocket_reqeust(websocket_request, force_kwargs, api_method, *args, **kwargs):
"""Override the ApiClient.request method with an alternative websocket based
method and call the supplied Kubernetes API method with that in place."""
if force_kwargs:
for kwarg, value in force_kwargs.items():
kwargs[kwarg] = value
yliaog marked this conversation as resolved.
Show resolved Hide resolved
api_client = api_method.__self__.api_client
# old generated code's api client has config. new ones has configuration
try:
Expand All @@ -34,4 +37,5 @@ def _websocket_reqeust(websocket_request, api_method, *args, **kwargs):
api_client.request = prev_request


stream = functools.partial(_websocket_reqeust, ws_client.websocket_call)
stream = functools.partial(_websocket_reqeust, ws_client.websocket_call, None)
portforward = functools.partial(_websocket_reqeust, ws_client.portforward_call, {'_preload_content':False})
172 changes: 171 additions & 1 deletion stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from kubernetes.client.rest import ApiException
from kubernetes.client.rest import ApiException, ApiValueError

import certifi
import collections
import select
import socket
import ssl
import threading
import time

import six
Expand Down Expand Up @@ -225,6 +227,143 @@ def close(self, **kwargs):
WSResponse = collections.namedtuple('WSResponse', ['data'])


class PortForward:
def __init__(self, websocket, ports):
"""A websocket client with support for port forwarding.

Port Forward command sends on 2 channels per port, a read/write
data channel and a read only error channel. Both channels are sent an
initial frame contaning the port number that channel is associated with.
"""

self.websocket = websocket
self.ports = {}
for ix, port_number in enumerate(ports):
self.ports[port_number] = self._Port(ix, port_number)
threading.Thread(
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
).start()
yliaog marked this conversation as resolved.
Show resolved Hide resolved

def socket(self, port_number):
if port_number not in self.ports:
raise ValueError("Invalid port number")
return self.ports[port_number].socket

def error(self, port_number):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please change it to error_channel or something more descriptive. 'error' looks like error/exception, which is confusing.

Copy link
Contributor Author

@iciclespider iciclespider Sep 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just pushed the change to error_channel.

However, a little clarification. This is not really a "channel" that can send and receive data. From the client perspective, it is read only and cannot be written to. And the only time it is written to by kubernetes is when there is an error/exception case on the server and that port is unusable. It is written as the parting message.

See: https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/cri/streaming/portforward/websocket.go#L184-L197

func (h *websocketStreamHandler) portForward(p *websocketStreamPair) {
	defer p.dataStream.Close()
	defer p.errorStream.Close()

	klog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port)
	err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream)
	klog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port)

	if err != nil {
		msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err)
		runtime.HandleError(msg)
		fmt.Fprint(p.errorStream, msg.Error())
	}
}

if port_number not in self.ports:
raise ValueError("Invalid port number")
return self.ports[port_number].error

def close(self):
for port in self.ports.values():
port.socket.close()

class _Port:
def __init__(self, ix, number):
self.number = number
self.channel = bytes([ix * 2])
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
yliaog marked this conversation as resolved.
Show resolved Hide resolved
self.socket = self._Socket(s)
yliaog marked this conversation as resolved.
Show resolved Hide resolved
self.data = b''
self.error = None

class _Socket:
def __init__(self, socket):
self._socket = socket

def __getattr__(self, name):
return getattr(self._socket, name)

def setsockopt(self, level, optname, value):
yliaog marked this conversation as resolved.
Show resolved Hide resolved
# The following socket option is not valid with a socket created from socketpair,
# and is set when creating an SSLSocket from this socket.
if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY:
return
self._socket.setsockopt(level, optname, value)

# Proxy all socket data between the python code and the kubernetes websocket.
def _proxy(self):
channel_ports = []
channel_initialized = []
python_ports = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about call it local_port?

rlist = []
for port in self.ports.values():
channel_ports.append(port)
channel_initialized.append(False)
channel_ports.append(port)
channel_initialized.append(False)
yliaog marked this conversation as resolved.
Show resolved Hide resolved
python_ports[port.python] = port
rlist.append(port.python)
rlist.append(self.websocket.sock)
kubernetes_data = b''
while True:
wlist = []
for port in self.ports.values():
if port.data:
wlist.append(port.python)
if kubernetes_data:
wlist.append(self.websocket.sock)
r, w, _ = select.select(rlist, wlist, [])
for s in w:
if s == self.websocket.sock:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = python_ports[s]
sent = port.python.send(port.data)
port.data = port.data[sent:]
for s in r:
if s == self.websocket.sock:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_CLOSE:
for port in self.ports.values():
port.python.close()
return
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
channel = frame.data[0]
if channel >= len(channel_ports):
raise RuntimeError("Unexpected channel number: " + str(channel))
port = channel_ports[channel]
if channel_initialized[channel]:
if channel % 2:
port.error = frame.data[1:].decode()
if port.python in rlist:
port.python.close()
rlist.remove(port.python)
port.data = b''
else:
port.data += frame.data[1:]
else:
if len(frame.data) != 3:
raise RuntimeError(
"Unexpected initial channel frame data size"
)
port_number = frame.data[1] + (frame.data[2] * 256)
if port_number != port.number:
raise RuntimeError(
"Unexpected port number in initial channel frame: " + str(port_number)
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG):
raise RuntimeError("Unexpected websocket opcode: " + str(opcode))
else:
port = python_ports[s]
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
port.channel + data,
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
yliaog marked this conversation as resolved.
Show resolved Hide resolved
rlist.remove(s)
if len(rlist) == 1:
yliaog marked this conversation as resolved.
Show resolved Hide resolved
self.websocket.close()
return


def get_websocket_url(url, query_params=None):
parsed_url = urlparse(url)
parts = list(parsed_url)
Expand Down Expand Up @@ -302,3 +441,34 @@ def websocket_call(configuration, _method, url, **kwargs):
return WSResponse('%s' % ''.join(client.read_all()))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))


def portforward_call(configuration, _method, url, **kwargs):
"""An internal function to be called in api-client when a websocket
connection is required for port forwarding. args and kwargs are the
parameters of apiClient.request method."""

query_params = kwargs.get("query_params")

ports = []
yliaog marked this conversation as resolved.
Show resolved Hide resolved
for key, value in query_params:
if key == 'ports':
for port in value.split(','):
yliaog marked this conversation as resolved.
Show resolved Hide resolved
try:
port = int(port)
if not (0 < port < 65536):
raise ValueError
ports.append(port)
except ValueError:
raise ApiValueError("Invalid port number `" + str(port) + "`")
if not ports:
raise ApiValueError("Missing required parameter `ports`")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this logic can be abstracted out to another function with a simpler implementation like below.

def _valid_ports(ports):
    _ports = []

    for p in ports:
        try:
            p = int(p)
        except ValueError:
            raise ApiValueError("Invalid port number `{}`. Port number must be an integer".format(p))

        if (p < 0) or (p > 65536):
            raise ApiValueError("Invalid port number `{}`. Port number must be between 0 and 65536".format(p))

        _ports.append(p)

    return _ports

ports = query_params.get("ports", "")
if not ports:
    raise ApiValueError("Missing required parameter `ports`")

ports = ports.split(",")
ports = _valid_ports(ports)

# There is no check required here, since any unwanted port value would be checked in _valid_ports itself.
# We can be assured that we get a non-empty port list here.

Copy link
Contributor Author

@iciclespider iciclespider Aug 31, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line will not work:

ports = query_params.get("ports", "")

query_params is not a dictionary, it a list of two element tuples.


url = get_websocket_url(url, query_params)
headers = kwargs.get("headers")

try:
websocket = create_websocket(configuration, url, headers)
return PortForward(websocket, ports)
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))