Skip to content
This repository has been archived by the owner on Mar 13, 2022. It is now read-only.

Implement port forwarding. #210

Merged
merged 7 commits into from
Sep 8, 2020
151 changes: 78 additions & 73 deletions stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,33 +238,51 @@ def __init__(self, websocket, ports):

self.websocket = websocket
self.local_ports = {}
for ix, local_remote in enumerate(ports):
self.local_ports[local_remote[0]] = self._Port(ix, local_remote[1])
for ix, port_number in enumerate(ports):
self.local_ports[port_number] = self._Port(ix, port_number)
# There is a thread run per PortForward instance which performs the translation between the
# raw socket data sent by the python application and the websocket protocol. This thread
# terminates after either side has closed all ports, and after flushing all pending data.
threading.Thread(
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
name="Kubernetes port forward proxy: %s" % ', '.join([str(port) for port in ports]),
target=self._proxy,
daemon=True
).start()
yliaog marked this conversation as resolved.
Show resolved Hide resolved

def socket(self, local_number):
if local_number not in self.local_ports:
def socket(self, port_number):
if port_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.local_ports[local_number].socket
return self.local_ports[port_number].socket

def error(self, local_number):
if local_number not in self.local_ports:
def error(self, port_number):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please change it to error_channel or something more descriptive. 'error' looks like error/exception, which is confusing.

Copy link
Contributor Author

@iciclespider iciclespider Sep 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just pushed the change to error_channel.

However, a little clarification. This is not really a "channel" that can send and receive data. From the client perspective, it is read only and cannot be written to. And the only time it is written to by kubernetes is when there is an error/exception case on the server and that port is unusable. It is written as the parting message.

See: https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/cri/streaming/portforward/websocket.go#L184-L197

func (h *websocketStreamHandler) portForward(p *websocketStreamPair) {
	defer p.dataStream.Close()
	defer p.errorStream.Close()

	klog.V(5).Infof("(conn=%p) invoking forwarder.PortForward for port %d", h.conn, p.port)
	err := h.forwarder.PortForward(h.pod, h.uid, p.port, p.dataStream)
	klog.V(5).Infof("(conn=%p) done invoking forwarder.PortForward for port %d", h.conn, p.port)

	if err != nil {
		msg := fmt.Errorf("error forwarding port %d to pod %s, uid %v: %v", p.port, h.pod, h.uid, err)
		runtime.HandleError(msg)
		fmt.Fprint(p.errorStream, msg.Error())
	}
}

if port_number not in self.local_ports:
raise ValueError("Invalid port number")
return self.local_ports[local_number].error
return self.local_ports[port_number].error

def close(self):
for port in self.local_ports.values():
port.socket.close()

class _Port:
def __init__(self, ix, remote_number):
self.remote_number = remote_number
def __init__(self, ix, port_number):
# The remote port number
self.port_number = port_number
# The websocket channel byte number for this port
self.channel = bytes([ix * 2])
# A socket pair is created to provide a means of translating the data flow
# between the python application and the kubernetes websocket. The self.python
# half of the socket pair is used by the _proxy method to receive and send data
# to the running python application.
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
yliaog marked this conversation as resolved.
Show resolved Hide resolved
# The self.socket half of the pair is used by the python application to send
# and receive data to the eventual pod port. It is wrapped in the _Socket class
# because a socket pair is an AF_UNIX socket, not a AF_NET socket. This allows
# intercepting setting AF_INET socket options that would error against an AD_UNIX
yliaog marked this conversation as resolved.
Show resolved Hide resolved
# socket.
self.socket = self._Socket(s)
yliaog marked this conversation as resolved.
Show resolved Hide resolved
# Data accumulated from the websocket to be sent to the python application.
self.data = b''
# All data sent from kubernetes on the port error channel.
self.error = None

