diff --git a/array_api_strict/_array_object.py b/array_api_strict/_array_object.py index d9d485f..629af98 100644 --- a/array_api_strict/_array_object.py +++ b/array_api_strict/_array_object.py @@ -20,7 +20,7 @@ from collections.abc import Iterator from enum import IntEnum from types import EllipsisType, ModuleType -from typing import Any, Final, Literal, SupportsIndex +from typing import Any, Final, Literal, SupportsIndex, Callable import numpy as np import numpy.typing as npt @@ -125,6 +125,9 @@ def __new__(cls, *args: object, **kwargs: object) -> Array: raise TypeError( "The array_api_strict Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead." ) + + def __reduce__(self) -> tuple[Callable, tuple[npt.NDArray[Any], Device]]: + return (self._new, (self._array, self._device)) # These functions are not required by the spec, but are implemented for # the sake of usability. diff --git a/array_api_strict/tests/test_array_object.py b/array_api_strict/tests/test_array_object.py index 0269d13..f580585 100644 --- a/array_api_strict/tests/test_array_object.py +++ b/array_api_strict/tests/test_array_object.py @@ -1,6 +1,7 @@ import sys import warnings import operator +import pickle from builtins import all as all_ from numpy.testing import assert_raises @@ -747,3 +748,15 @@ def test_dlpack_2023_12(api_version): a.__dlpack__(copy=False) a.__dlpack__(copy=True) a.__dlpack__(copy=None) + +def test_pickle(): + """Check that arrays are pickleable (despite raising on `__new__`)""" + a = ones(2) + min_supported_protocol = 2 + for protocol in range(min_supported_protocol, pickle.HIGHEST_PROTOCOL + 1): + bytes = pickle.dumps(a, protocol=protocol) + a_from_pickle = pickle.loads(bytes) + assert a_from_pickle.device == a.device + assert a_from_pickle.dtype == a.dtype + assert a_from_pickle.shape == a.shape + assert all(a_from_pickle == a) diff --git a/pyproject.toml b/pyproject.toml index 00d09da..22fb964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,9 @@ classifiers = [ "Operating System :: OS Independent", ] +[project.optional-dependencies] +test = ["pytest", "hypothesis"] + [project.urls] Homepage = "https://data-apis.org/array-api-strict/" Repository = "https://github.com/data-apis/array-api-strict"