from itertools import cycle
from pprint import pprint as pprint
import argparse
import matplotlib.pyplot as plt
import numpy as np


def initialize_centroids(data, k):
	'''
	Forgy initialization: choose k distinct rows of data uniformly at random
	and use them as the initial cluster centroids.

	data: numpy array of shape (N, d)
	k: int: the number of cluster centroids to return

	Returns a numpy array of shape (k, d), representing the cluster centroids.
	The result is a float copy, so later updates never alias the input data.
	Sampling uses numpy's global random state, so it is reproducible after
	np.random.seed(...).

	Raises ValueError if k is not between 1 and N (inclusive).
	'''
	n = data.shape[0]
	if not 1 <= k <= n:
		raise ValueError('k must be between 1 and the number of data points ({}), got {}'.format(n, k))

	# Sample row *indices* without replacement, so no datapoint is picked twice.
	# (If data itself contains duplicate rows, two centroids may still coincide.)
	indices = np.random.choice(n, size=k, replace=False)
	centroids = np.array(data[indices], dtype=float)

	assert len(centroids) == k
	return centroids


def distance_euclidean(p1, p2):
	'''
	p1: numpy array of shape (d, 1): 1st point
	p2: numpy array of shape (d, 1): 2nd point

	Returns the Euclidean distance b/w the two points, as a Python float.

	Works for points of any dimension d. Inputs may be numpy arrays of shape
	(d, 1) or (d,), lists, or tuples. Both points are flattened first, so
	mixing shapes (d, 1) and (d,) does not broadcast into a (d, d) result.
	'''
	# Flatten both points so differently shaped but equal-length inputs line up element-wise.
	diff = np.asarray(p1, dtype=float).ravel() - np.asarray(p2, dtype=float).ravel()
	distance = float(np.sqrt(np.sum(diff * diff)))
	return distance


def kmeans_iteration_one(data, centroids):
	'''
	Perform one iteration (assignment + update) of the k-means algorithm.

	data: numpy array of shape (N, d): the list of data points
	centroids: numpy array of shape (k, d): the current centroids of the clusters

	Returns :
	new_centroids : numpy array of shape (k, d) representing the new centroids of the clusters
	clusters : numpy array of shape (N, 1) representing the cluster of each of the N points after one iteration of k-means clustering algorithm.

	Assignment: each point goes to its nearest centroid under the Euclidean
	distance, the same metric distance_euclidean is documented to return.
	Ties go to the lowest-index centroid.
	Update: each new centroid is the mean of its assigned points. A centroid
	that receives no points keeps its previous position.
	'''

	n = data.shape[0]
	k = centroids.shape[0]
	d = data.shape[1]
	new_centroids = np.zeros((k, d), dtype=float)
	clusters = np.zeros((n, 1), dtype=int)

	# Assignment step: build the (N, k) point-to-centroid distance matrix one
	# centroid column at a time. This keeps extra memory at O(N*d) instead of
	# the O(N*k*d) a full 3-D broadcast would allocate. The metric is computed
	# inline (vectorized) rather than via a per-point Python call.
	distances = np.empty((n, k), dtype=float)
	for j in range(k):
		distances[:, j] = np.sqrt(np.sum((data - centroids[j]) ** 2, axis=1))
	# np.argmin returns the first minimum, so ties resolve to the lowest centroid index.
	labels = np.argmin(distances, axis=1)
	clusters[:, 0] = labels

	# Update step: move each centroid to the mean of the points assigned to it.
	for j in range(k):
		members = data[labels == j]
		if len(members) == 0:
			# Empty cluster: the mean would be NaN, so keep the old centroid instead.
			new_centroids[j] = centroids[j]
		else:
			new_centroids[j] = members.mean(axis=0)

	assert len(new_centroids) == len(centroids)
	return new_centroids, clusters


def hasconverged(old_centroids, new_centroids, epsilon=1e-4):
	'''
	old_centroids: numpy array of shape (k, d): The cluster centroids found by the previous iteration
	new_centroids: numpy array of shape (k, d): The cluster centroids found by the current iteration
	epsilon: float: the largest displacement still considered "not moved"

	Returns true iff no cluster centroid moved more than epsilon distance.
	Displacement is measured with the Euclidean distance. The result is a plain
	Python bool.
	'''
	old = np.asarray(old_centroids, dtype=float)
	new = np.asarray(new_centroids, dtype=float)

	# Per-centroid Euclidean displacement, one value per row (centroid).
	displacements = np.sqrt(np.sum((new - old) ** 2, axis=1))
	converged = bool(np.all(displacements <= epsilon))
	return converged


def performance_SSE(data, centroids, clusters):
	'''
	data: numpy array of shape (N, d): the list of data points
	centroids: numpy array of shape (k, d): representing the cluster centroids
	clusters : numpy array of shape (N, 1) representing the cluster assigned to the point

	Returns: The Sum Squared Error of the clustering represented by centroids, on the data.
	This is the sum over all points of the squared Euclidean distance between
	the point and its assigned centroid, returned as a Python float. It is 0.0
	for empty data.
	'''
	# Flatten the (N, 1) label column so it can index rows of centroids directly.
	labels = np.asarray(clusters, dtype=int).reshape(-1)
	residuals = data - centroids[labels]
	# Squared Euclidean distance per point, summed over all points.
	sse = float(np.sum(residuals ** 2))
	return sse

########################################################################
#                       DO NOT EDIT THE FOLLOWING                      #
########################################################################


