Skip to content

Commit

Permalink
Add type annotations, refactor sync/async (#623)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart authored Mar 26, 2021
1 parent 50dff2e commit 8c81bd4
Show file tree
Hide file tree
Showing 22 changed files with 1,495 additions and 1,695 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ jobs:
pip install --upgrade setuptools pip
pip install --upgrade --upgrade-strategy eager -e .[test] pytest-cov codecov 'coverage<5'
pip freeze
- name: Check types
run: mypy jupyter_client/manager.py jupyter_client/multikernelmanager.py jupyter_client/client.py jupyter_client/blocking/client.py jupyter_client/asynchronous/client.py jupyter_client/channels.py jupyter_client/session.py jupyter_client/adapter.py jupyter_client/connect.py jupyter_client/consoleapp.py jupyter_client/jsonutil.py jupyter_client/kernelapp.py jupyter_client/launcher.py
- name: Run the tests
run: py.test --cov jupyter_client -v jupyter_client
- name: Code coverage
Expand Down
163 changes: 128 additions & 35 deletions jupyter_client/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

import re
import json
from typing import List, Tuple, Dict, Any

from jupyter_client import protocol_version_info

def code_to_line(code, cursor_pos):
def code_to_line(
code: str,
cursor_pos: int
) -> Tuple[str, int]:
"""Turn a multiline code block and cursor position into a single line
and new cursor position.
Expand All @@ -29,14 +33,17 @@ def code_to_line(code, cursor_pos):
_end_bracket = re.compile(r'\([^\(]*$', re.UNICODE)
_identifier = re.compile(r'[a-z_][0-9a-z._]*', re.I|re.UNICODE)

def extract_oname_v4(code, cursor_pos):
def extract_oname_v4(
code: str,
cursor_pos: int
) -> str:
"""Reimplement token-finding logic from IPython 2.x javascript
for adapting object_info_request from v5 to v4
"""

line, _ = code_to_line(code, cursor_pos)

oldline = line
line = _match_bracket.sub('', line)
while oldline != line:
Expand All @@ -58,29 +65,44 @@ class Adapter(object):
Override message_type(msg) methods to create adapters.
"""

msg_type_map = {}
msg_type_map: Dict[str, str] = {}

def update_header(self, msg):
def update_header(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
return msg

def update_metadata(self, msg):
def update_metadata(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
return msg

def update_msg_type(self, msg):
def update_msg_type(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
header = msg['header']
msg_type = header['msg_type']
if msg_type in self.msg_type_map:
msg['msg_type'] = header['msg_type'] = self.msg_type_map[msg_type]
return msg

def handle_reply_status_error(self, msg):
def handle_reply_status_error(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
"""This will be called *instead of* the regular handler
on any reply with status != ok
"""
return msg

def __call__(self, msg):
def __call__(
self,
msg: Dict[str, Any]
):
msg = self.update_header(msg)
msg = self.update_metadata(msg)
msg = self.update_msg_type(msg)
Expand All @@ -95,7 +117,9 @@ def __call__(self, msg):
return self.handle_reply_status_error(msg)
return handler(msg)

def _version_str_to_list(version):
def _version_str_to_list(
version: str
) -> List[int]:
"""convert a version string to a list of ints
non-int segments are excluded
Expand All @@ -121,14 +145,20 @@ class V5toV4(Adapter):
'inspect_reply' : 'object_info_reply',
}

def update_header(self, msg):
def update_header(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
msg['header'].pop('version', None)
msg['parent_header'].pop('version', None)
return msg

# shell channel

def kernel_info_reply(self, msg):
def kernel_info_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
v4c = {}
content = msg['content']
for key in ('language_version', 'protocol_version'):
Expand All @@ -145,18 +175,27 @@ def kernel_info_reply(self, msg):
msg['content'] = v4c
return msg

def execute_request(self, msg):
def execute_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
content.setdefault('user_variables', [])
return msg

def execute_reply(self, msg):
def execute_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
content.setdefault('user_variables', {})
# TODO: handle payloads
return msg

def complete_request(self, msg):
def complete_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
code = content['code']
cursor_pos = content['cursor_pos']
Expand All @@ -169,7 +208,10 @@ def complete_request(self, msg):
new_content['cursor_pos'] = cursor_pos
return msg

def complete_reply(self, msg):
def complete_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
cursor_start = content.pop('cursor_start')
cursor_end = content.pop('cursor_end')
Expand All @@ -178,7 +220,10 @@ def complete_reply(self, msg):
content.pop('metadata', None)
return msg

def object_info_request(self, msg):
def object_info_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
code = content['code']
cursor_pos = content['cursor_pos']
Expand All @@ -189,19 +234,28 @@ def object_info_request(self, msg):
new_content['detail_level'] = content['detail_level']
return msg

def object_info_reply(self, msg):
def object_info_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
"""inspect_reply can't be easily backward compatible"""
msg['content'] = {'found' : False, 'oname' : 'unknown'}
return msg

# iopub channel

def stream(self, msg):
def stream(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
content['data'] = content.pop('text')
return msg

def display_data(self, msg):
def display_data(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
content.setdefault("source", "display")
data = content['data']
Expand All @@ -215,7 +269,10 @@ def display_data(self, msg):

# stdin channel

def input_request(self, msg):
def input_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
msg['content'].pop('password', None)
return msg

Expand All @@ -227,15 +284,21 @@ class V4toV5(Adapter):
# invert message renames above
msg_type_map = {v:k for k,v in V5toV4.msg_type_map.items()}

def update_header(self, msg):
def update_header(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
msg['header']['version'] = self.version
if msg['parent_header']:
msg['parent_header']['version'] = self.version
return msg

# shell channel

def kernel_info_reply(self, msg):
def kernel_info_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
for key in ('protocol_version', 'ipython_version'):
if key in content:
Expand All @@ -257,15 +320,21 @@ def kernel_info_reply(self, msg):
content['banner'] = ''
return msg

def execute_request(self, msg):
def execute_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
user_variables = content.pop('user_variables', [])
user_expressions = content.setdefault('user_expressions', {})
for v in user_variables:
user_expressions[v] = v
return msg

def execute_reply(self, msg):
def execute_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
user_expressions = content.setdefault('user_expressions', {})
user_variables = content.pop('user_variables', {})
Expand All @@ -281,15 +350,21 @@ def execute_reply(self, msg):

return msg

def complete_request(self, msg):
def complete_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
old_content = msg['content']

new_content = msg['content'] = {}
new_content['code'] = old_content['line']
new_content['cursor_pos'] = old_content['cursor_pos']
return msg

def complete_reply(self, msg):
def complete_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
# complete_reply needs more context than we have to get cursor_start and end.
# use special end=null to indicate current cursor position and negative offset
# for start relative to the cursor.
Expand All @@ -306,7 +381,10 @@ def complete_reply(self, msg):
new_content['metadata'] = {}
return msg

def inspect_request(self, msg):
def inspect_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
name = content['oname']

Expand All @@ -316,7 +394,10 @@ def inspect_request(self, msg):
new_content['detail_level'] = content['detail_level']
return msg

def inspect_reply(self, msg):
def inspect_reply(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
"""inspect_reply can't be easily backward compatible"""
content = msg['content']
new_content = msg['content'] = {'status' : 'ok'}
Expand All @@ -340,12 +421,18 @@ def inspect_reply(self, msg):

# iopub channel

def stream(self, msg):
def stream(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
content['text'] = content.pop('data')
return msg

def display_data(self, msg):
def display_data(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
content = msg['content']
content.pop("source", None)
data = content['data']
Expand All @@ -359,13 +446,19 @@ def display_data(self, msg):

# stdin channel

def input_request(self, msg):
def input_request(
self,
msg: Dict[str, Any]
) -> Dict[str, Any]:
msg['content'].setdefault('password', False)
return msg



def adapt(msg, to_version=protocol_version_info[0]):
def adapt(
msg: Dict[str, Any],
to_version: int =protocol_version_info[0]
) -> Dict[str, Any]:
"""Adapt a single message to a target version
Parameters
Expand Down
Loading

0 comments on commit 8c81bd4

Please sign in to comment.