Exercise 9.3.6

The $k$th quantiles of an $n$-element set are $k - 1$ order statistics that divide the sorted set into $k$ equal-sized sets (to within 1). Give an $O(n \lg k)$-time algorithm to list the $k$th quantiles of a set.
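
A sort-based baseline makes the target output concrete. It runs in $O(n \lg n)$ rather than $O(n \lg k)$, so it is only useful as a correctness oracle. The exercise only says "to within 1", so the rank formula below is one assumed convention: the $n - (k - 1)$ non-quantile elements are spread over the $k$ segments as evenly as possible. The helper name `sorted_quantiles` is hypothetical.

```python
def sorted_quantiles(items, k):
    """Return the k - 1 kth quantiles of items by sorting (O(n lg n)).

    Hypothetical reference helper. Assumed convention: with
    m = n - (k - 1) non-quantile elements, the first i segments together
    hold floor(i * m / k) of them, so segment sizes differ by at most 1.
    Assumes distinct items and 1 <= k <= len(items) + 1.
    """
    ordered = sorted(items)
    m = len(ordered) - (k - 1)  # elements that are not quantiles
    # The ith quantile (1-based) is preceded by floor(i*m/k) ordinary
    # elements and i - 1 earlier quantiles; their sum is its 0-based rank.
    return [ordered[i * m // k + i - 1] for i in range(1, k)]
```

For example, `sorted_quantiles([7, 3, 0, 5, 1, 6, 2, 4], 3)` returns `[2, 5]`: the remaining six elements fall into three segments of two.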

  1. If $k = 1$, we return an empty list.
  2. If $k$ is even, we find the median, partition around it, solve two similar subproblems of size at most $\lfloor n / 2 \rfloor$ with $k/2$ segments each, and return their solutions plus the median.
  3. If $k$ is odd, we find the $\lfloor k/2 \rfloor$th and $\lceil k/2 \rceil$th quantiles, which bound the middle segment, and then reduce to two subproblems (the elements below the first and above the second), each with size less than $n/2$ and $\lfloor k/2 \rfloor$ segments. The worst-case recurrence is:

$$T(n, k) = 2T(\lfloor n/2 \rfloor, k/2) + O(n)$$

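which solves to the desired bound $O(n \lg k)$, since $k$ halves at each of the $O(\lg k)$ levels of recursion and each level does $O(n)$ total work.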
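
Unrolling the recurrence level by level makes the bound explicit. This is a sketch that ignores the floors and assumes $k$ is a power of 2 with $2 \le k \le n$, and $T(n, 1) = O(1)$ (the $k = 1$ case returns immediately); $c$ denotes the constant hidden in the $O(n)$ term.

$$
\begin{aligned}
T(n, k) &= c\,n + 2\,T(n/2,\, k/2) \\
        &= c\,n + 2 \cdot c\,\frac{n}{2} + 4\,T(n/4,\, k/4) \\
        &\;\;\vdots \\
        &= \sum_{i=0}^{\lg k - 1} 2^i \cdot c\,\frac{n}{2^i} + k \cdot T(n/k,\, 1) \\
        &= c\,n \lg k + O(k) = O(n \lg k).
\end{aligned}
$$

Each of the $\lg k$ levels contributes $c\,n$, and the $k$ leaves contribute $O(k) = O(n)$.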

This works cleanly when the number of elements is $ak + k - 1$ for a positive integer $a$: removing the $k - 1$ quantiles then leaves exactly $a$ elements in each of the $k$ segments. For other values of $n$, the split indices must be rounded with some care to avoid creating two segments whose sizes differ by more than 1.
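
One way to inspect the rounding is to run the recursion on ranks alone, without any data. In the sketch below, `quantile_ranks` and `segment_sizes` are hypothetical helpers: `quantile_ranks` mirrors the index arithmetic of the Python code in this section (lower median for even $k$, $\lceil \lfloor k/2 \rfloor \cdot n / k \rceil - 1$ for odd $k$), and `segment_sizes` counts the elements between consecutive quantiles so that a given $(n, k)$ can be checked against the "within 1" requirement. Both assume $k \ge 1$ and $k - 1 \le n$.

```python
def quantile_ranks(n, k, offset=0):
    """0-based ranks the divide-and-conquer scheme picks as quantiles in a
    sorted array of n elements (hypothetical helper; no data needed)."""
    if k == 1:
        return []
    if k % 2 == 0:
        # Even k: split at the lower median, k/2 segments on each side.
        mid = (n - 1) // 2
        return (quantile_ranks(mid, k // 2, offset)
                + [offset + mid]
                + quantile_ranks(n - mid - 1, k // 2, offset + mid + 1))
    # Odd k: take the two quantiles around the middle segment, then recurse
    # on the floor(k/2) segments to the left and to the right of them.
    left = (k // 2 * n + k - 1) // k - 1  # ceil(floor(k/2) * n / k) - 1
    right = n - left - 1
    return (quantile_ranks(left, k // 2, offset)
            + [offset + left, offset + right]
            + quantile_ranks(n - right - 1, k // 2, offset + right + 1))


def segment_sizes(n, ranks):
    """Number of elements strictly between consecutive quantile ranks."""
    bounds = [-1] + ranks + [n]
    return [b - a - 1 for a, b in zip(bounds, bounds[1:])]
```

For example, `quantile_ranks(10, 4)` gives `[1, 4, 7]`, and `segment_sizes(10, [1, 4, 7])` gives `[1, 2, 2, 2]`, which still satisfies the within-1 requirement even though $10$ is not of the form $ak + k - 1$.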


Python code

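```python
def k_quantiles(items, k):
    """Return the k - 1 kth quantiles of items in increasing order.

    Assumes distinct elements and len(items) >= k - 1. The list is
    permuted in place by select and partition.
    """
    if k == 1:
        return []
    elif k % 2:
        # Odd k: find the two quantiles that bound the middle segment, then
        # recurse on the floor(k/2) segments on each side of them.
        n = len(items)
        # Integer form of ceil(floor(k/2) * n / k) - 1 (no float rounding).
        left_index = (k // 2 * n + k - 1) // k - 1
        right_index = n - left_index - 1

        left = select(items, left_index)
        right = select(items, right_index)

        # Slice by the ranks, not by the element values.
        partition(items, left)
        lower = k_quantiles(items[:left_index], k // 2)
        partition(items, right)
        upper = k_quantiles(items[right_index + 1:], k // 2)

        return lower + [left, right] + upper
    else:
        # Even k: split around the (lower) median, k/2 segments per side.
        index = median_index(len(items))
        median = select(items, index)
        partition(items, median)

        return (k_quantiles(items[:index], k // 2)
                + [median]
                + k_quantiles(items[index + 1:], k // 2))


def median_index(n):
    """0-based rank of the lower median of n elements."""
    if n % 2:
        return n // 2
    else:
        return n // 2 - 1


def partition(items, element):
    """Lomuto partition of items around the value element (which must be
    present); returns the final index of element."""
    i = 0

    for j in range(len(items) - 1):
        # Move the pivot to the end when we meet it; the value swapped in
        # has not been examined yet and is compared just below.
        if items[j] == element:
            items[j], items[-1] = items[-1], items[j]

        if items[j] < element:
            items[i], items[j] = items[j], items[i]
            i += 1

    items[i], items[-1] = items[-1], items[i]

    return i


def select(items, n):
    """Return the element of 0-based rank n (median-of-medians SELECT)."""
    if len(items) <= 1:
        return items[0]

    medians = []

    # Sort each group of 5 and collect its median.
    for i in range(0, len(items), 5):
        group = sorted(items[i:i + 5])
        items[i:i + 5] = group
        median = group[median_index(len(group))]
        medians.append(median)

    # The median of the medians is the pivot.
    pivot = select(medians, median_index(len(medians)))
    index = partition(items, pivot)

    if n == index:
        return items[index]
    elif n < index:
        return select(items[:index], n)
    else:
        return select(items[index + 1:], n - index - 1)
```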