

How to Scale K-Means Clustering with just ClickHouse SQL

Dale McDiarmid
Apr 11, 2024

Introduction

Recently, when helping a user who wanted to compute centroids from vectors held in ClickHouse, we realized that the same solution could be used to implement K-Means clustering. They wanted to solve this at scale across potentially billions of data points while ensuring memory could be tightly managed. In this post, we take a shot at implementing K-Means clustering using just SQL and show that it scales to billions of rows.

In the writing of this blog, we became aware of the work performed by Boris Tyshkevich. While we use a different approach in this blog, we would like to recognize Boris for his work and for having this idea well before we did!

As part of implementing K-Means with ClickHouse SQL, we cluster 170M NYC taxi rides in under 3 minutes. The equivalent scikit-learn operation with the same resources takes over 100 minutes and requires 90GB of RAM. With no memory limitations and ClickHouse automatically distributing the computation, we show that ClickHouse can accelerate machine learning workloads and reduce iteration time.

kmeans_cluster_1.png

All of the code for this blog post can be found in a notebook here.

Why K-Means in ClickHouse SQL?

The key motivation for using ClickHouse SQL to do K-Means is that training is not memory-bound, making it possible to cluster PB-scale datasets thanks to the incremental computation of centroids (with settings to limit memory overhead). In contrast, distributing this workload across servers using Python-based approaches would require an additional framework and added complexity.

Additionally, we can easily increase the level of parallelism in our clustering to use the full resources of a ClickHouse instance. Should we need to handle larger datasets, we simply scale the database service - a simple operation in ClickHouse Cloud.

Transforming the data for K-Means is a simple SQL query that can process billions of rows per second. With centroids and points held in ClickHouse, we can compute statistics such as model errors with just SQL and potentially use our clusters for other operations e.g. product quantization for vector search.

K-Means recap

K-Means is an unsupervised machine learning algorithm for partitioning a dataset into K distinct, non-overlapping subgroups (clusters) where each data point belongs to the cluster with the nearest mean (the cluster's centroid). The process begins by initializing K centroids randomly or based on some heuristic. These centroids serve as the initial representatives of the clusters. The algorithm then iterates through two main steps until convergence: assignment and update.

In the assignment step, each data point is assigned to the nearest cluster based on the Euclidean distance (or another distance metric) between it and the centroids. In the update step, the centroids are recalculated as the mean of all points assigned to their respective clusters, potentially shifting their positions.
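In standard notation (our summary of the two steps above), with points $x$ and centroids $\mu_1, \dots, \mu_K$:

$$c(x) = \underset{j \in \{1,\dots,K\}}{\arg\min} \; \lVert x - \mu_j \rVert_2 \quad \text{(assignment)}, \qquad \mu_j = \frac{1}{|C_j|} \sum_{x \in C_j} x \quad \text{(update)}$$

where $C_j$ is the set of points currently assigned to centroid $j$.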

This process is guaranteed to converge, with the assignments of points to clusters eventually stabilizing and not changing between iterations. The number of clusters, K, needs to be specified beforehand and heavily influences the algorithm's effectiveness with the optimal value depending on the dataset and the goal of the clustering. For more details, we recommend this excellent overview.

Points and centroids

The key problem our user posed was the ability to efficiently compute centroids. Suppose we have a simple data schema for a transactions table, where each row represents a bank transaction for a specific customer. Vectors in ClickHouse are represented as an Array type.

CREATE TABLE transactions
(
  id UInt32,
  vector Array(Float32),
  -- e.g. [0.6860357,-1.0086979,0.83166444,-1.0089169,0.22888935]
  customer UInt32
  -- other columns omitted for brevity
)
ENGINE = MergeTree ORDER BY id

Our user wanted to find the centroid for each customer, effectively the positional average of all the transaction vectors associated with each customer. To find the set of average vectors, we can use the avgForEach function. For instance, consider the example of computing the average of 3 vectors, each with 4 elements:

WITH vectors AS
   (
       SELECT c1 AS vector
       FROM VALUES([1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12])
   )
SELECT avgForEach(vector) AS centroid
FROM vectors

┌─centroid──┐
│ [5,6,7,8] │
└───────────┘

In our original transactions table, computing the average per customer thus becomes:

SELECT customer, avgForEach(vector) AS centroid FROM transactions GROUP BY customer

While simple, this approach has a few limitations. Firstly, for very large datasets, when the vector contains many Float32 points and the customer column has many unique elements (high cardinality), this query can be very memory intensive. Secondly, and maybe more relevant to K-Means, this approach requires us to rerun the query if new rows are inserted, which is inefficient. We can address these problems through Materialized Views and the AggregatingMergeTree engine.

Incrementally computing centroids with Materialized Views

Materialized Views allow us to shift the cost of computing our centroids to insert time. Unlike in other databases, a ClickHouse Materialized View is just a trigger that runs a query on blocks of data as they are inserted into a table. The results of this query are inserted into a second "target" table. In our case, the Materialized View query will compute our centroids, inserting the results into a table centroids.

Incremental computing centroids with MV.png

There are some important details here:

  • Our query, which computes our centroids, must produce the result set in a format that can be merged with subsequent result sets - since every block inserted will produce a result set. Rather than just sending averages to our centroids table (the average of an average would be incorrect), we send the “average state”. The average state representation contains the sum of each vector position, along with a count. This is achieved using the avgForEachState function - notice how we’ve just appended State to our function name! The AggregatingMergeTree table engine is required to store these aggregation states. We explore this more below.
  • The entire process is incremental with the centroids table containing the final state i.e. a row per centroid. Readers will notice that the table which receives inserts has a Null table engine. This causes the inserted rows to be thrown away, saving the IO associated with writing the full dataset on each iteration.
  • The query of our Materialized View is only executed on the blocks as they are inserted. The number of rows in each block can vary depending on the method of insertion. We recommend at least 1000 rows per block if formulating blocks on the client side, e.g., using the Go client. If the server is left to form blocks (e.g. when inserting over HTTP), the size can also be specified.
  • If using an INSERT INTO SELECT, where ClickHouse reads rows from another table or an external source, e.g. S3, the block size can be controlled by several key parameters discussed in detail in previous blogs. These settings (along with the number of insert threads) can have a dramatic effect on both the memory used (larger blocks = more memory) and the speed of ingestion (larger blocks = faster), meaning memory usage can be finely traded off against ingest performance - see the sketch after this list.
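As a flavor of the control available, the sketch below applies a few of these standard insert settings to an INSERT INTO SELECT; the source URL and the values chosen are purely illustrative.

-- Illustrative only: trade block size (and hence memory) against ingest speed
-- when reading from an external source. The URL and values are examples.
INSERT INTO transactions (id, vector, customer)
SELECT id, vector, customer
FROM s3('https://example-bucket.s3.amazonaws.com/transactions.parquet', 'Parquet')
SETTINGS
    min_insert_block_size_rows = 1000000,    -- larger blocks = faster, but more memory
    min_insert_block_size_bytes = 256000000, -- alternatively, cap block size by bytes
    max_insert_threads = 4                   -- more insert threads = more parallelism and memory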

AggregatingMergeTree

Our target table centroids uses the engine AggregatingMergeTree:

CREATE TABLE centroids
(
   customer UInt32,
   vector AggregateFunction(avgForEach, Array(Float32))
)
ENGINE = AggregatingMergeTree  ORDER BY customer

Our vector column here contains the aggregate states produced by the avgForEachState function above. These are intermediate centroids that must be merged to produce a final answer. This column needs to be of the appropriate type AggregateFunction(avgForEach, Array(Float32)).
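For completeness, a minimal sketch of the pattern for this per-customer example might look like the following, with a Null-engine table receiving the inserts as described above (the table and view names here are ours, for illustration only):

CREATE TABLE transactions_null
(
   id UInt32,
   vector Array(Float32),
   customer UInt32
)
ENGINE = Null

-- computes per-customer average states on every inserted block and writes them to centroids
CREATE MATERIALIZED VIEW customer_centroids_mv TO centroids
AS SELECT customer, avgForEachState(vector) AS vector
FROM transactions_null GROUP BY customer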

Like all ClickHouse MergeTree tables, the AggregatingMergeTree stores data as parts that are merged transparently to allow more efficient querying. When merging parts containing our aggregate states, this must be done so that only states pertaining to the same customer are merged. This is effectively achieved by ordering the table by the customer column with the ORDER BY clause. At query time, we must also ensure intermediate states are grouped and merged. This can be achieved by ensuring we GROUP BY the customer column and use the Merge equivalent of the avgForEach function: avgForEachMerge.

SELECT customer, avgForEachMerge(vector) AS centroid
FROM centroids GROUP BY customer

All aggregation functions have an equivalent state function, obtained by appending State to their name, which produces an intermediate representation that can be stored and then retrieved and merged with a Merge equivalent. For more details, we recommend this blog and the video from our very own Mark.
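The pattern is easiest to see with the plain avg function. The toy query below (ours, just for illustration) builds two partial states over the numbers 0-9 and merges them back into the overall average, 4.5:

SELECT avgMerge(s) AS merged_avg
FROM
(
    -- two partial aggregation states: one for even numbers, one for odd
    SELECT avgState(number) AS s
    FROM numbers(10)
    GROUP BY number % 2
)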

The avgForEachMerge query above will be very fast compared to our earlier GROUP BY. Most of the work for computing averages has been moved to insert time, with a small number of rows left for query-time merging. Consider the performance of the following two approaches using 100m random transactions on a 48GiB, 12 vCPU Cloud service. Steps to load the data can be found here.

Contrast the performance of computing our centroids from the transactions table:

SELECT customer, avgForEach(vector) AS centroid
FROM transactions GROUP BY customer
ORDER BY customer ASC
LIMIT 1 FORMAT Vertical

10 rows in set. Elapsed: 147.526 sec. Processed 100.00 million rows, 41.20 GB (677.85 thousand rows/s., 279.27 MB/s.)

Row 1:
──────
customer: 1
centroid: [0.49645231463677153,0.5042792240640065,...,0.5017436349466129]

1 row in set. Elapsed: 36.017 sec. Processed 100.00 million rows, 41.20 GB (2.78 million rows/s., 1.14 GB/s.)
Peak memory usage: 437.54 MiB.

vs. the centroids table, which is over 1700x faster:

SELECT customer, avgForEachMerge(vector) AS centroid
FROM centroids GROUP BY customer
ORDER BY customer ASC
LIMIT 1
FORMAT Vertical

Row 1:
──────
customer: 1
centroid: [0.49645231463677153,0.5042792240640065,...,0.5017436349466129]

1 row in set. Elapsed: 0.085 sec. Processed 10.00 thousand rows, 16.28 MB (117.15 thousand rows/s., 190.73 MB/s.)

Putting it all together

With our ability to compute centroids incrementally, let's focus on K-Means clustering. Let's assume we're trying to cluster a table points where each row has a vector representation. Here, we will cluster on similarity rather than just basing our centroids on the customer as we did with transactions.

A single iteration

We need to be able to store the current centroids after each iteration of the algorithm. For now, let's assume we have identified an optimal value of K. Our target table for our centroids might look like this:

CREATE TABLE centroids
(
  k UInt32,
  iteration UInt32,
  centroid UInt32,
  vector AggregateFunction(avgForEach, Array(Float32))
)
ENGINE = AggregatingMergeTree 
ORDER BY (k, iteration, centroid)

The value of the k column is set to our chosen value of K. Our centroid column here denotes the centroid number itself, with a value between 0 and K-1. Rather than use a separate table for each iteration of the algorithm, we simply include an iteration column and ensure our ordering key is (k, iteration, centroid). ClickHouse will ensure the intermediate state is only merged for each unique K, centroid, and iteration. This means our final row count will be small, ensuring fast querying of these centroids.

Our Materialized View for computing our centroids should look familiar, with only a small adjustment to also GROUP BY k, centroid, and iteration:

CREATE TABLE temp
(
   k UInt32,
   iteration UInt32,
   centroid UInt32,
   vector Array(Float32)
)
ENGINE = Null

CREATE MATERIALIZED VIEW centroids_mv TO centroids
AS SELECT k, iteration, centroid, avgForEachState(vector) AS vector
FROM temp GROUP BY k, centroid, iteration

Notice that our query executes over blocks inserted into a temp table, not our data source table transactions, which does not have an iteration or centroid column. This temp table will receive our inserts and uses the Null table engine again to avoid writing data. With these building blocks in place, we can visualize a single iteration of the algorithm assuming K = 5:

kmeans_clickhouse.png

The above shows how we insert into our temp table and thus compute our centroids by performing an INSERT INTO SELECT with the points table as our source data. This insertion effectively represents an iteration of the algorithm. The SELECT query here is critical, as it needs to produce each point's vector along with its current centroid and the iteration number (and the fixed value of K). How might we compute the latter two? The full INSERT INTO SELECT is shown below:

INSERT INTO temp 
WITH
  5 as k_val,
  -- (1) obtain the max value of iteration - will be the previous iteration
  (
      SELECT max(iteration)
      FROM centroids 
      -- As later we will reuse this table for all values of K
      WHERE k = k_val
  ) AS c_iteration,
  (
      -- (3) convert centroids into an array of tuples
      -- i.e. [(0, [vector]), (1, [vector]), ... , (k-1, [vector])]
      SELECT groupArray((centroid, position))
      FROM
      (
         -- (2) compute the centroids from the previous iteration
          SELECT
              centroid,
              avgForEachMerge(vector) AS position
          FROM centroids
          WHERE iteration = c_iteration AND k = k_val
          GROUP BY centroid
      )
  ) AS c_centroids
SELECT
  k_val AS k,
  -- (4) increment the iteration
  c_iteration + 1 AS iteration,
  -- (5) find the closest centroid for this vector using Euclidean distance
  (arraySort(c -> (c.2), arrayMap(x -> (x.1, L2Distance(x.2, vector)), c_centroids))[1]).1 AS centroid,
  vector AS v
FROM points

Firstly, at (1), this query identifies the number of the previous iteration. This is then used within the CTE at (2) to determine the centroids produced for that iteration (and chosen K), using the same avgForEachMerge query shown earlier. These centroids are collapsed into a single row containing an array of tuples via the groupArray function (3) to facilitate easy matching against the points. In the SELECT, we increment the iteration number (4) and compute the new closest centroid for each point (5), using the Euclidean distance (L2Distance function) together with the arrayMap and arraySort functions.
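To see the nearest-centroid expression at (5) in isolation, here is a toy standalone query with two hand-picked centroids and a point at the origin (the values are ours, purely for illustration); it returns 0, the index of the closer centroid:

-- Toy example: find the nearest of two centroids to the point [0, 0]
WITH
    [(0, [1.0, 1.0]), (1, [5.0, 5.0])] AS c_centroids,
    [0.0, 0.0] AS vector
SELECT (arraySort(c -> (c.2), arrayMap(x -> (x.1, L2Distance(x.2, vector)), c_centroids))[1]).1 AS closest_centroid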

By inserting the rows into temp here, with a centroid based on the previous iteration, we can allow the Materialized View to compute the new centroids (with the iteration value +1).

Initializing the centroids

The above assumes we have some initial centroids for iteration 1, which are used to compute membership. This requires us to initialize the system. We can do this by simply selecting and inserting K random points with the following query (k=5):

INSERT INTO temp WITH 
  5 as k_val,
  vectors AS
  (
      SELECT vector
      FROM points
      -- select random points, use k to make pseudo-random
      ORDER BY cityHash64(concat(toString(id), toString(k_val))) ASC
      LIMIT k_val -- k
  )
SELECT
  k_val as k,
  1 AS iteration,
  rowNumberInAllBlocks() AS centroid,
  vector
FROM vectors

Successful clustering is very sensitive to the initial placement of centroids; poor assignment leads to slow convergence or suboptimal clustering. We will discuss this a little later.

Centroid assignment and when to stop iterating

All of the above represents a single iteration (and initialization step). After each iteration, we need to make a decision as to whether to stop based on an empirical measurement of whether the clustering has converged. The simplest way to do this is to simply stop when points no longer change centroids (and thus clusters) between iterations.

To identify which points belong to which centroids, we can use the above SELECT from our earlier INSERT INTO SELECT at any time.

To compute the number of points that moved clusters in the last iteration, we first compute the centroids for the previous two iterations (1) and (2). Using these, we identify the centroid for each point in each iteration (3) and (4). If these are the same, we return 0, and 1 otherwise (5). Summing these values (6) gives us the number of points that moved clusters.

WITH 5 as k_val,
(
      SELECT max(iteration)
      FROM centroids
) AS c_iteration,
(
  -- (1) current centroids
  SELECT groupArray((centroid, position))
  FROM
  (
      SELECT
          centroid,
          avgForEachMerge(vector) AS position
      FROM centroids
      WHERE iteration = c_iteration AND k = k_val
      GROUP BY centroid
  )
) AS c_centroids,
(
  -- (2) previous centroids
  SELECT groupArray((centroid, position))
  FROM
  (
      SELECT
          centroid,
          avgForEachMerge(vector) AS position
      FROM centroids
      WHERE iteration = (c_iteration-1) AND k = k_val
      GROUP BY centroid
  )
) AS c_p_centroids
-- (6) sum differences
SELECT sum(changed) FROM (
  SELECT id,
  -- (3) current centroid for point
  (arraySort(c -> (c.2), arrayMap(x -> (x.1, L2Distance(x.2, vector)), c_centroids))[1]).1 AS cluster,
  -- (4) previous centroid for point
  (arraySort(c -> (c.2), arrayMap(x -> (x.1, L2Distance(x.2, vector)), c_p_centroids))[1]).1 AS cluster_p,
  -- (5) difference in allocation
  if(cluster = cluster_p, 0, 1) as changed
  FROM points
)

A test dataset

The above has been mostly theoretical. Let's see if the above actually works on a real dataset! For this, we'll use a 3m row subset of the popular NYC taxis dataset as the clusters are hopefully relatable. To create and insert the data from S3:

CREATE TABLE trips (
  trip_id         	UInt32,
  pickup_datetime 	DateTime,
  dropoff_datetime	DateTime,
  pickup_longitude	Nullable(Float64),
  pickup_latitude 	Nullable(Float64),
  dropoff_longitude   Nullable(Float64),
  dropoff_latitude	Nullable(Float64),
  passenger_count 	UInt8,
  trip_distance   	Float32,
  fare_amount     	Float32,
  extra           	Float32,
  tip_amount      	Float32,
  tolls_amount    	Float32,
  total_amount    	Float32,
  payment_type    	Enum('CSH' = 1, 'CRE' = 2, 'NOC' = 3, 'DIS' = 4, 'UNK' = 5),
  pickup_ntaname  	LowCardinality(String),
  dropoff_ntaname 	LowCardinality(String)
)
ENGINE = MergeTree
ORDER BY (pickup_datetime, dropoff_datetime);

INSERT INTO trips SELECT trip_id, pickup_datetime, dropoff_datetime, pickup_longitude, pickup_latitude, dropoff_longitude, dropoff_latitude, passenger_count, trip_distance, fare_amount, extra, tip_amount, tolls_amount, total_amount, payment_type, pickup_ntaname, dropoff_ntaname
FROM gcs('https://storage.googleapis.com/clickhouse-public-datasets/nyc-taxi/trips_{0..2}.gz', 'TabSeparatedWithNames');

Feature selection

Feature selection is crucial for good clustering as it directly impacts the quality of the clusters formed. We won’t go into detail here on how we selected our features. For those interested, we include the notes in the notebook. We end up with the following points table:

CREATE TABLE points
(
   `id` UInt32,
   `vector` Array(Float32),
   `pickup_hour` UInt8,
   `pickup_day_of_week` UInt8,
   `pickup_day_of_month` UInt8,
   `dropoff_hour` UInt8,
   `pickup_longitude` Float64,
   `pickup_latitude` Float64,
   `dropoff_longitude` Float64,
   `dropoff_latitude` Float64,
   `passenger_count` UInt8,
   `trip_distance` Float32,
   `fare_amount` Float32,
   `total_amount` Float32
) ENGINE = MergeTree ORDER BY id

To populate this table, we use an INSERT INTO SELECT SQL query, which creates the features, scales them, and filters any outliers. Note that our final feature columns are also encoded together in the vector column.

The linked query is our first attempt at producing features. We expect more work to be possible here, which might produce better results than those shown. Suggestions are welcome!
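We don't reproduce the full linked query here, but as a flavor of the kind of transformation involved, min-max scaling a single feature might look like the sketch below (the notebook's actual query builds and scales the whole vector and differs in detail):

-- Illustrative only: min-max scale one feature into [0, 1]
WITH
    (SELECT min(trip_distance) FROM trips) AS min_d,
    (SELECT max(trip_distance) FROM trips) AS max_d
SELECT
    trip_id,
    (trip_distance - min_d) / (max_d - min_d) AS trip_distance_scaled
FROM trips
LIMIT 5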

A little bit of Python

We have described how an iteration in the algorithm effectively reduces to an INSERT INTO SELECT, with the Materialized View handling the maintenance of the centroids. This means we need to invoke this statement N times until convergence has occurred.

Rather than waiting to reach a state where no points move between centroids, we use a threshold of 1000 i.e. if fewer than 1000 points move clusters, we stop. This check is made every 5 iterations.

The pseudo code for performing K-Means for a specific value of K becomes very simple given most of the work is performed by ClickHouse.

import time

def kmeans(k, report_every=5, min_cluster_move=1000):
    start_time = time.time()
    # INITIALIZATION QUERY
    run_init_query(k)
    i = 0
    while True:
        # ITERATION QUERY
        run_iteration_query(k)
        # check for convergence every N iterations (and after the first)
        if (i + 1) % report_every == 0 or i == 0:
            num_moved = calculate_points_moved(k)
            if num_moved <= min_cluster_move:
                break
        i += 1
    execution_time = time.time() - start_time
    # COMPUTE d^2 ERROR
    d_2_error = compute_d2_error(k)
    # return the d^2 error, execution time and number of iterations required
    return d_2_error, execution_time, i + 1

The full code for this loop, including the queries, can be found in the notebook.

Choosing K

So far, we’ve assumed K has been identified. There are several techniques for determining the optimal value of K, the simplest of which is to compute the sum of the squared distances (SSE) between each point and its assigned cluster centroid for each value of K. This gives us a cost metric that we aim to minimize. The method compute_d2_error computes this using the following SQL query (assuming a value of 5 for K):

WITH 5 as k_val,
(
       SELECT max(iteration)
       FROM centroids WHERE k = k_val
) AS c_iteration,
(
   SELECT groupArray((centroid, position))
   FROM
   (
       SELECT
           centroid,
           avgForEachMerge(vector) AS position
       FROM centroids
       WHERE iteration = c_iteration AND k=k_val
       GROUP BY centroid
   )
) AS c_centroids
SELECT
   sum(pow((arraySort(c -> (c.2), arrayMap(x -> (x.1, L2Distance(x.2, vector)), c_centroids))[1]).2, 2)) AS distance
FROM points
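In other words, under the nearest-centroid assignment used above, the query computes the standard SSE objective:

$$\mathrm{SSE}(K) = \sum_{i=1}^{N} \lVert x_i - \mu_{c(x_i)} \rVert_2^2$$

where $c(x_i)$ is the index of the centroid closest to point $x_i$.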

This value is guaranteed to decrease as we increase K e.g. if we set K to the number of points, each cluster will have 1 point, giving us an error of 0. Unfortunately, such a model won’t generalize well!

As K increases, SSE typically decreases because the data points are closer to their cluster centroids. The goal is to find the "elbow point" where the rate of decrease in SSE sharply changes. This point indicates a diminishing return on the benefit of increasing K. Choosing K at the elbow point provides a model that captures the inherent grouping in the data without overfitting. A simple way to identify this elbow point is to plot K vs SSE and identify the value visually. For our NYC taxi data, we measure and plot SSE for the K values 2 to 20:

k_vs_d_2.png

The elbow point here isn’t as clear as we’d like, but a value of 5 seems a reasonable candidate.

The above results are based on a single end-to-end run for each value of K. K-Means can converge to a local minimum, with no guarantee that nearby points will end up in the same cluster. It would be advisable to perform multiple runs for each value of K, each time with different initial centroids, to find the best candidate.

Results

If we select 5 as our value for K, the algorithm takes around 30 iterations and 20 seconds to converge on a 12 vCPU ClickHouse Cloud node. This approach considers all 3 million rows for each iteration.

k=5
initializing...OK
Iteration 0
Number changed cluster in first iteration: 421206
Iteration 1, 2, 3, 4
Number changed cluster in iteration 5: 87939
Iteration 5, 6, 7, 8, 9
Number changed cluster in iteration 10: 3610
Iteration 10, 11, 12, 13, 14
Number changed cluster in iteration 15: 1335
Iteration 15, 16, 17, 18, 19
Number changed cluster in iteration 20: 1104
Iteration 20, 21, 22, 23, 24
Number changed cluster in iteration 25: 390
stopping as moved less than 1000 clusters in last iteration
Execution time in seconds: 20.79200577735901
D^2 error for 5: 33000373.34968858

To visualize these clusters, we need to reduce the dimensionality. For this, we use Principal Component Analysis (PCA). We defer the implementation of PCA in SQL to another blog and just use Python with a sample of 10,000 random points. We can evaluate the effectiveness of PCA in capturing the essential properties of data by checking how much variance the principal components account for. 82% is less than the typically used threshold of 90%, but sufficient for an understanding of the effectiveness of our clustering:

Explained variances of the 3 principal components: 0.824

Using our 3 principal components, we can plot the same random 10,000 points and associate a color with each according to its cluster.

kmeans_1.png

The PCA visualization of the clusters shows a dense plane across PC1 and PC3, neatly divided into four distinct clusters, suggesting constrained variance within these dimensions. Along the 2nd principal component (PC2), the visualization becomes sparser, with a cluster (number 3) that diverges from the main group and could be particularly interesting.

To understand our clusters, we need labels. Ideally, we would produce these by exploring the distribution of every column in each cluster, looking for unique characteristics and temporal/spatial patterns. We’ll try to do this succinctly with a SQL query to understand the distribution of each column in each cluster. For the columns to focus on, we can inspect the values of the PCA components and identify the dimensions that dominate. Code for doing this can be found in the notebook and identifies the following:

PCA1:: ['pickup_day_of_month: 0.9999497049810415', 'dropoff_latitude: -0.006371842399701939', 'pickup_hour: 0.004444108327647353', 'dropoff_hour: 0.003868258226185553', …]

PCA 2:: ['total_amount: 0.5489526881298809', 'fare_amount: 0.5463895585884886', 'pickup_longitude: 0.43181504878694826', 'pickup_latitude: -0.3074228612885196', 'dropoff_longitude: 0.2756342866763702', 'dropoff_latitude: -0.19809343490462433', …]

PCA 3:: ['dropoff_hour: -0.6998176337701472', 'pickup_hour: -0.6995098287872831', 'pickup_day_of_week: 0.1134719682173672', 'pickup_longitude: -0.05495391127067617', …]

For PC1, pickup_day_of_month is important, suggesting a focus on the time of the month. For PC2, the pickup and drop-off locations and the cost of the ride appear to contribute heavily. This component probably focuses on a specific trip type. Finally, for PC3, the hour in which the trip occurred seems the most relevant. To understand how these columns differ per cluster with respect to time, date, and price, we can again just use an SQL query:

WITH
   5 AS k_val,
   (
       SELECT max(iteration)
       FROM centroids
       WHERE k = k_val
   ) AS c_iteration,
   (
       SELECT groupArray((centroid, position))
       FROM
       (
           SELECT
               centroid,
               avgForEachMerge(vector) AS position
           FROM centroids
           WHERE (iteration = c_iteration) AND (k = k_val)
           GROUP BY centroid
       )
   ) AS c_centroids
SELECT
   (arraySort(c -> (c.2), arrayMap(x -> (x.1, L2Distance(x.2, vector)), c_centroids))[1]).1 AS cluster,
   floor(avg(pickup_day_of_month)) AS pickup_day_of_month,
   round(avg(pickup_hour)) AS avg_pickup_hour,
   round(avg(fare_amount)) AS avg_fare_amount,
   round(avg(total_amount)) AS avg_total_amount
FROM points
GROUP BY cluster
ORDER BY cluster ASC

┌─cluster─┬─pickup_day_of_month─┬─avg_pickup_hour─┬─avg_fare_amount─┬─avg_total_amount─┐
│       0 │                  11 │              14 │              11 │               13 │
│       1 │                   3 │              14 │              12 │               14 │
│       2 │                  18 │              13 │              11 │               13 │
│       3 │                  16 │              14 │              49 │               58 │
│       4 │                  26 │              14 │              12 │               14 │
└─────────┴─────────────────────┴─────────────────┴─────────────────┴──────────────────┘

5 rows in set. Elapsed: 0.625 sec. Processed 2.95 million rows, 195.09 MB (4.72 million rows/s., 312.17 MB/s.)
Peak memory usage: 720.16 MiB.

Cluster 3 is clearly associated with more expensive trips. Given that the cost of the trip was associated with a principal component, which also identified pickup and drop-off locations as key, these are probably associated with a specific trip type. Other clusters need a deeper analysis but seem to be focused on monthly patterns. We can plot the pickup and drop-off locations for just cluster 3 on a map visualization. Blue and red points represent the pickup and drop-off locations, respectively, in the following plot:

clusters_nyc_map.png

On close inspection of the plot, this cluster is associated with airport trips to and from JFK.
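The coordinates plotted above can be pulled by reusing the nearest-centroid expression as a filter; a sketch is shown below (the notebook's exact query may differ):

WITH
    5 AS k_val,
    (
        SELECT max(iteration)
        FROM centroids
        WHERE k = k_val
    ) AS c_iteration,
    (
        SELECT groupArray((centroid, position))
        FROM
        (
            SELECT centroid, avgForEachMerge(vector) AS position
            FROM centroids
            WHERE (iteration = c_iteration) AND (k = k_val)
            GROUP BY centroid
        )
    ) AS c_centroids
SELECT
    pickup_longitude,
    pickup_latitude,
    dropoff_longitude,
    dropoff_latitude
FROM points
-- keep only points whose nearest centroid is cluster 3
WHERE (arraySort(c -> (c.2), arrayMap(x -> (x.1, L2Distance(x.2, vector)), c_centroids))[1]).1 = 3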

Scaling

Our previous example uses only a 3m row subset of the NYC taxi rides. Testing on a larger dataset of all taxi rides for 2009 (170m rows), we can complete clustering for k=5 in around 3 minutes with a ClickHouse service using 60 cores.

k=5
initializing...OK
…
Iteration 15, 16, 17, 18, 19
Number changed cluster in iteration 20: 288
stopping as moved less than 1000 clusters in last iteration
Execution time in seconds: 178.61135005950928
D^2 error for 5: 1839404623.265372
Completed in 178.61135005950928s and 20 iterations with error 1839404623.265372

This produces similar clusters to our previous smaller subset. Running the same clustering on a 64-core m5d.16xlarge using scikit-learn takes 6132s, over 34x slower! Steps to reproduce this benchmark can be found at the end of the notebook, along with these steps for scikit-learn.

Potential improvements & future work

Clustering is very sensitive to the initial points selected. K-Means++ is an improvement over standard K-Means clustering that addresses this by introducing a smarter initialization process that aims to spread out the initial centroids, reducing the likelihood of poor initial placement and leading to faster convergence as well as potentially better clustering. We leave this improvement as an exercise for the reader.

K-Means also struggles to handle categorical variables. This can be partially handled with one-hot encoding (also possible in SQL) as well as dedicated algorithms such as KModes clustering designed for this class of data. Custom distance functions for specific domains instead of just Euclidean distance are also common and should be implementable using User Defined Functions (UDFs).
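As a flavor of what one-hot encoding could look like in ClickHouse SQL, the sketch below expands the payment_type column from our earlier trips table into binary flags (purely illustrative; this column wasn't part of our feature set):

-- Illustrative only: one-hot encode a categorical column as 0/1 flags
SELECT
    trip_id,
    payment_type = 'CSH' AS payment_cash,
    payment_type = 'CRE' AS payment_credit,
    payment_type = 'NOC' AS payment_no_charge,
    payment_type = 'DIS' AS payment_dispute
FROM trips
LIMIT 5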

Finally, it might also be interesting to explore other soft clustering algorithms, such as Gaussian Mixture Models for normally distributed features, or Hierarchical Clustering algorithms, such as Agglomerative clustering. These latter approaches also overcome one of the main limitations of K-Means - the need to specify K. We would love to see attempts to implement these in ClickHouse SQL!

Get started with ClickHouse Cloud today and receive $300 in credits. At the end of your 30-day trial, continue with a pay-as-you-go plan, or contact us to learn more about our volume-based discounts. Visit our pricing page for details.
