diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 213f887266..1f18830b28 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,8 +60,13 @@ jobs: - name: Install dependencies run: | - uv tool install --with='click!=8.3.0' hatch + echo "::group::Install hatch" + uv tool install hatch + echo "::endgroup::" + echo "::group::Create environment" hatch -v env create ${{ matrix.env.name }} + echo "::endgroup::" + hatch run ${{ matrix.env.name }}:session-info scanpy anndata - name: Run tests if: matrix.env.test-type == null diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..4a5695190e 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -71,19 +71,7 @@ // followed by the pip installed packages). // "matrix": { - "numpy": [""], - "scipy": [""], - "h5py": [""], - "natsort": [""], - "pandas": [""], - "memory_profiler": [""], - "zarr": [""], - "pytest": [""], - "pip+igraph": [""], // https://github.com/airspeed-velocity/asv/issues/1554 - // "psutil": [""] "pooch": [""], - "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 - // "scikit-misc": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/docs/release-notes/3929.fix.md b/docs/release-notes/3929.fix.md new file mode 100644 index 0000000000..e67cc5388e --- /dev/null +++ b/docs/release-notes/3929.fix.md @@ -0,0 +1 @@ +Fix compatibility with pandas 3.0 {smaller}`P Angerer` diff --git a/hatch.toml b/hatch.toml index bb7c357e71..02e8df4ff3 100644 --- a/hatch.toml +++ b/hatch.toml @@ -35,6 +35,7 @@ overrides.matrix.deps.python = [ ] overrides.matrix.deps.extra-dependencies = [ { if = [ "pre" ], value = "anndata @ git+https://github.com/scverse/anndata.git" }, + { if = [ "pre" ], value = "pandas>=3rc0" }, ] overrides.matrix.deps.dependency-groups = [ { if = [ "stable", "pre", "low-vers" ], value = "test" }, diff --git a/pyproject.toml b/pyproject.toml index 7b80ef798e..5565870587 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dependencies = [ "numpy>=2", "fast-array-utils[accel,sparse]>=1.2.1", "matplotlib>=3.9", - "pandas >=2.2.2, <3.0.0rc0", + "pandas >=2.2.2", "scipy>=1.13", "seaborn>=0.13.2", "h5py>=3.11", diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index 956656d581..598b498bc8 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -29,6 +29,7 @@ import h5py import numpy as np +import pandas as pd from anndata._core.sparse_dataset import BaseCompressedSparseDataset from packaging.version import Version @@ -44,6 +45,7 @@ from anndata import AnnData from igraph import Graph from numpy.typing import ArrayLike, NDArray + from pandas._typing import Dtype as PdDtype from .._compat import CSRBase from ..neighbors import NeighborsParams, RPForestDict @@ -79,6 +81,7 @@ "sanitize_anndata", "select_groups", "update_params", + "with_cat_dtype", ] @@ -494,6 +497,23 @@ def moving_average(a: np.ndarray, n: int): return ret[n - 1 :] / n +@singledispatch +def with_cat_dtype[X: pd.Series | pd.CategoricalIndex | pd.Categorical]( + x: X, dtype: PdDtype +) -> X: + raise NotImplementedError + + +@with_cat_dtype.register(pd.Series) +def _(x: pd.Series, dtype: PdDtype) -> pd.Series: + return x.cat.set_categories(x.cat.categories.astype(dtype)) + + +@with_cat_dtype.register(pd.Categorical | pd.CategoricalIndex) +def _[X: pd.Categorical | pd.CategoricalIndex](x: X, dtype: PdDtype) -> X: + return x.set_categories(x.categories.astype(dtype)) + + # -------------------------------------------------------------------------------- # Deal with tool parameters # -------------------------------------------------------------------------------- diff --git a/src/scanpy/external/exporting.py b/src/scanpy/external/exporting.py index a8b1cf567b..60605a3330 100644 --- a/src/scanpy/external/exporting.py +++ b/src/scanpy/external/exporting.py @@ -219,7 +219,7 @@ def spring_project( # noqa: PLR0912, PLR0915 np.save(subplot_dir / "cell_filter.npy", np.arange(x.shape[0])) # Write 2-D coordinates, after adjusting to roughly match SPRING's default d3js force layout parameters - coords = coords - coords.min(0)[None, :] + coords = coords - coords.min(axis=0)[None, :] coords = ( coords * (np.array([1000, 1000]) / coords.ptp(0))[None, :] + np.array([200, -200])[None, :] @@ -342,8 +342,8 @@ def _get_color_stats_genes(color_stats, x, gene_list): means, variances = mean_var(x, axis=0, correction=1) stdevs = np.zeros(variances.shape, dtype=float) stdevs[variances > 0] = np.sqrt(variances[variances > 0]) - mins = x.min(0).todense().A1 - maxes = x.max(0).todense().A1 + mins = x.min(axis=0).todense().A1 + maxes = x.max(axis=0).todense().A1 pctl = 99.6 pctl_n = (100 - pctl) / 100.0 * x.shape[0] diff --git a/src/scanpy/get/get.py b/src/scanpy/get/get.py index ab1c62f0e5..c106db8a31 100644 --- a/src/scanpy/get/get.py +++ b/src/scanpy/get/get.py @@ -259,8 +259,8 @@ def obs_df( >>> plotdf = sc.get.obs_df( ... pbmc, keys=["CD8B", "n_genes"], obsm_keys=[("X_umap", 0), ("X_umap", 1)] ... ) - >>> plotdf.columns - Index(['CD8B', 'n_genes', 'X_umap-0', 'X_umap-1'], dtype='object') + >>> plotdf.columns.astype("string") + Index(['CD8B', 'n_genes', 'X_umap-0', 'X_umap-1'], dtype='string') >>> plotdf.plot.scatter("X_umap-0", "X_umap-1", c="CD8B") # doctest: +SKIP diff --git a/src/scanpy/plotting/_anndata.py b/src/scanpy/plotting/_anndata.py index 65d63cc1d6..30156bf482 100755 --- a/src/scanpy/plotting/_anndata.py +++ b/src/scanpy/plotting/_anndata.py @@ -749,7 +749,7 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915 layer: str | None = None, density_norm: DensityNorm = "width", order: Sequence[str] | None = None, - multi_panel: bool | None = None, + multi_panel: bool = False, xlabel: str = "", ylabel: str | Sequence[str] | None = None, rotation: float | None = None, @@ -1202,11 +1202,11 @@ def heatmap( # noqa: PLR0912, PLR0913, PLR0915 ).issubset(categories) if standard_scale == "obs": - obs_tidy = obs_tidy.sub(obs_tidy.min(1), axis=0) - obs_tidy = obs_tidy.div(obs_tidy.max(1), axis=0).fillna(0) + obs_tidy = obs_tidy.sub(obs_tidy.min(axis=1), axis=0) + obs_tidy = obs_tidy.div(obs_tidy.max(axis=1), axis=0).fillna(0) elif standard_scale == "var": - obs_tidy -= obs_tidy.min(0) - obs_tidy = (obs_tidy / obs_tidy.max(0)).fillna(0) + obs_tidy -= obs_tidy.min(axis=0) + obs_tidy = (obs_tidy / obs_tidy.max(axis=0)).fillna(0) elif standard_scale is None: pass else: diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index ca2855b62f..c8bb7cb492 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -213,11 +213,13 @@ def __init__( # noqa: PLR0913 dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean() if standard_scale == "group": - dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0) - dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0) + dot_color_df = dot_color_df.sub(dot_color_df.min(axis=1), axis=0) + dot_color_df = dot_color_df.div( + dot_color_df.max(axis=1), axis=0 + ).fillna(0) elif standard_scale == "var": - dot_color_df -= dot_color_df.min(0) - dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0) + dot_color_df -= dot_color_df.min(axis=0) + dot_color_df = (dot_color_df / dot_color_df.max(axis=0)).fillna(0) elif standard_scale is None: pass else: @@ -696,10 +698,10 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915 group_axis = 1 if standard_scale is not None: dot_color = dot_color.sub( - dot_color.min((group_axis + 1) % 2), axis=group_axis + dot_color.min(axis=1 - group_axis), axis=group_axis ) dot_color = dot_color.div( - dot_color.max((group_axis + 1) % 2), axis=group_axis + dot_color.max(axis=1 - group_axis), axis=group_axis ).fillna(0) # make scatter plot in which # x = var_names diff --git a/src/scanpy/plotting/_matrixplot.py b/src/scanpy/plotting/_matrixplot.py index 674ccf723c..e803bdc1eb 100644 --- a/src/scanpy/plotting/_matrixplot.py +++ b/src/scanpy/plotting/_matrixplot.py @@ -180,11 +180,11 @@ def __init__( # noqa: PLR0913 ) if standard_scale == "group": - values_df = values_df.sub(values_df.min(1), axis=0) - values_df = values_df.div(values_df.max(1), axis=0).fillna(0) + values_df = values_df.sub(values_df.min(axis=1), axis=0) + values_df = values_df.div(values_df.max(axis=1), axis=0).fillna(0) elif standard_scale == "var": - values_df -= values_df.min(0) - values_df = (values_df / values_df.max(0)).fillna(0) + values_df -= values_df.min(axis=0) + values_df = (values_df / values_df.max(axis=0)).fillna(0) elif standard_scale is None: pass else: diff --git a/src/scanpy/plotting/_scrublet.py b/src/scanpy/plotting/_scrublet.py index 5a3d7b158a..2705419236 100644 --- a/src/scanpy/plotting/_scrublet.py +++ b/src/scanpy/plotting/_scrublet.py @@ -78,7 +78,7 @@ def scrublet_score_distribution( if "batched_by" in adata.uns["scrublet"]: batched_by = adata.uns["scrublet"]["batched_by"] - batches = adata.obs[batched_by].astype("category", copy=False) + batches = adata.obs[batched_by].astype("category") n_batches = len(batches.cat.categories) figsize = (figsize[0], figsize[1] * n_batches) else: diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 284bef860f..19d6aa7a97 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -229,11 +229,13 @@ def __init__( # noqa: PLR0913 msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead" warn(msg, FutureWarning) if standard_scale == "group": - self.obs_tidy = self.obs_tidy.sub(self.obs_tidy.min(1), axis=0) - self.obs_tidy = self.obs_tidy.div(self.obs_tidy.max(1), axis=0).fillna(0) + self.obs_tidy = self.obs_tidy.sub(self.obs_tidy.min(axis=1), axis=0) + self.obs_tidy = self.obs_tidy.div(self.obs_tidy.max(axis=1), axis=0).fillna( + 0 + ) elif standard_scale == "var": - self.obs_tidy -= self.obs_tidy.min(0) - self.obs_tidy = (self.obs_tidy / self.obs_tidy.max(0)).fillna(0) + self.obs_tidy -= self.obs_tidy.min(axis=0) + self.obs_tidy = (self.obs_tidy / self.obs_tidy.max(axis=0)).fillna(0) elif standard_scale is None: pass else: @@ -556,7 +558,7 @@ def _make_rows_of_violinplots( x=x, y="values", data=_df, - orient="vertical", + orient="v", ax=row_ax, # use a single `color`` if row_colors[idx] is defined # else use the palette diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index 13b74fc636..eab7908c87 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -16,7 +16,7 @@ from ... import logging as logg from ..._compat import old_positionals from ..._settings import settings -from ..._utils import _doc_params, _empty, sanitize_anndata +from ..._utils import _doc_params, _empty, sanitize_anndata, with_cat_dtype from ...get import rank_genes_groups_df from .._anndata import ranking from .._docs import ( @@ -1295,12 +1295,13 @@ def rank_genes_groups_violin( # noqa: PLR0913 _gene_names = _gene_names.tolist() df = obs_df(adata, _gene_names, use_raw=use_raw, gene_symbols=gene_symbols) new_gene_names = df.columns - df["hue"] = adata.obs[groups_key].astype(str).values + df["hue"] = adata.obs[groups_key].astype(str).array if reference == "rest": df.loc[df["hue"] != group_name, "hue"] = "rest" else: df.loc[~df["hue"].isin([group_name, reference]), "hue"] = np.nan - df["hue"] = df["hue"].astype("category") + # Convert categories to object because of https://github.com/mwaskom/seaborn/issues/3893 + df["hue"] = with_cat_dtype(df["hue"].astype("category"), object) df_tidy = pd.melt(df, id_vars="hue", value_vars=new_gene_names) x = "variable" y = "value" @@ -1316,7 +1317,7 @@ def rank_genes_groups_violin( # noqa: PLR0913 hue="hue", split=split, density_norm=density_norm, - orient="vertical", + orient="v", ax=ax, ) if strip: diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index c22a03917f..26d5a7763e 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -814,9 +814,7 @@ def highly_variable_genes( # noqa: PLR0913 adata.var["highly_variable"] = df["highly_variable"] adata.var["means"] = df["means"] adata.var["dispersions"] = df["dispersions"] - adata.var["dispersions_norm"] = df["dispersions_norm"].astype( - np.float32, copy=False - ) + adata.var["dispersions_norm"] = df["dispersions_norm"].astype(np.float32) if batch_key is not None: adata.var["highly_variable_nbatches"] = df["highly_variable_nbatches"] diff --git a/tests/conftest.py b/tests/conftest.py index 68dedd681d..a0111acde2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import shutil import sys +from contextlib import ExitStack from pathlib import Path from textwrap import dedent from typing import TYPE_CHECKING, TypedDict, cast @@ -148,3 +149,9 @@ def plt(): from matplotlib import pyplot as plt return plt + + +@pytest.fixture +def exit_stack() -> Generator[ExitStack]: + with ExitStack() as stack: + yield stack diff --git a/tests/external/test_hashsolo.py b/tests/external/test_hashsolo.py index 9468c4f6ce..d2d376728a 100644 --- a/tests/external/test_hashsolo.py +++ b/tests/external/test_hashsolo.py @@ -39,6 +39,7 @@ def test_cell_demultiplexing(): expected = pd.array(doublets + classes + negatives, dtype="string") classification = test_data.obs["Classification"].array.astype("string") # This is a bit flaky, so allow some mismatches: - if (expected != classification).sum() > 3: + # (Series() because of https://github.com/pandas-dev/pandas/issues/63458) + if pd.Series(expected != classification).sum() > 3: # Compare lists for better error message assert classification.tolist() == expected.tolist() diff --git a/tests/notebooks/test_pbmc3k.py b/tests/notebooks/test_pbmc3k.py index ea6e063f4e..c131222801 100644 --- a/tests/notebooks/test_pbmc3k.py +++ b/tests/notebooks/test_pbmc3k.py @@ -17,12 +17,10 @@ import numpy as np import pytest -from matplotlib.testing import setup from sklearn.exceptions import ConvergenceWarning -setup() - import scanpy as sc +from scanpy._compat import pkg_version from testing.scanpy._pytest.marks import needs HERE: Path = Path(__file__).parent @@ -32,7 +30,7 @@ @needs.leidenalg # https://github.com/pandas-dev/pandas/issues/61928 @pytest.mark.filterwarnings("ignore:invalid value encountered in cast:RuntimeWarning") -def test_pbmc3k(image_comparer): # noqa: PLR0915 +def test_pbmc3k(subtests: pytest.Subtests, image_comparer) -> None: # noqa: PLR0915 # ensure violin plots and other non-determinstic plots have deterministic behavior np.random.seed(0) save_and_compare_images = partial(image_comparer, ROOT, tol=20) @@ -55,19 +53,22 @@ def test_pbmc3k(image_comparer): # noqa: PLR0915 # add the total counts per cell as observations-annotation to adata adata.obs["n_counts"] = adata.X.sum(axis=1).A1 - sc.pl.violin( - adata, - ["n_genes", "n_counts", "percent_mito"], - jitter=False, - multi_panel=True, - show=False, - ) - save_and_compare_images("violin") + with subtests.test("violin"): + sc.pl.violin( + adata, + ["n_genes", "n_counts", "percent_mito"], + jitter=False, + multi_panel=True, + show=False, + ) + save_and_compare_images("violin") - sc.pl.scatter(adata, x="n_counts", y="percent_mito", show=False) - save_and_compare_images("scatter_1") - sc.pl.scatter(adata, x="n_counts", y="n_genes", show=False) - save_and_compare_images("scatter_2") + with subtests.test("scatter_1"): + sc.pl.scatter(adata, x="n_counts", y="percent_mito", show=False) + save_and_compare_images("scatter_1") + with subtests.test("scatter_2"): + sc.pl.scatter(adata, x="n_counts", y="n_genes", show=False) + save_and_compare_images("scatter_2") adata = adata[adata.obs["n_genes"] < 2500, :] adata = adata[adata.obs["percent_mito"] < 0.05, :] @@ -84,9 +85,11 @@ def test_pbmc3k(image_comparer): # noqa: PLR0915 max_mean=3, min_disp=0.5, ) - with pytest.warns(FutureWarning, match=r"sc\.pl\.highly_variable_genes"): - sc.pl.filter_genes_dispersion(filter_result, show=False) - save_and_compare_images("filter_genes_dispersion") + + with subtests.test("filter_genes_dispersion"): + with pytest.warns(FutureWarning, match=r"sc\.pl\.highly_variable_genes"): + sc.pl.filter_genes_dispersion(filter_result, show=False) + save_and_compare_images("filter_genes_dispersion") adata = adata[:, filter_result.gene_subset].copy() sc.pp.log1p(adata) @@ -96,19 +99,17 @@ def test_pbmc3k(image_comparer): # noqa: PLR0915 # PCA sc.pp.pca(adata, svd_solver="arpack") - sc.pl.pca(adata, color="CST3", show=False) - save_and_compare_images("pca") + with subtests.test("pca"): + sc.pl.pca(adata, color="CST3", show=False) + save_and_compare_images("pca") - sc.pl.pca_variance_ratio(adata, log=True, show=False) - save_and_compare_images("pca_variance_ratio") + with subtests.test("pca_variance_ratio"): + sc.pl.pca_variance_ratio(adata, log=True, show=False) + save_and_compare_images("pca_variance_ratio") - # UMAP + # Neighbors sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40) - # sc.tl.umap(adata) # umaps lead to slight variations - - # sc.pl.umap(adata, color=['CST3', 'NKG7', 'PPBP'], use_raw=False, show=False) - # save_and_compare_images('umap_1') # Clustering the graph @@ -121,10 +122,9 @@ def test_pbmc3k(image_comparer): # noqa: PLR0915 flavor="igraph", ) - # sc.pl.umap(adata, color=["leiden", "CST3", "NKG7"], show=False) - # save_and_compare_images("umap_2") - sc.pl.scatter(adata, "CST3", "NKG7", color="leiden", show=False) - save_and_compare_images("scatter_3") + with subtests.test("scatter_3"): + sc.pl.scatter(adata, "CST3", "NKG7", color="leiden", show=False) + save_and_compare_images("scatter_3") # Finding marker genes # Due to incosistency with our test runner vs local, these clusters need to @@ -136,9 +136,10 @@ def test_pbmc3k(image_comparer): # noqa: PLR0915 data_df = adata[:, marker_genes].to_df() data_df["leiden"] = adata.obs["leiden"] max_idxs = data_df.groupby("leiden", observed=True).mean().idxmax() - assert not max_idxs[marker_genes][ - max_idxs[marker_genes].duplicated(keep=False) - ].tolist(), "Not all marker genes are unique per cluster" + with subtests.test("marker_genes_unique"): + assert not max_idxs[marker_genes][ + max_idxs[marker_genes].duplicated(keep=False) + ].tolist(), "Not all marker genes are unique per cluster" leiden_relabel = { max_idxs[marker_gene]: str(i) for i, marker_gene in enumerate(marker_genes) } @@ -152,23 +153,29 @@ def test_pbmc3k(image_comparer): # noqa: PLR0915 ) sc.tl.rank_genes_groups(adata, "leiden") - sc.pl.rank_genes_groups(adata, n_genes=20, sharey=False, show=False) - save_and_compare_images("rank_genes_groups_1") + with subtests.test("rank_genes_groups_1"): + sc.pl.rank_genes_groups(adata, n_genes=20, sharey=False, show=False) + save_and_compare_images("rank_genes_groups_1") with warnings.catch_warnings(): # This seems to only happen with older versions of scipy for some reason warnings.filterwarnings("always", category=ConvergenceWarning) sc.tl.rank_genes_groups(adata, "leiden", method="logreg") - sc.pl.rank_genes_groups(adata, n_genes=20, sharey=False, show=False) - save_and_compare_images("rank_genes_groups_2") + with subtests.test("rank_genes_groups_2"): + sc.pl.rank_genes_groups(adata, n_genes=20, sharey=False, show=False) + save_and_compare_images("rank_genes_groups_2") sc.tl.rank_genes_groups(adata, "leiden", groups=["0"], reference="1") - sc.pl.rank_genes_groups(adata, groups="0", n_genes=20, show=False) - save_and_compare_images("rank_genes_groups_3") + with subtests.test("rank_genes_groups_3"): + sc.pl.rank_genes_groups(adata, groups="0", n_genes=20, show=False) + save_and_compare_images("rank_genes_groups_3") - # gives a strange error, probably due to jitter or something - # sc.pl.rank_genes_groups_violin(adata, groups='0', n_genes=8) - # save_and_compare_images('rank_genes_groups_4') + with subtests.test("rank_genes_groups_4"): + sc.pl.rank_genes_groups_violin(adata, groups="0", n_genes=8, show=False) + try: + save_and_compare_images("rank_genes_groups_4") + except AssertionError: + pytest.xfail("rank_genes_groups_violin not reproducible (jitter?)") new_cluster_names = [ *["CD4 T cells", "CD8 T cells", "B cells", "NK cells"], @@ -176,9 +183,14 @@ def test_pbmc3k(image_comparer): # noqa: PLR0915 ] adata.rename_categories("leiden", new_cluster_names) - # sc.pl.umap(adata, color='leiden', legend_loc='on data', title='', frameon=False, show=False) - # save_and_compare_images('umap_3') - sc.pl.violin( - adata, ["CST3", "NKG7", "PPBP"], groupby="leiden", rotation=90, show=False - ) - save_and_compare_images("violin_2") + with subtests.test("violin_2"): + sc.pl.violin( + adata, ["CST3", "NKG7", "PPBP"], groupby="leiden", rotation=90, show=False + ) + try: + save_and_compare_images("violin_2") + except AssertionError: + if pkg_version("pandas").major >= 3: + # See https://github.com/scverse/scanpy/pull/3929#issuecomment-3685784980 + pytest.xfail("seaborn is incompatible with pandas 3") + raise diff --git a/tests/test_plotting.py b/tests/test_plotting.py index adedf1279e..10e23b9844 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -28,6 +28,7 @@ if TYPE_CHECKING: from collections.abc import Callable + from contextlib import ExitStack from typing import Any, Literal from matplotlib.axes import Axes @@ -174,7 +175,14 @@ def test_clustermap(image_comparer, obs_keys, name): params_dotplot_matrixplot_stacked_violin = [ - pytest.param(id, fn, id=id) + pytest.param( + *(id, fn), + id=id, + # See https://github.com/scverse/scanpy/pull/3929#issuecomment-3685784980 + marks=[pytest.mark.xfail(reason="seaborn is incompatible with pandas 3")] + if pkg_version("pandas").major >= 3 and "stacked_violin" in id + else [], + ) for id, fn in [ ( "dotplot", @@ -458,7 +466,14 @@ def test_stacked_violin_obj(image_comparer, plt): # checking for https://github.com/scverse/scanpy/issues/3152 -def test_stacked_violin_swap_axes_match(image_comparer): +def test_stacked_violin_swap_axes_match( + request: pytest.FixtureRequest, image_comparer +) -> None: + if pkg_version("pandas").major >= 3: + # See https://github.com/scverse/scanpy/pull/3929#issuecomment-3685784980 + reason = "seaborn is incompatible with pandas 3" + request.applymarker(pytest.mark.xfail(reason=reason)) + save_and_compare_images = partial(image_comparer, ROOT, tol=10) pbmc = pbmc68k_reduced() sc.tl.rank_genes_groups( @@ -534,14 +549,18 @@ def test_multiple_plots(image_comparer): save_and_compare_images("multiple_plots") -def test_violin(image_comparer): +def test_violin( + subtests: pytest.Subtests, exit_stack: ExitStack, image_comparer +) -> None: save_and_compare_images = partial(image_comparer, ROOT, tol=40) + exit_stack.enter_context(plt.rc_context()) + sc.pl.set_rcParams_defaults() + sc.set_figure_params(dpi=50, color_map="viridis") - with plt.rc_context(): - sc.pl.set_rcParams_defaults() - sc.set_figure_params(dpi=50, color_map="viridis") + pbmc = pbmc68k_reduced() + pbmc.layers["negative"] = pbmc.X * -1 - pbmc = pbmc68k_reduced() + with subtests.test("default"): sc.pl.violin( pbmc, ["n_genes", "percent_mito", "n_counts"], @@ -552,6 +571,7 @@ def test_violin(image_comparer): ) save_and_compare_images("violin_multi_panel") + with subtests.test(groupby="bulk_labels"): sc.pl.violin( pbmc, ["n_genes", "percent_mito", "n_counts"], @@ -563,10 +583,15 @@ def test_violin(image_comparer): show=False, rotation=90, ) - save_and_compare_images("violin_multi_panel_with_groupby") - - # test use of layer - pbmc.layers["negative"] = pbmc.X * -1 + try: + save_and_compare_images("violin_multi_panel_with_groupby") + except AssertionError: + if pkg_version("pandas").major >= 3: + # See https://github.com/scverse/scanpy/pull/3929#issuecomment-3685784980 + pytest.skip("seaborn is incompatible with pandas 3") + raise + + with subtests.test(layer="negative"): sc.pl.violin( pbmc, "CST3", @@ -831,7 +856,19 @@ def test_correlation(image_comparer): @pytest.mark.parametrize( ("name", "fn"), - [pytest.param(name, fn, id=name) for name, fn in _RANK_GENES_GROUPS_PARAMS], + [ + pytest.param( + name, + fn, + id=name, + # See https://github.com/scverse/scanpy/pull/3929#issuecomment-3685784980 + # and https://github.com/mwaskom/seaborn/issues/3893 + marks=[pytest.mark.xfail(reason="seaborn is incompatible with pandas 3")] + if pkg_version("pandas").major >= 3 and "violin" in name + else [], + ) + for name, fn in _RANK_GENES_GROUPS_PARAMS + ], ) def test_rank_genes_groups(image_comparer, name, fn): save_and_compare_images = partial(image_comparer, ROOT, tol=15) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index a061b2c64d..57c5420750 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -112,8 +112,5 @@ def test_write_strings_to_cats(fmt: Literal["h5ad", "zarr"], *, s2c: bool) -> No adata_read = sc.read(p) assert_equal(adata_read, adata) - assert ( - adata_read.obs["a"].dtype - == adata.obs["a"].dtype - == ("category" if s2c else "object") - ) + assert adata_read.obs["a"].dtype == adata.obs["a"].dtype + assert adata_read.obs["a"].dtype in (("category",) if s2c else ("object", "string"))