class _Socket:
Expand All @@ -285,42 +303,44 @@ def setsockopt(self, level, optname, value):
def _proxy(self):
channel_ports = []
channel_initialized = []
python_ports = {}
rlist = []
local_ports = {}
for port in self.local_ports.values():
# Setup the data channel for this port number
channel_ports.append(port)
channel_initialized.append(False)
# Setup the error channel for this port number
channel_ports.append(port)
channel_initialized.append(False)
yliaog marked this conversation as resolved.
Show resolved Hide resolved
python_ports[port.python] = port
rlist.append(port.python)
rlist.append(self.websocket.sock)
port.python.setblocking(True)
local_ports[port.python] = port
# The data to send on the websocket socket
kubernetes_data = b''
while True:
wlist = []
rlist = [] # List of sockets to read from
wlist = [] # List of sockets to write to
if self.websocket.connected:
rlist.append(self.websocket)
if kubernetes_data:
wlist.append(self.websocket)
all_closed = True
for port in self.local_ports.values():
if port.data:
wlist.append(port.python)
if kubernetes_data:
wlist.append(self.websocket.sock)
if port.python.fileno() != -1:
if port.data:
wlist.append(port.python)
all_closed = False
else:
if self.websocket.connected:
rlist.append(port.python)
yliaog marked this conversation as resolved.
Show resolved Hide resolved
all_closed = False
else:
port.python.close()
yliaog marked this conversation as resolved.
Show resolved Hide resolved
if all_closed and (not self.websocket.connected or not kubernetes_data):
yliaog marked this conversation as resolved.
Show resolved Hide resolved
self.websocket.close()
return
r, w, _ = select.select(rlist, wlist, [])
for s in w:
if s == self.websocket.sock:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = python_ports[s]
sent = port.python.send(port.data)
port.data = port.data[sent:]
for s in r:
if s == self.websocket.sock:
for sock in r:
if sock == self.websocket:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_CLOSE:
for port in self.local_ports.values():
port.python.close()
return
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
Expand All @@ -341,27 +361,32 @@ def _proxy(self):
"Unexpected initial channel frame data size"
)
port_number = frame.data[1] + (frame.data[2] * 256)
if port_number != port.remote_number:
if port_number != port.port_number:
raise RuntimeError(
"Unexpected port number in initial channel frame: " + str(port_number)
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG):
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
raise RuntimeError("Unexpected websocket opcode: " + str(opcode))
else:
port = python_ports[s]
port = local_ports[sock]
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
port.channel + data,
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
rlist.remove(s)
if len(rlist) == 1:
self.websocket.close()
return
if not port.data:
port.python.close()
for sock in w:
if sock == self.websocket:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = local_ports[sock]
sent = port.python.send(port.data)
port.data = port.data[sent:]


def get_websocket_url(url, query_params=None):
Expand Down Expand Up @@ -451,38 +476,18 @@ def portforward_call(configuration, _method, url, **kwargs):
query_params = kwargs.get("query_params")

ports = []
yliaog marked this conversation as resolved.
Show resolved Hide resolved
for ix in range(len(query_params)):
if query_params[ix][0] == 'ports':
remote_ports = []
for port in query_params[ix][1].split(','):
for param, value in query_params:
if param == 'ports':
for port in value.split(','):
yliaog marked this conversation as resolved.
Show resolved Hide resolved
try:
local_remote = port.split(':')
if len(local_remote) > 2:
raise ValueError
if len(local_remote) == 1:
local_remote[0] = int(local_remote[0])
if not (0 < local_remote[0] < 65536):
raise ValueError
local_remote.append(local_remote[0])
elif len(local_remote) == 2:
if local_remote[0]:
local_remote[0] = int(local_remote[0])
if not (0 <= local_remote[0] < 65536):
raise ValueError
else:
local_remote[0] = 0
local_remote[1] = int(local_remote[1])
if not (0 < local_remote[1] < 65536):
raise ValueError
if not local_remote[0]:
local_remote[0] = len(ports) + 1
else:
raise ValueError
ports.append(local_remote)
remote_ports.append(str(local_remote[1]))
port_number = int(port)
except ValueError:
raise ApiValueError("Invalid port number `" + port + "`")
query_params[ix] = ('ports', ','.join(remote_ports))
raise ApiValueError("Invalid port number: %s" % port)
if not (0 < port_number < 65536):
raise ApiValueError("Port number must be between 0 and 65536: %s" % port)
if port_number in ports:
raise ApiValueError("Duplicate port numbers: %s" % port)
ports.append(port_number)
if not ports:
raise ApiValueError("Missing required parameter `ports`")

Expand Down