From f4b9405d585abdb2b2e891637309ec39ef821aac Mon Sep 17 00:00:00 2001 From: Bryan Yue Date: Mon, 5 Jan 2026 22:38:55 -0800 Subject: [PATCH] Quick Select --- src/sort/quick_sort.py | 38 ++++++++++++++++++++++++++----------- tst/sort/test_quick_sort.py | 19 ++++++++++++++++++- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/sort/quick_sort.py b/src/sort/quick_sort.py index a76c77c..725a43b 100644 --- a/src/sort/quick_sort.py +++ b/src/sort/quick_sort.py @@ -1,17 +1,18 @@ import random +def get_partition(array, start_idx, end_idx): + pivot_idx = random.randint(start_idx, end_idx) + pivot = array[pivot_idx] + array[pivot_idx], array[end_idx] = array[end_idx], array[pivot_idx] + i = start_idx - 1 + for j in range(start_idx, end_idx): + if array[j] <= pivot: + i += 1 + array[i], array[j] = array[j], array[i] + array[i + 1], array[end_idx] = array[end_idx], array[i + 1] + return i + 1 + def sort(array): - def get_partition(array, start_idx, end_idx): - pivot_idx = random.randint(start_idx, end_idx) - pivot = array[pivot_idx] - array[pivot_idx], array[end_idx] = array[end_idx], array[pivot_idx] - i = start_idx - 1 - for j in range(start_idx, end_idx): - if array[j] <= pivot: - i += 1 - array[i], array[j] = array[j], array[i] - array[i + 1], array[end_idx] = array[end_idx], array[i + 1] - return i + 1 def quick_sort(array, start_idx, end_idx): if start_idx < end_idx: @@ -20,3 +21,18 @@ def quick_sort(array, start_idx, end_idx): quick_sort(array, partition + 1, end_idx) quick_sort(array, 0, len(array) - 1) + +def select(array, k): + """Return the k-th smallest element (0-indexed) using quickselect.""" + if k < 0 or k >= len(array): + return None + + l, r = 0, len(array) - 1 + while l <= r: + pivot_idx = get_partition(array, l, r) + if pivot_idx == k: + return array[pivot_idx] + if pivot_idx < k: + l = pivot_idx + 1 + else: + r = pivot_idx - 1 \ No newline at end of file diff --git a/tst/sort/test_quick_sort.py b/tst/sort/test_quick_sort.py index ed4753a..2be876a 100644 --- a/tst/sort/test_quick_sort.py +++ b/tst/sort/test_quick_sort.py @@ -1,4 +1,4 @@ -from src.sort.quick_sort import sort +from src.sort.quick_sort import select, sort class TestQuickSort: def test_empty_sort(self): @@ -32,3 +32,20 @@ def fake_randint(start_idx, end_idx): assert array == [2, 3, 5, 8, 10] assert calls[0] == (0, 4) assert all(start <= end for start, end in calls) + + def test_select_basic(self): + array = [7, 1, 5, 3, 9, 2] + assert select(array[:], 0) == 1 + assert select(array[:], 3) == 5 + assert select(array[:], 5) == 9 + + def test_select_with_duplicates(self): + array = [4, 2, 5, 2, 3, 4, 1] + assert select(array[:], 0) == 1 + assert select(array[:], 2) == 2 + assert select(array[:], 4) == 4 + + def test_select_invalid_index(self): + array = [10, 20, 30] + assert select(array[:], -1) is None + assert select(array[:], 3) is None