11from __future__ import annotations
22
3+ import warnings
34from abc import ABC , abstractmethod
45from typing import TYPE_CHECKING , Literal , NamedTuple
56
67import jax
78import numpy as np
89import pandas as pd
9- from numba import jit
10+ from numba import jit , prange
1011from ott .geometry .geometry import Geometry
1112from ott .geometry .pointcloud import PointCloud
1213from ott .problems .linear .linear_problem import LinearProblem
2930 from anndata import AnnData
3031
3132
@jit(nopython=True, cache=True)
def _euclidean_distance(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Euclidean (L2) distance between two 1-D vectors.

    The sum of squared differences is accumulated element by element in a
    scalar loop, so no temporary difference array is created.

    Args:
        x: First vector; its length determines how many entries are compared.
        y: Second vector; must provide at least as many entries as ``x``.

    Returns:
        Square root of the summed squared element-wise differences.
    """
    n_features = x.shape[0]
    squared_sum = 0.0
    for pos in range(n_features):
        delta = x[pos] - y[pos]
        squared_sum += delta * delta
    return np.sqrt(squared_sum)
41+
42+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def _euclidean_pairwise_mean_within(X: np.ndarray) -> float:
    """Average Euclidean distance over all unordered row pairs of ``X``.

    Only pairs ``(i, j)`` with ``i < j`` are visited, so self-distances are
    not part of the average and every pair is counted exactly once.

    Args:
        X: Array whose first axis enumerates the samples.

    Returns:
        Sum of the pair distances divided by ``n * (n - 1) / 2``, or ``0.0``
        when ``X`` has fewer than two rows.
    """
    n_rows = X.shape[0]
    if n_rows < 2:
        return 0.0

    distance_sum = 0.0
    # Outer loop is a prange; ``distance_sum`` is updated only through ``+=``.
    for row in prange(n_rows):
        for other in range(row + 1, n_rows):
            distance_sum += _euclidean_distance(X[row], X[other])

    pair_count = n_rows * (n_rows - 1) / 2.0
    return distance_sum / pair_count
58+
59+
@jit(nopython=True, parallel=True, cache=True, fastmath=True)
def _euclidean_pairwise_mean_between(X: np.ndarray, Y: np.ndarray) -> float:
    """Average Euclidean distance over every (row of ``X``, row of ``Y``) pair.

    Args:
        X: Array whose first axis enumerates the samples of the first group.
        Y: Array whose first axis enumerates the samples of the second group.

    Returns:
        Sum of all cross distances divided by ``X.shape[0] * Y.shape[0]``, or
        ``0.0`` when either array has no rows.
    """
    n_x = X.shape[0]
    n_y = Y.shape[0]
    if n_x == 0 or n_y == 0:
        return 0.0

    distance_sum = 0.0
    # Outer loop is a prange; ``distance_sum`` is updated only through ``+=``.
    for row_x in prange(n_x):
        for row_y in range(n_y):
            distance_sum += _euclidean_distance(X[row_x], Y[row_y])

    pair_count = n_x * n_y
    return distance_sum / pair_count
77+
78+
def pairwise_distance_mean(X: np.ndarray, Y: np.ndarray | None = None, metric: str = "euclidean", **kwargs) -> float:
    """Return the mean pairwise distance within ``X`` or between ``X`` and ``Y``.

    For ``metric="euclidean"`` the work is delegated to the compiled helpers
    ``_euclidean_pairwise_mean_within`` / ``_euclidean_pairwise_mean_between``.
    Any other metric is evaluated by taking ``.mean()`` of the result of
    ``pairwise_distances``.

    NOTE(review): with ``Y=None`` the euclidean helper averages over unordered
    pairs ``i < j`` only, whereas the generic path averages the whole result of
    ``pairwise_distances(X, X)`` (presumably including the zero self-distances).
    Confirm that this difference in normalisation is intended.

    Args:
        X: First array of shape (n_samples_X, n_features).
        Y: Second array of shape (n_samples_Y, n_features). If None, the
            distances are taken within ``X``.
        metric: Distance metric to use.
        kwargs: Extra keyword arguments forwarded to ``pairwise_distances``.
            They are ignored, with a ``UserWarning``, for the euclidean metric.

    Returns:
        Mean pairwise distance.
    """
    # Generic metrics: a missing Y means "compare X with itself".
    if metric != "euclidean":
        reference = X if Y is None else Y
        return pairwise_distances(X, reference, metric=metric, **kwargs).mean()

    if kwargs:
        warnings.warn(
            "kwargs are not used for euclidean distance.",
            UserWarning,
            stacklevel=2,
        )

    if Y is None:
        return _euclidean_pairwise_mean_within(X)
    return _euclidean_pairwise_mean_between(X, Y)
110+
111+
class MeanVar(NamedTuple):
    """Named pair of two floats: a mean and a variance."""

    # Mean value. Presumably the mean over bootstrap repetitions; verify against the producer.
    mean: float
    # Variance value. Presumably the variance over bootstrap repetitions; verify against the producer.
    variance: float
@@ -327,12 +407,49 @@ def pairwise(
327407 df_var = pd .DataFrame (index = groups , columns = groups , dtype = float )
328408 fct = track if show_progressbar else lambda iterable : iterable
329409
330- # Some metrics are able to handle precomputed distances. This means that
331- # the pairwise distances between all cells are computed once and then
332- # passed to the metric function. This is much faster than computing the
333- # pairwise distances for each group separately. Other metrics are not
334- # able to handle precomputed distances such as the PseudobulkDistance.
335- if self .metric_fct .accepts_precomputed :
410+ # Check if metric supports value caching (within/between distances) - more efficient than precomputed matrix
411+ # This mode is incompatible with bootstrap since cached values would be invalid
412+ use_value_cache = self .metric_fct .supports_value_cache () and not bootstrap
413+
414+ if use_value_cache :
415+ # Value caching mode: precompute within distances per group and between distances per pair
416+ embedding = adata .layers [self .layer_key ] if self .layer_key else adata .obsm [self .obsm_key ]
417+
418+ # Precompute within distances for each group
419+ df_within = pd .Series (index = groups , dtype = float )
420+ for group in fct (groups ):
421+ idx_group = grouping == group
422+ cells_group = embedding [np .asarray (idx_group )]
423+ df_within [group ] = self .metric_fct .compute_within_distance (cells_group , ** kwargs )
424+
425+ # Precompute between distances for each pair
426+ df_between = pd .DataFrame (index = groups , columns = groups , dtype = float )
427+ for index_x , group_x in enumerate (fct (groups )):
428+ idx_x = grouping == group_x
429+ cells_x = embedding [np .asarray (idx_x )]
430+ for group_y in groups [index_x :]: # type: ignore
431+ if group_x == group_y :
432+ df_between .loc [group_x , group_y ] = 0.0
433+ else :
434+ idx_y = grouping == group_y
435+ cells_y = embedding [np .asarray (idx_y )]
436+ between = self .metric_fct .compute_between_distance (cells_x , cells_y , ** kwargs )
437+ df_between .loc [group_x , group_y ] = between
438+ df_between .loc [group_y , group_x ] = between
439+
440+ # Compute distances from cached values
441+ for group_x in groups :
442+ for group_y in groups :
443+ if group_x == group_y :
444+ df .loc [group_x , group_y ] = 0.0
445+ else :
446+ dist = self .metric_fct .from_cached_values (
447+ df_within [group_x ], df_within [group_y ], df_between .loc [group_x , group_y ], ** kwargs
448+ )
449+ df .loc [group_x , group_y ] = dist
450+
451+ elif self .metric_fct .accepts_precomputed :
452+ # Precomputed pairwise distance matrix mode
336453 # Precompute the pairwise distances if needed
337454 if f"{ self .obsm_key } _{ self .cell_wise_metric } _predistances" not in adata .obsp :
338455 self .precompute_distances (adata , n_jobs = n_jobs , ** kwargs )
@@ -364,6 +481,7 @@ def pairwise(
364481 df .loc [group_x , group_y ] = df .loc [group_y , group_x ] = bootstrap_output .mean
365482 df_var .loc [group_x , group_y ] = df_var .loc [group_y , group_x ] = bootstrap_output .variance
366483 else :
484+ # Standard mode: compute distances directly
367485 embedding = adata .layers [self .layer_key ] if self .layer_key else adata .obsm [self .obsm_key ].copy ()
368486 for index_x , group_x in enumerate (fct (groups )):
369487 cells_x = embedding [np .asarray (grouping == group_x )].copy ()
@@ -461,12 +579,39 @@ def onesided_distances(
461579 df_var = pd .Series (index = groups , dtype = float )
462580 fct = track if show_progressbar else lambda iterable : iterable
463581
464- # Some metrics are able to handle precomputed distances. This means that
465- # the pairwise distances between all cells are computed once and then
466- # passed to the metric function. This is much faster than computing the
467- # pairwise distances for each group separately. Other metrics are not
468- # able to handle precomputed distances such as the PseudobulkDistance.
469- if self .metric_fct .accepts_precomputed :
582+ # Check if metric supports value caching (within/between distances) - more efficient than precomputed matrix
583+ # This mode is incompatible with bootstrap since cached values would be invalid
584+ use_value_cache = self .metric_fct .supports_value_cache () and not bootstrap
585+
586+ if use_value_cache :
587+ # Value caching mode: precompute within distances per group and between distances per pair
588+ embedding = adata .layers [self .layer_key ] if self .layer_key else adata .obsm [self .obsm_key ]
589+
590+ # Precompute within distance for selected_group (only need it once)
591+ idx_selected = grouping == selected_group
592+ cells_selected = embedding [np .asarray (idx_selected )]
593+ within_selected = self .metric_fct .compute_within_distance (cells_selected , ** kwargs )
594+
595+ # Precompute within distances for each group and between distances to selected_group
596+ for group_x in fct (groups ):
597+ if group_x == selected_group :
598+ df .loc [group_x ] = 0.0 # by distance axiom
599+ else :
600+ idx_x = grouping == group_x
601+ cells_x = embedding [np .asarray (idx_x )]
602+
603+ # Compute within distance for this group
604+ within_x = self .metric_fct .compute_within_distance (cells_x , ** kwargs )
605+
606+ # Compute between distance to selected_group
607+ between = self .metric_fct .compute_between_distance (cells_x , cells_selected , ** kwargs )
608+
609+ # Compute distance from cached values
610+ dist = self .metric_fct .from_cached_values (within_x , within_selected , between , ** kwargs )
611+ df .loc [group_x ] = dist
612+
613+ elif self .metric_fct .accepts_precomputed :
614+ # Precomputed pairwise distance matrix mode
470615 # Precompute the pairwise distances if needed
471616 if f"{ self .obsm_key } _{ self .cell_wise_metric } _predistances" not in adata .obsp :
472617 self .precompute_distances (adata , n_jobs = n_jobs , ** kwargs )
@@ -495,6 +640,7 @@ def onesided_distances(
495640 df .loc [group_x ] = bootstrap_output .mean
496641 df_var .loc [group_x ] = bootstrap_output .variance
497642 else :
643+ # Standard mode: compute distances directly
498644 embedding = adata .layers [self .layer_key ] if self .layer_key else adata .obsm [self .obsm_key ].copy ()
499645 for group_x in fct (groups ):
500646 cells_x = embedding [np .asarray (grouping == group_x )].copy ()
@@ -655,6 +801,61 @@ def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
655801 """
656802 raise NotImplementedError ("Metric class is abstract." )
657803
def supports_value_cache(self) -> bool:
    """Report whether the metric can be assembled from cached within/between values.

    This base implementation opts out; subclasses that provide
    ``compute_within_distance``, ``compute_between_distance`` and
    ``from_cached_values`` override it to return True.

    Returns:
        bool: Always False here.
    """
    value_cache_supported = False
    return value_cache_supported
811+
def compute_within_distance(self, X: np.ndarray, **kwargs) -> float:
    """Return the within-group statistic of one group, for later caching.

    Meant to be invoked only when ``supports_value_cache()`` is True. This
    default always raises; metrics with value caching override it.

    Args:
        X: Array of shape (n_samples, n_features) holding a single group.
        kwargs: Additional keyword arguments.

    Returns:
        float: Within-group distance statistic.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "Metric does not support value caching."
    raise NotImplementedError(message)
826+
def compute_between_distance(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
    """Return the between-group statistic of two groups, for later caching.

    Meant to be invoked only when ``supports_value_cache()`` is True. This
    default always raises; metrics with value caching override it.

    Args:
        X: Array of shape (n_samples, n_features) for the first group.
        Y: Array of shape (n_samples, n_features) for the second group.
        kwargs: Additional keyword arguments.

    Returns:
        float: Between-group distance statistic.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "Metric does not support value caching."
    raise NotImplementedError(message)
842+
def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
    """Combine previously cached statistics into the final distance.

    Meant to be invoked only when ``supports_value_cache()`` is True and the
    three inputs have already been computed. This default always raises.

    Args:
        within_X: Cached within-group statistic of group X.
        within_Y: Cached within-group statistic of group Y.
        between: Cached between-group statistic of the pair (X, Y).
        kwargs: Additional keyword arguments.

    Returns:
        float: Distance between X and Y.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = "Metric does not support value caching."
    raise NotImplementedError(message)
858+
658859
659860class Edistance (AbstractDistance ):
660861 """Edistance metric."""
@@ -665,16 +866,32 @@ def __init__(self) -> None:
665866 self .cell_wise_metric = "euclidean"
666867
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
    """Compute the energy distance between two groups of observations.

    Args:
        X: Observations of the first group.
        Y: Observations of the second group.
        kwargs: Forwarded to ``pairwise_distance_mean``.

    Returns:
        ``2 * between - within_X - within_Y`` where each term is a mean
        pairwise distance under ``self.cell_wise_metric``.
    """
    cell_metric = self.cell_wise_metric
    # Evaluation order (X, Y, then cross term) is kept as in the original.
    mean_within_x = pairwise_distance_mean(X, metric=cell_metric, **kwargs)
    mean_within_y = pairwise_distance_mean(Y, metric=cell_metric, **kwargs)
    mean_between = pairwise_distance_mean(X, Y, metric=cell_metric, **kwargs)
    return 2 * mean_between - mean_within_x - mean_within_y
672873
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
    """Compute the energy distance from a precomputed distance matrix.

    Args:
        P: Pairwise distance matrix over all observations.
        idx: Boolean mask selecting group X; its negation selects group Y.
        kwargs: Unused.

    Returns:
        ``2 * between - within_X - within_Y`` where each term is the mean of
        the corresponding sub-block of ``P``.
    """
    complement = ~idx
    # Rows of group X are shared by the X-X block and the X-Y block.
    rows_x = P[idx, :]
    mean_within_x = rows_x[:, idx].mean()
    mean_within_y = P[complement, :][:, complement].mean()
    mean_between = rows_x[:, complement].mean()
    return 2 * mean_between - mean_within_x - mean_within_y
879+
def supports_value_cache(self) -> bool:
    """Signal that the E-distance can be built from cached within/between means."""
    value_cache_supported = True
    return value_cache_supported
883+
def compute_within_distance(self, X: np.ndarray, **kwargs) -> float:
    """Return the mean pairwise distance inside a single group.

    Args:
        X: Observations of the group.
        kwargs: Forwarded to ``pairwise_distance_mean``.

    Returns:
        Mean pairwise distance within ``X`` under ``self.cell_wise_metric``.
    """
    cell_metric = self.cell_wise_metric
    return pairwise_distance_mean(X, metric=cell_metric, **kwargs)
887+
def compute_between_distance(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
    """Return the mean pairwise distance between two groups.

    Args:
        X: Observations of the first group.
        Y: Observations of the second group.
        kwargs: Forwarded to ``pairwise_distance_mean``.

    Returns:
        Mean pairwise distance between ``X`` and ``Y`` under
        ``self.cell_wise_metric``.
    """
    cell_metric = self.cell_wise_metric
    return pairwise_distance_mean(X, Y, metric=cell_metric, **kwargs)
891+
def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
    """Assemble the E-distance from cached mean distances.

    Args:
        within_X: Cached mean pairwise distance within group X.
        within_Y: Cached mean pairwise distance within group Y.
        between: Cached mean pairwise distance between X and Y.
        kwargs: Unused.

    Returns:
        ``2 * between - within_X - within_Y``.
    """
    doubled_between = 2 * between
    return doubled_between - within_X - within_Y
678895
679896
680897class MMD (AbstractDistance ):
@@ -706,6 +923,40 @@ def __call__(self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0,
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
    """Reject evaluation from a pairwise distance matrix.

    Args:
        P: Pairwise distance matrix (not used).
        idx: Group mask (not used).
        kwargs: Unused.

    Raises:
        NotImplementedError: Always.
    """
    message = "MMD cannot be called on a pairwise distance matrix."
    raise NotImplementedError(message)
708925
def supports_value_cache(self) -> bool:
    """Signal that MMD can be built from cached within/between kernel means."""
    value_cache_supported = True
    return value_cache_supported
929+
def compute_within_distance(self, X: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs) -> float:
    """Return the mean of the kernel matrix of a group with itself.

    Args:
        X: Observations of the group.
        kernel: One of ``"linear"``, ``"rbf"`` or ``"poly"``.
        gamma: Passed to the rbf and polynomial kernels.
        degree: Passed to the polynomial kernel.
        kwargs: Unused.

    Returns:
        Mean over all entries of the X-by-X kernel matrix.

    Raises:
        ValueError: If ``kernel`` is not one of the supported names.
    """
    if kernel == "linear":
        return np.dot(X, X.T).mean()
    if kernel == "rbf":
        return rbf_kernel(X, X, gamma=gamma).mean()
    if kernel == "poly":
        return polynomial_kernel(X, X, degree=degree, gamma=gamma, coef0=0).mean()
    raise ValueError(f"Kernel {kernel} not recognized.")
941+
def compute_between_distance(
    self, X: np.ndarray, Y: np.ndarray, *, kernel="linear", gamma=1.0, degree=2, **kwargs
) -> float:
    """Return the mean of the kernel matrix between two groups.

    Args:
        X: Observations of the first group.
        Y: Observations of the second group.
        kernel: One of ``"linear"``, ``"rbf"`` or ``"poly"``.
        gamma: Passed to the rbf and polynomial kernels.
        degree: Passed to the polynomial kernel.
        kwargs: Unused.

    Returns:
        Mean over all entries of the X-by-Y kernel matrix.

    Raises:
        ValueError: If ``kernel`` is not one of the supported names.
    """
    if kernel == "linear":
        return np.dot(X, Y.T).mean()
    if kernel == "rbf":
        return rbf_kernel(X, Y, gamma=gamma).mean()
    if kernel == "poly":
        return polynomial_kernel(X, Y, degree=degree, gamma=gamma, coef0=0).mean()
    raise ValueError(f"Kernel {kernel} not recognized.")
955+
def from_cached_values(self, within_X: float, within_Y: float, between: float, **kwargs) -> float:
    """Assemble the MMD value from cached kernel means.

    Args:
        within_X: Cached mean of the kernel matrix within group X.
        within_Y: Cached mean of the kernel matrix within group Y.
        between: Cached mean of the kernel matrix between X and Y.
        kwargs: Unused.

    Returns:
        ``within_X + within_Y - 2 * between``.
    """
    within_total = within_X + within_Y
    return within_total - 2 * between
959+
709960
710961class WassersteinDistance (AbstractDistance ):
711962 def __init__ (self ) -> None :
@@ -810,7 +1061,7 @@ def __init__(self) -> None:
8101061 self .accepts_precomputed = True
8111062
def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
    """Return the mean pairwise distance between two groups.

    Args:
        X: Observations of the first group.
        Y: Observations of the second group.
        kwargs: Forwarded unchanged to ``pairwise_distance_mean`` (so a
            ``metric`` keyword, if given, selects the metric there).

    Returns:
        Mean pairwise distance between ``X`` and ``Y``.
    """
    mean_distance = pairwise_distance_mean(X, Y, **kwargs)
    return mean_distance
8141065
def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
    """Return the mean cross-group distance from a precomputed matrix.

    Args:
        P: Pairwise distance matrix over all observations.
        idx: Boolean mask selecting the first group; its negation selects
            the second group.
        kwargs: Unused.

    Returns:
        Mean of the sub-block of ``P`` with rows in ``idx`` and columns in
        ``~idx``.
    """
    rows_in_group = P[idx, :]
    cross_block = rows_in_group[:, ~idx]
    return cross_block.mean()
0 commit comments