Very Elegant Image Color Segmentation

Let’s start with this evocative image of a tree covered in pink flowers.

What I want to do is simplify it so that only its main colors remain. How can I do that? This is where unsupervised learning techniques come in handy.

Clustering the pixels of the image, which are nothing more than points in a 3D space (one axis each for red, green and blue), will let us find the main colors of the photo.
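To make the "points in a 3D space" idea concrete, here is a small sketch that uses a made-up 2×2 NumPy array in place of a real photo (the pixel values are hypothetical). Reshaping to `(-1, 3)` turns each pixel into one row with three coordinates: R, G and B.

```python
import numpy as np

# A hypothetical 2x2 RGB "image" with values in [0, 1]; shape (height, width, 3)
image = np.array([
    [[1.0, 0.75, 0.8], [1.0, 0.75, 0.8]],   # two pink pixels
    [[0.25, 0.5, 0.25], [1.0, 1.0, 1.0]],   # one green pixel, one white pixel
])

# Merge the height and width axes: one row per pixel, one column per channel
points = image.reshape(-1, 3)
print(points.shape)  # (4, 3)

# Each row is now an (R, G, B) point; pixels of the same color are the same point
unique_colors = np.unique(points, axis=0)
print(len(unique_colors))  # 3
```

Pixels with similar colors end up close together in this space, which is exactly what a clustering algorithm can exploit.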

To get a realistic result, you have to choose the right number of clusters. There are more systematic techniques, such as examining the inertia (the elbow method) or the silhouette score, but since we can easily recognize colors by eye, we can simply judge how many clusters are needed. I opted for 6.
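For readers who prefer the systematic route, here is a minimal sketch of both scores using scikit-learn’s `KMeans.inertia_` attribute and `sklearn.metrics.silhouette_score`. To keep it self-contained it runs on a hypothetical synthetic set of pixels drawn around six pure colors rather than on the actual photo, and `score_cluster_counts` is a helper name made up for this example. With a real image you would pass `initial.reshape(-1, 3)` instead, ideally subsampled, because the silhouette score compares every pair of points.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score


def score_cluster_counts(pixels, k_values, random_state=0):
    """Hypothetical helper: fit KMeans for each k, return {k: (inertia, silhouette)}."""
    scores = {}
    for k in k_values:
        km = KMeans(n_clusters=k, n_init=10, random_state=random_state)
        labels = km.fit_predict(pixels)
        # inertia_: sum of squared distances from each pixel to its centroid
        # silhouette_score: mean silhouette coefficient in [-1, 1], higher is better
        scores[k] = (km.inertia_, silhouette_score(pixels, labels))
    return scores


# Hypothetical stand-in for a photo: 100 noisy pixels around each of 6 colors
rng = np.random.default_rng(0)
base_colors = np.array([
    [0.9, 0.1, 0.1],  # red
    [0.1, 0.9, 0.1],  # green
    [0.1, 0.1, 0.9],  # blue
    [0.9, 0.9, 0.1],  # yellow
    [0.9, 0.9, 0.9],  # white
    [0.1, 0.1, 0.1],  # black
])
noise = rng.normal(0.0, 0.03, size=(600, 3))
pixels = np.clip(np.repeat(base_colors, 100, axis=0) + noise, 0.0, 1.0)

scores = score_cluster_counts(pixels, range(2, 10))
for k, (inertia, silhouette) in scores.items():
    print(f"k={k}: inertia={inertia:.2f}, silhouette={silhouette:.3f}")

best_k = max(scores, key=lambda k: scores[k][1])
print("k with the highest silhouette score:", best_k)
```

Inertia tends to keep decreasing as k grows, so with that score one looks for the "elbow" where it stops dropping sharply; the silhouette score instead peaks at the best-separated clustering, which is why it is used to pick `best_k` here.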

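from sklearn.cluster import KMeans
from matplotlib.image import imread
import matplotlib.pyplot as plt

# Read the photo; a 3-channel JPEG is loaded as uint8 values in 0-255,
# so scale them to floats in [0, 1]
initial = imread('Peach.jpg') / 255

# Flatten the (height, width, 3) image into a (height*width, 3) array:
# one row per pixel, one column per RGB channel
edited = initial.reshape(-1, 3)

# Cluster the pixels into 6 groups; n_init and random_state are set
# explicitly for stable, reproducible results
km = KMeans(n_clusters=6, n_init=10, random_state=42)
km.fit(edited)

# Replace every pixel with the centroid (main color) of its cluster,
# then restore the original image shape
segmented = km.cluster_centers_[km.labels_]
segmented = segmented.reshape(initial.shape)

# Save under a new name so the original photo is not overwritten
plt.imsave('peach_segmented.jpg', segmented)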

As you can see, I read the image and reshape it into a 2D array with one row per pixel, holding its R, G and B values. Then I initialize a standard KMeans model, which finds the centroids of the clusters (the main colors).

Once the model is fitted, I replace each pixel with the centroid of its cluster and reshape the result back to the dimensions of the original image.
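The replacement step relies on NumPy integer-array ("fancy") indexing: `km.cluster_centers_` has one row per cluster and `km.labels_` holds the cluster index of every pixel, so indexing the first with the second builds an array with one centroid row per pixel. Here is a minimal sketch with hand-written, hypothetical centroids and labels standing in for a fitted model.

```python
import numpy as np

# Hypothetical stand-ins for km.cluster_centers_ (2 colors) and km.labels_
centers = np.array([
    [0.9, 0.5, 0.7],   # cluster 0: pink
    [0.2, 0.4, 0.2],   # cluster 1: green
])
labels = np.array([0, 0, 1, 0, 1, 1])  # cluster index of each of 6 pixels

# Fancy indexing: row i of the result is centers[labels[i]]
segmented = centers[labels]
print(segmented.shape)  # (6, 3)

# Reshape back to a hypothetical 2x3 image (height 2, width 3, 3 channels);
# pixel (row 1, column 1) is flat index 4, whose label is 1, so it is green
segmented = segmented.reshape(2, 3, 3)
print(segmented[1, 1])
```

Because every pixel now carries one of only two colors, the reshaped array is a two-color version of the original layout; with the real model the same lookup yields a six-color image.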

Here’s the output.

Here’s the link to the code on GitHub:

https://github.com/mattiagiuri/machine-learning-projects/blob/main/Image%20Modifier/BasicImageSegmentation.py
