Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections.abc import Iterator
from enum import IntEnum
from types import EllipsisType, ModuleType
from typing import Any, Final, Literal, SupportsIndex
from typing import Any, Final, Literal, SupportsIndex, Callable

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -125,6 +125,9 @@ def __new__(cls, *args: object, **kwargs: object) -> Array:
raise TypeError(
"The array_api_strict Array object should not be instantiated directly. Use an array creation function, such as asarray(), instead."
)

def __reduce__(self) -> tuple[Callable, tuple[npt.NDArray[Any], Device]]:
return (self._new, (self._array, self._device))

# These functions are not required by the spec, but are implemented for
# the sake of usability.
Expand Down
13 changes: 13 additions & 0 deletions array_api_strict/tests/test_array_object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import warnings
import operator
import pickle
from builtins import all as all_

from numpy.testing import assert_raises
Expand Down Expand Up @@ -747,3 +748,15 @@ def test_dlpack_2023_12(api_version):
a.__dlpack__(copy=False)
a.__dlpack__(copy=True)
a.__dlpack__(copy=None)

def test_pickle():
"""Check that arrays are pickleable (despite raising on `__new__`)"""
a = ones(2)
min_supported_protocol = 2
for protocol in range(min_supported_protocol, pickle.HIGHEST_PROTOCOL + 1):
bytes = pickle.dumps(a, protocol=protocol)
a_from_pickle = pickle.loads(bytes)
assert a_from_pickle.device == a.device
assert a_from_pickle.dtype == a.dtype
assert a_from_pickle.shape == a.shape
assert all(a_from_pickle == a)
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ classifiers = [
"Operating System :: OS Independent",
]

[project.optional-dependencies]
test = ["pytest", "hypothesis"]

[project.urls]
Homepage = "https://data-apis.org/array-api-strict/"
Repository = "https://github.com/data-apis/array-api-strict"
Expand Down
Loading