Last active
June 3, 2022 20:07
-
-
Save neilmehra/c88f0376919da35fe0e2f175360b8ad0 to your computer and use it in GitHub Desktop.
python implementation of the kmeans clustering algorithm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| import PIL | |
| from PIL import Image, ImageTk | |
| import urllib.request | |
| import io, sys, os, random, time | |
| import tkinter as tk | |
def choose_random_means(k, img, pix):
    """Pick k starting means by sampling k random pixels of the image.

    Args:
        k: Number of means to pick.
        img: Image object; only ``img.size`` -> ``(width, height)`` is read.
        pix: Pixel accessor indexed as ``pix[x, y]``.

    Returns:
        A list of k pixel values. The same pixel may be picked more than once.

    Raises:
        ValueError: If k > 0 and the image has zero width or height
            (``random.randrange`` rejects an empty range).
    """
    w, h = img.size
    # randrange excludes its upper bound, so x stays in [0, w) and y in [0, h).
    # randint(0, w) is inclusive and could index one past the last column/row.
    return [pix[random.randrange(w), random.randrange(h)] for _ in range(k)]
def check_move_count(mc):
    """Report whether every per-cluster move count is zero.

    Returns True when ``mc`` equals a list of zeros of the same length
    (the empty list included), and False otherwise.
    """
    all_zero = [0 for _ in range(len(mc))]
    return all_zero == mc
def color_dist(col1, col2):
    """Return the Euclidean distance between two colors.

    Only the components at indices 0, 1 and 2 of each color are compared;
    any further components are ignored.
    """
    squared_diffs = []
    for channel in range(3):
        delta = col1[channel] - col2[channel]
        squared_diffs.append(delta ** 2)
    return math.sqrt(sum(squared_diffs))
def dist(col, means):
    """Return the index of the mean closest to ``col``.

    Distances are measured with ``color_dist``. On a tie the earliest mean
    wins. Returns -1 when ``means`` is empty or when no distance is below
    10000000.
    """
    best_index, best_so_far = -1, 10000000
    for index, center in enumerate(means):
        candidate = color_dist(col, center)
        if not candidate < best_so_far:
            continue
        best_index, best_so_far = index, candidate
    return best_index
def clustering(img, pix, cb, mc, means, count):
    """Run one k-means step: assign every pixel, then recompute the means.

    Args:
        img: Image object; only ``img.size`` -> ``(width, height)`` is read.
        pix: Pixel accessor indexed as ``pix[x, y]``; components 0..2 of each
            pixel are used.
        cb: Cluster sizes from the previous step (one entry per mean).
        mc: Previous move counts. Not read; kept so callers need not change.
        means: Current cluster means.
        count: Step number, used only in the progress line that is printed.

    Returns:
        A tuple ``(sizes, moves, new_means)``:
        sizes is the number of pixels assigned to each mean, moves is
        ``sizes[i] - cb[i]`` per cluster, and new_means holds the
        per-channel average of each cluster's pixels as a 3-element list.
        A cluster that received no pixels keeps its previous mean unchanged.
    """
    k = len(means)
    sizes = [0] * k
    # Running per-channel totals: the averages need only these sums, so the
    # pixels themselves are not stored (memory stays O(k), not O(pixels)).
    channel_sums = [[0, 0, 0] for _ in range(k)]
    w, h = img.size
    for x in range(w):
        for y in range(h):
            p = pix[x, y]
            nearest = dist(p, means)
            sizes[nearest] += 1
            totals = channel_sums[nearest]
            totals[0] += p[0]
            totals[1] += p[1]
            totals[2] += p[2]
    new_means = [
        [total / size for total in totals] if size else mean
        for mean, size, totals in zip(means, sizes, channel_sums)
    ]
    moves = [new - old for new, old in zip(sizes, cb)]
    print('diff', count, ':', moves)
    return sizes, moves, new_means
def update_picture(img, pix, means):
    """Repaint each pixel with the integer-truncated mean nearest to it.

    The means are converted with ``int()`` per component first; every pixel
    is then overwritten in place with its nearest converted mean (as a
    tuple). Returns the same ``pix`` accessor that was passed in.
    """
    width, height = img.size
    palette = [[int(component) for component in mean] for mean in means]
    for col in range(width):
        for row in range(height):
            nearest = dist(pix[col, row], palette)
            pix[col, row] = tuple(palette[nearest])
    return pix
def distinct_pix_count(img, pix):
    """Count distinct pixel values and find the most frequent one.

    Returns a tuple ``(distinct, value, occurrences)``: the number of
    different pixel values, the most frequent value, and how often it
    appears. On a tie, the value met first in the scan (x outer, y inner)
    is reported.
    """
    tally = {}
    width, height = img.size
    for col in range(width):
        for row in range(height):
            value = pix[col, row]
            tally[value] = tally.get(value, 0) + 1
    # Start from the top-left pixel with a count of 0 so any real count beats it.
    top_value, top_count = pix[0, 0], 0
    for value, occurrences in tally.items():
        if occurrences <= top_count:
            continue
        top_value, top_count = value, occurrences
    return len(tally), top_value, top_count
