Skip to content

Instantly share code, notes, and snippets.

@neilmehra
Last active June 3, 2022 20:07
Show Gist options
  • Select an option

  • Save neilmehra/c88f0376919da35fe0e2f175360b8ad0 to your computer and use it in GitHub Desktop.

Select an option

Save neilmehra/c88f0376919da35fe0e2f175360b8ad0 to your computer and use it in GitHub Desktop.
Python implementation of the k-means clustering algorithm
import math
import PIL
from PIL import Image, ImageTk
import urllib.request
import io, sys, os, random, time
import tkinter as tk
def choose_random_means(k, img, pix):
    """Pick *k* initial cluster means by sampling random pixels of *img*.

    Args:
        k: number of means to pick.
        img: image object; only ``img.size`` (width, height) is read.
        pix: pixel-access object indexed as ``pix[x, y]``.

    Returns:
        A list of *k* pixel values taken from *pix* (duplicates possible).
    """
    w, h = img.size
    # randrange(n) yields 0..n-1. randint(0, n) is inclusive of n and could
    # index one column/row past the edge of the image.
    return [pix[random.randrange(w), random.randrange(h)] for _ in range(k)]
def check_move_count(mc):
    """Return True when the move-count list *mc* contains only zeros.

    An empty list counts as "all zeros" and yields True.
    """
    all_zero = [0 for _ in range(len(mc))]
    return mc == all_zero
def color_dist(col1, col2):
    """Return the Euclidean distance between two colors.

    Only the first three channels (indices 0, 1, 2) are compared; any
    further channels are ignored.
    """
    total = 0
    for channel in range(3):
        delta = col1[channel] - col2[channel]
        total += delta ** 2
    return math.sqrt(total)
def dist(col, means):
    """Return the index of the mean in *means* nearest to color *col*.

    Despite the name, the result is an index, not a distance. Distances are
    measured with ``color_dist``. On ties the lowest index wins, because only
    a strictly smaller distance replaces the current best.

    Returns:
        The index of the nearest mean, or -1 if *means* is empty.
    """
    min_index = -1
    # math.inf rather than a "large enough" magic number, so any finite
    # distance is accepted no matter how large.
    min_dist = math.inf
    for i, mean in enumerate(means):
        d = color_dist(col, mean)
        if d < min_dist:
            min_dist = d
            min_index = i
    return min_index
def clustering(img, pix, cb, mc, means, count):
    """Run one k-means iteration over every pixel of *img*.

    Each pixel is assigned to its nearest mean (via ``dist``). Each mean is
    then recomputed as the per-channel average (first three channels) of the
    pixels assigned to it. A mean that received no pixels keeps its
    previous value.

    Args:
        img: image object; only ``img.size`` (width, height) is read.
        pix: pixel-access object indexed as ``pix[x, y]``.
        cb: cluster sizes from the previous iteration.
        mc: previous size differences; not read, kept for interface
            compatibility.
        means: current cluster means.
        count: iteration number; only used in the progress print.

    Returns:
        ``(sizes, diffs, new_means)`` where ``sizes[i]`` is the number of
        pixels assigned to cluster i, ``diffs[i] == sizes[i] - cb[i]``, and
        ``new_means[i]`` is either a list of three float averages or the
        unchanged old mean.
    """
    k = len(means)
    sizes = [0] * k
    # Running per-channel sums instead of keeping every pixel per cluster:
    # O(k) memory instead of O(width * height). The additions happen in the
    # same order as summing a stored list, so the results are identical.
    channel_sums = [[0, 0, 0] for _ in range(k)]
    w, h = img.size
    for x in range(w):
        for y in range(h):
            p = pix[x, y]
            idx = dist(p, means)
            sizes[idx] += 1
            acc = channel_sums[idx]
            for c in range(3):
                acc[c] += p[c]
    new_means = []
    for i in range(k):
        if sizes[i] == 0:
            new_means.append(means[i])
        else:
            new_means.append([channel_sums[i][c] / sizes[i] for c in range(3)])
    diffs = [a - b for a, b in zip(sizes, cb)]
    print('diff', count, ':', diffs)
    return sizes, diffs, new_means
def update_picture(img, pix, means):
    """Recolor every pixel of *img* with its nearest cluster mean.

    The means are first truncated channel-wise with ``int()``. The nearest
    mean is looked up (via ``dist``) among those truncated values, and the
    pixel is overwritten with it as a tuple. *pix* is modified in place and
    also returned.
    """
    int_means = [[int(channel) for channel in mean] for mean in means]
    width, height = img.size
    for col in range(width):
        for row in range(height):
            nearest = int_means[dist(pix[col, row], int_means)]
            pix[col, row] = tuple(nearest)
    return pix
def distinct_pix_count(img, pix):
    """Count the distinct colors in *img* and find the most frequent one.

    Pixels are scanned with x in the outer loop and y in the inner loop. On
    a tie for most frequent, the color encountered first in that order wins.

    Args:
        img: image object; only ``img.size`` (width, height) is read.
        pix: pixel-access object indexed as ``pix[x, y]``; values must be
            hashable.

    Returns:
        ``(distinct_count, most_common_color, most_common_count)``. For an
        image with no pixels, returns ``(0, None, 0)`` instead of failing.
    """
    from collections import Counter  # only needed here

    w, h = img.size
    counts = Counter(pix[x, y] for x in range(w) for y in range(h))
    if not counts:
        return 0, None, 0
    # most_common keeps first-encountered order among equal counts.
    (max_col, max_count), = counts.most_common(1)
    return len(counts), max_col, max_count