def argmin(values):
	# Index of the smallest element of an iterable; min() keeps the first of
	# several equal minima, so ties resolve to the lowest index. Raises
	# ValueError (from min) when values is empty.
	def _value_of(indexed_pair):
		return indexed_pair[1]

	position, _smallest = min(enumerate(values), key=_value_of)
	return position


def parse():
	'''
	Parse the command-line arguments of the clustering script, print them, and
	return them as a plain dict keyed by each option's dest name
	('input', 'maxiter', 'epsilon', 'k', 'seed').
	'''
	parser = argparse.ArgumentParser()
	parser.add_argument(dest='input', type=str, help='Dataset filename')
	parser.add_argument('-m', '--iter', '--maxiter', dest='maxiter', type=int, default=1000, help='Maximum number of iterations of the algorithm to perform (may stop earlier if convergence is achieved). Default: 1000')
	# The help text must match the real default (1e-4); it previously advertised 1e-3.
	parser.add_argument('-e', '--eps', '--epsilon', dest='epsilon', type=float, default=1e-4, help='Minimum distance the cluster centroids move b/w two consecutive iterations for the algorithm to continue. Default: 1e-4')
	parser.add_argument('-k', '--k', dest='k', type=int, default=8, help='The number of clusters to use. Default: 8')
	parser.add_argument('-s', '--seed', dest='seed', type=int, default=0, help='The Random Generator seed. Default: 0')
	_a = parser.parse_args()

	# Independent copy of the Namespace's attributes as a plain dict.
	args = dict(vars(_a))

	print('-'*40 + '\n')
	print('Arguments:')
	pprint(args)
	print('-'*40 + '\n')
	return args


def readfile(filename):
	'''
	File format: Each line contains a comma separated list of real numbers, representing a single point.
	Returns a numpy array of shape (N, d). N points, where each point is d-dimensional.
	The result is always 2-D, even for a single-line file (shape (1, d)) or
	for one-dimensional points (shape (N, 1)).
	'''
	# Without ndmin=2, np.loadtxt squeezes single-row / single-column files to
	# 1-D arrays, which breaks every caller that reads data.shape[1].
	return np.loadtxt(filename, delimiter=",", dtype=float, ndmin=2)


def kmeans(data, centroids, maxiter, epsilon=1e-3):
	'''
	data: numpy array of shape (N, d): the list of data points
	centroids: numpy array of shape (k, d): the initial cluster centroids
	maxiter: int: Maximum number of iterations to perform (must be >= 1)
	epsilon: float: convergence threshold forwarded to hasconverged

	Performs up to maxiter iterations of the clustering algorithm and saves the
	cluster centroids of every iteration. Stops early once hasconverged reports
	that the centroids stopped moving.

	Returns:
	all_centroids : list of numpy arrays of shape (k, d), the centroids produced by each iteration performed; the last entry is the final centroids
	clusters : numpy array of shape (N, 1) representing the cluster of each of the N points, as returned by the last iteration performed

	Raises ValueError if maxiter < 1. In that case no iteration would run and
	there would be no cluster assignment to return.
	'''
	if maxiter < 1:
		raise ValueError('maxiter must be at least 1, got {}'.format(maxiter))

	all_centroids = []
	clusters = None
	for _ in range(maxiter):
		new_centroids, clusters = kmeans_iteration_one(data, centroids)
		all_centroids.append(new_centroids)

		# Compare against the centroids this iteration started from.
		if hasconverged(centroids, new_centroids, epsilon):
			break
		centroids = new_centroids

	return all_centroids, clusters



def visualize_data(data, clusters, k):
	# Scatter-plot the first two coordinates of every point, coloured by its
	# cluster label. Colours cycle through an 8-entry palette, so with k > 8
	# some clusters share a colour.
	print('Visualizing...')
	palette = cycle(["r", "g", "b", "#654899", "k", "c", "m", "y"])
	colour_of_cluster = [next(palette) for _ in range(k)]
	point_colours = [colour_of_cluster[label_row[0]] for label_row in clusters]
	xs = data[:, 0]
	ys = data[:, 1]
	plt.scatter(xs, ys, c=point_colours)

	plt.show()


def visualize_performance(data, all_centroids, clusters):
	# Plot one SSE value per recorded iteration. Every value is measured
	# against the single `clusters` assignment passed in.
	sse_per_iteration = [
		performance_SSE(data, iteration_centroids, clusters)
		for iteration_centroids in all_centroids
	]

	iteration_axis = range(len(all_centroids))
	plt.plot(iteration_axis, sse_per_iteration)
	plt.title('Performance plot')
	plt.xlabel('Iteration')
	plt.ylabel('Sum Squared Error')
	plt.show()


if __name__ == '__main__':

	cli_args = parse()

	# Load the dataset from the file named on the command line
	points = readfile(cli_args['input'])
	print('Number of points in input data: ', len(points))

	# Seed numpy's global RNG so any random sampling done afterwards
	# (e.g. the Forgy initialization) is reproducible
	np.random.seed(cli_args["seed"])

	# Pick starting centroids, then run the full clustering loop
	initial_centroids = initialize_centroids(points, cli_args['k'])
	centroid_history, assignment = kmeans(points, initial_centroids, cli_args['maxiter'], cli_args['epsilon'])

	# Only scatter-plot datasets that are both small and two-dimensional
	is_small = len(points) < 5000
	if is_small and len(points[0]) == 2:
		visualize_data(points, assignment, cli_args['k'])

	visualize_performance(points, centroid_history, assignment)
