Exercise 9.3.6

The $k$th quantiles of an $n$-element set are $k - 1$ order statistics that divide the sorted set into $k$ equal-sized sets (to within 1). Give an $\O(n\lg{k})$-time algorithm to list the $k$th quantiles of a set.

  1. If $k = 1$ we return an empty list.
  2. If $k$ is even, we find the median, partition around it, solve two similar subproblems of size $\lfloor n / 2 \rfloor$ and return their solutions plus the median.
  3. If $k$ is odd, we find the $\lfloor k/2 \rfloor$th and $\lceil k/2 \rceil$th quantiles, partition around them, and reduce to two subproblems, each of size less than $n/2$ and each asking for the $\lfloor k/2 \rfloor$th quantiles of its part. The worst-case recurrence is:

$$ T(n, k) = 2T(\lfloor n/2 \rfloor, \lfloor k/2 \rfloor) + \O(n) $$

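Each level of the recursion halves $k$ and does $\O(n)$ work in total, so there are about $\lg k$ levels and the recurrence solves to the desired bound, $\O(n\lg{k})$.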
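
To see where the bound comes from, the recurrence can be unrolled level by level. This is only a sketch: it ignores the floors, writes the linear-time selection and partitioning cost as $cn$ for some constant $c$, and takes $T(n, 1) = \O(1)$. The constant $c$ and that base case are assumptions introduced here for illustration.

$$
\begin{aligned}
T(n, k) &\le 2T(n/2, k/2) + cn \\
        &\le 4T(n/4, k/4) + 2 \cdot c(n/2) + cn = 4T(n/4, k/4) + 2cn \\
        &\;\;\vdots \\
        &\le 2^i \, T(n/2^i, k/2^i) + i \cdot cn
\end{aligned}
$$

The recursion bottoms out when $k/2^i = 1$, that is, at $i = \lg k$. That leaves $k \cdot \O(1) + cn\lg{k}$, which is $\O(n\lg{k})$ because $k \le n + 1$.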

This works out exactly when the number of elements is $ak + k - 1$ for a positive integer $a$: the $k - 1$ quantiles then separate $k$ segments of exactly $a$ elements each. For any other number of elements, some care with rounding is needed to avoid creating two segments whose sizes differ by more than 1.
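
The rounding concern can be made concrete with a small sketch that computes which ranks the quantiles should occupy. It is one possible rounding scheme, not necessarily the one used by the solution: the helper names `quantile_ranks` and `segment_sizes` are hypothetical, ranks are 0-based, and the leftover elements are simply handed to the leftmost segments.

```python
def quantile_ranks(n, k):
    """Return 0-based ranks of the k - 1 quantiles among n sorted elements.

    Illustrative rounding scheme (an assumption, not the only valid one):
    the n - (k - 1) non-quantile elements are spread over k segments, and
    the first `extra` segments each receive one additional element.
    """
    if k < 1 or n < k - 1:
        raise ValueError("need k >= 1 and n >= k - 1")

    base, extra = divmod(n - (k - 1), k)
    ranks = []
    position = 0
    for segment in range(k - 1):
        # Skip over the elements of this segment, then place a quantile.
        position += base + (1 if segment < extra else 0)
        ranks.append(position)
        position += 1
    return ranks


def segment_sizes(n, k):
    """Return the sizes of the k segments produced by quantile_ranks."""
    bounds = [-1] + quantile_ranks(n, k) + [n]
    return [hi - lo - 1 for lo, hi in zip(bounds, bounds[1:])]


print(quantile_ranks(11, 3))  # [3, 7]
print(segment_sizes(10, 4))   # [2, 2, 2, 1]
```

With $n = 11$ and $k = 3$ we have $n = ak + k - 1$ for $a = 3$, so all three segments hold exactly 3 elements. With $n = 10$ and $k = 4$ the sizes cannot all be equal, but they still differ by at most 1.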


Python code

import math

def k_quantiles(items, k):
    # Return the k - 1 quantiles of items; assumes distinct elements
    # and len(items) >= k - 1. Note: reorders items in place.

    if k == 1:
        return []
    elif k % 2:
        n = len(items)
        left_index  = math.ceil((k // 2) * (n / k)) - 1
        right_index = n - left_index - 1

        left  = select(items, left_index)
        right = select(items, right_index)

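        # Slice by position (left_index / right_index), not by the
        # selected values, which are elements rather than indices.
        partition(items, left)
        lower = k_quantiles(items[:left_index], k // 2)
        partition(items, right)
        upper = k_quantiles(items[right_index + 1:], k // 2)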

        return lower + [left, right] + upper
    else:
        index = median_index(len(items))
        median = select(items, index)
        partition(items, median)

        return k_quantiles(items[:index], k // 2) + \
                    [median] + \
                    k_quantiles(items[index + 1:], k // 2)

def median_index(n):
    if n % 2:
        return n // 2
    else:
        return n // 2 - 1

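def partition(items, element):
    # Partition items in place around the value element (assumed to be
    # present) and return the index where it ends up.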
    i = 0

    for j in range(len(items) - 1):
        if items[j] == element:
            items[j], items[-1] = items[-1], items[j]

        if items[j] < element:
            items[i], items[j] = items[j], items[i]
            i += 1

    items[i], items[-1] = items[-1], items[i]

    return i

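def select(items, n):
    # Return the element of rank n (0-based), choosing pivots by
    # median of medians over groups of 5. Reorders items in place.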
    if len(items) <= 1:
        return items[0]

    medians = []

    for i in range(0, len(items), 5):
        group = sorted(items[i:i + 5])
        items[i:i + 5] = group
        median = group[median_index(len(group))]
        medians.append(median)

    pivot = select(medians, median_index(len(medians)))
    index = partition(items, pivot)

    if n == index:
        return items[index]
    elif n < index:
        return select(items[:index], n)
    else:
        return select(items[index + 1:], n - index - 1)
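
One way to sanity-check a quantile routine is to compare its output against the sorted order. The sketch below is a hypothetical checker, not part of the solution: the names `gap_sizes` and `is_valid_quantile_list` are introduced here, it assumes distinct elements, and it deliberately uses `sorted`, so it runs in $\O(n\lg{n})$ time and is only meant for testing.

```python
def gap_sizes(items, quantiles):
    """Sizes of the segments that the quantiles cut the sorted items into."""
    ordered = sorted(items)
    cuts = sorted(ordered.index(q) for q in quantiles)
    bounds = [-1] + cuts + [len(ordered)]
    return [hi - lo - 1 for lo, hi in zip(bounds, bounds[1:])]


def is_valid_quantile_list(items, quantiles, k):
    """True if quantiles are k - 1 elements splitting items into k parts
    whose sizes differ by at most 1 (assumes distinct elements)."""
    if len(set(quantiles)) != k - 1:
        return False
    if any(q not in items for q in quantiles):
        return False
    sizes = gap_sizes(items, quantiles)
    return max(sizes) - min(sizes) <= 1


data = [7, 3, 9, 1, 5, 8, 2, 6, 4, 10, 0]
print(is_valid_quantile_list(data, [3, 7], 3))  # True: segments of 3, 3, 3
print(is_valid_quantile_list(data, [2, 7], 3))  # False: segments of 2, 4, 3
```

A check such as `is_valid_quantile_list(data, k_quantiles(list(data), k), k)` can then be run over a range of sizes and values of `k`; passing a copy matters because `k_quantiles` reorders its argument.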