diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 799f7f2..b453036 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -20,38 +20,29 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [ '3.8', '3.9', '3.10.0', '3.11' ] + python-version: [ '3.9', '3.10', '3.11', '3.12', '3.13' ] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - name: Setup conda - uses: mamba-org/setup-micromamba@v1 + - name: Setup Python + uses: actions/setup-python@v5 with: - init-shell: bash powershell - cache-environment: true - post-cleanup: 'all' - environment-name: test-env - create-args: >- - python=${{ matrix.python-version }} - pip + python-version: ${{ matrix.python-version }} - name: Install akernel run: | - micromamba activate test-env - pip install .[test] + pip install . --group test - name: Check style and types run: | - micromamba activate test-env - black --check akernel - ruff check akernel - mypy akernel + ruff check --show-fixes + ruff format --check + mypy src - name: Run tests run: | - micromamba activate test-env akernel --help - test -f ${CONDA_PREFIX}/share/jupyter/kernels/akernel/kernel.json - pytest akernel/tests -v --reruns 5 + test -f ${pythonLocation}/../share/jupyter/kernels/akernel/kernel.json + pytest -v --reruns 5 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e383f65 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,7 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.13 + hooks: + - id: ruff-check + args: [--fix, --show-fixes] + - id: ruff-format diff --git a/akernel/__init__.py b/akernel/__init__.py deleted file mode 100644 index 3ced358..0000000 --- a/akernel/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.2.1" diff --git a/akernel/akernel.py b/akernel/akernel.py deleted file mode 100644 index 9df1a17..0000000 --- a/akernel/akernel.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Optional - -import typer - -from .kernel import Kernel -from .kernelspec import write_kernelspec - - -cli = typer.Typer() - - -@cli.command() -def install( - mode: str = typer.Argument("", help="Mode of the kernel to install."), - cache_dir: Optional[str] = typer.Option( - None, "-c", help="Path to the cache directory, if mode is 'cache'." - ), -): - kernel_name = "akernel" - if mode: - modes = mode.split("-") - modes.sort() - mode = "-".join(modes) - kernel_name += f"-{mode}" - display_name = f"Python 3 ({kernel_name})" - write_kernelspec(kernel_name, mode, display_name, cache_dir) - - -@cli.command() -def launch( - mode: str = typer.Argument("", help="Mode of the kernel to launch."), - cache_dir: Optional[str] = typer.Option( - None, "-c", help="Path to the cache directory, if mode is 'cache'." 
- ), - connection_file: str = typer.Option(..., "-f", help="Path to the connection file."), -): - Kernel(mode, cache_dir, connection_file) - - -if __name__ == "__main__": - cli() diff --git a/plugins/akernel_task/LICENSE b/plugins/akernel_task/LICENSE new file mode 100644 index 0000000..513ffaf --- /dev/null +++ b/plugins/akernel_task/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2025 David Brochart + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/plugins/akernel_task/README.md b/plugins/akernel_task/README.md new file mode 100644 index 0000000..1a46c52 --- /dev/null +++ b/plugins/akernel_task/README.md @@ -0,0 +1,3 @@ +# fps-akernel-task + +An FPS plugin for the kernel task API. diff --git a/plugins/akernel_task/fps_akernel_task/__init__.py b/plugins/akernel_task/fps_akernel_task/__init__.py new file mode 100644 index 0000000..43a8fd7 --- /dev/null +++ b/plugins/akernel_task/fps_akernel_task/__init__.py @@ -0,0 +1,6 @@ +import importlib.metadata + +try: + __version__ = importlib.metadata.version("fps_akernel_task") +except importlib.metadata.PackageNotFoundError: + __version__ = "unknown" diff --git a/plugins/akernel_task/fps_akernel_task/akernel_task.py b/plugins/akernel_task/fps_akernel_task/akernel_task.py new file mode 100644 index 0000000..6d4262b --- /dev/null +++ b/plugins/akernel_task/fps_akernel_task/akernel_task.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from anyio import TASK_STATUS_IGNORED, create_task_group, sleep_forever +from anyio.abc import TaskStatus +from jupyverse_api.kernel import Kernel as _Kernel + +from akernel.kernel import Kernel + + +class AKernelTask(_Kernel): + def __init__(self, *args, **kwargs): + super().__init__() + + async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: + async with ( + create_task_group() as self.task_group, + self._to_shell_send_stream, + self._to_shell_receive_stream, + self._from_shell_send_stream, + self._from_shell_receive_stream, + self._to_control_send_stream, + self._to_control_receive_stream, + self._from_control_send_stream, + self._from_control_receive_stream, + self._to_stdin_send_stream, + self._to_stdin_receive_stream, + self._from_stdin_send_stream, + self._from_stdin_receive_stream, + self._from_iopub_send_stream, + self._from_iopub_receive_stream, + ): + self.kernel = Kernel( + self._to_shell_receive_stream, + self._from_shell_send_stream, + self._to_control_receive_stream, + self._from_control_send_stream, + self._to_stdin_receive_stream, + 
self._from_stdin_send_stream, + self._from_iopub_send_stream, + ) + self.task_group.start_soon(self.kernel.start) + task_status.started() + await sleep_forever() + + async def stop(self) -> None: + self.task_group.cancel_scope.cancel() + + async def interrupt(self) -> None: + pass diff --git a/plugins/akernel_task/fps_akernel_task/main.py b/plugins/akernel_task/fps_akernel_task/main.py new file mode 100644 index 0000000..7f99c85 --- /dev/null +++ b/plugins/akernel_task/fps_akernel_task/main.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from fps import Module + +from jupyverse_api.kernel import KernelFactory +from jupyverse_api.kernels import Kernels + +from .akernel_task import AKernelTask + + +class AKernelTaskModule(Module): + async def prepare(self) -> None: + kernels = await self.get(Kernels) + kernels.register_kernel_factory("akernel", KernelFactory(AKernelTask)) diff --git a/akernel/IPython/core/__init__.py b/plugins/akernel_task/fps_akernel_task/py.typed similarity index 100% rename from akernel/IPython/core/__init__.py rename to plugins/akernel_task/fps_akernel_task/py.typed diff --git a/plugins/akernel_task/pyproject.toml b/plugins/akernel_task/pyproject.toml new file mode 100644 index 0000000..b7dbce4 --- /dev/null +++ b/plugins/akernel_task/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "fps_akernel_task" +version = "0.1.0" +description = "An FPS plugin for the kernel task API" +keywords = ["jupyter", "server", "fastapi", "plugins"] +requires-python = ">=3.9" +dependencies = [ + "jupyverse-api >=0.10.0,<0.11.0", + "anyio", +] + +[[project.authors]] +name = "David Brochart" +email = "david.brochart@gmail.com" + +[project.readme] +file = "README.md" +content-type = "text/markdown" + +[project.license] +text = "MIT" + +[project.urls] +Homepage = "https://github.com/davidbrochart/akernel" + +[project.entry-points] +"fps.modules" = {akernel_task = "fps_akernel_task.main:AKernelTaskModule"} +"jupyverse.modules" = {akernel_task = "fps_akernel_task.main:AKernelTaskModule"} diff --git a/pyproject.toml b/pyproject.toml index 3e48ce4..9380776 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,38 +4,35 @@ build-backend = "hatchling.build" [project] name = "akernel" -dynamic = ["version"] +version = "0.2.1" description = "An asynchronous Python Jupyter kernel" readme = "README.md" license = {text = "MIT"} -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [{name = "David Brochart", email = "david.brochart@gmail.com"}] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] keywords = [ "jupyter" ] dependencies = [ - "pyzmq", - "typer >=0.4.0", - "click", "python-dateutil", "colorama", - "gast >=0.5.3", + "gast >=0.6.0, <0.7.0", "comm >=0.1.3,<1", ] -[project.optional-dependencies] +[dependency-groups] test = [ "mypy", "ruff", - "black", "pytest", "pytest-asyncio", "pytest-rerunfailures", @@ -45,6 +42,12 @@ test = [ "zict", ] +[project.optional-dependencies] +subprocess = [ + "zmq-anyio >=0.3.9,<0.4.0", + "typer >=0.4.0", +] + react = [ "ipyx >=0.1.7", ] @@ -56,14 +59,19 @@ cache 
= [ [project.scripts] akernel = "akernel.akernel:cli" +[tool.hatch.build.targets.wheel] +ignore-vcs = true +packages = ["src/akernel"] + [tool.hatch.build.targets.wheel.shared-data] "share/jupyter/kernels/akernel/kernel.json" = "share/jupyter/kernels/akernel/kernel.json" [project.urls] Homepage = "https://github.com/davidbrochart/akernel" -[tool.hatch.version] -path = "akernel/__init__.py" - [tool.ruff] line-length = 100 +exclude = ["examples"] + +[tool.uv.sources] +fps-akernel-task = { workspace = true } diff --git a/akernel/IPython/__init__.py b/src/akernel/IPython/__init__.py similarity index 100% rename from akernel/IPython/__init__.py rename to src/akernel/IPython/__init__.py diff --git a/akernel/display/__init__.py b/src/akernel/IPython/core/__init__.py similarity index 100% rename from akernel/display/__init__.py rename to src/akernel/IPython/core/__init__.py diff --git a/akernel/IPython/core/getipython.py b/src/akernel/IPython/core/getipython.py similarity index 100% rename from akernel/IPython/core/getipython.py rename to src/akernel/IPython/core/getipython.py diff --git a/akernel/IPython/core/interactiveshell.py b/src/akernel/IPython/core/interactiveshell.py similarity index 100% rename from akernel/IPython/core/interactiveshell.py rename to src/akernel/IPython/core/interactiveshell.py diff --git a/src/akernel/__init__.py b/src/akernel/__init__.py new file mode 100644 index 0000000..26498ca --- /dev/null +++ b/src/akernel/__init__.py @@ -0,0 +1,6 @@ +import importlib.metadata + +try: + __version__ = importlib.metadata.version("akernel") +except importlib.metadata.PackageNotFoundError: + __version__ = "unknown" diff --git a/src/akernel/akernel.py b/src/akernel/akernel.py new file mode 100644 index 0000000..1e6c1ee --- /dev/null +++ b/src/akernel/akernel.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import json +from typing import cast + +import typer +from anyio import create_memory_object_stream, create_task_group, run, sleep_forever + +from .connect import connect_channel +from .kernel import Kernel +from .kernelspec import write_kernelspec +from .message import receive_message, send_message + + +cli = typer.Typer() + + +@cli.command() +def install( + mode: str = typer.Argument("", help="Mode of the kernel to install."), + cache_dir: str | None = typer.Option( + None, "-c", help="Path to the cache directory, if mode is 'cache'." + ), +): + kernel_name = "akernel" + if mode: + modes = mode.split("-") + modes.sort() + mode = "-".join(modes) + kernel_name += f"-{mode}" + display_name = f"Python 3 ({kernel_name})" + write_kernelspec(kernel_name, mode, display_name, cache_dir) + + +@cli.command() +def launch( + mode: str = typer.Argument("", help="Mode of the kernel to launch."), + cache_dir: str | None = typer.Option( + None, "-c", help="Path to the cache directory, if mode is 'cache'." 
+ ), + connection_file: str | None = typer.Option(None, "-f", help="Path to the connection file."), +): + akernel = AKernel(mode, cache_dir, connection_file) + run(akernel.start) + + +class AKernel: + def __init__(self, mode, cache_dir, connection_file): + self._to_shell_send_stream, self._to_shell_receive_stream = create_memory_object_stream[list[bytes]]() + self._from_shell_send_stream, self._from_shell_receive_stream = create_memory_object_stream[list[bytes]]() + self._to_control_send_stream, self._to_control_receive_stream = create_memory_object_stream[list[bytes]]() + self._from_control_send_stream, self._from_control_receive_stream = create_memory_object_stream[list[bytes]]() + self._to_stdin_send_stream, self._to_stdin_receive_stream = create_memory_object_stream[list[bytes]]() + self._from_stdin_send_stream, self._from_stdin_receive_stream = create_memory_object_stream[list[bytes]]() + self._from_iopub_send_stream, self._from_iopub_receive_stream = create_memory_object_stream[list[bytes]](max_buffer_size=float("inf")) + self.kernel = Kernel( + self._to_shell_receive_stream, + self._from_shell_send_stream, + self._to_control_receive_stream, + self._from_control_send_stream, + self._to_stdin_receive_stream, + self._from_stdin_send_stream, + self._from_iopub_send_stream, + mode, + cache_dir, + ) + with open(connection_file) as f: + connection_cfg = json.load(f) + self.kernel.key = cast(str, connection_cfg["key"]) + self.shell_channel = connect_channel("shell", connection_cfg) + self.iopub_channel = connect_channel("iopub", connection_cfg) + self.control_channel = connect_channel("control", connection_cfg) + self.stdin_channel = connect_channel("stdin", connection_cfg) + + async def start(self) -> None: + async with ( + create_task_group() as tg, + self._to_shell_send_stream, + self._to_shell_receive_stream, + self._from_shell_send_stream, + self._from_shell_receive_stream, + self._to_control_send_stream, + self._to_control_receive_stream, + self._from_control_send_stream, + self._from_control_receive_stream, + self._to_stdin_send_stream, + self._to_stdin_receive_stream, + self._from_stdin_send_stream, + self._from_stdin_receive_stream, + self._from_iopub_send_stream, + self._from_iopub_receive_stream, + self.shell_channel, + self.control_channel, + self.stdin_channel, + self.iopub_channel, + ): + tg.start_soon(self.kernel.start) + tg.start_soon(self.to_shell) + tg.start_soon(self.from_shell) + tg.start_soon(self.to_control) + tg.start_soon(self.from_control) + tg.start_soon(self.to_stdin) + tg.start_soon(self.from_stdin) + tg.start_soon(self.from_iopub) + await sleep_forever() + + async def to_shell(self) -> None: + while True: + msg = await receive_message(self.shell_channel) + await self._to_shell_send_stream.send(msg) + + async def from_shell(self) -> None: + async for msg in self._from_shell_receive_stream: + await send_message(msg, self.shell_channel) + + async def to_control(self) -> None: + while True: + msg = await receive_message(self.control_channel) + await self._to_control_send_stream.send(msg) + + async def from_control(self) -> None: + async for msg in self._from_control_receive_stream: + await send_message(msg, self.control_channel) + + async def to_stdin(self) -> None: + while True: + msg = await receive_message(self.stdin_channel) + await self._to_stdin_send_stream.send(msg) + + async def from_stdin(self) -> None: + async for msg in self._from_stdin_receive_stream: + await send_message(msg, self.stdin_channel) + + async def from_iopub(self) -> None: + async for msg in 
self._from_iopub_receive_stream: + await send_message(msg, self.iopub_channel) + + +if __name__ == "__main__": + cli() diff --git a/akernel/cache.py b/src/akernel/cache.py similarity index 74% rename from akernel/cache.py rename to src/akernel/cache.py index 3d5361b..7fdf1b6 100644 --- a/akernel/cache.py +++ b/src/akernel/cache.py @@ -10,9 +10,7 @@ def cache(cache_dir: str | None): if not cache_dir: - cache_dir = os.path.join( - sys.prefix, "share", "jupyter", "kernels", "akernel", "cache" - ) + cache_dir = os.path.join(sys.prefix, "share", "jupyter", "kernels", "akernel", "cache") l4 = File(cache_dir) l3 = Func(zlib.compress, zlib.decompress, l4) diff --git a/akernel/code.py b/src/akernel/code.py similarity index 83% rename from akernel/code.py rename to src/akernel/code.py index c613f9a..7831682 100644 --- a/akernel/code.py +++ b/src/akernel/code.py @@ -1,5 +1,6 @@ from __future__ import annotations +import ast import copy import gast # type: ignore @@ -48,9 +49,9 @@ """ ).strip() -body_declare = gast.parse(code_declare).body -body_assign = gast.parse(code_assign).body -body_return = gast.parse(code_return).body +body_declare = ast.parse(code_declare).body +body_assign = ast.parse(code_assign).body +body_return = ast.parse(code_return).body def get_declare_body(lhs: str): @@ -88,14 +89,12 @@ def get_return_body(val): return body -body_globals_update_locals = gast.parse("globals().update(locals())").body +body_globals_update_locals = ast.parse("globals().update(locals())").body class Transform: - def __init__( - self, code: str, task_i: int | None = None, react: bool = False - ) -> None: - self.gtree = gast.parse(code) + def __init__(self, code: str, task_i: int | None = None, react: bool = False) -> None: + self.gtree = ast.parse(code) self.task_i = task_i self.react = react c = GlobalUseCollector() @@ -107,26 +106,24 @@ def __init__( if react: self.make_react() - def get_async_ast(self) -> gast.Module: + def get_async_ast(self) -> ast.Module: new_body = [] if self.globals: - new_body += [gast.Global(names=list(self.globals))] - if isinstance(self.last_statement, gast.Expr): + new_body += [ast.Global(names=list(self.globals))] + if isinstance(self.last_statement, ast.Expr): self.gtree.body.remove(self.last_statement) if self.react: last_statement = get_return_body(self.last_statement.value) else: - last_statement = [gast.Return(value=self.last_statement.value)] + last_statement = [ast.Return(value=self.last_statement.value)] new_body += self.gtree.body + body_globals_update_locals + last_statement else: new_body += self.gtree.body + body_globals_update_locals - name = ( - "__async_cell__" if self.task_i is None else f"__async_cell{self.task_i}__" - ) + name = "__async_cell__" if self.task_i is None else f"__async_cell{self.task_i}__" body = [ - gast.AsyncFunctionDef( + ast.AsyncFunctionDef( name=name, - args=gast.arguments( + args=ast.arguments( args=[], posonlyargs=[], vararg=None, @@ -141,20 +138,20 @@ def get_async_ast(self) -> gast.Module: type_comment=None, ), ] - gtree = gast.Module(body=body, type_ignores=[]) - gast.fix_missing_locations(gtree) + gtree = ast.Module(body=body, type_ignores=[]) + ast.fix_missing_locations(gtree) return gtree def get_code(self) -> str: - return gast.unparse(self.gtree) + return ast.unparse(self.gtree) def get_async_code(self) -> str: gtree = self.get_async_ast() - return gast.unparse(gtree) + return ast.unparse(gtree) def get_async_bytecode(self) -> CodeType: - gtree = self.get_async_ast() - tree = gast.gast_to_ast(gtree) + tree = 
self.get_async_ast() + #tree = gast.gast_to_ast(gtree) bytecode = compile(tree, filename="", mode="exec") return bytecode @@ -169,16 +166,12 @@ def make_react(self): ): # RHS rhs_calls = [ - n - for n in gast.walk(statement.value) - if isinstance(n, gast.Call) + n for n in gast.walk(statement.value) if isinstance(n, gast.Call) ] for n in rhs_calls: ipyx_name = gast.Name(id="ipyx", ctx=gast.Load()) n.func = gast.Call( - func=gast.Attribute( - value=ipyx_name, attr="F", ctx=gast.Load() - ), + func=gast.Attribute(value=ipyx_name, attr="F", ctx=gast.Load()), args=[n.func], keywords=[], ) @@ -243,9 +236,7 @@ def visit_Name(self, node): def visit_Assign(self, node): ctx, g = self.context[-1] if ctx == "global": - self.outputs += [ - target.id for target in node.targets if isinstance(target, gast.Name) - ] + self.outputs += [target.id for target in node.targets if isinstance(target, gast.Name)] self.generic_visit(node) def visit_AugAssign(self, node): diff --git a/akernel/comm/__init__.py b/src/akernel/comm/__init__.py similarity index 100% rename from akernel/comm/__init__.py rename to src/akernel/comm/__init__.py diff --git a/akernel/comm/comm.py b/src/akernel/comm/comm.py similarity index 62% rename from akernel/comm/comm.py rename to src/akernel/comm/comm.py index 4842643..f1d770e 100644 --- a/akernel/comm/comm.py +++ b/src/akernel/comm/comm.py @@ -1,19 +1,19 @@ from __future__ import annotations -from typing import Dict, List, Any, Callable +from typing import Any, Callable -import comm # type: ignore +import comm -from ..message import send_message, create_message +from ..message import create_message, serialize class Comm(comm.base_comm.BaseComm): _msg_callback: Callable | None comm_id: str topic: bytes - parent_header: Dict[str, Any] + parent_header: dict[str, Any] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: from akernel.kernel import KERNEL, PARENT_VAR, Kernel self.kernel: Kernel = KERNEL @@ -23,26 +23,23 @@ def __init__(self, **kwargs): def publish_msg( self, msg_type: str, - data: Dict[str, Any], - metadata: Dict[str, Any], - buffers: List[bytes], - **keys, + data: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + buffers: list[bytes] | None = None, + **keys: Any, ) -> None: msg = create_message( msg_type, content=dict(data=data, comm_id=self.comm_id, **keys), metadata=metadata, parent_header=self.parent_header, - ) - send_message( - msg, - self.kernel.iopub_channel, - self.kernel.key, - address=self.topic, buffers=buffers, + address=self.topic, ) + to_send = serialize(msg, self.kernel.key) + self.kernel.from_iopub_send_stream.send_nowait(to_send) - def handle_msg(self, msg: Dict[str, Any]) -> None: + def handle_msg(self, msg: dict[str, Any]) -> None: if self._msg_callback: self.kernel.execution_state = "busy" msg2 = create_message( @@ -50,7 +47,8 @@ def handle_msg(self, msg: Dict[str, Any]) -> None: parent_header=msg["header"], content={"execution_state": self.kernel.execution_state}, ) - send_message(msg2, self.kernel.iopub_channel, self.kernel.key) + to_send = serialize(msg2, self.kernel.key) + self.kernel.from_iopub_send_stream.send_nowait(to_send) self._msg_callback(msg) self.kernel.execution_state = "idle" msg2 = create_message( @@ -58,7 +56,8 @@ def handle_msg(self, msg: Dict[str, Any]) -> None: parent_header=msg["header"], content={"execution_state": self.kernel.execution_state}, ) - send_message(msg2, self.kernel.iopub_channel, self.kernel.key) + to_send = serialize(msg2, self.kernel.key) + 
self.kernel.from_iopub_send_stream.send_nowait(to_send) comm.create_comm = Comm diff --git a/akernel/comm/manager.py b/src/akernel/comm/manager.py similarity index 59% rename from akernel/comm/manager.py rename to src/akernel/comm/manager.py index c722f6e..94fe237 100644 --- a/akernel/comm/manager.py +++ b/src/akernel/comm/manager.py @@ -1,21 +1,22 @@ -from typing import Dict, Callable +from typing import Dict, Callable, cast -import comm # type: ignore +import comm from .comm import Comm class CommManager(comm.CommManager): - comms: Dict[str, Comm] + comms: dict[str, comm.base_comm.BaseComm] targets: Dict[str, Callable] - def __init__(self): + def __init__(self) -> None: super().__init__() from akernel.kernel import KERNEL, Kernel self.kernel: Kernel = KERNEL - def register_comm(self, comm: Comm) -> str: + def register_comm(self, comm: comm.base_comm.BaseComm) -> str: + comm = cast(Comm, comm) comm_id = comm.comm_id comm.kernel = self.kernel self.comms[comm_id] = comm diff --git a/akernel/connect.py b/src/akernel/connect.py similarity index 69% rename from akernel/connect.py rename to src/akernel/connect.py index 1150f04..8c43f66 100644 --- a/akernel/connect.py +++ b/src/akernel/connect.py @@ -1,12 +1,13 @@ -import zmq -from typing import Dict, Union +from __future__ import annotations -from zmq.asyncio import Context, Socket +import zmq +from zmq import Context +from zmq_anyio import Socket context = Context() -cfg_t = Dict[str, Union[str, int]] +cfg_t = dict[str, str | int] channel_socket_types = { "shell": zmq.ROUTER, @@ -21,12 +22,11 @@ def create_socket(channel: str, cfg: cfg_t) -> Socket: port = cfg[f"{channel}_port"] url = f"tcp://{ip}:{port}" socket_type = channel_socket_types[channel] - sock = context.socket(socket_type) + sock = Socket(context.socket(socket_type)) sock.linger = 1000 sock.bind(url) return sock def connect_channel(channel_name: str, cfg: cfg_t) -> Socket: - sock = create_socket(channel_name, cfg) - return sock + return create_socket(channel_name, cfg) diff --git a/src/akernel/display/__init__.py b/src/akernel/display/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/akernel/display/display.py b/src/akernel/display/display.py similarity index 71% rename from akernel/display/display.py rename to src/akernel/display/display.py index 5ea58bd..6634ac1 100644 --- a/akernel/display/display.py +++ b/src/akernel/display/display.py @@ -1,4 +1,4 @@ -from ..message import send_message, create_message +from ..message import create_message, serialize def display(*args, raw: bool = False) -> None: @@ -11,7 +11,8 @@ def display(*args, raw: bool = False) -> None: content=dict(data=data, transient={}, metadata={}), parent_header=parent_header, ) - send_message(msg, KERNEL.iopub_channel, KERNEL.key) + to_send = serialize(msg, KERNEL.key) + KERNEL.from_iopub_send_stream.send_nowait(to_send) def clear_output() -> None: diff --git a/akernel/execution.py b/src/akernel/execution.py similarity index 99% rename from akernel/execution.py rename to src/akernel/execution.py index bea45eb..17c704f 100644 --- a/akernel/execution.py +++ b/src/akernel/execution.py @@ -43,8 +43,7 @@ def pre_execute( f"{Style.RESET_ALL}:", f"{Fore.RED}{exception.text.rstrip()}{Style.RESET_ALL}", (exception.offset - 1) * " " + "^", - f"{Fore.RED}{type(exception).__name__}{Style.RESET_ALL}: " - f"{exception.args[0]}", + f"{Fore.RED}{type(exception).__name__}{Style.RESET_ALL}: {exception.args[0]}", ] else: if cache is not None: diff --git a/akernel/kernel.py b/src/akernel/kernel.py similarity index 
72% rename from akernel/kernel.py rename to src/akernel/kernel.py index 196650f..427e868 100644 --- a/akernel/kernel.py +++ b/src/akernel/kernel.py @@ -1,21 +1,21 @@ from __future__ import annotations +import asyncio import sys import platform -import asyncio import json from io import StringIO from contextvars import ContextVar from typing import Dict, Any, List, Union, Awaitable, cast -from zmq.asyncio import Socket +from anyio import Event, create_task_group, sleep import comm # type: ignore from akernel.comm.manager import CommManager from akernel.display import display import akernel.IPython from akernel.IPython import core from .connect import connect_channel -from .message import receive_message, send_message, create_message, check_message +from .message import create_message, feed_identities, deserialize, serialize from .execution import pre_execute, cache_execution from .traceback import get_traceback from . import __version__ @@ -33,17 +33,12 @@ class Kernel: - shell_channel: Socket - iopub_channel: Socket - control_channel: Socket - stdin_channel: Socket - connection_cfg: Dict[str, Union[str, int]] - stop: asyncio.Event + stop_event: Event restart: bool key: str comm_manager: CommManager kernel_mode: str - cell_done: Dict[int, asyncio.Event] + cell_done: Dict[int, Event] running_cells: Dict[int, asyncio.Task] task_i: int execution_count: int @@ -53,31 +48,41 @@ class Kernel: _multi_kernel: bool | None _cache_kernel: bool | None _react_kernel: bool | None - kernel_initialized: List[str] + kernel_initialized: set[str] cache: Dict[str, Any] | None def __init__( self, - kernel_mode: str, - cache_dir: str | None, - connection_file: str, + to_shell_receive_stream, + from_shell_send_stream, + to_control_receive_stream, + from_control_send_stream, + to_stdin_receive_stream, + from_stdin_send_stream, + from_iopub_send_stream, + kernel_mode: str = "", + cache_dir: str | None = None, ): global KERNEL KERNEL = self self.comm_manager = CommManager() + comm.get_comm_manager = lambda: self.comm_manager - def get_comm_manager(): - return self.comm_manager + self.to_shell_receive_stream = to_shell_receive_stream + self.from_shell_send_stream = from_shell_send_stream + self.to_control_receive_stream = to_control_receive_stream + self.from_control_send_stream = from_control_send_stream + self.to_stdin_receive_stream = to_stdin_receive_stream + self.from_stdin_send_stream = from_stdin_send_stream + self.from_iopub_send_stream = from_iopub_send_stream - comm.get_comm_manager = get_comm_manager - self.loop = asyncio.get_event_loop() self.kernel_mode = kernel_mode self.cache_dir = cache_dir self._concurrent_kernel = None self._multi_kernel = None self._cache_kernel = None self._react_kernel = None - self.kernel_initialized = [] + self.kernel_initialized = set() self.globals = {} self.locals = {} self._chain_execution = not self.concurrent_kernel @@ -86,9 +91,6 @@ def get_comm_manager(): self.task_i = 0 self.execution_count = 1 self.execution_state = "starting" - with open(connection_file) as f: - self.connection_cfg = json.load(f) - self.key = cast(str, self.connection_cfg["key"]) self.restart = False self.interrupted = False self.msg_cnt = 0 @@ -98,27 +100,8 @@ def get_comm_manager(): self.cache = cache(cache_dir) else: self.cache = None - self.shell_channel = connect_channel("shell", self.connection_cfg) - self.iopub_channel = connect_channel("iopub", self.connection_cfg) - self.control_channel = connect_channel("control", self.connection_cfg) - self.stdin_channel = connect_channel("stdin", 
self.connection_cfg) - msg = self.create_message( - "status", content={"execution_state": self.execution_state} - ) - send_message(msg, self.iopub_channel, self.key) - self.execution_state = "idle" - self.stop = asyncio.Event() - while True: - try: - self.loop.run_until_complete(self.main()) - except KeyboardInterrupt: - self.interrupt() - else: - if not self.restart: - break - finally: - self.shell_task.cancel() - self.control_task.cancel() + self.stop_event = Event() + self.key = "0" def chain_execution(self) -> None: self._chain_execution = True @@ -166,14 +149,13 @@ def init_kernel(self, namespace): self.locals[namespace] = {} if self.react_kernel: code = ( - "import ipyx, ipywidgets;" - "globals().update({'ipyx': ipyx, 'ipywidgets': ipywidgets})" + "import ipyx, ipywidgets;globals().update({'ipyx': ipyx, 'ipywidgets': ipywidgets})" ) exec(code, self.globals[namespace], self.locals[namespace]) - self.kernel_initialized.append(namespace) + self.kernel_initialized.add(namespace) - def get_namespace(self, parent_header): + def get_namespace(self, parent_header) -> str: if self.multi_kernel: return parent_header["session"] @@ -185,15 +167,32 @@ def interrupt(self): task.cancel() self.running_cells = {} - async def main(self) -> None: - self.shell_task = asyncio.create_task(self.listen_shell()) - self.control_task = asyncio.create_task(self.listen_control()) + async def start(self) -> None: + async with create_task_group() as self.task_group: + msg = self.create_message("status", content={"execution_state": self.execution_state}) + to_send = serialize(msg, self.key) + await self.from_iopub_send_stream.send(to_send) + self.execution_state = "idle" + while True: + try: + await self._start() + except KeyboardInterrupt: + self.interrupt() + else: + if not self.restart: + break + finally: + self.task_group.cancel_scope.cancel() + + async def _start(self) -> None: + self.task_group.start_soon(self.listen_shell) + self.task_group.start_soon(self.listen_control) while True: # run until shutdown request - await self.stop.wait() + await self.stop_event.wait() if self.restart: # kernel restart - self.stop.clear() + self.stop_event = Event() else: # kernel shutdown break @@ -201,15 +200,15 @@ async def main(self) -> None: async def listen_shell(self) -> None: while True: # let a chance to execute a blocking cell - await asyncio.sleep(0) + await sleep(0) # if there was a blocking cell execution, and it was interrupted, # let's ignore all the following execution requests until the pipe # is empty - if self.interrupted and not await check_message(self.shell_channel): + if self.interrupted and self.to_shell_receive_stream.statistics().tasks_waiting_send == 0: self.interrupted = False - res = await receive_message(self.shell_channel) - assert res is not None - idents, msg = res + msg_list = await self.to_shell_receive_stream.receive() + idents, msg_list = feed_identities(msg_list) + msg = deserialize(msg_list) msg_type = msg["header"]["msg_type"] parent_header = msg["header"] parent = msg @@ -230,14 +229,17 @@ async def listen_shell(self) -> None: }, "banner": "Python " + sys.version, }, + address=idents[0], ) - send_message(msg, self.shell_channel, self.key, idents[0]) + to_send = serialize(msg, self.key) + await self.from_shell_send_stream.send(to_send) msg = self.create_message( "status", parent_header=parent_header, content={"execution_state": self.execution_state}, ) - send_message(msg, self.iopub_channel, self.key) + to_send = serialize(msg, self.key) + await self.from_iopub_send_stream.send(to_send) 
elif msg_type == "execute_request": self.execution_state = "busy" code = msg["content"]["code"] @@ -246,16 +248,18 @@ async def listen_shell(self) -> None: parent_header=parent_header, content={"execution_state": self.execution_state}, ) - send_message(msg, self.iopub_channel, self.key) + to_send = serialize(msg, self.key) + await self.from_iopub_send_stream.send(to_send) if self.interrupted: - self.finish_execution(idents, parent_header, None, no_exec=True) + await self.finish_execution(idents, parent_header, None, no_exec=True) continue msg = self.create_message( "execute_input", parent_header=parent_header, content={"code": code, "execution_count": self.execution_count}, ) - send_message(msg, self.iopub_channel, self.key) + to_send = serialize(msg, self.key) + await self.from_iopub_send_stream.send(to_send) namespace = self.get_namespace(parent_header) self.init_kernel(namespace) traceback, exception, cache_info = pre_execute( @@ -268,7 +272,7 @@ async def listen_shell(self) -> None: cache=self.cache, ) if cache_info["cached"]: - self.finish_execution( + await self.finish_execution( idents, parent_header, self.execution_count, @@ -276,7 +280,7 @@ async def listen_shell(self) -> None: ) self.execution_count += 1 elif traceback: - self.finish_execution( + await self.finish_execution( idents, parent_header, self.execution_count, @@ -305,7 +309,8 @@ async def listen_shell(self) -> None: parent_header=parent_header, content={"execution_state": self.execution_state}, ) - send_message(msg2, self.iopub_channel, self.key) + to_send = serialize(msg2, self.key) + await self.from_iopub_send_stream.send(to_send) if "target_name" in msg["content"]: target_name = msg["content"]["target_name"] comms: List[str] = [] @@ -314,28 +319,28 @@ async def listen_shell(self) -> None: parent_header=parent_header, content={ "status": "ok", - "comms": { - comm_id: {"target_name": target_name} - for comm_id in comms - }, + "comms": {comm_id: {"target_name": target_name} for comm_id in comms}, }, + address=idents[0], ) - send_message(msg2, self.shell_channel, self.key, idents[0]) + to_send = serialize(msg2, self.key) + await self.from_shell_send_stream.send(to_send) self.execution_state = "idle" msg2 = self.create_message( "status", parent_header=parent_header, content={"execution_state": self.execution_state}, ) - send_message(msg2, self.iopub_channel, self.key) + to_send = serialize(msg2, self.key) + await self.from_iopub_send_stream.send(to_send) elif msg_type == "comm_msg": - self.comm_manager.comm_msg(None, None, msg) + self.comm_manager.comm_msg(None, None, msg) # type: ignore[arg-type] async def listen_control(self) -> None: while True: - res = await receive_message(self.control_channel) - assert res is not None - idents, msg = res + msg_list = await self.to_control_receive_stream.receive() + idents, msg_list = feed_identities(msg_list) + msg = deserialize(msg_list) msg_type = msg["header"]["msg_type"] parent_header = msg["header"] if msg_type == "shutdown_request": @@ -344,11 +349,13 @@ async def listen_control(self) -> None: "shutdown_reply", parent_header=parent_header, content={"restart": self.restart}, + address=idents[0], ) - send_message(msg, self.control_channel, self.key, idents[0]) + to_send = serialize(msg, self.key) + await self.from_control_send_stream.send(to_send) if self.restart: self.execution_count = 1 - self.stop.set() + self.stop_event.set() async def execute_and_finish( self, @@ -376,12 +383,12 @@ async def execute_and_finish( exception = e traceback = get_traceback(code, e, 
execution_count) else: - self.show_result(result, self.globals[namespace], parent_header) + await self.show_result(result, self.globals[namespace], parent_header) cache_execution(self.cache, cache_info, self.globals[namespace], result) finally: self.cell_done[task_i].set() del self.locals[namespace][f"__async_cell{task_i}__"] - self.finish_execution( + await self.finish_execution( idents, parent_header, execution_count, @@ -391,7 +398,7 @@ async def execute_and_finish( if task_i in self.running_cells: del self.running_cells[task_i] - def finish_execution( + async def finish_execution( self, idents: List[bytes], parent_header: Dict[str, Any], @@ -403,7 +410,7 @@ def finish_execution( ) -> None: if result: namespace = self.get_namespace(parent_header) - self.show_result(result, self.globals[namespace], parent_header) + await self.show_result(result, self.globals[namespace], parent_header) if no_exec: status = "aborted" else: @@ -419,22 +426,26 @@ def finish_execution( "traceback": traceback, }, ) - send_message(msg, self.iopub_channel, self.key) + to_send = serialize(msg, self.key) + await self.from_iopub_send_stream.send(to_send) else: status = "ok" msg = self.create_message( "execute_reply", parent_header=parent_header, content={"status": status, "execution_count": execution_count}, + address=idents[0], ) - send_message(msg, self.shell_channel, self.key, idents[0]) + to_send = serialize(msg, self.key) + await self.from_shell_send_stream.send(to_send) self.execution_state = "idle" msg = self.create_message( "status", parent_header=parent_header, content={"execution_state": self.execution_state}, ) - send_message(msg, self.iopub_channel, self.key) + to_send = serialize(msg, self.key) + await self.from_iopub_send_stream.send(to_send) def task(self, cell_i: int = -1) -> Awaitable: if cell_i < 0: @@ -453,10 +464,13 @@ async def ainput(self, prompt: str = "") -> Any: "input_request", parent_header=parent["header"], content={"prompt": prompt, "password": False}, + address=idents[0], ) - send_message(msg, self.stdin_channel, self.key, idents[0]) - res = await receive_message(self.stdin_channel) - assert res is not None + to_send = serialize(msg, self.key) + await self.from_stdin_send_stream.send(to_send) + msg_list = await self.to_stdin_receive_stream.receive() + idents, msg_list = feed_identities(msg_list) + msg = deserialize(msg_list) idents, msg = res if msg["content"]["status"] == "ok": return msg["content"]["value"] @@ -485,21 +499,27 @@ def print( parent_header=PARENT_VAR.get()["header"], content={"name": name, "text": text}, ) - send_message(msg, self.iopub_channel, self.key) + to_send = serialize(msg, self.key) + self.from_iopub_send_stream.send_nowait(to_send) def create_message( self, msg_type: str, content: Dict = {}, parent_header: Dict[str, Any] = {}, + address: bytes | None = None, ) -> Dict[str, Any]: msg = create_message( - msg_type, content=content, parent_header=parent_header, msg_cnt=self.msg_cnt + msg_type, + content=content, + parent_header=parent_header, + msg_cnt=self.msg_cnt, + address=address, ) self.msg_cnt += 1 return msg - def show_result(self, result, globals_, parent_header): + async def show_result(self, result, globals_, parent_header): if result is not None: globals_["_"] = result send_stream = True @@ -522,4 +542,5 @@ def show_result(self, result, globals_, parent_header): parent_header=parent_header, content={"name": "stdout", "text": f"{repr(result)}\n"}, ) - send_message(msg, self.iopub_channel, self.key) + to_send = serialize(msg, self.key) + await 
self.from_iopub_send_stream.send(to_send) diff --git a/akernel/kernelspec.py b/src/akernel/kernelspec.py similarity index 85% rename from akernel/kernelspec.py rename to src/akernel/kernelspec.py index c0a8644..77971f4 100644 --- a/akernel/kernelspec.py +++ b/src/akernel/kernelspec.py @@ -5,9 +5,7 @@ import json -def write_kernelspec( - dir_name: str, mode: str, display_name: str, cache_dir: str | None -) -> None: +def write_kernelspec(dir_name: str, mode: str, display_name: str, cache_dir: str | None) -> None: argv = ["akernel", "launch"] if mode: argv.append(mode) diff --git a/akernel/message.py b/src/akernel/message.py similarity index 59% rename from akernel/message.py rename to src/akernel/message.py index f07628a..03d5832 100644 --- a/akernel/message.py +++ b/src/akernel/message.py @@ -4,10 +4,10 @@ import hmac import hashlib from datetime import datetime, timezone -from typing import List, Dict, Tuple, Any, cast +from typing import Any, cast from zmq.utils import jsonapi -from zmq.asyncio import Socket +from zmq_anyio import Socket from dateutil.parser import parse as dateutil_parse # type: ignore @@ -17,8 +17,8 @@ DELIM = b"" -def date_to_str(obj: Dict[str, Any]): - if obj is not None and "date" in obj and type(obj["date"]) != str: +def date_to_str(obj: dict[str, Any]): + if obj is not None and "date" in obj and not isinstance(obj["date"], str): obj["date"] = obj["date"].isoformat().replace("+00:00", "Z") return obj @@ -27,14 +27,13 @@ def utcnow() -> datetime: return datetime.utcnow().replace(tzinfo=timezone.utc) -def feed_identities(msg_list: List[bytes]) -> Tuple[List[bytes], List[bytes]]: +def feed_identities(msg_list: list[bytes]) -> tuple[list[bytes], list[bytes]]: idx = msg_list.index(DELIM) - return msg_list[:idx], msg_list[idx + 1 :] # noqa + idents = msg_list[:idx] or ["foo"] + return idents , msg_list[idx + 1 :] # noqa -def create_message_header( - msg_type: str, session_id: str, msg_cnt: int -) -> Dict[str, Any]: +def create_message_header(msg_type: str, session_id: str, msg_cnt: int) -> dict[str, Any]: if not session_id: session_id = msg_id = uuid.uuid4().hex else: @@ -52,12 +51,20 @@ def create_message_header( def create_message( msg_type: str, - content: Dict = {}, - metadata: Dict = {}, - parent_header: Dict[str, Any] = {}, + content: dict = {}, + metadata: dict[str, Any] | None = None, + parent_header: dict[str, Any] = {}, session_id: str = "", msg_cnt: int = 0, -) -> Dict[str, Any]: + buffers: list = [], + address: bytes | None = None, +) -> dict[str, Any]: + for buf in buffers: + if isinstance(buf, memoryview): + view = buf + else: + view = memoryview(buf) + assert view.contiguous if parent_header: session_id = parent_header["session"] header = create_message_header(msg_type, session_id, msg_cnt) @@ -68,11 +75,14 @@ def create_message( "parent_header": parent_header, "content": content, "metadata": metadata, + "buffers": buffers, } + if address is not None: + msg["address"] = address return msg -def serialize(msg: Dict[str, Any], key: str, address: bytes = b"") -> List[bytes]: +def serialize(msg: dict[str, Any], key: str) -> list[bytes]: message = [ pack(date_to_str(msg["header"])), pack(date_to_str(msg["parent_header"])), @@ -80,52 +90,34 @@ def serialize(msg: Dict[str, Any], key: str, address: bytes = b"") -> List[bytes pack(date_to_str(msg.get("content", {}))), ] to_send = [] - if address: + address = msg.get("address") + if address is not None: to_send.append(address) - to_send += [DELIM, sign(message, key)] + message + to_send += [DELIM, sign(message, key)] 
+ message + msg.get("buffers", []) return to_send -async def receive_message( - sock: Socket, timeout: float = float("inf") -) -> Tuple[List[bytes], Dict[str, Any]] | None: - timeout *= 1000 # in ms - ready = await sock.poll(timeout) - if ready: - msg_list = await sock.recv_multipart() - idents, msg_list = feed_identities(msg_list) - return idents, deserialize(msg_list) +async def receive_message(sock: Socket) -> tuple[list[bytes], dict[str, Any]] | None: + return await sock.arecv_multipart().wait() return None -def send_message( - msg: Dict[str, Any], +async def send_message( + msg: dict[str, Any], sock: Socket, - key: str, - address: bytes = b"", - buffers: List | None = None, ) -> None: - to_send = serialize(msg, key, address) - buffers = buffers or [] - for buf in buffers: - if isinstance(buf, memoryview): - view = buf - else: - view = memoryview(buf) - assert view.contiguous - to_send += buffers - sock.send_multipart(to_send, copy=True) + await sock.asend_multipart(msg, copy=True).wait() -def pack(obj: Dict[str, Any]) -> bytes: +def pack(obj: dict[str, Any]) -> bytes: return jsonapi.dumps(obj) -def unpack(s: bytes) -> Dict[str, Any]: - return cast(Dict[str, Any], jsonapi.loads(s)) +def unpack(s: bytes) -> dict[str, Any]: + return cast(dict[str, Any], jsonapi.loads(s)) -def sign(msg_list: List[bytes], key: str) -> bytes: +def sign(msg_list: list[bytes], key: str) -> bytes: auth = hmac.new(key.encode("ascii"), digestmod=hashlib.sha256) h = auth.copy() for m in msg_list: @@ -133,14 +125,14 @@ def sign(msg_list: List[bytes], key: str) -> bytes: return h.hexdigest().encode() -def str_to_date(obj: Dict[str, Any]) -> Dict[str, Any]: +def str_to_date(obj: dict[str, Any]) -> dict[str, Any]: if "date" in obj: obj["date"] = dateutil_parse(obj["date"]) return obj -def deserialize(msg_list: List[bytes]) -> Dict[str, Any]: - message: Dict[str, Any] = {} +def deserialize(msg_list: list[bytes]) -> dict[str, Any]: + message: dict[str, Any] = {} header = unpack(msg_list[1]) message["header"] = str_to_date(header) message["msg_id"] = header["msg_id"] @@ -150,7 +142,3 @@ def deserialize(msg_list: List[bytes]) -> Dict[str, Any]: message["content"] = unpack(msg_list[4]) message["buffers"] = [memoryview(b) for b in msg_list[5:]] return message - - -async def check_message(sock: Socket) -> int: - return await sock.poll(0) diff --git a/akernel/traceback.py b/src/akernel/traceback.py similarity index 92% rename from akernel/traceback.py rename to src/akernel/traceback.py index baa1b89..74d43de 100644 --- a/akernel/traceback.py +++ b/src/akernel/traceback.py @@ -43,8 +43,5 @@ def get_traceback(code: str, exception, execution_count: int = 0): ] trace.append(code.splitlines()[frame.f_lineno - 1]) traceback += trace - traceback += [ - f"{Fore.RED}{type(exception).__name__}{Style.RESET_ALL}: " - f"{exception.args[0]}" - ] + traceback += [f"{Fore.RED}{type(exception).__name__}{Style.RESET_ALL}: {exception.args[0]}"] return traceback diff --git a/akernel/tests/conftest.py b/tests/conftest.py similarity index 82% rename from akernel/tests/conftest.py rename to tests/conftest.py index cc0447b..1cc455e 100644 --- a/akernel/tests/conftest.py +++ b/tests/conftest.py @@ -17,9 +17,7 @@ def all_modes(request): mode = "-".join(modes) kernel_name += f"-{mode}" if mode == "cache": - cache_dir = os.path.join( - sys.prefix, "share", "jupyter", "kernels", "akernel", "cache" - ) + cache_dir = os.path.join(sys.prefix, "share", "jupyter", "kernels", "akernel", "cache") shutil.rmtree(cache_dir, ignore_errors=True) display_name = 
f"Python 3 ({kernel_name})" write_kernelspec(kernel_name, mode, display_name, None) diff --git a/akernel/tests/test_async_code.py b/tests/test_async_code.py similarity index 100% rename from akernel/tests/test_async_code.py rename to tests/test_async_code.py diff --git a/akernel/tests/test_execution.py b/tests/test_execution.py similarity index 100% rename from akernel/tests/test_execution.py rename to tests/test_execution.py diff --git a/akernel/tests/test_kernel.py b/tests/test_kernel.py similarity index 92% rename from akernel/tests/test_kernel.py rename to tests/test_kernel.py index d94e0e3..08929c8 100644 --- a/akernel/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -3,6 +3,7 @@ import asyncio import signal import re +from pathlib import Path from textwrap import dedent import pytest @@ -10,9 +11,7 @@ TIMEOUT = 5 -KERNELSPEC_PATH = ( - os.environ["CONDA_PREFIX"] + "/share/jupyter/kernels/akernel/kernel.json" -) +KERNELSPEC_PATH = str(Path(sys.prefix) / "share" / "jupyter" / "kernels" / "akernel" / "kernel.json") ANSI_ESCAPE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") @@ -102,12 +101,8 @@ async def test_concurrent_cells(capfd, all_modes): kd = KernelDriver(kernelspec_path=KERNELSPEC_PATH, log=False) await kd.start(startup_timeout=TIMEOUT) asyncio.create_task(kd.execute("__unchain_execution__()", timeout=TIMEOUT)) - asyncio.create_task( - kd.execute("await asyncio.sleep(0.2)\nprint('done1')", timeout=TIMEOUT) - ) - asyncio.create_task( - kd.execute("await asyncio.sleep(0.1)\nprint('done2')", timeout=TIMEOUT) - ) + asyncio.create_task(kd.execute("await asyncio.sleep(0.2)\nprint('done1')", timeout=TIMEOUT)) + asyncio.create_task(kd.execute("await asyncio.sleep(0.1)\nprint('done2')", timeout=TIMEOUT)) await asyncio.sleep(0.5) await kd.stop() @@ -119,9 +114,7 @@ async def test_concurrent_cells(capfd, all_modes): async def test_chained_cells(capfd, all_modes): kd = KernelDriver(kernelspec_path=KERNELSPEC_PATH, log=False) await kd.start(startup_timeout=TIMEOUT) - asyncio.create_task( - kd.execute("await asyncio.sleep(0.2)\nprint('done1')", timeout=TIMEOUT) - ) + asyncio.create_task(kd.execute("await asyncio.sleep(0.2)\nprint('done1')", timeout=TIMEOUT)) asyncio.create_task( kd.execute( "await __task__()\nawait asyncio.sleep(0.1)\nprint('done2')", @@ -197,9 +190,7 @@ async def test_interrupt_blocking(capfd, all_modes): ) ) asyncio.create_task( - kd.execute( - "print('before 1')\ntime.sleep(1)\nprint('after 1')", timeout=TIMEOUT - ) + kd.execute("print('before 1')\ntime.sleep(1)\nprint('after 1')", timeout=TIMEOUT) ) await asyncio.sleep(0.1) interrupt_kernel(kd.kernel_process) diff --git a/akernel/tests/test_react_code.py b/tests/test_react_code.py similarity index 92% rename from akernel/tests/test_react_code.py rename to tests/test_react_code.py index 4bce480..b19b005 100644 --- a/akernel/tests/test_react_code.py +++ b/tests/test_react_code.py @@ -9,9 +9,7 @@ def test_assign_constant(): a = 1 """ ).strip() - expected = ( - code_assign.replace("lhs", "a").replace("rhs.v", "1 .v").replace("rhs", "1") - ) + expected = code_assign.replace("lhs", "a").replace("rhs.v", "1 .v").replace("rhs", "1") assert Transform(code, react=True).get_code() == expected