Skip to content

Commit c36eb42

Browse files
Speedup For EDistance-like distances (#880)
* init * use a kernel * kwargs for pairwisedistance * fix n_pairs blunder * update subproject commit * Update pertpy/tools/_distances/_distances.py Co-authored-by: Lukas Heumos <[email protected]> * resolve comments * Tutorials Signed-off-by: Lukas Heumos <[email protected]> * Update pertpy/tools/_distances/_distances.py Co-authored-by: Lukas Heumos <[email protected]> * Fix mixscape Signed-off-by: Lukas Heumos <[email protected]> * fastmath=True --------- Signed-off-by: Lukas Heumos <[email protected]> Co-authored-by: Lukas Heumos <[email protected]>
1 parent 543b37c commit c36eb42

File tree

2 files changed

+280
-26
lines changed

2 files changed

+280
-26
lines changed

pertpy/tools/_distances/_distances.py

Lines changed: 273 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import annotations
22

3+
import warnings
34
from abc import ABC, abstractmethod
45
from typing import TYPE_CHECKING, Literal, NamedTuple
56

67
import jax
78
import numpy as np
89
import pandas as pd
9-
from numba import jit
10+
from numba import jit, prange
1011
from ott.geometry.geometry import Geometry
1112
from ott.geometry.pointcloud import PointCloud
1213
from ott.problems.linear.linear_problem import LinearProblem
@@ -29,6 +30,85 @@
2930
from anndata import AnnData
3031

3132

33+
@jit(nopython=True, cache=True)
34+
def _euclidean_distance(x: np.ndarray, y: np.ndarray) -> float:
35+
"""Compute euclidean distance between two vectors."""
36+
dist_sq = 0.0
37+
for k in range(x.shape[0]):
38+
diff = x[k] - y[k]
39+
dist_sq += diff * diff
40+
return np.sqrt(dist_sq)
41+
42+
43+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
44+
def _euclidean_pairwise_mean_within(X: np.ndarray) -> float:
45+
"""Compute mean pairwise euclidean distance within a group (X to X)."""
46+
n_samples = X.shape[0]
47+
if n_samples < 2:
48+
return 0.0
49+
50+
total_distance = 0.0
51+
n_pairs = n_samples * (n_samples - 1) / 2.0
52+
53+
for i in prange(n_samples):
54+
for j in range(i + 1, n_samples):
55+
total_distance += _euclidean_distance(X[i], X[j])
56+
57+
return total_distance / n_pairs
58+
59+
60+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
61+
def _euclidean_pairwise_mean_between(X: np.ndarray, Y: np.ndarray) -> float:
62+
"""Compute mean pairwise euclidean distance between two groups (X to Y)."""
63+
n_samples_X = X.shape[0]
64+
n_samples_Y = Y.shape[0]
65+
66+
if n_samples_X == 0 or n_samples_Y == 0:
67+
return 0.0
68+
69+
total_distance = 0.0
70+
n_pairs = n_samples_X * n_samples_Y
71+
72+
for i in prange(n_samples_X):
73+
for j in range(n_samples_Y):
74+
total_distance += _euclidean_distance(X[i], Y[j])
75+
76+
return total_distance / n_pairs
77+
78+
79+
def pairwise_distance_mean(X: np.ndarray, Y: np.ndarray | None = None, metric: str = "euclidean", **kwargs) -> float:
80+
"""Compute mean pairwise distance. Memory-efficient and fast for euclidean.
81+
82+
If Y is None, computes within-group distances (X to X).
83+
84+
Args:
85+
X: First array of shape (n_samples_X, n_features).
86+
Y: Second array of shape (n_samples_Y, n_features). If None, computes within-group distances.
87+
metric: Distance metric to use.
88+
kwargs: Additional keyword arguments passed to the metric function.
89+
90+
Returns:
91+
Mean pairwise distance.
92+
"""
93+
if metric == "euclidean":
94+
if len(kwargs) > 0:
95+
warnings.warn(
96+
"kwargs are not used for euclidean distance.",
97+
UserWarning,
98+
stacklevel=2,
99+
)
100+
if Y is None:
101+
# Within-group distance (X to X)
102+
return _euclidean_pairwise_mean_within(X)
103+
else:
104+
# Between-group distance (X to Y)
105+
return _euclidean_pairwise_mean_between(X, Y)
106+
elif Y is None:
107+
return pairwise_distances(X, X, metric=metric, **kwargs).mean()
108+
else:
109+
return pairwise_distances(X, Y, metric=metric, **kwargs).mean()
110+
111+
32112
class MeanVar(NamedTuple):
33113
mean: float
34114
variance: float
@@ -327,12 +407,49 @@ def pairwise(
327407
df_var = pd.DataFrame(index=groups, columns=groups, dtype=float)
328408
fct = track if show_progressbar else lambda iterable: iterable
329409

330-
# Some metrics are able to handle precomputed distances. This means that
331-
# the pairwise distances between all cells are computed once and then
332-
# passed to the metric function. This is much faster than computing the
333-
# pairwise distances for each group separately. Other metrics are not
334-
# able to handle precomputed distances such as the PseudobulkDistance.
335-
if self.metric_fct.accepts_precomputed:
410+
# Check if metric supports value caching (within/between distances) - more efficient than precomputed matrix
411+
# This mode is incompatible with bootstrap since cached values would be invalid
412+
use_value_cache = self.metric_fct.supports_value_cache() and not bootstrap
413+
414+
if use_value_cache:
415+
# Value caching mode: precompute within distances per group and between distances per pair
416+
embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key]
417+
418+
# Precompute within distances for each group
419+
df_within = pd.Series(index=groups, dtype=float)
420+
for group in fct(groups):
421+
idx_group = grouping == group
422+
cells_group = embedding[np.asarray(idx_group)]
423+
df_within[group] = self.metric_fct.compute_within_distance(cells_group, **kwargs)
424+
425+
# Precompute between distances for each pair
426+
df_between = pd.DataFrame(index=groups, columns=groups, dtype=float)
427+
for index_x, group_x in enumerate(fct(groups)):
428+
idx_x = grouping == group_x
429+
cells_x = embedding[np.asarray(idx_x)]
430+
for group_y in groups[index_x:]: # type: ignore
431+
if group_x == group_y:
432+
df_between.loc[group_x, group_y] = 0.0
433+
else:
434+
idx_y = grouping == group_y
435+
cells_y = embedding[np.asarray(idx_y)]
436+
between = self.metric_fct.compute_between_distance(cells_x, cells_y, **kwargs)
437+
df_between.loc[group_x, group_y] = between
438+
df_between.loc[group_y, group_x] = between
439+
440+
# Compute distances from cached values
441+
for group_x in groups:
442+
for group_y in groups:
443+
if group_x == group_y:
444+
df.loc[group_x, group_y] = 0.0
445+
else:
446+
dist = self.metric_fct.from_cached_values(
447+
df_within[group_x], df_within[group_y], df_between.loc[group_x, group_y], **kwargs
448+
)
449+
df.loc[group_x, group_y] = dist
450+
451+
elif self.metric_fct.accepts_precomputed:
452+
# Precomputed pairwise distance matrix mode
336453
# Precompute the pairwise distances if needed
337454
if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp:
338455
self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
@@ -364,6 +481,7 @@ def pairwise(
364481
df.loc[group_x, group_y] = df.loc[group_y, group_x] = bootstrap_output.mean
365482
df_var.loc[group_x, group_y] = df_var.loc[group_y, group_x] = bootstrap_output.variance
366483
else:
484+
# Standard mode: compute distances directly
367485
embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy()
368486
for index_x, group_x in enumerate(fct(groups)):
369487
cells_x = embedding[np.asarray(grouping == group_x)].copy()
@@ -461,12 +579,39 @@ def onesided_distances(
461579
df_var = pd.Series(index=groups, dtype=float)
462580
fct = track if show_progressbar else lambda iterable: iterable
463581

464-
# Some metrics are able to handle precomputed distances. This means that
465-
# the pairwise distances between all cells are computed once and then
466-
# passed to the metric function. This is much faster than computing the
467-
# pairwise distances for each group separately. Other metrics are not
468-
# able to handle precomputed distances such as the PseudobulkDistance.
469-
if self.metric_fct.accepts_precomputed:
582+
# Check if metric supports value caching (within/between distances) - more efficient than precomputed matrix
583+
# This mode is incompatible with bootstrap since cached values would be invalid
584+
use_value_cache = self.metric_fct.supports_value_cache() and not bootstrap
585+
586+
if use_value_cache:
587+
# Value caching mode: precompute within distances per group and between distances per pair
588+
embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key]
589+
590+
# Precompute within distance for selected_group (only need it once)
591+
idx_selected = grouping == selected_group
592+
cells_selected = embedding[np.asarray(idx_selected)]
593+
within_selected = self.metric_fct.compute_within_distance(cells_selected, **kwargs)
594+
595+
# Precompute within distances for each group and between distances to selected_group
596+
for group_x in fct(groups):
597+
if group_x == selected_group:
598+
df.loc[group_x] = 0.0 # by distance axiom
599+
else:
600+
idx_x = grouping == group_x
601+
cells_x = embedding[np.asarray(idx_x)]
602+
603+
# Compute within distance for this group
604+
within_x = self.metric_fct.compute_within_distance(cells_x, **kwargs)
605+
606+
# Compute between distance to selected_group
607+
between = self.metric_fct.compute_between_distance(cells_x, cells_selected, **kwargs)
608+
609+
# Compute distance from cached values
610+
dist = self.metric_fct.from_cached_values(within_x, within_selected, between, **kwargs)
611+
df.loc[group_x] = dist
612+
613+
elif self.metric_fct.accepts_precomputed:
614+
# Precomputed pairwise distance matrix mode
470615
# Precompute the pairwise distances if needed
471616
if f"{self.obsm_key}_{self.cell_wise_metric}_predistances" not in adata.obsp:
472617
self.precompute_distances(adata, n_jobs=n_jobs, **kwargs)
@@ -495,6 +640,7 @@ def onesided_distances(
495640
df.loc[group_x] = bootstrap_output.mean
496641
df_var.loc[group_x] = bootstrap_output.variance
497642
else:
643+
# Standard mode: compute distances directly
498644
embedding = adata.layers[self.layer_key] if self.layer_key else adata.obsm[self.obsm_key].copy()
499645
for group_x in fct(groups):
500646
cells_x = embedding[np.asarray(grouping == group_x)].copy()
@@ -655,6 +801,61 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
655801
"""
656802
raise NotImplementedError("Metric class is abstract.")
657803

804+
def supports_value_cache(self) -> bool:
805+
"""Whether this metric supports value-level caching (within/between distances).
806+
807+
Returns:
808+
bool: True if value caching is supported, False otherwise.
809+
"""
810+
return False
811+
812+
def compute_within_distance(self, X: np.ndarray, **kwargs) -> float:
813+
"""Compute within-group distance statistic for caching.
814+
815+
Only called if supports_value_cache() returns True.
816+
This represents the mean pairwise distance within a single group.
817+
818+
Args:
819+
X: Vector of shape (n_samples, n_features) for a single group.
820+
kwargs: Additional keyword arguments.
821+
822+
Returns:
823+
float: Cached within-group distance statistic.
824+
"""
825+
raise NotImplementedError("Metric does not support value caching.")
826+
827+
def compute_between_distance(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
828+
"""Compute between-group distance statistic for caching.
829+
830+
Only called if supports_value_cache() returns True.
831+
This represents the mean pairwise distance between two groups.
832+
833+
Args:
834+
X: First vector of shape (n_samples, n_features).
835+
Y: Second vector of shape (n_samples, n_features).
836+
kwargs: Additional keyword arguments.
837+
838+
Returns:
839+
float: Cached between-group distance statistic.
840+
"""
841+
raise NotImplementedError("Metric does not support value caching.")
842+
843+
def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
844+
"""Compute distance using precomputed cached values.
845+
846+
Only called if supports_value_cache() returns True and values have been cached.
847+
848+
Args:
849+
within_X: Precomputed within-group distance for group X.
850+
within_Y: Precomputed within-group distance for group Y.
851+
between: Precomputed between-group distance for pair (X, Y).
852+
kwargs: Additional keyword arguments.
853+
854+
Returns:
855+
float: Distance between X and Y.
856+
"""
857+
raise NotImplementedError("Metric does not support value caching.")
858+
658859

659860
class Edistance(AbstractDistance):
660861
"""Edistance metric."""
@@ -665,16 +866,32 @@ def __init__(self) -> None:
665866
self.cell_wise_metric = "euclidean"
666867

667868
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
668-
sigma_X = pairwise_distances(X, X, metric=self.cell_wise_metric, **kwargs).mean()
669-
sigma_Y = pairwise_distances(Y, Y, metric=self.cell_wise_metric, **kwargs).mean()
670-
delta = pairwise_distances(X, Y, metric=self.cell_wise_metric, **kwargs).mean()
671-
return 2 * delta - sigma_X - sigma_Y
869+
within_X = pairwise_distance_mean(X, metric=self.cell_wise_metric, **kwargs)
870+
within_Y = pairwise_distance_mean(Y, metric=self.cell_wise_metric, **kwargs)
871+
between = pairwise_distance_mean(X, Y, metric=self.cell_wise_metric, **kwargs)
872+
return 2 * between - within_X - within_Y
672873

673874
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
674-
sigma_X = P[idx, :][:, idx].mean()
675-
sigma_Y = P[~idx, :][:, ~idx].mean()
676-
delta = P[idx, :][:, ~idx].mean()
677-
return 2 * delta - sigma_X - sigma_Y
875+
within_X = P[idx, :][:, idx].mean()
876+
within_Y = P[~idx, :][:, ~idx].mean()
877+
between = P[idx, :][:, ~idx].mean()
878+
return 2 * between - within_X - within_Y
879+
880+
def supports_value_cache(self) -> bool:
881+
"""Edistance benefits from caching within and between distances."""
882+
return True
883+
884+
def compute_within_distance(self, X: np.ndarray, **kwargs) -> float:
885+
"""Compute within-group distance (mean pairwise distance within group)."""
886+
return pairwise_distance_mean(X, metric=self.cell_wise_metric, **kwargs)
887+
888+
def compute_between_distance(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
889+
"""Compute between-group distance (mean pairwise distance between groups)."""
890+
return pairwise_distance_mean(X, Y, metric=self.cell_wise_metric, **kwargs)
891+
892+
def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
893+
"""Compute edistance using cached within and between distances."""
894+
return 2 * between - within_X - within_Y
678895

679896

680897
class MMD(AbstractDistance):
@@ -706,6 +923,40 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0,
706923
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
707924
raise NotImplementedError("MMD cannot be called on a pairwise distance matrix.")
708925

926+
def supports_value_cache(self) -> bool:
927+
"""MMD benefits from caching within and between kernel means."""
928+
return True
929+
930+
def compute_within_distance(self, X: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs) -> float:
931+
"""Compute within-group kernel mean (mean of kernel matrix within group)."""
932+
if kernel == "linear":
933+
XX = np.dot(X, X.T)
934+
elif kernel == "rbf":
935+
XX = rbf_kernel(X, X, gamma=gamma)
936+
elif kernel == "poly":
937+
XX = polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=0)
938+
else:
939+
raise ValueError(f"Kernel {kernel} not recognized.")
940+
return XX.mean()
941+
942+
def compute_between_distance(
943+
self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs
944+
) -> float:
945+
"""Compute between-group kernel mean (mean of kernel matrix between groups)."""
946+
if kernel == "linear":
947+
XY = np.dot(X, Y.T)
948+
elif kernel == "rbf":
949+
XY = rbf_kernel(X, Y, gamma=gamma)
950+
elif kernel == "poly":
951+
XY = polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=0)
952+
else:
953+
raise ValueError(f"Kernel {kernel} not recognized.")
954+
return XY.mean()
955+
956+
def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
957+
"""Compute MMD using cached within and between kernel means."""
958+
return within_X + within_Y - 2 * between
959+
709960

710961
class WassersteinDistance(AbstractDistance):
711962
def __init__(self) -> None:
@@ -810,7 +1061,7 @@ def __init__(self) -> None:
8101061
self.accepts_precomputed = True
8111062

8121063
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
813-
return pairwise_distances(X, Y, **kwargs).mean()
1064+
return pairwise_distance_mean(X, Y, **kwargs)
8141065

8151066
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
8161067
return P[idx, :][:, ~idx].mean()

pertpy/tools/_mixscape.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import copy
44
import warnings
55
from collections import OrderedDict
6-
from typing import TYPE_CHECKING, Literal
6+
from typing import TYPE_CHECKING, Any, Literal
77

88
import matplotlib.pyplot as plt
99
import numpy as np
@@ -1220,9 +1220,12 @@ def __init__(
12201220
if self.fixed_cov_indices:
12211221
self.fixed_cov_values = np.array([fixed_covariances[i] for i in self.fixed_cov_indices])
12221222

1223-
def _m_step(self, X: np.ndarray, log_resp: np.ndarray):
1224-
"""Modified M-step to respect fixed means and covariances."""
1225-
super()._m_step(X, log_resp)
1223+
def _m_step(self, X: np.ndarray, log_resp: np.ndarray, xp: Any | None = None):
1224+
"""Modified M-step to respect fixed means and covariances.
1225+
1226+
xp is the array API namespace passed by sklearn 1.6+ for backend compatibility.
1227+
"""
1228+
super()._m_step(X, log_resp, xp=xp)
12261229

12271230
if self.fixed_mean_indices:
12281231
self.means_[self.fixed_mean_indices] = self.fixed_mean_values

0 commit comments

Comments
 (0)