Skip to content

Commit b094c4c

Browse files
authored
Merge pull request #753 from Avid29/ColorAnalysis/Clustering_Refinement
Added DBScan clustering to ColorPaletteSampler
2 parents 4399602 + 7ed9c4b commit b094c4c

File tree

3 files changed

+167
-3
lines changed

3 files changed

+167
-3
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System.Numerics;
6+
7+
namespace CommunityToolkit.WinUI.Helpers;
8+
9+
public partial class ColorPaletteSampler
10+
{
11+
private ref struct DBScan
12+
{
13+
private const int Unclassified = -1;
14+
15+
public static Vector3[] Cluster(Span<Vector3> points, float epsilon, int minPoints, ref float[] weights)
16+
{
17+
var centroids = new List<Vector3>();
18+
var newWeights = new List<float>();
19+
20+
// Create context
21+
var context = new DBScan(points, weights, epsilon, minPoints);
22+
23+
// Attempt to create a cluster around each point,
24+
// skipping that point if already classified
25+
for (int i = 0; i < points.Length; i++)
26+
{
27+
// Already classified, skip
28+
if (context.PointClusterIds[i] is not Unclassified)
29+
continue;
30+
31+
// Attempt to create cluster
32+
if(context.CreateCluster(i, out var centroid, out var weight))
33+
{
34+
centroids.Add(centroid);
35+
newWeights.Add(weight);
36+
}
37+
}
38+
39+
weights = newWeights.ToArray();
40+
return centroids.ToArray();
41+
}
42+
43+
private bool CreateCluster(int originIndex, out Vector3 centroid, out float weight)
44+
{
45+
weight = 0;
46+
centroid = Vector3.Zero;
47+
var seeds = GetSeeds(originIndex, out bool isCore);
48+
49+
// Not enough seeds to be a core point.
50+
// Cannot create a cluster around it
51+
if (!isCore)
52+
{
53+
return false;
54+
}
55+
56+
ExpandCluster(seeds, out centroid, out weight);
57+
ClusterId++;
58+
59+
return true;
60+
}
61+
62+
private void ExpandCluster(Queue<int> seeds, out Vector3 centroid, out float weight)
63+
{
64+
weight = 0;
65+
centroid = Vector3.Zero;
66+
while(seeds.Count > 0)
67+
{
68+
var seedIndex = seeds.Dequeue();
69+
70+
// Skip duplicate seed entries
71+
if (PointClusterIds[seedIndex] is not Unclassified)
72+
continue;
73+
74+
// Assign this seed's id to the cluster
75+
PointClusterIds[seedIndex] = ClusterId;
76+
var w = Weights[seedIndex];
77+
centroid += Points[seedIndex] * w;
78+
weight += w;
79+
80+
// Check if this seed is a core point
81+
var grandSeeds = GetSeeds(seedIndex, out var seedIsCore);
82+
if (!seedIsCore)
83+
continue;
84+
85+
// This seed is a core point. Enqueue all its seeds
86+
foreach(var grandSeedIndex in grandSeeds)
87+
if (PointClusterIds[grandSeedIndex] is Unclassified)
88+
seeds.Enqueue(grandSeedIndex);
89+
}
90+
91+
centroid /= weight;
92+
}
93+
94+
private Queue<int> GetSeeds(int originIndex, out bool isCore)
95+
{
96+
var origin = Points[originIndex];
97+
98+
// NOTE: Seeding could be done using a spatial data structure to improve traversal
99+
// speeds. However currently DBSCAN is run after KMeans with a maximum of 8 points.
100+
// There is no need.
101+
102+
var seeds = new Queue<int>();
103+
for (int i = 0; i < Points.Length; i++)
104+
{
105+
if (Vector3.DistanceSquared(origin, Points[i]) <= Epsilon2)
106+
seeds.Enqueue(i);
107+
}
108+
109+
// Count includes self, so compare without checking equals
110+
isCore = seeds.Count > MinPoints;
111+
return seeds;
112+
}
113+
114+
private DBScan(Span<Vector3> points, Span<float> weights, float epsilon, int minPoints)
115+
{
116+
Points = points;
117+
Weights = weights;
118+
Epsilon2 = epsilon * epsilon;
119+
MinPoints = minPoints;
120+
121+
ClusterId = 0;
122+
PointClusterIds = new int[points.Length];
123+
for(int i = 0; i < points.Length; i++)
124+
PointClusterIds[i] = Unclassified;
125+
}
126+
127+
/// <summary>
128+
/// Gets the points being clustered.
129+
/// </summary>
130+
public Span<Vector3> Points { get; }
131+
132+
/// <summary>
133+
/// Gets the weights of the points.
134+
/// </summary>
135+
public Span<float> Weights { get; }
136+
137+
/// <summary>
138+
/// Gets or sets the id of the currently evaluating cluster.
139+
/// </summary>
140+
public int ClusterId { get; set; }
141+
142+
/// <summary>
143+
/// Gets an array containing the id of the cluster each point belongs to.
144+
/// </summary>
145+
public int[] PointClusterIds { get; }
146+
147+
/// <summary>
148+
/// Gets epsilon squared. Where epsilon is the max distance to consider two points connected.
149+
/// </summary>
150+
/// <remarks>
151+
/// This is cached as epsilon squared to skip a sqrt operation when comparing distances to epsilon.
152+
/// </remarks>
153+
public double Epsilon2 { get; }
154+
155+
/// <summary>
156+
/// Gets the minimum number of points required to make a core point.
157+
/// </summary>
158+
public int MinPoints { get; }
159+
}
160+
}

components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.Clustering.cs renamed to components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.KMeans.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ private static void Split(int k, int[] clusterIds)
6969
/// <summary>
7070
/// Calculates the centroid of each cluster, and prunes empty clusters.
7171
/// </summary>
72-
private static void CalculateCentroidsAndPrune(ref Span<Vector3> centroids, ref int[] counts, Span<Vector3> points, int[] clusterIds)
72+
internal static void CalculateCentroidsAndPrune(ref Span<Vector3> centroids, ref int[] counts, Span<Vector3> points, int[] clusterIds)
7373
{
7474
// Clear centroids and counts before recalculation
7575
for (int i = 0; i < centroids.Length; i++)

components/ColorAnalyzer/src/ColorPaletteSampler/ColorPaletteSampler.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ public async Task UpdatePaletteAsync()
5252

5353
const int sampleCount = 4096;
5454
const int k = 8;
55+
const float mergeDistance = 0.12f;
5556

5657
// Retreive pixel samples from source
5758
var samples = await SampleSourcePixelColorsAsync(sampleCount);
@@ -62,8 +63,11 @@ public async Task UpdatePaletteAsync()
6263

6364
// Cluster samples in RGB floating-point color space
6465
// With Euclidean Squared distance function, then construct palette data.
65-
var clusters = KMeansCluster(samples, k, out var sizes);
66-
var colorData = clusters.Select((vectorColor, i) => new PaletteColor(vectorColor.ToColor(), (float)sizes[i] / samples.Length));
66+
// Merge KMeans results that are too similar, using DBScan
67+
var kClusters = KMeansCluster(samples, k, out var counts);
68+
var weights = counts.Select(x => (float)x / samples.Length).ToArray();
69+
var dbCluster = DBScan.Cluster(kClusters, mergeDistance, 0, ref weights);
70+
var colorData = dbCluster.Select((vectorColor, i) => new PaletteColor(vectorColor.ToColor(), weights[i]));
6771

6872
// Update palettes on the UI thread
6973
foreach (var palette in PaletteSelectors)

0 commit comments

Comments
 (0)