def valid(x, y, w, h):
    """Return True if (x, y) lies inside a w-by-h grid with 0-based indices."""
    return 0 <= x < w and 0 <= y < h
def bfs(x, y, vis, pix, w, h):
    """Flood-fill the same-colored region that contains pixel (x, y).

    Marks in *vis* every pixel reachable from (x, y) through 8-connected
    steps (horizontal, vertical and diagonal) onto pixels whose color is
    exactly equal to the pixel being expanded.

    Args:
        x, y: starting pixel.
        vis: visited flags indexed ``vis[x][y]``; updated in place.
        pix: pixel-access object indexed as ``pix[x, y]``.
        w, h: grid width and height.

    Returns:
        *vis* (the same object, mutated).
    """
    # Traversal order doesn't matter for marking a region, so a plain list
    # used as a stack (depth-first) is enough despite the function's name.
    stack = [(x, y)]
    vis[x][y] = True
    while stack:
        cx, cy = stack.pop()
        col = pix[cx, cy]
        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                # Only the zero offset is the pixel itself; every other
                # offset in the 3x3 window is a neighbor.
                if dx == 0 and dy == 0:
                    continue
                nx, ny = cx + dx, cy + dy
                # Bounds test is the same check valid(nx, ny, w, h) performs.
                if (0 <= nx < w and 0 <= ny < h
                        and not vis[nx][ny] and pix[nx, ny] == col):
                    vis[nx][ny] = True
                    stack.append((nx, ny))
    return vis
def region_counts(img, pix, means):
    """Count same-color regions per cluster.

    Pixels are scanned column by column (x outer, y inner). Each pixel not
    already marked by an earlier flood fill starts a new region. That region
    is credited to the cluster whose mean is nearest to the pixel (via
    ``dist``) and is then flood-filled with ``bfs`` so its remaining pixels
    are skipped.

    Returns:
        A list whose i-th entry is the number of regions credited to
        cluster i.
    """
    width, height = img.size
    counts = [0] * len(means)
    seen = [[False] * height for _ in range(width)]
    for px in range(width):
        for py in range(height):
            if seen[px][py]:
                continue
            counts[dist(pix[px, py], means)] += 1
            seen = bfs(px, py, seen, pix, width, height)
    return counts
def main():
    """Quantize an image's colors with k-means, print statistics, and save it.

    Command line: ``K [IMAGE]``. K is the number of clusters. IMAGE is a
    local file path or a URL and defaults to ``wallpaper.png``. The recolored
    image is saved as ``<unix-timestamp>.png`` in the current directory.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: ' + sys.argv[0] + ' K [IMAGE_PATH_OR_URL]')
    k = int(sys.argv[1])
    # Optional second argument; falls back to the previously hard-coded name.
    file = sys.argv[2] if len(sys.argv) > 2 else 'wallpaper.png'
    if not os.path.isfile(file):
        # Not a local file: treat it as a URL, buffer the download in memory,
        # and close the connection once it has been read.
        with urllib.request.urlopen(file) as response:
            file = io.BytesIO(response.read())
    window = tk.Tk()
    # Normalize to RGB so every pixel is an (r, g, b) tuple. Palette or
    # greyscale images would otherwise yield plain ints that the distance
    # code cannot index.
    img = Image.open(file).convert('RGB')
    # Independent copy for the "before" label instead of re-opening the
    # source, which may be an already-consumed in-memory stream.
    cp_img = img.copy()
    cp_img_tk = ImageTk.PhotoImage(cp_img)
    tk.Label(window, image=cp_img_tk).pack()
    pix = img.load()  # pix[0, 0] : (r, g, b)
    print('Size:', img.size[0], 'x', img.size[1])
    print('Pixels:', img.size[0] * img.size[1])
    d_count, m_col, m_count = distinct_pix_count(img, pix)
    print('Distinct pixel count:', d_count)
    print('Most common pixel:', m_col, '=>', m_count)
    count_buckets = [0] * k
    # Non-zero placeholder so the convergence loop runs at least once.
    move_count = [10] * k
    means = choose_random_means(k, img, pix)
    print('random means:', means)
    count = 1
    # Iterate until no cluster changes size between two consecutive passes.
    while not check_move_count(move_count):
        count += 1
        count_buckets, move_count, means = clustering(
            img, pix, count_buckets, move_count, means, count)
        if count == 2:
            print('first means:', means)
            print('starting sizes:', count_buckets)
    pix = update_picture(img, pix, means)
    print('Final sizes:', count_buckets)
    print('Final means:')
    for i, (mean, size) in enumerate(zip(means, count_buckets), start=1):
        print(i, ':', mean, '=>', size)
    region_list = region_counts(img, pix, means)
    print('region count:', region_list)
    img_tk = ImageTk.PhotoImage(img)
    tk.Label(window, image=img_tk).pack()  # display the recolored image
    im_name = str(int(time.time())) + '.png'
    img.save(im_name, 'PNG')
    # window.mainloop()  # uncomment to keep the before/after window open


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment