from matplotlib import pyplot
from numpy import asarray, unique, where
from sklearn.manifold import TSNE
# t-SNE serves purely as a visualization aid here: it projects the input
# data down to two components so the clusters can be drawn in a 2-D plot.
# NOTE(review): X is not defined in this section; it is assumed to be
# created earlier in the file — confirm.
tsne = TSNE(
    random_state=0,
    init='pca',
    n_components=2,
)
x = tsne.fit_transform(X)
def plot_clusters(X, cluster_ids):
    """Scatter-plot points coloured by their cluster assignment, then show the figure.

    Each distinct cluster id gets its own ``pyplot.scatter`` call, so
    matplotlib's colour cycle gives every cluster a different colour, and a
    legend maps colours back to cluster ids.

    Parameters
    ----------
    X : array-like
        Points to plot; only columns 0 and 1 are used (for example a 2-D
        embedding).
    cluster_ids : array-like
        One cluster id per row of ``X`` (a list or a NumPy array).
    """
    # Convert up front so that list inputs compare element-wise; with a plain
    # list, ``cluster_ids == value`` would be a single bool and nothing would
    # be plotted.
    X = asarray(X)
    cluster_ids = asarray(cluster_ids)

    # Plot every id that actually occurs instead of assuming exactly 0, 1, 2.
    # unique() returns the ids sorted, so ids 0..2 keep their original colours.
    labels = unique(cluster_ids)
    for class_value in labels:
        # [0] takes the 1-D index array out of where()'s tuple, so the
        # selected coordinates are 1-D rather than shaped (1, k).
        row_ix = where(cluster_ids == class_value)[0]
        pyplot.scatter(X[row_ix, 0], X[row_ix, 1], label=str(class_value))

    # Only draw a legend when something was plotted; an empty legend warns.
    if labels.size:
        pyplot.legend()
    pyplot.show()