Mean-shift clustering

Mean-shift is a density-based clustering model. Density-based models define clusters by the density of points in the data space: regions with a high density of data points become clusters, which are typically separated from one another by low-density regions.

The algorithm treats the data points as samples from a density function (in practice, a kernel density estimate), in which high-density areas (high concentrations of data points) appear as peaks. Each data point is then shifted iteratively toward its nearest peak, and all points that arrive at the same peak form one cluster.
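The shifting idea described above can be sketched in a few lines of NumPy. This is a minimal illustration, not scikit-learn's implementation: it assumes a flat kernel (every point inside the window counts equally), and the names `mean_shift_flat`, `group_modes`, and `merge_radius` are hypothetical helpers introduced only for this example. The toy data is synthetic and is not the circles dataset used later.

```python
import numpy as np


def mean_shift_flat(points, bandwidth, max_iter=300, tol=1e-6):
    """Shift every point to the mean of its neighbours until it stops moving."""
    points = np.asarray(points, dtype=float)
    modes = points.copy()
    for _ in range(max_iter):
        # Distances between the current positions and the original data points
        dists = np.linalg.norm(modes[:, None, :] - points[None, :, :], axis=2)
        # Flat kernel: a point is a neighbour if it lies inside the window
        within = (dists <= bandwidth).astype(float)
        # New position = mean of the original points inside each window
        shifted = (within @ points) / within.sum(axis=1, keepdims=True)
        converged = np.linalg.norm(shifted - modes, axis=1).max() < tol
        modes = shifted
        if converged:
            break
    return modes


def group_modes(modes, merge_radius):
    """Give the same label to points whose final positions are close together."""
    centers = []
    labels = np.empty(len(modes), dtype=int)
    for i, mode in enumerate(modes):
        for k, center in enumerate(centers):
            if np.linalg.norm(mode - center) <= merge_radius:
                labels[i] = k
                break
        else:
            # No existing peak nearby: this position starts a new cluster
            centers.append(mode)
            labels[i] = len(centers) - 1
    return np.array(centers), labels


# Toy data (assumption): two well-separated synthetic blobs
rng = np.random.default_rng(0)
toy = np.vstack([
    rng.normal(loc=(0.0, 0.0), scale=0.1, size=(50, 2)),
    rng.normal(loc=(3.0, 3.0), scale=0.1, size=(50, 2)),
])
modes = mean_shift_flat(toy, bandwidth=1.0)
centers, labels = group_modes(modes, merge_radius=0.5)
print("Number of clusters found:", len(centers))
```

The sketch computes all pairwise distances at every iteration, which is fine for a few hundred points but scales quadratically. Library implementations typically use neighbour searches and seeding strategies to avoid that cost.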

Importing libraries and packages

# Mathematical operations and data manipulation
import pandas as pd

# Model
from sklearn.cluster import MeanShift
from sklearn.metrics import silhouette_score
from sklearn.metrics import calinski_harabasz_score

# Plotting
import matplotlib.pyplot as plt

# Warnings
import warnings

warnings.filterwarnings("ignore")

%matplotlib inline

Set paths

# Path to datasets directory
data_path = "./datasets"
# Path to assets directory (for saving results to)
assets_path = "./assets"

Loading dataset

dataset = pd.read_csv(f"{data_path}/circles.csv")

Exploring dataset

# Shape of the dataset
print("Shape of the dataset: ", dataset.shape)
# Preview (the notebook displays the first and last rows)
dataset
Shape of the dataset:  (1500, 2)
0 1
0 0.393992 -0.416376
1 0.528243 -0.828242
2 -0.740158 0.607730
3 -0.971016 0.316821
4 0.018693 -0.605288
... ... ...
1495 0.085197 -0.463028
1496 0.890820 0.003483
1497 0.223768 -0.419122
1498 0.221189 -0.510314
1499 0.544376 0.049358

1500 rows × 2 columns

# Scatter plot of the two features
plt.scatter(dataset.iloc[:, 0], dataset.iloc[:, 1])
plt.show()
[Figure: scatter plot of the dataset, feature 0 against feature 1]
# Histogram of the first feature, selected by position with iloc
plt.hist(dataset.iloc[:, 0])
plt.show()
[Figure: histogram of feature 0]

<a id="mean-shift"></a>

Mean-Shift

# The mean-shift algorithm does not need the number of clusters to be
# specified. The bandwidth can be fine-tuned to get the best performance.
# Train the model with a bandwidth of 0.5 - chosen after trying out 0.1 and 0.9.
est_meanshift = MeanShift(bandwidth=0.5)
# Fit to the data
est_meanshift.fit(dataset)
# Use the fitted model to assign a cluster to each data point
pred_meanshift = est_meanshift.predict(dataset)
# Plot the data points, colored by their assigned cluster
plt.scatter(dataset.iloc[:, 0], dataset.iloc[:, 1], c=pred_meanshift)
plt.savefig(f"{assets_path}/circles-mean-shift.png", bbox_inches="tight")
plt.show()
[Figure: scatter plot of the dataset, colored by Mean-Shift cluster assignment]
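Because the bandwidth is the main tuning knob, it can help to look at how the number of clusters changes across a few candidate values before settling on one. The sketch below is only an illustration under stated assumptions: it uses synthetic concentric circles from `make_circles` as a stand-in for `circles.csv` (how that file was generated is not known here), enables `bin_seeding=True` purely to keep the run fast, and picks `quantile=0.2` for `estimate_bandwidth` as an arbitrary example value.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_circles

# Stand-in data (assumption): synthetic concentric circles, not circles.csv
X, _ = make_circles(n_samples=300, factor=0.5, noise=0.05, random_state=0)

# Fit one model per candidate bandwidth and record how many clusters it finds
cluster_counts = {}
for bandwidth in (0.1, 0.5, 0.9):
    model = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X)
    cluster_counts[bandwidth] = len(model.cluster_centers_)
    print(f"bandwidth={bandwidth}: {cluster_counts[bandwidth]} clusters")

# Data-driven starting point; the quantile is an illustrative choice
suggested = estimate_bandwidth(X, quantile=0.2, random_state=0)
print(f"estimate_bandwidth suggests {suggested:.3f}")
```

A smaller bandwidth typically splits the data into many small clusters, while a larger one merges them. The value returned by `estimate_bandwidth` is best treated as a starting point to refine, not a final answer.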

Metrics

# Silhouette score
meanshift_silhouette_score = silhouette_score(
    dataset, pred_meanshift, metric="euclidean"
)
print(meanshift_silhouette_score)
0.3705351315933256
# Calinski-Harabasz score
meanshift_calinski_harabasz_score = calinski_harabasz_score(
    dataset, pred_meanshift
)
print(meanshift_calinski_harabasz_score)
1305.1425653034062
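Both metrics are only defined when the labelling contains at least two clusters (and fewer clusters than samples), and scikit-learn raises a `ValueError` otherwise. Since mean-shift chooses the number of clusters itself, a large bandwidth can leave a single cluster. The helper below is a hedged sketch of one way to guard against that. The name `score_clustering` is hypothetical, and the blob data is synthetic, used only to make the example self-contained.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
from sklearn.metrics import calinski_harabasz_score, silhouette_score


def score_clustering(X, labels):
    """Return both scores, or None when they are undefined for this labelling."""
    n_clusters = len(np.unique(labels))
    # Both scores need between 2 and n_samples - 1 distinct labels
    if not 2 <= n_clusters <= len(X) - 1:
        return None
    return {
        "silhouette": silhouette_score(X, labels, metric="euclidean"),
        "calinski_harabasz": calinski_harabasz_score(X, labels),
    }


# Synthetic data (assumption): two compact, well-separated blobs
X, _ = make_blobs(
    n_samples=200, centers=[[0, 0], [5, 5]], cluster_std=0.3, random_state=0
)
labels = MeanShift(bandwidth=1.5).fit_predict(X)
scores = score_clustering(X, labels)
print(scores)

# A single-cluster labelling is reported as "no score" instead of raising
single_cluster = score_clustering(X, np.zeros(len(X), dtype=int))
print(single_cluster)
```

The silhouette score lies between -1 and 1, with higher values indicating better-separated clusters. The Calinski-Harabasz score has no fixed upper bound, so it is mainly useful for comparing different clusterings of the same data.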