Skip to content

Merge Sort — Optimize

12 optimizations with before / after in Go, Java, Python.


1. Allocate Aux Buffer Once (Not Per Recursive Call)

Before — Per-call allocation

def merge_sort(arr):
    if len(arr) <= 1: return arr
    mid = len(arr) // 2
    return merge(merge_sort(arr[:mid]), merge_sort(arr[mid:]))

After — One shared buffer

def merge_sort(arr):
    aux = [0] * len(arr)
    _sort(arr, aux, 0, len(arr) - 1)

def _sort(arr, aux, lo, hi):
    if lo >= hi: return
    mid = (lo + hi) // 2
    _sort(arr, aux, lo, mid)
    _sort(arr, aux, mid + 1, hi)
    _merge(arr, aux, lo, mid, hi)
n=100k Before After
Time 70 ms 25 ms (3× faster)
Allocations O(n log n) 1

2. Insertion Sort Cutoff for Small Subarrays

Before — pure merge

def _sort(arr, aux, lo, hi):
    if lo >= hi: return
    mid = (lo + hi) // 2
    _sort(arr, aux, lo, mid)
    _sort(arr, aux, mid + 1, hi)
    _merge(arr, aux, lo, mid, hi)

After — switch to Insertion at n ≤ 16

CUTOFF = 16

def _sort(arr, aux, lo, hi):
    if hi - lo < CUTOFF:
        _insertion(arr, lo, hi)
        return
    mid = (lo + hi) // 2
    _sort(arr, aux, lo, mid)
    _sort(arr, aux, mid + 1, hi)
    _merge(arr, aux, lo, mid, hi)

def _insertion(arr, lo, hi):
    for i in range(lo + 1, hi + 1):
        x = arr[i]; j = i - 1
        while j >= lo and arr[j] > x:
            arr[j+1] = arr[j]; j -= 1
        arr[j+1] = x

Speedup: ~1.5-2× on random input.


3. Skip Merge if Already Sorted

Before

_sort(arr, aux, lo, mid)
_sort(arr, aux, mid + 1, hi)
_merge(arr, aux, lo, mid, hi)  # always

After

_sort(arr, aux, lo, mid)
_sort(arr, aux, mid + 1, hi)
if arr[mid] <= arr[mid + 1]:
    return  # already in order
_merge(arr, aux, lo, mid, hi)

Win: Already-sorted input becomes O(n) instead of O(n log n).


4. Bottom-Up Iterative Merge Sort

Before — recursive (stack-bound)

def merge_sort(arr): ...  # recursive, may overflow on huge n

After — iterative

def merge_sort_bottom_up(arr):
    n = len(arr)
    aux = [0] * n
    size = 1
    while size < n:
        for lo in range(0, n - size, size * 2):
            mid = lo + size - 1
            hi = min(lo + size * 2 - 1, n - 1)
            _merge(arr, aux, lo, mid, hi)
        size *= 2

Win: No stack overflow on huge n. Slightly faster cache behavior.


5. Use Insertion Sort to Build Initial Runs (Bottom-Up)

Before — start with size=1

size = 1
while size < n: ...

After — start with size=16 using insertion

RUN = 16
for i in range(0, n, RUN):
    _insertion(arr, i, min(i + RUN - 1, n - 1))
size = RUN
while size < n: ...

Win: Skip first 4 levels of merging (log₂(16) = 4). Real speedup on random data.


6. Galloping Merge (TimSort Optimization)

When one run consistently wins the merge, exponential search to skip ahead.

After — galloping mode

def merge_galloping(l, r):
    out = []
    i = j = 0
    streak = 0
    while i < len(l) and j < len(r):
        if l[i] <= r[j]:
            out.append(l[i]); i += 1
            streak += 1
        else:
            out.append(r[j]); j += 1
            streak = 0
        if streak >= 7:  # MIN_GALLOP
            # binary-search how many more from l fit before r[j]
            import bisect
            k = bisect.bisect_right(l, r[j], i)
            out.extend(l[i:k]); i = k
            streak = 0
    out.extend(l[i:]); out.extend(r[j:])
    return out

Win: When runs are very imbalanced (e.g., one all-small + one all-large), 2-5× faster.


7. Eliminate Sentinel Comparison Overhead

Before (CLRS sentinel form):

def merge(arr, lo, mid, hi):
    L = arr[lo:mid+1] + [float('inf')]
    R = arr[mid+1:hi+1] + [float('inf')]
    # ... main loop without bounds checks

After (no-sentinel)

def merge(arr, aux, lo, mid, hi):
    for k in range(lo, hi+1): aux[k] = arr[k]
    i, j = lo, mid + 1
    for k in range(lo, hi + 1):
        if   i > mid: arr[k] = aux[j]; j += 1
        elif j > hi:  arr[k] = aux[i]; i += 1
        elif aux[i] <= aux[j]: arr[k] = aux[i]; i += 1
        else: arr[k] = aux[j]; j += 1

Win: No sentinel allocation; explicit bounds may JIT/optimize better.


8. Branch-Free Merge (SIMD-Friendly)

Before — branchy

if l[i] <= r[j]: out.append(l[i]); i += 1
else: out.append(r[j]); j += 1

After — branchless (Python doesn't gain; C/Go do)

// In C-like syntax
int take_left = l[i] <= r[j];
out[k] = take_left ? l[i] : r[j];
i += take_left;
j += !take_left;

Win: Eliminates branch mispredictions on random data. Modest gain (~10%) in compiled languages.


9. Parallel Merge Sort (Multi-Core)

After — fork halves to threads

Go

func ParallelMergeSort(arr []int) []int {
    if len(arr) < 10000 { return SequentialMergeSort(arr) }
    mid := len(arr) / 2
    var l, r []int
    var wg sync.WaitGroup
    wg.Add(2)
    go func() { defer wg.Done(); l = ParallelMergeSort(arr[:mid]) }()
    go func() { defer wg.Done(); r = ParallelMergeSort(arr[mid:]) }()
    wg.Wait()
    return merge(l, r)
}

Win: 3-5× speedup on 8 cores for n > 10⁶.


10. External Merge Sort (Data > RAM)

After — chunk to disk + k-way merge

import heapq, tempfile, os

def external_sort(input_iter, output_path, chunk=100_000):
    runs = []
    buf = []
    for x in input_iter:
        buf.append(x)
        if len(buf) >= chunk:
            buf.sort()
            f = tempfile.NamedTemporaryFile(mode='w', delete=False)
            f.write('\n'.join(map(str, buf)) + '\n'); f.close()
            runs.append(f.name)
            buf = []
    if buf:
        buf.sort()
        f = tempfile.NamedTemporaryFile(mode='w', delete=False)
        f.write('\n'.join(map(str, buf)) + '\n'); f.close()
        runs.append(f.name)
    its = [(int(x) for x in open(p)) for p in runs]
    with open(output_path, 'w') as out:
        for v in heapq.merge(*its):
            out.write(f"{v}\n")
    for p in runs: os.unlink(p)

Win: Sorts terabytes with megabytes of RAM. Standard for databases.


11. Replace arr[:mid] with Index Recursion (Python Specific)

Before — slice copies

def merge_sort(arr):
    if len(arr) <= 1: return arr
    mid = len(arr) // 2
    return merge(merge_sort(arr[:mid]), merge_sort(arr[mid:]))

After — index-only

def merge_sort(arr):
    aux = [0] * len(arr)
    _sort(arr, aux, 0, len(arr) - 1)

Win: Eliminates O(n log n) extra slice allocations.


12. Use Built-in heapq.merge for k-Way

Before — manual heap

def k_way_merge(streams):
    heap = []
    iters = [iter(s) for s in streams]
    for i, it in enumerate(iters):
        try: heapq.heappush(heap, (next(it), i))
        except StopIteration: pass
    while heap:
        val, i = heapq.heappop(heap)
        yield val
        try: heapq.heappush(heap, (next(iters[i]), i))
        except StopIteration: pass

After — built-in

import heapq
def k_way_merge(streams):
    yield from heapq.merge(*streams)

Win: Same performance, fewer lines, well-tested.


Summary

# Optimization Impact
1 One shared aux buffer 3× speedup
2 Insertion cutoff 1.5-2×
3 Skip merge if sorted O(n) on sorted
4 Bottom-up iterative No stack overflow
5 Initial Insertion runs Skip 4 merge levels
6 Galloping merge 2-5× on skewed runs
7 No sentinels Cleaner JIT
8 Branch-free merge ~10% in C/Go
9 Parallel 3-5× on 8 cores
10 External sort Sort > RAM data
11 Index-based recursion (Py) Eliminate slice copies
12 Built-in heapq.merge Less code

Final benchmark (Go n=100k random):

Vanilla Merge Sort         : 18 ms
+ shared aux               :  9 ms
+ insertion cutoff         :  6 ms
+ skip-if-sorted           :  6 ms (random — no win)
Bottom-up                  :  9 ms
TimSort (Java equivalent)  :  8 ms
Pdqsort (sort.Ints)        :  4 ms ← winner

Lesson: Even fully-optimized Merge Sort is ~50% slower than Pdqsort for random in-memory numeric data. Merge Sort wins on stability, worst-case guarantee, and external/parallel scenarios.