Skip to content

Commit

Permalink
Enable binary support for WSClient
Browse files Browse the repository at this point in the history
Currently, under python 3, the WSClient decodes all data via UTF-8. This
will break, e.g. capturing the stdout of tar or gzip.
This adds a new 'binary' kwarg to the WSClient class and websocket_call
function. If this is set to true, then the decoding will not happen, and
all channels will be interpreted as binary.
This does raise a slight complication, as the OpenAPI-generated client
will convert the output to a string, no matter what, which it ends up
doing by (effectively) calling repr(). This requires a bit of magic to
recover the orignial bytes, and is inefficient. However, this is only
the case when using the default _preload_content=True, setting this to
False and manually calling read_all or read_channel, this issue does not
arise.
  • Loading branch information
meln5674 committed Feb 28, 2024
1 parent 7712421 commit 488518d
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 14 deletions.
13 changes: 11 additions & 2 deletions kubernetes/base/stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,18 @@ def _websocket_request(websocket_request, force_kwargs, api_method, *args, **kwa
except AttributeError:
configuration = api_client.config
prev_request = api_client.request
binary = kwargs.pop('binary', False)
try:
api_client.request = functools.partial(websocket_request, configuration)
return api_method(*args, **kwargs)
api_client.request = functools.partial(websocket_request, configuration, binary=binary)
out = api_method(*args, **kwargs)
# The api_client insists on converting this to a string using its representation, so we have
# to do this dance to strip it of the b' prefix and ' suffix, encode it byte-per-byte (latin1),
# escape all of the unicode \x*'s, then encode it back byte-by-byte
# However, if _preload_content=False is passed, then the entire WSClient is returned instead
# of a response, and we want to leave it alone
if binary and kwargs.get('_preload_content', True):
out = out[2:-1].encode('latin1').decode('unicode_escape').encode('latin1')
return out
finally:
api_client.request = prev_request

Expand Down
29 changes: 19 additions & 10 deletions kubernetes/base/stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
import six
import yaml


from six.moves.urllib.parse import urlencode, urlparse, urlunparse
from six import StringIO
from six import StringIO, BytesIO

from websocket import WebSocket, ABNF, enableTrace
from base64 import urlsafe_b64decode
Expand All @@ -48,7 +49,7 @@ def getvalue(self):


class WSClient:
def __init__(self, configuration, url, headers, capture_all):
def __init__(self, configuration, url, headers, capture_all, binary=False):
"""A websocket client with support for channels.
Exec command uses different channels for different streams. for
Expand All @@ -58,8 +59,10 @@ def __init__(self, configuration, url, headers, capture_all):
"""
self._connected = False
self._channels = {}
self.binary = binary
self.newline = '\n' if not self.binary else b'\n'
if capture_all:
self._all = StringIO()
self._all = StringIO() if not self.binary else BytesIO()
else:
self._all = _IgnoredIO()
self.sock = create_websocket(configuration, url, headers)
Expand Down Expand Up @@ -92,8 +95,8 @@ def readline_channel(self, channel, timeout=None):
while self.is_open() and time.time() - start < timeout:
if channel in self._channels:
data = self._channels[channel]
if "\n" in data:
index = data.find("\n")
if self.newline in data:
index = data.find(self.newline)
ret = data[:index]
data = data[index+1:]
if data:
Expand Down Expand Up @@ -197,10 +200,12 @@ def update(self, timeout=0):
return
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
data = frame.data
if six.PY3:
if six.PY3 and not self.binary:
data = data.decode("utf-8", "replace")
if len(data) > 1:
channel = ord(data[0])
channel = data[0]
if six.PY3 and not self.binary:
channel = ord(channel)
data = data[1:]
if data:
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
Expand Down Expand Up @@ -518,13 +523,17 @@ def websocket_call(configuration, _method, url, **kwargs):
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
capture_all = kwargs.get("capture_all", True)

binary = kwargs.get('binary', False)
try:
client = WSClient(configuration, url, headers, capture_all)
client = WSClient(configuration, url, headers, capture_all, binary=binary)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
return WSResponse('%s' % ''.join(client.read_all()))
all = client.read_all()
if binary:
return WSResponse(all)
else:
return WSResponse('%s' % ''.join(all))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))

Expand Down
45 changes: 43 additions & 2 deletions kubernetes/e2e_test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import unittest
import uuid
import six
import io
import gzip

from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api
Expand Down Expand Up @@ -118,15 +120,28 @@ def test_pod_apis(self):
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, str)
self.assertEqual(3, len(resp.splitlines()))

exec_command = ['/bin/sh',
'-c',
'echo -n "This is a test string" | gzip']
resp = stream(api.connect_get_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False,
binary=True)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, bytes)
self.assertEqual("This is a test string", gzip.decompress(resp).decode('utf-8'))

exec_command = 'uptime'
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s' % repr(resp))
self.assertEqual(1, len(resp.splitlines()))

resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
Expand Down Expand Up @@ -154,6 +169,32 @@ def test_pod_apis(self):
resp.update(timeout=5)
self.assertFalse(resp.is_open())

resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command='/bin/sh',
stderr=True, stdin=True,
stdout=True, tty=False,
binary=True,
_preload_content=False)
resp.write_stdin(b"echo test string 1\n")
line = resp.readline_stdout(timeout=5)
self.assertFalse(resp.peek_stderr())
self.assertEqual(b"test string 1", line)
resp.write_stdin(b"echo test string 2 >&2\n")
line = resp.readline_stderr(timeout=5)
self.assertFalse(resp.peek_stdout())
self.assertEqual(b"test string 2", line)
resp.write_stdin(b"exit\n")
resp.update(timeout=5)
while True:
line = resp.read_channel(ERROR_CHANNEL)
if len(line) != 0:
break
time.sleep(1)
status = json.loads(line)
self.assertEqual(status['status'], 'Success')
resp.update(timeout=5)
self.assertFalse(resp.is_open())

number_of_pods = len(api.list_pod_for_all_namespaces().items)
self.assertTrue(number_of_pods > 0)

Expand Down

0 comments on commit 488518d

Please sign in to comment.