def valid(x, y, w, h):
    """Return True when (x, y) lies inside a w-by-h grid, else False."""
    return 0 <= x < w and 0 <= y < h
def bfs(x, y, vis, pix, w, h):
    """Flood-fill the same-colored region containing (x, y).

    Starting from (x, y), marks every pixel that has the same value and is
    reachable through steps to any of the eight surrounding cells (diagonals
    included).

    Args:
        x, y: Start coordinates; must lie inside the grid.
        vis: Nested list of booleans indexed ``vis[x][y]``; updated in place.
        pix: Pixel accessor indexed as ``pix[x, y]``.
        w, h: Grid width and height.

    Returns:
        The same ``vis`` object, with the whole region marked True.
    """
    region_color = pix[x, y]
    vis[x][y] = True
    # Used as a stack; the set of cells a flood fill reaches does not depend
    # on the order in which pending cells are expanded.
    pending = [(x, y)]
    while pending:
        cx, cy = pending.pop()
        # Clamping the ranges to the grid keeps every neighbor in bounds.
        # All offsets in -1..1 are covered: comparing the offsets against the
        # coordinates (instead of against 0) dropped neighbors whenever the
        # current cell had x or y equal to 0 or 1.
        for nx in range(max(cx - 1, 0), min(cx + 2, w)):
            for ny in range(max(cy - 1, 0), min(cy + 2, h)):
                # The current cell is already marked, so it is skipped here.
                if not vis[nx][ny] and pix[nx, ny] == region_color:
                    vis[nx][ny] = True
                    pending.append((nx, ny))
    return vis
def region_counts(img, pix, means):
    """Count connected same-colored regions, grouped by nearest mean.

    Scans the image (x outer, y inner). Each pixel not yet visited starts a
    new region: the counter of the mean nearest to that pixel (per ``dist``)
    is incremented and ``bfs`` marks the rest of the region as visited.
    Returns one counter per entry of ``means``.
    """
    totals = [0] * len(means)
    width, height = img.size
    seen = [[False] * height for _ in range(width)]
    for col in range(width):
        for row in range(height):
            if seen[col][row]:
                continue
            totals[dist(pix[col, row], means)] += 1
            seen = bfs(col, row, seen, pix, width, height)
    return totals
def main():
    """Cluster an image's colors with k-means, then report and save the result.

    Command line: ``<k> [image]``. ``k`` is the number of clusters; ``image``
    is an optional file path or URL and defaults to ``wallpaper.png``. When
    the value is not an existing file it is fetched with ``urllib``.

    Prints image statistics and per-step cluster size changes, repaints the
    image with the final means, prints the region counts, and saves the
    result as ``<unix time>.png`` in the current directory.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: ' + sys.argv[0] + ' <k> [image path or URL]')
    k = int(sys.argv[1])
    if k < 1:
        # With no means at all, the nearest-mean lookup has nothing to return.
        sys.exit('k must be a positive integer')
    file = sys.argv[2] if len(sys.argv) > 2 else 'wallpaper.png'
    if not os.path.isfile(file):
        # Not a local file: treat it as a URL and download it into memory.
        with urllib.request.urlopen(file) as response:
            file = io.BytesIO(response.read())
    window = tk.Tk()
    # Force 3-channel RGB so every pixel is an (r, g, b) tuple; palette,
    # greyscale or RGBA images would otherwise break the per-channel math
    # and the 3-tuple write-back.
    img = Image.open(file).convert('RGB')
    # Keep a reference to the PhotoImage so it is not garbage-collected
    # while the label still uses it.
    cp_img_tk = ImageTk.PhotoImage(img.copy())
    tk.Label(window, image=cp_img_tk).pack()
    pix = img.load()  # pix[0, 0] : (r, g, b)
    print('Size:', img.size[0], 'x', img.size[1])
    print('Pixels:', img.size[0] * img.size[1])
    d_count, m_col, m_count = distinct_pix_count(img, pix)
    print('Distinct pixel count:', d_count)
    print('Most common pixel:', m_col, '=>', m_count)
    count_buckets = [0] * k
    # Non-zero start values so the loop below runs at least once.
    move_count = [10] * k
    means = choose_random_means(k, img, pix)
    print('random means:', means)
    count = 1
    # Iterate until no cluster changed size between two consecutive steps.
    while not check_move_count(move_count):
        count += 1
        count_buckets, move_count, means = clustering(
            img, pix, count_buckets, move_count, means, count)
        if count == 2:
            print('first means:', means)
            print('starting sizes:', count_buckets)
    pix = update_picture(img, pix, means)
    print('Final sizes:', count_buckets)
    print('Final means:')
    for number, (mean, size) in enumerate(zip(means, count_buckets), 1):
        print(number, ':', mean, '=>', size)
    region_list = region_counts(img, pix, means)
    print('region count:', region_list)
    img_tk = ImageTk.PhotoImage(img)
    tk.Label(window, image=img_tk).pack()  # display the image at window
    im_name = str(int(time.time())) + '.png'
    img.save(im_name, 'PNG')
    # The event loop is left disabled, as before: the script exits right
    # after saving instead of keeping the window open.
    #window.mainloop()
# Run the clustering demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment