Skip to content
This repository has been archived by the owner on Mar 13, 2022. It is now read-only.

stream+ws_client: Optional capture-all #178

Merged
merged 2 commits into from
Jan 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@


def stream(func, *args, **kwargs):
"""Stream given API call using websocket"""
"""Stream given API call using websocket.
Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()"""

def _intercept_request_call(*args, **kwargs):
# old generated code's api client has config. new ones has
Expand Down
24 changes: 18 additions & 6 deletions stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import yaml

from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse
from six import StringIO

from websocket import WebSocket, ABNF, enableTrace

Expand All @@ -33,9 +34,16 @@
ERROR_CHANNEL = 3
RESIZE_CHANNEL = 4

class _IgnoredIO:
def write(self, _x):
pass

def getvalue(self):
raise TypeError("Tried to read_all() from a WSClient configured to not capture. Did you mean `capture_all=True`?")


class WSClient:
def __init__(self, configuration, url, headers):
def __init__(self, configuration, url, headers, capture_all):
"""A websocket client with support for channels.

Exec command uses different channels for different streams. for
Expand All @@ -47,7 +55,10 @@ def __init__(self, configuration, url, headers):
header = []
self._connected = False
self._channels = {}
self._all = ""
if capture_all:
self._all = StringIO()
else:
self._all = _IgnoredIO()

# We just need to pass the Authorization, ignore all the other
# http headers we get from the generated code
Expand Down Expand Up @@ -157,8 +168,8 @@ def read_all(self):
TODO: Maybe we can process this and return a more meaningful map with
channels mapped for each input.
"""
out = self._all
self._all = ""
out = self._all.getvalue()
self._all = self._all.__class__()
self._channels = {}
return out

Expand Down Expand Up @@ -195,7 +206,7 @@ def update(self, timeout=0):
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
# keeping all messages in the order they received
# for non-blocking call.
self._all += data
self._all.write(data)
if channel not in self._channels:
self._channels[channel] = data
else:
Expand Down Expand Up @@ -257,6 +268,7 @@ def websocket_call(configuration, *args, **kwargs):
url = args[1]
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
capture_all = kwargs.get("capture_all", True)
headers = kwargs.get("headers")

# Expand command parameter list to indivitual command params
Expand All @@ -272,7 +284,7 @@ def websocket_call(configuration, *args, **kwargs):
url += '?' + urlencode(query_params)

try:
client = WSClient(configuration, get_websocket_url(url), headers)
client = WSClient(configuration, get_websocket_url(url), headers, capture_all)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
Expand Down