diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f7b174cef3..447004dfa4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,13 @@ repos: files: \.py$ args: [--profile=black] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.942 + hooks: + - id: mypy + exclude: examples/simple/setup.py + additional_dependencies: [types-requests] + - repo: https://github.com/pre-commit/mirrors-prettier rev: v2.6.2 hooks: diff --git a/MANIFEST.in b/MANIFEST.in index 5ec49b2ec5..d5cd493a2e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,7 @@ include README.md include RELEASE.md include CHANGELOG.md include package.json +include jupyter_server/py.typed # include everything in package_data recursive-include jupyter_server * diff --git a/docs/source/conf.py b/docs/source/conf.py index c78443f012..5bf33d15de 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,7 +16,7 @@ import shutil import sys -from pkg_resources import parse_version +from packaging.version import parse as parse_version HERE = osp.abspath(osp.dirname(__file__)) diff --git a/examples/authorization/jupyter_nbclassic_readonly_config.py b/examples/authorization/jupyter_nbclassic_readonly_config.py index 292644c284..7fa83b17be 100644 --- a/examples/authorization/jupyter_nbclassic_readonly_config.py +++ b/examples/authorization/jupyter_nbclassic_readonly_config.py @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource): return True -c.ServerApp.authorizer_class = ReadOnly +c.ServerApp.authorizer_class = ReadOnly # type:ignore[name-defined] diff --git a/examples/authorization/jupyter_nbclassic_rw_config.py b/examples/authorization/jupyter_nbclassic_rw_config.py index 261efcf984..c56c6dcc8f 100644 --- a/examples/authorization/jupyter_nbclassic_rw_config.py +++ b/examples/authorization/jupyter_nbclassic_rw_config.py @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource): return True -c.ServerApp.authorizer_class = ReadWriteOnly +c.ServerApp.authorizer_class = ReadWriteOnly # type:ignore[name-defined] diff --git a/examples/authorization/jupyter_temporary_config.py b/examples/authorization/jupyter_temporary_config.py index e1bd2fb507..d19b5f74df 100644 --- a/examples/authorization/jupyter_temporary_config.py +++ b/examples/authorization/jupyter_temporary_config.py @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource): return True -c.ServerApp.authorizer_class = TemporaryServerPersonality +c.ServerApp.authorizer_class = TemporaryServerPersonality # type:ignore[name-defined] diff --git a/examples/simple/jupyter_server_config.py b/examples/simple/jupyter_server_config.py index 723d6cdadb..4e3a70049e 100644 --- a/examples/simple/jupyter_server_config.py +++ b/examples/simple/jupyter_server_config.py @@ -3,4 +3,6 @@ # Application(SingletonConfigurable) configuration # ------------------------------------------------------------------------------ # The date format used by logging formatters for %(asctime)s -c.Application.log_datefmt = "%Y-%m-%d %H:%M:%S Simple_Extensions_Example" +c.Application.log_datefmt = ( # type:ignore[name-defined] + "%Y-%m-%d %H:%M:%S Simple_Extensions_Example" +) diff --git a/examples/simple/jupyter_simple_ext11_config.py b/examples/simple/jupyter_simple_ext11_config.py index d2baa1360a..b1035b8746 100644 --- a/examples/simple/jupyter_simple_ext11_config.py +++ b/examples/simple/jupyter_simple_ext11_config.py @@ -1 +1 @@ -c.SimpleApp11.ignore_js = True +c.SimpleApp11.ignore_js = True # type:ignore[name-defined] diff --git a/examples/simple/jupyter_simple_ext1_config.py b/examples/simple/jupyter_simple_ext1_config.py index f40b66afaf..5e32346335 100644 --- a/examples/simple/jupyter_simple_ext1_config.py +++ b/examples/simple/jupyter_simple_ext1_config.py @@ -1,4 +1,4 @@ -c.SimpleApp1.configA = "ConfigA from file" -c.SimpleApp1.configB = "ConfigB from file" -c.SimpleApp1.configC = "ConfigC from file" -c.SimpleApp1.configD = "ConfigD from file" +c.SimpleApp1.configA = "ConfigA from file" # type:ignore[name-defined] +c.SimpleApp1.configB = "ConfigB from file" # type:ignore[name-defined] +c.SimpleApp1.configC = "ConfigC from file" # type:ignore[name-defined] +c.SimpleApp1.configD = "ConfigD from file" # type:ignore[name-defined] diff --git a/examples/simple/jupyter_simple_ext2_config.py b/examples/simple/jupyter_simple_ext2_config.py index f145cbb87a..d5faa9e942 100644 --- a/examples/simple/jupyter_simple_ext2_config.py +++ b/examples/simple/jupyter_simple_ext2_config.py @@ -1 +1 @@ -c.SimpleApp2.configD = "ConfigD from file" +c.SimpleApp2.configD = "ConfigD from file" # type:ignore[name-defined] diff --git a/jupyter_server/__init__.py b/jupyter_server/__init__.py index d5b97f0c90..199ff6b0e3 100644 --- a/jupyter_server/__init__.py +++ b/jupyter_server/__init__.py @@ -13,7 +13,9 @@ del os -from ._version import __version__, version_info # noqa +from ._version import __version__, version_info + +__all__ = ["__version__", "version_info"] def _cleanup(): diff --git a/jupyter_server/auth/__init__.py b/jupyter_server/auth/__init__.py index 54477ffd1b..a05b8b095f 100644 --- a/jupyter_server/auth/__init__.py +++ b/jupyter_server/auth/__init__.py @@ -1,3 +1,5 @@ from .authorizer import * # noqa from .decorator import authorized # noqa from .security import passwd # noqa + +__all__ = ["authorized", "passwd"] diff --git a/jupyter_server/auth/decorator.py b/jupyter_server/auth/decorator.py index 72a489dbe9..8a09f5b4ec 100644 --- a/jupyter_server/auth/decorator.py +++ b/jupyter_server/auth/decorator.py @@ -3,19 +3,21 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. from functools import wraps -from typing import Callable, Optional, Union +from typing import Any, Callable, Optional, TypeVar, Union, cast from tornado.log import app_log from tornado.web import HTTPError from .utils import HTTP_METHOD_TO_AUTH_ACTION, warn_disabled_authorization +T = TypeVar("T", bound=Callable[..., Any]) + def authorized( - action: Optional[Union[str, Callable]] = None, + action: Optional[Union[str, Callable[..., Any]]] = None, resource: Optional[str] = None, message: Optional[str] = None, -) -> Callable: +) -> Callable[..., Any]: """A decorator for tornado.web.RequestHandler methods that verifies whether the current user is authorized to make the following request. @@ -38,7 +40,7 @@ def authorized( a message for the unauthorized action. """ - def wrapper(method): + def wrapper(method: T) -> T: @wraps(method) def inner(self, *args, **kwargs): # default values for action, resource @@ -70,7 +72,7 @@ def inner(self, *args, **kwargs): # Raise an exception if the method wasn't returned (i.e. not authorized) raise HTTPError(status_code=403, log_message=message) - return inner + return cast(T, inner) if callable(action): method = action diff --git a/jupyter_server/auth/login.py b/jupyter_server/auth/login.py index 382077d9e0..c81a060355 100644 --- a/jupyter_server/auth/login.py +++ b/jupyter_server/auth/login.py @@ -83,7 +83,7 @@ def post(self): elif self.token and self.token == typed_password: self.set_login_cookie(self, uuid.uuid4().hex) if new_password and self.settings.get("allow_password_change"): - config_dir = self.settings.get("config_dir") + config_dir = self.settings.get("config_dir", "") config_file = os.path.join(config_dir, "jupyter_server_config.json") set_password(new_password, config_file=config_file) self.log.info("Wrote hashed password to %s" % config_file) diff --git a/jupyter_server/auth/security.py b/jupyter_server/auth/security.py index fa7dded7fb..219687e1ae 100644 --- a/jupyter_server/auth/security.py +++ b/jupyter_server/auth/security.py @@ -64,9 +64,9 @@ def passwd(passphrase=None, algorithm="argon2"): time_cost=10, parallelism=8, ) - h = ph.hash(passphrase) + h_ph = ph.hash(passphrase) - return ":".join((algorithm, h)) + return ":".join((algorithm, h_ph)) h = hashlib.new(algorithm) salt = ("%0" + str(salt_len) + "x") % random.getrandbits(4 * salt_len) diff --git a/jupyter_server/auth/utils.py b/jupyter_server/auth/utils.py index b939b87ae0..76cd9f18ba 100644 --- a/jupyter_server/auth/utils.py +++ b/jupyter_server/auth/utils.py @@ -44,9 +44,9 @@ def get_regex_to_resource_map(): from jupyter_server.serverapp import JUPYTER_SERVICE_HANDLERS modules = [] - for mod in JUPYTER_SERVICE_HANDLERS.values(): - if mod: - modules.extend(mod) + for mod_name in JUPYTER_SERVICE_HANDLERS.values(): + if mod_name: + modules.extend(mod_name) resource_map = {} for handler_module in modules: mod = importlib.import_module(handler_module) diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index 42f7fb3d5e..1bcd4acb50 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -10,6 +10,7 @@ import re import traceback import types +import typing as t import warnings from http.client import responses from http.cookies import Morsel @@ -114,7 +115,7 @@ def force_clear_cookie(self, name, path="/", domain=None): name = escape.native_str(name) expires = datetime.datetime.utcnow() - datetime.timedelta(days=365) - morsel = Morsel() + morsel: Morsel[str] = Morsel() morsel.set(name, "", '""') morsel["expires"] = httputil.format_timestamp(expires) morsel["path"] = path @@ -241,8 +242,8 @@ def mathjax_config(self): return self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe") @property - def base_url(self): - return self.settings.get("base_url", "/") + def base_url(self) -> str: + return self.settings.get("base_url", "/") # type:ignore[no-any-return] @property def default_url(self): @@ -476,7 +477,9 @@ def check_host(self): return True # Remove port (e.g. ':8888') from host - host = re.match(r"^(.*?)(:\d+)?$", self.request.host).group(1) + match = re.match(r"^(.*?)(:\d+)?$", self.request.host) + assert match is not None + host = match.group(1) # Browsers format IPv6 addresses like [::1]; we need to remove the [] if host.startswith("[") and host.endswith("]"): @@ -567,7 +570,7 @@ def write_error(self, status_code, **kwargs): exc_info = kwargs.get("exc_info") message = "" status_message = responses.get(status_code, "Unknown HTTP Error") - exception = "(unknown)" + if exc_info: exception = exc_info[1] # get the custom message, if defined @@ -580,6 +583,8 @@ def write_error(self, status_code, **kwargs): reason = getattr(exception, "reason", "") if reason: status_message = reason + else: + exception = "(unknown)" # build template namespace ns = dict( @@ -602,6 +607,8 @@ def write_error(self, status_code, **kwargs): class APIHandler(JupyterHandler): """Base class for API handlers""" + _user_cache: str + def prepare(self): if not self.check_origin(): raise web.HTTPError(404) @@ -611,7 +618,7 @@ def write_error(self, status_code, **kwargs): """APIHandler errors are JSON, not human pages""" self.set_header("Content-Type", "application/json") message = responses.get(status_code, "Unknown HTTP Error") - reply = { + reply: t.Dict[str, t.Any] = { "message": message, } exc_info = kwargs.get("exc_info") @@ -627,13 +634,13 @@ def write_error(self, status_code, **kwargs): self.log.warning(reply["message"]) self.finish(json.dumps(reply)) - def get_current_user(self): + def get_current_user(self) -> str: """Raise 403 on API handlers instead of redirecting to human login page""" # preserve _user_cache so we don't raise more than once if hasattr(self, "_user_cache"): return self._user_cache self._user_cache = user = super().get_current_user() - return user + return t.cast(str, user) def get_login_url(self): # if get_login_url is invoked in an API handler, @@ -733,13 +740,14 @@ def head(self, path): @web.authenticated def get(self, path): - if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", False): + if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", None): name = path.rsplit("/", 1)[-1] self.set_attachment_header(name) return web.StaticFileHandler.get(self, path) def get_content_type(self): + assert self.absolute_path is not None path = self.absolute_path.strip("/") if "/" in path: _, name = path.rsplit("/", 1) @@ -818,7 +826,8 @@ class FileFindHandler(JupyterHandler, web.StaticFileHandler): """subclass of StaticFileHandler for serving files from a search path""" # cache search results, don't search for files more than once - _static_paths = {} + _static_paths: t.Dict[str, str] = {} + root: t.Any def set_headers(self): super().set_headers() @@ -882,6 +891,7 @@ class TrailingSlashHandler(web.RequestHandler): """ def get(self): + assert self.request.uri is not None path, *rest = self.request.uri.partition("?") # trim trailing *and* leading / # to avoid misinterpreting repeated '//' diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py index 28e296c722..d26688b6af 100644 --- a/jupyter_server/base/zmqhandlers.py +++ b/jupyter_server/base/zmqhandlers.py @@ -5,6 +5,7 @@ import re import struct import sys +import typing as t from urllib.parse import urlparse import tornado @@ -17,6 +18,7 @@ from jupyter_client.jsonutil import extract_dates from jupyter_client.session import Session from tornado import ioloop, web +from tornado.iostream import IOStream from tornado.websocket import WebSocketHandler from jupyter_server.auth.utils import warn_disabled_authorization @@ -93,7 +95,7 @@ def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None): else: msg_list = msg_or_list channel = channel.encode("utf-8") - offsets = [] + offsets: t.List[t.Any] = [] offsets.append(8 * (1 + 1 + len(msg_list) + 1)) offsets.append(len(channel) + offsets[-1]) for msg in msg_list: @@ -122,9 +124,9 @@ class WebSocketMixin: """Mixin for common websocket options""" ping_callback = None - last_ping = 0 - last_pong = 0 - stream = None + last_ping = 0.0 + last_pong = 0.0 + stream = None # type: t.Optional[IOStream] @property def ping_interval(self): @@ -132,7 +134,7 @@ def ping_interval(self): Set ws_ping_interval = 0 to disable pings. """ - return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) + return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined] @property def ping_timeout(self): @@ -140,9 +142,12 @@ def ping_timeout(self): close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). Default is max of 3 pings or 30 seconds. """ - return self.settings.get("ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)) + return self.settings.get( # type:ignore[attr-defined] + "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL) + ) - def check_origin(self, origin=None): + @t.no_type_check + def check_origin(self, origin: t.Optional[str] = None) -> bool: """Check Origin == Host or Access-Control-Allow-Origin. Tornado >= 4 calls this method automatically, raising 403 if it returns False. @@ -188,6 +193,7 @@ def clear_cookie(self, *args, **kwargs): """meaningless for websockets""" pass + @t.no_type_check def open(self, *args, **kwargs): self.log.debug("Opening websocket %s", self.request.path) @@ -203,6 +209,7 @@ def open(self, *args, **kwargs): self.ping_callback.start() return super().open(*args, **kwargs) + @t.no_type_check def send_ping(self): """send a ping to keep the websocket alive""" if self.ws_connection is None and self.ping_callback is not None: @@ -327,7 +334,7 @@ def pre_get(self): elif not self.authorizer.is_authorized(self, user, "execute", "kernels"): raise web.HTTPError(403) - if self.get_argument("session_id", False): + if self.get_argument("session_id", None): self.session.session = self.get_argument("session_id") else: self.log.warning("No session ID specified") diff --git a/jupyter_server/config_manager.py b/jupyter_server/config_manager.py index 25c5efd28f..8bd8506a3e 100644 --- a/jupyter_server/config_manager.py +++ b/jupyter_server/config_manager.py @@ -6,6 +6,7 @@ import glob import json import os +import typing as t from traitlets.config import LoggingConfigurable from traitlets.traitlets import Bool, Unicode @@ -95,7 +96,7 @@ def get(self, section_name, include_root=True): section_name, "\n\t".join(paths), ) - data = {} + data: t.Dict[str, t.Any] = {} for path in paths: if os.path.isfile(path): with open(path, encoding="utf-8") as f: diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index 167f6dd94e..27f51337eb 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -1,6 +1,7 @@ import logging import re import sys +import typing as t from jinja2 import Environment, FileSystemLoader from jupyter_core.application import JupyterApp, NoStart @@ -137,7 +138,7 @@ class method. This method can be set as a entry_point in # A useful class property that subclasses can override to # configure the underlying Jupyter Server when this extension # is launched directly (using its `launch_instance` method). - serverapp_config = {} + serverapp_config: t.Dict[str, t.Any] = {} # Some subclasses will likely override this trait to flip # the default value to False if they don't offer a browser @@ -165,7 +166,7 @@ def config_file_paths(self): # file, jupyter_{name}_config. # This should also match the jupyter subcommand used to launch # this extension from the CLI, e.g. `jupyter {name}`. - name = None + name = "ExtensionApp" @classmethod def get_extension_package(cls): @@ -318,7 +319,7 @@ def _prepare_handlers(self): handler = handler_items[1] # Get handler kwargs, if given - kwargs = {} + kwargs: t.Dict[str, t.Any] = {} if issubclass(handler, ExtensionHandlerMixin): kwargs["name"] = self.name diff --git a/jupyter_server/extension/handler.py b/jupyter_server/extension/handler.py index 164d74bb15..a68159eb2b 100644 --- a/jupyter_server/extension/handler.py +++ b/jupyter_server/extension/handler.py @@ -1,3 +1,5 @@ +from typing import no_type_check + from jinja2.exceptions import TemplateNotFound from jupyter_server.base.handlers import FileFindHandler @@ -8,6 +10,7 @@ class ExtensionHandlerJinjaMixin: template rendering. """ + @no_type_check def get_template(self, name): """Return the jinja template object for a given name""" try: @@ -33,17 +36,17 @@ def initialize(self, name): @property def extensionapp(self): - return self.settings[self.name] + return self.settings[self.name] # type:ignore[attr-defined] @property def serverapp(self): key = "serverapp" - return self.settings[key] + return self.settings[key] # type:ignore[attr-defined] @property def log(self): if not hasattr(self, "name"): - return super().log + return super().log # type:ignore[misc] # Attempt to pull the ExtensionApp's log, otherwise fall back to ServerApp. try: return self.extensionapp.log @@ -52,15 +55,15 @@ def log(self): @property def config(self): - return self.settings[f"{self.name}_config"] + return self.settings[f"{self.name}_config"] # type:ignore[attr-defined] @property def server_config(self): - return self.settings["config"] + return self.settings["config"] # type:ignore[attr-defined] @property - def base_url(self): - return self.settings.get("base_url", "/") + def base_url(self) -> str: + return self.settings.get("base_url", "/") # type:ignore @property def static_url_prefix(self): @@ -68,7 +71,7 @@ def static_url_prefix(self): @property def static_path(self): - return self.settings[f"{self.name}_static_paths"] + return self.settings[f"{self.name}_static_paths"] # type:ignore[attr-defined] def static_url(self, path, include_host=None, **kwargs): """Returns a static URL for the given relative static file path. @@ -89,9 +92,9 @@ def static_url(self, path, include_host=None, **kwargs): """ key = f"{self.name}_static_paths" try: - self.require_setting(key, "static_url") + self.require_setting(key, "static_url") # type:ignore[attr-defined] except Exception as e: - if key in self.settings: + if key in self.settings: # type:ignore[attr-defined] raise Exception( "This extension doesn't have any static paths listed. Check that the " "extension's `static_paths` trait is set." @@ -99,13 +102,15 @@ def static_url(self, path, include_host=None, **kwargs): else: raise e - get_url = self.settings.get("static_handler_class", FileFindHandler).make_static_url + get_url = self.settings.get( # type:ignore[attr-defined] + "static_handler_class", FileFindHandler + ).make_static_url if include_host is None: include_host = getattr(self, "include_host", False) if include_host: - base = self.request.protocol + "://" + self.request.host + base = self.request.protocol + "://" + self.request.host # type:ignore[attr-defined] else: base = "" diff --git a/jupyter_server/extension/manager.py b/jupyter_server/extension/manager.py index 1efb2cadd0..0a96737d70 100644 --- a/jupyter_server/extension/manager.py +++ b/jupyter_server/extension/manager.py @@ -1,6 +1,7 @@ import importlib import sys import traceback +import typing as t from tornado.gen import multi from traitlets import Any, Bool, Dict, HasTraits, Instance, Unicode, default, observe @@ -167,7 +168,7 @@ def __init__(self, *args, **kwargs): self._linked_points = {} super().__init__(*args, **kwargs) - _linked_points = {} + _linked_points: t.Dict[str, t.Any] = {} @validate_trait("name") def _validate_name(self, proposed): diff --git a/jupyter_server/files/handlers.py b/jupyter_server/files/handlers.py index c76fdc28d3..de60117324 100644 --- a/jupyter_server/files/handlers.py +++ b/jupyter_server/files/handlers.py @@ -4,6 +4,7 @@ import json import mimetypes from base64 import decodebytes +from typing import List from tornado import web @@ -57,7 +58,7 @@ async def get(self, path, include_body=True): model = await ensure_async(cm.get(path, type="file", content=include_body)) - if self.get_argument("download", False): + if self.get_argument("download", None): self.set_attachment_header(name) # get mimetype from filename @@ -91,4 +92,4 @@ async def get(self, path, include_body=True): self.flush() -default_handlers = [] +default_handlers: List[JupyterHandler] = [] diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py index a36f2d4faf..24da860215 100644 --- a/jupyter_server/gateway/handlers.py +++ b/jupyter_server/gateway/handlers.py @@ -5,6 +5,7 @@ import mimetypes import os import random +import typing as t from jupyter_client.session import Session from tornado import web @@ -17,7 +18,7 @@ from ..base.handlers import APIHandler, JupyterHandler from ..utils import url_path_join -from .managers import GatewayClient +from .gateway_client import GatewayClient # Keepalive ping interval (default: 30 seconds) GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", 30)) @@ -52,7 +53,8 @@ def authenticate(self): self.log.warning("Couldn't authenticate WebSocket connection") raise web.HTTPError(403) - if self.get_argument("session_id", False): + if self.get_argument("session_id", None): + assert self.session is not None self.session.session = self.get_argument("session_id") else: self.log.warning("No session ID specified") @@ -79,6 +81,7 @@ def open(self, kernel_id, *args, **kwargs): self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000) self.ping_callback.start() + assert self.gateway is not None self.gateway.on_open( kernel_id=kernel_id, message_callback=self.write_message, @@ -87,6 +90,7 @@ def open(self, kernel_id, *args, **kwargs): def on_message(self, message): """Forward message to gateway web socket handler.""" + assert self.gateway is not None self.gateway.on_message(message) def write_message(self, message, binary=False): @@ -105,6 +109,7 @@ def write_message(self, message, binary=False): def on_close(self): self.log.debug("Closing websocket connection %s", self.request.path) + assert self.gateway is not None self.gateway.on_close() super().on_close() @@ -137,7 +142,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.kernel_id = None self.ws = None - self.ws_future = Future() + self.ws_future: Future = Future() self.disconnected = False self.retry = 0 @@ -152,7 +157,7 @@ async def _connect(self, kernel_id, message_callback): "channels", ) self.log.info(f"Connecting to {ws_url}") - kwargs = {} + kwargs: t.Dict[str, t.Any] = {} kwargs = GatewayClient.instance().load_connection_args(**kwargs) request = HTTPRequest(ws_url, **kwargs) @@ -269,7 +274,8 @@ async def get(self, kernel_name, path, include_body=True): " resource serving.".format(path, kernel_name) ) else: - self.set_header("Content-Type", mimetypes.guess_type(path)[0]) + mimetype = mimetypes.guess_type(path)[0] or "text/plain" + self.set_header("Content-Type", mimetype) self.finish(kernel_spec_res) diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 73215eb90b..aec9307dce 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -3,10 +3,10 @@ import datetime import json import os +import typing as t from logging import Logger from queue import Queue from threading import Thread -from typing import Dict import websocket from jupyter_client.asynchronous.client import AsyncKernelClient @@ -29,7 +29,7 @@ class GatewayMappingKernelManager(AsyncMappingKernelManager): """Kernel manager that supports remote kernels hosted by Jupyter Kernel or Enterprise Gateway.""" # We'll maintain our own set of kernel ids - _kernels: Dict[str, "GatewayKernelManager"] = {} + _kernels: t.Dict[str, "GatewayKernelManager"] = {} @default("kernel_manager_class") def _default_kernel_manager_class(self): @@ -453,6 +453,7 @@ async def shutdown_kernel(self, now=False, restart=False): async def restart_kernel(self, **kw): """Restarts a kernel via HTTP.""" if self.has_kernel: + assert self.kernel_url is not None kernel_url = self.kernel_url + "/restart" self.log.debug("Request restart kernel at: %s", kernel_url) response = await gateway_request(kernel_url, method="POST", body=json_encode({})) @@ -461,6 +462,7 @@ async def restart_kernel(self, **kw): async def interrupt_kernel(self): """Interrupts the kernel via an HTTP request.""" if self.has_kernel: + assert self.kernel_url is not None kernel_url = self.kernel_url + "/interrupt" self.log.debug("Request interrupt kernel at: %s", kernel_url) response = await gateway_request(kernel_url, method="POST", body=json_encode({})) @@ -483,9 +485,9 @@ def cleanup_resources(self, restart=False): KernelManagerABC.register(GatewayKernelManager) -class ChannelQueue(Queue): +class ChannelQueue(Queue): # type:ignore[type-arg] - channel_name: str = None + channel_name: t.Optional[str] = None def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger): super().__init__() @@ -493,7 +495,7 @@ def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: self.channel_socket = channel_socket self.log = log - async def get_msg(self, *args, **kwargs) -> dict: + async def get_msg(self, *args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: timeout = kwargs.get("timeout", 1) msg = self.get(timeout=timeout) self.log.debug( @@ -502,9 +504,9 @@ async def get_msg(self, *args, **kwargs) -> dict: ) ) self.task_done() - return msg + return t.cast(t.Dict[str, t.Any], msg) - def send(self, msg: dict) -> None: + def send(self, msg: t.Dict[str, t.Any]) -> None: message = json.dumps(msg, default=ChannelQueue.serialize_datetime).replace(" None: @staticmethod def serialize_datetime(dt): - if isinstance(dt, (datetime.date, datetime.datetime)): + if isinstance(dt, (datetime.datetime)): return dt.timestamp() return None @@ -571,13 +573,18 @@ class GatewayKernelClient(AsyncKernelClient): # flag for whether execute requests should be allowed to call raw_input: allow_stdin = False _channels_stopped = False - _channel_queues = {} + _channel_queues: t.Optional[t.Dict[str, t.Any]] = {} + _control_channel: t.Optional[ChannelQueue] + _hb_channel: t.Optional[ChannelQueue] + _stdin_channel: t.Optional[ChannelQueue] + _iopub_channel: t.Optional[ChannelQueue] + _shell_channel: t.Optional[ChannelQueue] def __init__(self, **kwargs): super().__init__(**kwargs) self.kernel_id = kwargs["kernel_id"] - self.channel_socket = None - self.response_router = None + self.channel_socket: t.Optional[websocket.WebSocket] = None + self.response_router: t.Optional[Thread] = None # -------------------------------------------------------------------------- # Channel management methods @@ -626,7 +633,9 @@ def stop_channels(self): self._channels_stopped = True self.log.debug("Closing websocket connection") + assert self.channel_socket is not None self.channel_socket.close() + assert self.response_router is not None self.response_router.join() if self._channel_queues: @@ -640,7 +649,9 @@ def shell_channel(self): """Get the shell channel object for this kernel.""" if self._shell_channel is None: self.log.debug("creating shell channel queue") + assert self.channel_socket is not None self._shell_channel = ChannelQueue("shell", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["shell"] = self._shell_channel return self._shell_channel @@ -649,7 +660,9 @@ def iopub_channel(self): """Get the iopub channel object for this kernel.""" if self._iopub_channel is None: self.log.debug("creating iopub channel queue") + assert self.channel_socket is not None self._iopub_channel = ChannelQueue("iopub", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["iopub"] = self._iopub_channel return self._iopub_channel @@ -658,7 +671,9 @@ def stdin_channel(self): """Get the stdin channel object for this kernel.""" if self._stdin_channel is None: self.log.debug("creating stdin channel queue") + assert self.channel_socket is not None self._stdin_channel = ChannelQueue("stdin", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["stdin"] = self._stdin_channel return self._stdin_channel @@ -667,7 +682,9 @@ def hb_channel(self): """Get the hb channel object for this kernel.""" if self._hb_channel is None: self.log.debug("creating hb channel queue") + assert self.channel_socket is not None self._hb_channel = HBChannelQueue("hb", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["hb"] = self._hb_channel return self._hb_channel @@ -676,7 +693,9 @@ def control_channel(self): """Get the control channel object for this kernel.""" if self._control_channel is None: self.log.debug("creating control channel queue") + assert self.channel_socket is not None self._control_channel = ChannelQueue("control", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["control"] = self._control_channel return self._control_channel @@ -690,11 +709,13 @@ def _route_responses(self): """ try: while not self._channels_stopped: + assert self.channel_socket is not None raw_message = self.channel_socket.recv() if not raw_message: break response_message = json_decode(utf8(raw_message)) channel = response_message["channel"] + assert self._channel_queues is not None self._channel_queues[channel].put_nowait(response_message) except websocket.WebSocketConnectionClosedException: diff --git a/jupyter_server/i18n/__init__.py b/jupyter_server/i18n/__init__.py index e44aa11393..b49f0ac408 100644 --- a/jupyter_server/i18n/__init__.py +++ b/jupyter_server/i18n/__init__.py @@ -3,6 +3,7 @@ import errno import json import re +import typing as t from collections import defaultdict from os.path import dirname from os.path import join as pjoin @@ -15,7 +16,7 @@ # ... # } # }} -TRANSLATIONS_CACHE = {"nbjs": {}} +TRANSLATIONS_CACHE: t.Dict[str, t.Any] = {"nbjs": {}} _accept_lang_re = re.compile( @@ -87,7 +88,7 @@ def combine_translations(accept_language, domain="nbjs"): Returns data re-packaged in jed1.x format. """ lang_codes = parse_accept_lang_header(accept_language) - combined = {} + combined: t.Dict[str, t.Any] = {} for language in lang_codes: if language == "en": # en is default, all translations are in frontend. diff --git a/jupyter_server/prometheus/metrics.py b/jupyter_server/prometheus/metrics.py index ae98043c3e..947caa399d 100644 --- a/jupyter_server/prometheus/metrics.py +++ b/jupyter_server/prometheus/metrics.py @@ -4,6 +4,11 @@ Read https://prometheus.io/docs/practices/naming/ for naming conventions for metrics & labels. """ +__all__ = [ + "HTTP_REQUEST_DURATION_SECONDS", + "KERNEL_CURRENTLY_RUNNING_TOTAL", + "TERMINAL_CURRENTLY_RUNNING_TOTAL", +] try: # Jupyter Notebook also defines these metrics. Re-defining them results in a ValueError. diff --git a/jupyter_server/py.typed b/jupyter_server/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/pytest_plugin.py b/jupyter_server/pytest_plugin.py index 7b35795c63..49fcf31f7a 100644 --- a/jupyter_server/pytest_plugin.py +++ b/jupyter_server/pytest_plugin.py @@ -34,7 +34,9 @@ import asyncio if os.name == "nt" and sys.version_info >= (3, 7): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.set_event_loop_policy( + asyncio.WindowsSelectorEventLoopPolicy() # type:ignore[attr-defined] + ) # ============ Move to Jupyter Core ============= diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index f2c337d404..eeba145407 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -23,22 +23,25 @@ import sys import threading import time +import typing as t import urllib import warnings import webbrowser from base64 import encodebytes +from urllib.parse import urljoin +from urllib.request import pathname2url try: import resource except ImportError: # Windows - resource = None + resource = None # type:ignore[assignment] from jinja2 import Environment, FileSystemLoader from jupyter_core.paths import secure_write from jupyter_server.transutils import _i18n, trans -from jupyter_server.utils import pathname2url, run_sync_in_loop, urljoin +from jupyter_server.utils import run_sync_in_loop # the minimum viable tornado version: needs to be kept in sync with setup.py MIN_TORNADO = (6, 1, 0) @@ -102,8 +105,8 @@ from jupyter_server.extension.config import ExtensionConfigManager from jupyter_server.extension.manager import ExtensionManager from jupyter_server.extension.serverextension import ServerExtensionApp +from jupyter_server.gateway.gateway_client import GatewayClient from jupyter_server.gateway.managers import ( - GatewayClient, GatewayKernelSpecManager, GatewayMappingKernelManager, GatewaySessionManager, @@ -275,7 +278,7 @@ def init_settings( _template_path = (_template_path,) template_path = [os.path.expanduser(path) for path in _template_path] - jenv_opt = {"autoescape": True} + jenv_opt: t.Dict[str, t.Any] = {"autoescape": True} jenv_opt.update(jinja_env_options if jinja_env_options else {}) env = Environment( @@ -1036,7 +1039,7 @@ def _token_default(self): return os.getenv("JUPYTER_TOKEN") if os.getenv("JUPYTER_TOKEN_FILE"): self._token_generated = False - with open(os.getenv("JUPYTER_TOKEN_FILE")) as token_file: + with open(os.getenv("JUPYTER_TOKEN_FILE", "")) as token_file: return token_file.read() if self.password: # no token if password is enabled @@ -1191,10 +1194,10 @@ def _default_allow_remote(self): except ValueError: # Address is a hostname for info in socket.getaddrinfo(self.ip, self.port, 0, socket.SOCK_STREAM): - addr = info[4][0] + addr = info[4][0] # type:ignore[assignment] try: - parsed = ipaddress.ip_address(addr.split("%")[0]) + parsed = ipaddress.ip_address(addr.split("%")[0]) # type:ignore[union-attr] except ValueError: self.log.warning("Unrecognised IP address: %r", addr) continue @@ -1202,7 +1205,10 @@ def _default_allow_remote(self): # Macs map localhost to 'fe80::1%lo0', a link local address # scoped to the loopback interface. For now, we'll assume that # any scoped link-local address is effectively local. - if not (parsed.is_loopback or (("%" in addr) and parsed.is_link_local)): + if not ( + parsed.is_loopback + or (("%" in addr) and parsed.is_link_local) # type:ignore[operator] + ): return True return False else: @@ -1989,12 +1995,7 @@ def _get_urlparts(self, path=None, include_token=False): query = urllib.parse.urlencode({"token": token}) # Build the URL Parts to dump. urlparts = urllib.parse.ParseResult( - scheme=scheme, - netloc=netloc, - path=path, - params=None, - query=query, - fragment=None, + scheme=scheme, netloc=netloc, path=path, query=query or "", params="", fragment="" ) return urlparts @@ -2644,6 +2645,7 @@ def launch_browser(self): assembled_url, _ = self._prepare_browser_open() def target(): + assert browser is not None browser.open(assembled_url, new=self.webbrowser_open_new) threading.Thread(target=target).start() diff --git a/jupyter_server/services/config/__init__.py b/jupyter_server/services/config/__init__.py index 9a2aee241d..7bd910ee92 100644 --- a/jupyter_server/services/config/__init__.py +++ b/jupyter_server/services/config/__init__.py @@ -1 +1,3 @@ from .manager import ConfigManager # noqa + +__all__ = ["ConfigManager"] diff --git a/jupyter_server/services/config/manager.py b/jupyter_server/services/config/manager.py index 5f04925fe7..720c8e7bd7 100644 --- a/jupyter_server/services/config/manager.py +++ b/jupyter_server/services/config/manager.py @@ -3,6 +3,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. import os.path +import typing as t from jupyter_core.paths import jupyter_config_dir, jupyter_config_path from traitlets import Instance, List, Unicode, default, observe @@ -22,7 +23,7 @@ class ConfigManager(LoggingConfigurable): def get(self, section_name): """Get the config from all config sections.""" - config = {} + config: t.Dict[str, t.Any] = {} # step through back to front, to ensure front of the list is top priority for p in self.read_config_path[::-1]: cm = BaseJSONConfigManager(config_dir=p) diff --git a/jupyter_server/services/contents/fileio.py b/jupyter_server/services/contents/fileio.py index d01bfd16dc..e1d6ae66dc 100644 --- a/jupyter_server/services/contents/fileio.py +++ b/jupyter_server/services/contents/fileio.py @@ -204,11 +204,12 @@ def atomic_writing(self, os_path, *args, **kwargs): Depending on flag 'use_atomic_writing', the wrapper perform an actual atomic writing or simply writes the file (whatever an old exists or not)""" with self.perm_to_403(os_path): + kwargs["log"] = self.log if self.use_atomic_writing: - with atomic_writing(os_path, *args, log=self.log, **kwargs) as f: + with atomic_writing(os_path, *args, **kwargs) as f: yield f else: - with _simple_writing(os_path, *args, log=self.log, **kwargs) as f: + with _simple_writing(os_path, *args, **kwargs) as f: yield f @contextmanager diff --git a/jupyter_server/services/contents/filemanager.py b/jupyter_server/services/contents/filemanager.py index 88aa0e3620..1982468218 100644 --- a/jupyter_server/services/contents/filemanager.py +++ b/jupyter_server/services/contents/filemanager.py @@ -7,6 +7,7 @@ import shutil import stat import sys +import typing as t from datetime import datetime import nbformat @@ -331,7 +332,7 @@ def _notebook_model(self, path, content=True): os_path = self._get_os_path(path) if content: - validation_error = {} + validation_error: t.Dict[str, Exception] = {} nb = self._read_notebook( os_path, as_version=4, capture_validation_error=validation_error ) @@ -412,7 +413,7 @@ def save(self, model, path=""): os_path = self._get_os_path(path) self.log.debug("Saving %s", os_path) - validation_error = {} + validation_error: t.Dict[str, Exception] = {} try: if model["type"] == "notebook": nb = nbformat.from_dict(model["content"]) @@ -657,7 +658,7 @@ async def _notebook_model(self, path, content=True): os_path = self._get_os_path(path) if content: - validation_error = {} + validation_error: t.Dict[str, Exception] = {} nb = await self._read_notebook( os_path, as_version=4, capture_validation_error=validation_error ) @@ -738,7 +739,7 @@ async def save(self, model, path=""): os_path = self._get_os_path(path) self.log.debug("Saving %s", os_path) - validation_error = {} + validation_error: t.Dict[str, Exception] = {} try: if model["type"] == "notebook": nb = nbformat.from_dict(model["content"]) diff --git a/jupyter_server/services/contents/handlers.py b/jupyter_server/services/contents/handlers.py index 59c109ad84..6b98c5d6cf 100644 --- a/jupyter_server/services/contents/handlers.py +++ b/jupyter_server/services/contents/handlers.py @@ -54,7 +54,7 @@ def validate_model(model, expect_content): f"Keys unexpectedly None: {errors}", ) else: - errors = {key: model[key] for key in maybe_none_keys if model[key] is not None} + errors = {key: model[key] for key in maybe_none_keys if model[key] is not None} # type: ignore[assignment] if errors: raise web.HTTPError( 500, @@ -102,10 +102,10 @@ async def get(self, path=""): format = self.get_query_argument("format", default=None) if format not in {None, "text", "base64"}: raise web.HTTPError(400, "Format %r is invalid" % format) - content = self.get_query_argument("content", default="1") - if content not in {"0", "1"}: - raise web.HTTPError(400, "Content %r is invalid" % content) - content = int(content) + content_str = self.get_query_argument("content", default="1") + if content_str not in {"0", "1"}: + raise web.HTTPError(400, "Content %r is invalid" % content_str) + content = int(content_str or "") model = await ensure_async( self.contents_manager.get( diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index c5fd110fa9..aa005f0e49 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -5,6 +5,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. import json +import typing as t from textwrap import dedent from traceback import format_tb @@ -114,7 +115,10 @@ class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): # class-level registry of open sessions # allows checking for conflict on session-id, # which is used as a zmq identity and must be unique. - _open_sessions = {} + _open_sessions: t.Dict[str, t.Any] = {} + + _kernel_info_future: Future + _close_future: Future @property def kernel_info_timeout(self): @@ -177,7 +181,7 @@ def nudge(self): # establishing its zmq subscriptions before processing the next request. if getattr(kernel, "execution_state", None) == "busy": self.log.debug("Nudge: not nudging busy kernel %s", self.kernel_id) - f = Future() + f: Future = Future() f.set_result(None) return f # Use a transient shell channel to prevent leaking @@ -189,8 +193,8 @@ def nudge(self): # The IOPub used by the client, whose subscriptions we are verifying. iopub_channel = self.channels["iopub"] - info_future = Future() - iopub_future = Future() + info_future: Future = Future() + iopub_future: Future = Future() both_done = gen.multi([info_future, iopub_future]) def finish(_=None): @@ -203,7 +207,7 @@ def finish(_=None): def cleanup(_=None): """Common cleanup""" - loop.remove_timeout(nudge_handle) + loop.remove_timeout(nudge_handle) # type:ignore[has-type] iopub_channel.stop_on_recv() if not shell_channel.closed(): shell_channel.close() @@ -271,7 +275,7 @@ def nudge(count): log(f"Nudge: attempt {count} on kernel {self.kernel_id}") self.session.send(shell_channel, "kernel_info_request") self.session.send(control_channel, "kernel_info_request") - nonlocal nudge_handle + nonlocal nudge_handle # type:ignore[misc] nudge_handle = loop.call_later(0.5, nudge, count) nudge_handle = loop.call_later(0, nudge, count=0) @@ -293,8 +297,9 @@ def request_kernel_info(self): self.log.debug("Requesting kernel info from %s", self.kernel_id) # Create a kernel_info channel to query the kernel protocol version. # This channel will be closed after the kernel_info reply is received. - if self.kernel_info_channel is None: + if self.kernel_info_channel is None: # type:ignore[has-type] self.kernel_info_channel = km.connect_shell(self.kernel_id) + assert self.kernel_info_channel is not None self.kernel_info_channel.on_recv(self._handle_kernel_info_reply) self.session.send(self.kernel_info_channel, "kernel_info_request") # store the future on the kernel, so only one request is sent @@ -512,6 +517,7 @@ def on_message(self, ws_msg): ignore_msg = False if am: msg["header"] = self.get_part("header", msg["header"], msg_list) + assert msg["header"] is not None if msg["header"]["msg_type"] not in am: self.log.warning( 'Received message of type "%s", which is not allowed. Ignoring.' diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index f9b9af23bd..9142702696 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -403,7 +403,7 @@ async def restart_kernel(self, kernel_id, now=False): kernel = self.get_kernel(kernel_id) # return a Future that will resolve when the kernel has successfully restarted channel = kernel.connect_shell() - future = Future() + future: Future = Future() def finish(): """Common cleanup when restart finishes/fails for any reason.""" diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index 5ea14af5ac..b161966305 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -2,13 +2,14 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. import pathlib +import typing as t import uuid try: import sqlite3 except ImportError: # fallback on pysqlite2 if Python was build without sqlite - from pysqlite2 import dbapi2 as sqlite3 + from pysqlite2 import dbapi2 as sqlite3 # type:ignore[no-redef] from dataclasses import dataclass, fields from typing import Union @@ -41,7 +42,7 @@ class KernelSessionRecord: session_id: Union[None, str] = None kernel_id: Union[None, str] = None - def __eq__(self, other: "KernelSessionRecord") -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, KernelSessionRecord): condition1 = self.kernel_id and self.kernel_id == other.kernel_id condition2 = all( @@ -96,14 +97,14 @@ class KernelSessionRecordList: """ def __init__(self, *records): - self._records = [] + self._records: t.List[KernelSessionRecord] = [] for record in records: self.update(record) def __str__(self): return str(self._records) - def __contains__(self, record: Union[KernelSessionRecord, str]): + def __contains__(self, record: Union[KernelSessionRecord, str]) -> bool: """Search for records by kernel_id and session_id""" if isinstance(record, KernelSessionRecord) and record in self._records: return True diff --git a/jupyter_server/terminal/__init__.py b/jupyter_server/terminal/__init__.py index c8d2856087..a4b2f2a672 100644 --- a/jupyter_server/terminal/__init__.py +++ b/jupyter_server/terminal/__init__.py @@ -15,12 +15,14 @@ from .handlers import TermSocket from .terminalmanager import TerminalManager +__all__ = ["TermSocket", "TerminalManager", "initialize"] + def initialize(webapp, root_dir, connection_url, settings): if os.name == "nt": default_shell = "powershell.exe" else: - default_shell = which("sh") + default_shell = which("sh") # type:ignore[assignment] shell_override = settings.get("shell_command") shell = [os.environ.get("SHELL") or default_shell] if shell_override is None else shell_override # When the notebook server is not running in a terminal (e.g. when diff --git a/jupyter_server/terminal/api_handlers.py b/jupyter_server/terminal/api_handlers.py index e521dd353a..c1b8170dee 100644 --- a/jupyter_server/terminal/api_handlers.py +++ b/jupyter_server/terminal/api_handlers.py @@ -35,7 +35,7 @@ def post(self): if not cwd.resolve().exists(): cwd = Path(self.settings["server_root_dir"]).expanduser() / cwd if not cwd.resolve().exists(): - cwd = None + cwd = None # type:ignore[assignment] if cwd is None: server_root_dir = self.settings["server_root_dir"] diff --git a/jupyter_server/traittypes.py b/jupyter_server/traittypes.py index cad8b4e204..1034f6935c 100644 --- a/jupyter_server/traittypes.py +++ b/jupyter_server/traittypes.py @@ -8,6 +8,8 @@ class TypeFromClasses(ClassBasedTraitType): """A trait whose value must be a subclass of a class in a specified list of classes.""" + default_value: Undefined + def __init__(self, default_value=Undefined, klasses=None, **kwargs): """Construct a Type trait A Type trait specifies that its values must be subclasses of @@ -181,6 +183,7 @@ def validate(self, obj, value): def info(self): result = "an instance of " + assert self.klasses is not None for klass in self.klasses: if isinstance(klass, str): result += klass @@ -199,6 +202,7 @@ def instance_init(self, obj): def _resolve_classes(self): # Resolve all string names to actual classes. self.importable_klasses = [] + assert self.klasses is not None for klass in self.klasses: if isinstance(klass, str): # Try importing the classes to compare. Silently, ignore if not importable. diff --git a/jupyter_server/utils.py b/jupyter_server/utils.py index c7eb9a71f5..bcd9ebe853 100644 --- a/jupyter_server/utils.py +++ b/jupyter_server/utils.py @@ -149,7 +149,7 @@ def _check_pid_win32(pid): # OpenProcess returns 0 if no such process (of ours) exists # positive int otherwise - return bool(ctypes.windll.kernel32.OpenProcess(1, 0, pid)) + return bool(ctypes.windll.kernel32.OpenProcess(1, 0, pid)) # type:ignore[attr-defined] def _check_pid_posix(pid): diff --git a/pyproject.toml b/pyproject.toml index 6c549dda5b..0d6893059e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,3 +53,22 @@ default = "" [[tool.tbump.field]] name = "release" default = "" + +[tool.mypy] +check_untyped_defs = true +disallow_any_generics = true +disallow_incomplete_defs = true +disallow_untyped_decorators = true +follow_imports = "normal" +ignore_missing_imports = true +no_implicit_optional = true +no_implicit_reexport = true +pretty = true +show_error_context = true +show_error_codes = true +strict_equality = true +strict_optional = true +warn_unused_configs = true +warn_redundant_casts = true +warn_return_any = true +warn_unused_ignores = true diff --git a/tests/auth/test_authorizer.py b/tests/auth/test_authorizer.py index 096437b47a..a7f801c722 100644 --- a/tests/auth/test_authorizer.py +++ b/tests/auth/test_authorizer.py @@ -1,5 +1,6 @@ """Tests for authorization""" import json +import typing as t import pytest from jupyter_client.kernelspec import NATIVE_KERNEL_NAME @@ -18,7 +19,7 @@ class AuthorizerforTesting(Authorizer): # Set these class attributes from within a test # to verify that they match the arguments passed # by the REST API. - permissions = {} + permissions: t.Dict[str, str] = {} def normalize_url(self, path): """Drop the base URL and make sure path leads with a /""" diff --git a/tests/auth/test_login.py b/tests/auth/test_login.py index 0b918d91d5..6a8fb9d198 100644 --- a/tests/auth/test_login.py +++ b/tests/auth/test_login.py @@ -48,6 +48,7 @@ async def _login(jp_serverapp, http_server_client, jp_base_url, next): except HTTPClientError as e: if e.code != 302: raise + assert e.response is not None return e.response.headers["Location"] else: assert resp.code == 302, "Should have returned a redirect!" diff --git a/tests/extension/test_app.py b/tests/extension/test_app.py index 88a423f252..60e8674d83 100644 --- a/tests/extension/test_app.py +++ b/tests/extension/test_app.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from traitlets.config import Config @@ -78,7 +80,7 @@ def test_extensionapp_no_parent(): assert app.serverapp is not None -OPEN_BROWSER_COMBINATIONS = ( +OPEN_BROWSER_COMBINATIONS: Any = ( (True, {}), (True, {"ServerApp": {"open_browser": True}}), (False, {"ServerApp": {"open_browser": False}}), diff --git a/tests/extension/test_manager.py b/tests/extension/test_manager.py index 2b52fea543..03e23597ee 100644 --- a/tests/extension/test_manager.py +++ b/tests/extension/test_manager.py @@ -6,11 +6,13 @@ from jupyter_server.extension.manager import ( ExtensionManager, - ExtensionMetadataError, - ExtensionModuleNotFound, ExtensionPackage, ExtensionPoint, ) +from jupyter_server.extension.utils import ( + ExtensionMetadataError, + ExtensionModuleNotFound, +) # Use ServerApps environment because it monkeypatches # jupyter_core.paths and provides a config directory diff --git a/tests/nbconvert/test_handlers.py b/tests/nbconvert/test_handlers.py index 809f0ba3ec..f14fde35a2 100644 --- a/tests/nbconvert/test_handlers.py +++ b/tests/nbconvert/test_handlers.py @@ -3,9 +3,9 @@ from shutil import which import pytest -import tornado from nbformat import writes from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook, new_output +from tornado.httpclient import HTTPClientError from ..utils import expected_http_error @@ -75,7 +75,7 @@ async def test_from_file(jp_fetch, notebook): async def test_from_file_404(jp_fetch, notebook): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch( "nbconvert", "html", diff --git a/tests/services/contents/test_api.py b/tests/services/contents/test_api.py index 988dcdb603..1398a6efe9 100644 --- a/tests/services/contents/test_api.py +++ b/tests/services/contents/test_api.py @@ -1,6 +1,7 @@ import json import pathlib import sys +import typing as t from base64 import decodebytes, encodebytes from unicodedata import normalize @@ -54,7 +55,7 @@ def contents_dir(tmp_path, jp_serverapp): @pytest.fixture def contents(contents_dir): # Create files in temporary directory - paths = { + paths: t.Dict[str, t.Any] = { "notebooks": [], "textfiles": [], "blobs": [], diff --git a/tests/services/contents/test_manager.py b/tests/services/contents/test_manager.py index 6765cbbe54..7086c09660 100644 --- a/tests/services/contents/test_manager.py +++ b/tests/services/contents/test_manager.py @@ -1,8 +1,8 @@ import os import sys import time +import typing as t from itertools import combinations -from typing import Dict, Optional, Tuple from unittest.mock import patch import pytest @@ -77,8 +77,8 @@ def add_invalid_cell(notebook): async def prepare_notebook( - jp_contents_manager, make_invalid: Optional[bool] = False -) -> Tuple[Dict, str]: + jp_contents_manager: FileContentsManager, make_invalid: t.Optional[bool] = False +) -> t.Tuple[t.Dict[str, t.Any], str]: cm = jp_contents_manager model = await ensure_async(cm.new_untitled(type="notebook")) name = model["name"] @@ -235,7 +235,7 @@ async def test_good_symlink(jp_file_contents_manager_class, tmp_path): symlink(cm, file_model["path"], path) symlink_model = await ensure_async(cm.get(path, content=False)) dir_model = await ensure_async(cm.get(parent)) - assert sorted(dir_model["content"], key=lambda x: x["name"]) == [ + assert sorted(dir_model["content"], key=lambda x: x["name"]) == [ # type:ignore[no-any-return] symlink_model, file_model, ] @@ -756,7 +756,7 @@ async def test_validate_notebook_model(jp_contents_manager): with patch("jupyter_server.services.contents.manager.validate_nb") as mock_validate_nb: # Valid notebook and a non-None dictionary, no validate call expected - validation_error = {} + validation_error: t.Dict[str, object] = {} cm.validate_notebook_model(model, validation_error) assert mock_validate_nb.call_count == 0 mock_validate_nb.reset_mock() diff --git a/tests/services/kernels/test_api.py b/tests/services/kernels/test_api.py index bb91a588e4..2faf2df34e 100644 --- a/tests/services/kernels/test_api.py +++ b/tests/services/kernels/test_api.py @@ -1,6 +1,7 @@ import json import os import time +import typing as t import jupyter_client import pytest @@ -27,7 +28,7 @@ async def _(kernel_id): return _ -configs = [ +configs: t.List[t.Any] = [ { "ServerApp": { "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.MappingKernelManager" diff --git a/tests/services/kernelspecs/test_api.py b/tests/services/kernelspecs/test_api.py index ee14d6afb0..461cc40e3e 100644 --- a/tests/services/kernelspecs/test_api.py +++ b/tests/services/kernelspecs/test_api.py @@ -1,8 +1,8 @@ import json import pytest -import tornado from jupyter_client.kernelspec import NATIVE_KERNEL_NAME +from tornado.httpclient import HTTPClientError from ...utils import expected_http_error, some_resource @@ -51,7 +51,7 @@ async def test_get_kernelspecs(jp_fetch, jp_kernelspecs): async def test_get_nonexistant_kernelspec(jp_fetch, jp_kernelspecs): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("api", "kernelspecs", "nonexistant", method="GET") assert expected_http_error(e, 404) @@ -63,10 +63,10 @@ async def test_get_kernel_resource_file(jp_fetch, jp_kernelspecs): async def test_get_nonexistant_resource(jp_fetch, jp_kernelspecs): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("kernelspecs", "nonexistant", "resource.txt", method="GET") assert expected_http_error(e, 404) - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("kernelspecs", "sample", "nonexistant.txt", method="GET") assert expected_http_error(e, 404) diff --git a/tests/services/sessions/test_api.py b/tests/services/sessions/test_api.py index 19c165cbb1..2eb4c490d7 100644 --- a/tests/services/sessions/test_api.py +++ b/tests/services/sessions/test_api.py @@ -2,6 +2,7 @@ import os import shutil import time +import typing as t import jupyter_client import pytest @@ -25,13 +26,13 @@ def j(r): class NewPortsKernelManager(AsyncIOLoopKernelManager): - @default("cache_ports") + @default("cache_ports") # type:ignore[misc] def _default_cache_ports(self) -> bool: return False - async def restart_kernel(self, now: bool = False, newports: bool = True, **kw) -> None: + async def restart_kernel(self, now: bool = False, newports: bool = True, **kw: t.Any) -> None: self.log.debug(f"DEBUG**** calling super().restart_kernel with newports={newports}") - return await super().restart_kernel(now=now, newports=newports, **kw) + return await super().restart_kernel(now=now, newports=newports, **kw) # type:ignore class NewPortsMappingKernelManager(AsyncMappingKernelManager): @@ -41,7 +42,7 @@ def _default_kernel_manager_class(self): return "tests.services.sessions.test_api.NewPortsKernelManager" -configs = [ +configs: t.List[t.Any] = [ { "ServerApp": { "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.MappingKernelManager" @@ -65,7 +66,7 @@ def _default_kernel_manager_class(self): # See https://github.com/jupyter-server/jupyter_server/issues/672 if os.name != "nt" and jupyter_client._version.version_info >= (7, 1): # Add a pending kernels condition - c = { + c: t.Dict[str, t.Any] = { "ServerApp": { "kernel_manager_class": "tests.services.sessions.test_api.NewPortsMappingKernelManager" }, diff --git a/tests/services/sessions/test_manager.py b/tests/services/sessions/test_manager.py index a67dd6398e..48d5761746 100644 --- a/tests/services/sessions/test_manager.py +++ b/tests/services/sessions/test_manager.py @@ -16,6 +16,9 @@ class DummyKernel: + execution_state: str + last_activity: str + def __init__(self, kernel_name="python"): self.kernel_name = kernel_name @@ -132,7 +135,7 @@ def test_kernel_record_list(): # Test .get() r_ = records.get(r) assert r == r_ - r_ = records.get(r.kernel_id) + r_ = records.get(r.kernel_id or "") assert r == r_ with pytest.raises(ValueError): diff --git a/tests/test_files.py b/tests/test_files.py index 7fac8419d4..06f1932591 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -3,9 +3,9 @@ from pathlib import Path import pytest -import tornado from nbformat import writes from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook, new_output +from tornado.httpclient import HTTPClientError from .utils import expected_http_error @@ -28,7 +28,7 @@ async def fetch_expect_200(jp_fetch, *path_parts): async def fetch_expect_404(jp_fetch, *path_parts): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("files", *path_parts, method="GET") assert expected_http_error(e, 404), [path_parts, e] diff --git a/tests/test_gateway.py b/tests/test_gateway.py index d040999558..047cf2caff 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -1,9 +1,10 @@ """Test GatewayClient""" import json import os +import typing as t import uuid from datetime import datetime -from io import StringIO +from io import BytesIO from unittest.mock import patch import pytest @@ -11,7 +12,7 @@ from tornado.httpclient import HTTPRequest, HTTPResponse from tornado.web import HTTPError -from jupyter_server.gateway.managers import GatewayClient +from jupyter_server.gateway.gateway_client import GatewayClient from jupyter_server.utils import ensure_async from .utils import expected_http_error @@ -34,7 +35,7 @@ def generate_kernelspec(name): # We'll mock up two kernelspecs - kspec_foo and kspec_bar -kernelspecs = { +kernelspecs: t.Dict[str, t.Any] = { "default": "kspec_foo", "kernelspecs": { "kspec_foo": generate_kernelspec("kspec_foo"), @@ -72,16 +73,17 @@ async def mock_gateway_request(url, **kwargs): # Fetch all kernelspecs if endpoint.endswith("/api/kernelspecs") and method == "GET": - response_buf = StringIO(json.dumps(kernelspecs)) + response_buf = BytesIO(json.dumps(kernelspecs).encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response # Fetch named kernelspec if endpoint.rfind("/api/kernelspecs/") >= 0 and method == "GET": requested_kernelspec = endpoint.rpartition("/")[2] - kspecs = kernelspecs.get("kernelspecs") + kspecs: t.Dict[str, t.Any] = kernelspecs["kernelspecs"] if requested_kernelspec in kspecs: - response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec))) + response_str = json.dumps(kspecs.get(requested_kernelspec)) + response_buf = BytesIO(response_str.encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response else: @@ -96,7 +98,7 @@ async def mock_gateway_request(url, **kwargs): assert name == kspec_name # Ensure that KERNEL_ env values get propagated model = generate_model(name) running_kernels[model.get("id")] = model # Register model as a running kernel - response_buf = StringIO(json.dumps(model)) + response_buf = BytesIO(json.dumps(model).encode("utf-8")) response = await ensure_async(HTTPResponse(request, 201, buffer=response_buf)) return response @@ -106,7 +108,7 @@ async def mock_gateway_request(url, **kwargs): for kernel_id in running_kernels.keys(): model = running_kernels.get(kernel_id) kernels.append(model) - response_buf = StringIO(json.dumps(kernels)) + response_buf = BytesIO(json.dumps(kernels).encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response @@ -122,7 +124,8 @@ async def mock_gateway_request(url, **kwargs): raise HTTPError(404, message="Kernel does not exist: %s" % requested_kernel_id) elif action == "restart": if requested_kernel_id in running_kernels: - response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response_str = json.dumps(running_kernels.get(requested_kernel_id)) + response_buf = BytesIO(response_str.encode("utf-8")) response = await ensure_async(HTTPResponse(request, 204, buffer=response_buf)) return response else: @@ -143,7 +146,8 @@ async def mock_gateway_request(url, **kwargs): if endpoint.rfind("/api/kernels/") >= 0 and method == "GET": requested_kernel_id = endpoint.rpartition("/")[2] if requested_kernel_id in running_kernels: - response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response_str = json.dumps(running_kernels.get(requested_kernel_id)) + response_buf = BytesIO(response_str.encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response else: @@ -313,6 +317,7 @@ async def create_session(root_dir, jp_fetch, kernel_name): kernel_id = model.get("kernel").get("id") # ensure its in the running_kernels and name matches. running_kernel = running_kernels.get(kernel_id) + assert running_kernel is not None assert kernel_id == running_kernel.get("id") assert model.get("kernel").get("name") == running_kernel.get("name") session_id = model.get("id") @@ -359,6 +364,7 @@ async def create_kernel(jp_fetch, kernel_name): kernel_id = model.get("id") # ensure its in the running_kernels and name matches. running_kernel = running_kernels.get(kernel_id) + assert running_kernel is not None assert kernel_id == running_kernel.get("id") assert model.get("name") == kernel_name @@ -398,6 +404,7 @@ async def restart_kernel(jp_fetch, kernel_id): restarted_kernel_id = model.get("id") # ensure its in the running_kernels and name matches. running_kernel = running_kernels.get(restarted_kernel_id) + assert running_kernel is not None assert restarted_kernel_id == running_kernel.get("id") assert model.get("name") == running_kernel.get("name") diff --git a/tests/test_paths.py b/tests/test_paths.py index 0789be4ded..9a6a41b3ba 100644 --- a/tests/test_paths.py +++ b/tests/test_paths.py @@ -63,6 +63,7 @@ async def test_trailing_slash( ) # Capture the response from the raised exception value. response = err.value.response + assert response is not None assert response.code == 302 assert "Location" in response.headers assert response.headers["Location"] == url_path_join(jp_base_url, expected) diff --git a/tests/unix_sockets/test_serverapp_integration.py b/tests/unix_sockets/test_serverapp_integration.py index 5bb1038234..9661539d7e 100644 --- a/tests/unix_sockets/test_serverapp_integration.py +++ b/tests/unix_sockets/test_serverapp_integration.py @@ -35,6 +35,7 @@ def test_shutdown_sock_server_integration(jp_unix_socket_file): ) complete = False + assert p.stderr is not None for line in iter(p.stderr.readline, b""): if url in line: complete = True diff --git a/tests/utils.py b/tests/utils.py index 6e6649af42..4eabcdceaa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ import json -import tornado +from tornado.httpclient import HTTPClientError +from tornado.web import HTTPError some_resource = "The very model of a modern major general" @@ -20,7 +21,7 @@ def mkdir(tmp_path, *parts): def expected_http_error(error, expected_code, expected_message=None): """Check that the error matches the expected output error.""" e = error.value - if isinstance(e, tornado.web.HTTPError): + if isinstance(e, HTTPError): if expected_code != e.status_code: return False if expected_message is not None and expected_message != str(e): @@ -28,8 +29,8 @@ def expected_http_error(error, expected_code, expected_message=None): return True elif any( [ - isinstance(e, tornado.httpclient.HTTPClientError), - isinstance(e, tornado.httpclient.HTTPError), + isinstance(e, HTTPClientError), + isinstance(e, HTTPError), ] ): if expected_code != e.code: