Beispiel zu Clustering mit K-Means

Thomas Darimont

Erfahrenes Mitglied
Hallo,

hier mal ein kleines Beispiel zum Thema Clustering mit dem K-Means-Algorithmus:

Unser DataPoint:
Java:
package de.tutorials.clustering;

/**
 * A point in an n-dimensional space that can optionally be assigned to a
 * {@link Cluster}.
 *
 * @author Thomas.Darimont
 */
public class DataPoint {
    /** Number of coordinates; kept equal to {@code values.length}. */
    int dimensions;
    /** Coordinates of this point (the array is stored by reference, not copied). */
    double[] values;
    /** Cluster this point is currently assigned to, or {@code null}. */
    Cluster cluster;

    /**
     * Creates a data point that is not assigned to any cluster.
     *
     * @param values the coordinates of the point; stored by reference
     */
    public DataPoint(double[] values) {
        this.values = values;
        this.dimensions = values.length;
    }

    /**
     * Creates a data point that is already assigned to a cluster.
     *
     * @param cluster the owning cluster (may be {@code null})
     * @param values  the coordinates of the point; stored by reference
     */
    public DataPoint(Cluster cluster, double[] values) {
        this(values);
        this.cluster = cluster;
    }

    /**
     * @return the number of coordinates of this point
     */
    public int getDimensions() {
        return dimensions;
    }

    /**
     * @return the coordinate array (the live array, not a copy)
     */
    public double[] getValues() {
        return values;
    }

    /**
     * Replaces the coordinates of this point and updates
     * {@link #getDimensions()} to match the new array. Without this update a
     * longer or shorter array would leave a stale dimension count behind,
     * making distance calculations and {@link #toString()} skip coordinates
     * or index past the end of the array.
     *
     * @param values the new coordinates; a {@code null} array yields 0 dimensions
     */
    public void setValues(double[] values) {
        this.values = values;
        this.dimensions = values == null ? 0 : values.length;
    }

    /**
     * Computes the Euclidean distance between this point and another one.
     * Only the first {@link #getDimensions()} coordinates of this point are
     * compared, so {@code dataPoint} is expected to have at least as many
     * coordinates as this point.
     *
     * @param dataPoint the other point
     * @return the Euclidean distance
     */
    public double calculateEuclideanDistanceTo(DataPoint dataPoint) {
        double sumOfSquares = 0.0D;
        for (int i = 0; i < dimensions; i++) {
            double difference = values[i] - dataPoint.values[i];
            sumOfSquares += difference * difference;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * @return the cluster this point is assigned to, or {@code null}
     */
    public Cluster getCluster() {
        return cluster;
    }

    /**
     * @param cluster the cluster to assign this point to (may be {@code null})
     */
    public void setCluster(Cluster cluster) {
        this.cluster = cluster;
    }

    @Override
    public String toString() {
        StringBuilder stringBuilder = new StringBuilder("DataPoint: ");
        for (int i = 0; i < dimensions; i++) {
            stringBuilder.append(i);
            stringBuilder.append(":");
            stringBuilder.append(values[i]);
            stringBuilder.append(" ");
        }
        return stringBuilder.toString();
    }
}

Unser Cluster:
Java:
package de.tutorials.clustering;

import java.util.ArrayList;
import java.util.List;

/**
 * A K-Means cluster. The coordinates inherited from {@link DataPoint} are the
 * cluster's seed position; {@link #getDataPoints()} holds the points that are
 * currently assigned to it.
 *
 * @author Thomas.Darimont
 */
public class Cluster extends DataPoint {
    List<DataPoint> dataPoints;
    String name;
    int id;

    /**
     * Creates an empty cluster seeded at the given position.
     *
     * @param values seed coordinates of the cluster
     */
    public Cluster(double[] values) {
        super(values);
        this.dataPoints = new ArrayList<DataPoint>();
    }

    /**
     * Appends a point to this cluster's member list (does not touch the
     * point's own cluster reference).
     *
     * @param dataPoint the point to add
     */
    public void add(DataPoint dataPoint) {
        List<DataPoint> members = getDataPoints();
        members.add(dataPoint);
    }

    /**
     * Moves a point into {@code to}: removes it from {@code from} (if any),
     * points it at {@code to}, and appends it to {@code to}'s members.
     *
     * @param from      the previous cluster, or {@code null} if unassigned
     * @param to        the target cluster
     * @param dataPoint the point being moved
     */
    public static void swap(Cluster from, Cluster to, DataPoint dataPoint) {
        if (from == null) {
            // nothing to detach from
        } else {
            from.remove(dataPoint);
        }
        dataPoint.setCluster(to);
        to.add(dataPoint);
    }

    /**
     * Removes a point from this cluster's member list.
     *
     * @param dataPoint the point to remove
     */
    void remove(DataPoint dataPoint) {
        List<DataPoint> members = getDataPoints();
        members.remove(dataPoint);
    }

    /**
     * @return the live list of points assigned to this cluster
     */
    public List<DataPoint> getDataPoints() {
        return dataPoints;
    }

    /**
     * @param dataPoints the list of member points to use from now on
     */
    public void setDataPoints(List<DataPoint> dataPoints) {
        this.dataPoints = dataPoints;
    }

    /**
     * @return the display name of this cluster
     */
    public String getName() {
        return name;
    }

    /**
     * @param name the display name of this cluster
     */
    public void setName(String name) {
        this.name = name;
    }

    /**
     * Computes the Euclidean distance between the centers of two clusters.
     *
     * @param cluster the other cluster
     * @return distance between {@code getCenter()} of both clusters
     */
    public double calculateDistanceTo(Cluster cluster) {
        DataPoint ownCenter = getCenter();
        DataPoint otherCenter = cluster.getCenter();
        return ownCenter.calculateEuclideanDistanceTo(otherCenter);
    }

    /**
     * Computes the current center of this cluster as the mean of all member
     * points plus the seed position, which always counts as one extra point.
     * Because of that extra point an empty cluster stays at its seed instead
     * of dividing by zero.
     *
     * @return a new point, assigned to this cluster, located at the center
     */
    public DataPoint getCenter() {
        final int memberCount = getDataPoints().size();
        final double[] sums = new double[dimensions];

        for (int index = 0; index < memberCount; index++) {
            final double[] coordinates = getDataPoints().get(index).values;
            for (int axis = 0; axis < dimensions; axis++) {
                sums[axis] += coordinates[axis];
            }
        }

        // Seed counts as one more sample; multiply by the reciprocal exactly
        // like a per-axis "*= 1.0 / n" would.
        final double reciprocal = 1.0 / (getDataPoints().size() + 1);
        for (int axis = 0; axis < dimensions; axis++) {
            sums[axis] += this.values[axis];
            sums[axis] *= reciprocal;
        }

        return new DataPoint(this, sums);
    }

    /**
     * @return the numeric id of this cluster (used e.g. to pick a color)
     */
    public int getId() {
        return id;
    }

    /**
     * @param id the numeric id of this cluster
     */
    public void setId(int id) {
        this.id = id;
    }

    @Override
    public String toString() {
        DataPoint center = getCenter();
        return "Cluster: " + center;
    }
}

Unser K-Means-Clusterer:
Java:
/**
 * 
 */
package de.tutorials.clustering;

import java.util.Random;

/**
 * Sequential K-Means clusterer: each pass walks over all data points and
 * moves every point to the cluster whose current center is nearest. Centers
 * are recomputed on every distance query, so a move immediately influences
 * the assignment of the following points.
 *
 * @author Thomas.Darimont
 */
public class KMeans {
  /** Safety cap on the number of passes used by {@link #process(DataPointPresenter)}. */
  private static final int DEFAULT_MAX_RUNS = 1000;

  int dimensions;
  Cluster[] clusters;
  DataPoint[] dataPoints;
  static Random randomizer = new Random();

  /**
   * Replaces the clusters with {@code clusterCount} new clusters whose seed
   * coordinates are random integers in {@code [start, end)}.
   *
   * @param clusterCount number of clusters to create
   * @param start        inclusive lower bound of each coordinate
   * @param end          exclusive upper bound of each coordinate (must be greater than {@code start})
   */
  public void generateRandomClusters(int clusterCount, int start, int end) {
    clusters = new Cluster[clusterCount];
    for (int i = 0; i < clusterCount; i++) {
      double[] values = new double[dimensions];
      for (int j = 0; j < values.length; j++) {
        values[j] = start + randomizer.nextInt(end - start);
      }
      Cluster cluster = new Cluster(values);
      cluster.setId(i);
      clusters[i] = cluster;
    }
  }

  /**
   * Replaces the data set with {@code countOfDataPointsToBeGenerated} new,
   * unassigned points whose coordinates are random integers in {@code [start, end)}.
   *
   * @param countOfDataPointsToBeGenerated number of points to create
   * @param start inclusive lower bound of each coordinate
   * @param end   exclusive upper bound of each coordinate (must be greater than {@code start})
   */
  public void generateRandomDataPoints(int countOfDataPointsToBeGenerated, int start, int end) {
    DataPoint[] dataPoints = new DataPoint[countOfDataPointsToBeGenerated];
    for (int i = 0; i < dataPoints.length; i++) {
      double[] values = new double[dimensions];
      for (int j = 0; j < dimensions; j++) {
        values[j] = start + randomizer.nextInt(end - start);
      }
      dataPoints[i] = new DataPoint(values);
    }
    this.dataPoints = dataPoints;
  }

  /**
   * Uses the given points as the data set (stored by reference).
   *
   * @param dataPoints the points to cluster
   */
  public void setData(DataPoint[] dataPoints) {
    this.dataPoints = dataPoints;
  }

  /**
   * Runs passes until no point changes its cluster any more, or until
   * {@value #DEFAULT_MAX_RUNS} passes have been made.
   *
   * @param presenter optional renderer for clusters and points; may be {@code null}
   */
  public void process(DataPointPresenter presenter) {
    process(presenter, DEFAULT_MAX_RUNS);
  }

  /**
   * Runs passes until a pass moves no point to a different cluster (the
   * assignment is then stable), or until {@code maxRuns} passes have been
   * made. At least one pass is always executed.
   *
   * @param presenter optional renderer for clusters and points; may be {@code null}
   * @param maxRuns   upper bound on the number of passes, guarding against cycling
   * @throws IllegalArgumentException if {@code maxRuns} is less than 1
   */
  public void process(DataPointPresenter presenter, int maxRuns) {
    if (maxRuns < 1) {
      throw new IllegalArgumentException("maxRuns must be >= 1 but was " + maxRuns);
    }

    int run = 0;
    int countOfCurrentSwaps;
    do {
      System.out.println("Run: " + (run++));
      countOfCurrentSwaps = assignToNearestClusters(presenter);
      System.out.println("Swaps:" + countOfCurrentSwaps);
      // Converged only when a full pass moved nothing; comparing the swap
      // counts of two consecutive passes would stop too early.
    } while (countOfCurrentSwaps > 0 && run < maxRuns);
  }

  /**
   * Performs one pass: moves every point to its nearest cluster.
   *
   * @return the number of points that changed their cluster in this pass
   */
  private int assignToNearestClusters(DataPointPresenter presenter) {
    int swaps = 0;
    for (DataPoint currentDataPoint : dataPoints) {
      Cluster nearestCluster = findNearestCluster(currentDataPoint, presenter);

      // nearestCluster is null when there are no clusters or every distance
      // is NaN; swapping into null would throw inside Cluster.swap.
      if (nearestCluster != null && nearestCluster != currentDataPoint.getCluster()) {
        Cluster.swap(currentDataPoint.getCluster(), nearestCluster, currentDataPoint);
        swaps++;
      }

      if (null != presenter) {
        presenter.render(currentDataPoint);
      }
    }
    return swaps;
  }

  /**
   * Finds the cluster whose current center is closest to the given point,
   * rendering every cluster along the way when a presenter is given.
   *
   * @return the nearest cluster, or {@code null} if none could be determined
   */
  private Cluster findNearestCluster(DataPoint dataPoint, DataPointPresenter presenter) {
    Cluster nearestCluster = null;
    double currentMinimumDistance = Double.MAX_VALUE;
    for (Cluster currentCluster : clusters) {
      if (null != presenter) {
        presenter.render(currentCluster);
      }

      double distanceToCluster = currentCluster.getCenter().calculateEuclideanDistanceTo(dataPoint);
      if (distanceToCluster < currentMinimumDistance) {
        currentMinimumDistance = distanceToCluster;
        nearestCluster = currentCluster;
      }
    }
    return nearestCluster;
  }

  /**
   * @return the number of coordinates used for generated points and clusters
   */
  public int getDimensions() {
    return dimensions;
  }

  /**
   * @param dimensions number of coordinates used for generated points and clusters
   */
  public void setDimensions(int dimensions) {
    this.dimensions = dimensions;
  }
}

Unser DataPointPresenter:
Java:
package de.tutorials.clustering;

import java.awt.Color;
import java.awt.Graphics;

/**
 * Draws data points (small dots) and cluster centers (larger dots) onto an
 * AWT {@link Graphics} context, coloring each by its cluster id. Only the
 * first two coordinates are used as x/y.
 *
 * @author Thomas.Darimont
 */
public class DataPointPresenter {

    /** Width of the area wiped by {@link #clear()}. */
    private static final int CLEAR_WIDTH = 640;
    /** Height of the area wiped by {@link #clear()}. */
    private static final int CLEAR_HEIGHT = 480;

    /** One color per cluster, indexed by the cluster id. */
    Color[] clusterColors;

    /** Target drawing surface. */
    Graphics graphics;

    /**
     * @param graphics the surface to draw on
     */
    public DataPointPresenter(Graphics graphics) {
        this.graphics = graphics;
    }

    /**
     * Draws a point as a 5px dot in its cluster's color; unassigned points
     * are drawn in the current color of the graphics context.
     *
     * @param dataPoint the point to draw
     */
    public void render(DataPoint dataPoint) {
        final Color previous = graphics.getColor();
        final Cluster owner = dataPoint.getCluster();
        if (owner != null) {
            graphics.setColor(clusterColors[owner.getId()]);
        }
        fillDot(dataPoint.values, 2.5, 5);
        graphics.setColor(previous);
    }

    /**
     * Draws a cluster's current center as a 10px dot in the cluster's color.
     *
     * @param cluster the cluster to draw
     */
    public void render(Cluster cluster) {
        final Color previous = graphics.getColor();
        graphics.setColor(clusterColors[cluster.getId()]);
        fillDot(cluster.getCenter().values, 5, 10);
        graphics.setColor(previous);
    }

    /** Fills a circle of the given diameter centered on (coords[0], coords[1]). */
    private void fillDot(double[] coords, double radius, int diameter) {
        final int left = (int) (coords[0] - radius);
        final int top = (int) (coords[1] - radius);
        graphics.fillOval(left, top, diameter, diameter);
    }

    /**
     * @return the per-cluster colors
     */
    public Color[] getClusterColors() {
        return clusterColors;
    }

    /**
     * @param clusterColors one color per cluster, indexed by cluster id
     */
    public void setClusterColors(Color[] clusterColors) {
        this.clusterColors = clusterColors;
    }

    /** Wipes a fixed 640x480 area starting at the origin. */
    public void clear() {
        graphics.clearRect(0, 0, CLEAR_WIDTH, CLEAR_HEIGHT);
    }

}

Unsere KMeansClusteringPresentation:
Java:
package de.tutorials.clustering;

import java.awt.BorderLayout;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import javax.swing.JButton;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.SwingUtilities;

/**
 * Demo window: pressing "start" clusters 6000 random 2D points into 10
 * clusters and draws the progress directly onto the panel.
 *
 * @author Thomas.Darimont
 */
public class KMeansClusteringPresentation extends JFrame {
  static Random randomizer = new Random();
  JButton btnStart;
  JPanel panel;

  /**
   * Single background thread shared by all clicks. Runs are executed one
   * after another instead of each click leaking a new, never-shut-down
   * executor and drawing concurrently on the same panel.
   */
  private final transient ExecutorService executor = Executors.newSingleThreadExecutor();

  Runnable worker = new Runnable() {
    @Override
    public void run() {
      KMeans kmeans = new KMeans();

      int clusterCount = 10;

      // getGraphics() returns null while the panel is not displayable, and
      // every non-null context it returns must be disposed by the caller.
      Graphics graphics = panel.getGraphics();
      if (graphics == null) {
        return;
      }
      try {
        DataPointPresenter dataPointPresenter = new DataPointPresenter(graphics);
        dataPointPresenter.setClusterColors(generateColorsForCluster(clusterCount));
        dataPointPresenter.clear();
        kmeans.setDimensions(2);
        kmeans.generateRandomDataPoints(6000, 0, 400);
        kmeans.generateRandomClusters(clusterCount, 0, 400);
        kmeans.process(dataPointPresenter);
      } finally {
        graphics.dispose();
      }
    }
  };

  /**
   * Creates one random color per cluster.
   *
   * @param clusterCount number of colors to create
   * @return an array of {@code clusterCount} random colors
   */
  private Color[] generateColorsForCluster(int clusterCount) {
    Color[] colors = new Color[clusterCount];
    for (int i = 0; i < colors.length; i++) {
      // nextInt(256) covers the full 0..255 channel range (nextInt(255) never yields 255).
      colors[i] = new Color(randomizer.nextInt(256), randomizer.nextInt(256), randomizer.nextInt(256));
    }
    return colors;
  }

  /**
   * Builds and shows the window. Must be called on the Swing event thread.
   */
  public KMeansClusteringPresentation() {
    super("KMeansClusteringPresentation");
    setDefaultCloseOperation(EXIT_ON_CLOSE);
    setSize(410, 480);
    setResizable(false);
    setLocationRelativeTo(null);
    panel = new JPanel();

    panel.setPreferredSize(new Dimension(410, 480));

    btnStart = new JButton("start");
    btnStart.addActionListener(new ActionListener() {
      @Override
      public void actionPerformed(ActionEvent e) {
        // Keep the long-running clustering off the event dispatch thread.
        executor.execute(worker);
      }
    });

    add(panel, BorderLayout.CENTER);
    add(btnStart, BorderLayout.SOUTH);

    setVisible(true);
  }

  /**
   * Entry point: creates the window on the Swing event dispatch thread.
   *
   * @param args ignored
   */
  public static void main(String[] args) {
    SwingUtilities.invokeLater(new Runnable() {
      @Override
      public void run() {
        new KMeansClusteringPresentation();
      }
    });
  }
}

... und so schaut's nach ein paar Läufen aus:
Man sieht schön, wie sich die Clusterzentren (fette Punkte) langsam verschieben, weil Punkte ständig von einem Cluster zum jeweils nächstgelegenen Cluster wechseln, bis nach ein paar Läufen keine Wechsel mehr stattfinden und eine stabile Aufteilung erreicht ist. Diese ist nicht zwingend die global optimale Aufteilung – das Ergebnis hängt von den zufällig gewählten Startpositionen der Cluster ab.

Die dabei entstehende Aufteilung der Ebene nach dem jeweils nächstgelegenen Clusterzentrum nennt man auch Voronoi-Tessellation.

Gruß Tom
 

Anhänge

  • kmeans-clustering.jpg
    kmeans-clustering.jpg
    87,9 KB · Aufrufe: 345
  • kmeans-clustering1.jpg
    kmeans-clustering1.jpg
    78,1 KB · Aufrufe: 322
Zurück