diff --git a/stream/stream.py b/stream/stream.py index a9d0b402..6d5f05f8 100644 --- a/stream/stream.py +++ b/stream/stream.py @@ -16,7 +16,8 @@ def stream(func, *args, **kwargs): - """Stream given API call using websocket""" + """Stream given API call using websocket. + Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()""" def _intercept_request_call(*args, **kwargs): # old generated code's api client has config. new ones has diff --git a/stream/ws_client.py b/stream/ws_client.py index 0823abbd..9d017512 100644 --- a/stream/ws_client.py +++ b/stream/ws_client.py @@ -34,9 +34,16 @@ ERROR_CHANNEL = 3 RESIZE_CHANNEL = 4 +class _IgnoredIO: + def write(self, _x): + pass + + def getvalue(self): + raise TypeError("Tried to read_all() from a WSClient configured to not capture. Did you mean `capture_all=True`?") + class WSClient: - def __init__(self, configuration, url, headers): + def __init__(self, configuration, url, headers, capture_all): """A websocket client with support for channels. Exec command uses different channels for different streams. for @@ -48,7 +55,10 @@ def __init__(self, configuration, url, headers): header = [] self._connected = False self._channels = {} - self._all = StringIO() + if capture_all: + self._all = StringIO() + else: + self._all = _IgnoredIO() # We just need to pass the Authorization, ignore all the other # http headers we get from the generated code @@ -159,7 +169,7 @@ def read_all(self): channels mapped for each input. """ out = self._all.getvalue() - self._all = StringIO() + self._all = type(self._all)() self._channels = {} return out @@ -258,6 +268,7 @@ def websocket_call(configuration, *args, **kwargs): url = args[1] _request_timeout = kwargs.get("_request_timeout", 60) _preload_content = kwargs.get("_preload_content", True) + capture_all = kwargs.get("capture_all", True) headers = kwargs.get("headers") # Expand command parameter list to indivitual command params @@ -273,7 +284,7 @@ def websocket_call(configuration, *args, **kwargs): url += '?' + urlencode(query_params) try: - client = WSClient(configuration, get_websocket_url(url), headers) + client = WSClient(configuration, get_websocket_url(url), headers, capture_all) if not _preload_content: return client client.run_forever(timeout=_request_timeout)