diff --git a/src/sort/dutch_flag_sort.py b/src/sort/dutch_flag_sort.py new file mode 100644 index 0000000..1948b26 --- /dev/null +++ b/src/sort/dutch_flag_sort.py @@ -0,0 +1,49 @@ +import random + +def get_partition(array, start_idx, end_idx): + # works well on duplicate elements, where regular quicksort degrades to quadratic performance due to bad partitions + lt = start_idx + eq = start_idx + gt = end_idx + pivot = array[random.randint(start_idx, end_idx)] + while eq <= gt: + if array[eq] == pivot: + eq += 1 + elif array[eq] > pivot: + array[eq], array[gt] = array[gt], array[eq] + gt -= 1 + else: + array[eq], array[lt] = array[lt], array[eq] + eq += 1 + lt += 1 + return lt - 1, gt + 1 + +def sort(array): + def quick_sort(array, start_idx, end_idx): + if start_idx < end_idx: + l, r = get_partition(array, start_idx, end_idx) + quick_sort(array, start_idx, l) + quick_sort(array, r, end_idx) + quick_sort(array, 0, len(array) - 1) + +def select(array, k): + """Return the k-th smallest element (0-indexed) using quickselect.""" + if k < 0 or k >= len(array): + return None + + l, r = 0, len(array) - 1 + + while l <= r: + left_end, right_start = get_partition(array, l, r) + + # equal region is [left_end + 1 .. right_start - 1] + eq_l = left_end + 1 + eq_r = right_start - 1 + + if k < eq_l: + r = left_end # search in "< pivot" + elif k > eq_r: + l = right_start # search in "> pivot" + else: + # k is inside the "== pivot" block + return array[k] \ No newline at end of file diff --git a/tst/sort/test_dutch_flag_sort.py b/tst/sort/test_dutch_flag_sort.py new file mode 100644 index 0000000..2119265 --- /dev/null +++ b/tst/sort/test_dutch_flag_sort.py @@ -0,0 +1,51 @@ +from src.sort.dutch_flag_sort import select, sort + +class TestQuickSort: + def test_empty_sort(self): + array = [] + sort(array) + assert array == [] + + def test_sorted_and_reverse_inputs(self): + sorted_array = [1, 2, 3, 4, 5] + reverse_array = list(reversed(sorted_array)) + sort(sorted_array) + sort(reverse_array) + assert sorted_array == [1, 2, 3, 4, 5] + assert reverse_array == [1, 2, 3, 4, 5] + + def test_duplicates_and_negatives(self): + array = [3, -1, 0, -1, 5, 3, 2] + sort(array) + assert array == [-1, -1, 0, 2, 3, 3, 5] + + def test_random_pivot_usage(self, monkeypatch): + calls = [] + + def fake_randint(start_idx, end_idx): + calls.append((start_idx, end_idx)) + return start_idx # pick first element to force swaps with pivot + + monkeypatch.setattr("src.sort.quick_sort.random.randint", fake_randint) + array = [10, 5, 8, 3, 2] + sort(array) + assert array == [2, 3, 5, 8, 10] + assert calls[0] == (0, 4) + assert all(start <= end for start, end in calls) + + def test_select_basic(self): + array = [7, 1, 5, 3, 9, 2] + assert select(array[:], 0) == 1 + assert select(array[:], 3) == 5 + assert select(array[:], 5) == 9 + + def test_select_with_duplicates(self): + array = [4, 2, 5, 2, 3, 4, 1] + assert select(array[:], 0) == 1 + assert select(array[:], 2) == 2 + assert select(array[:], 4) == 4 + + def test_select_invalid_index(self): + array = [10, 20, 30] + assert select(array[:], -1) is None + assert select(array[:], 3) is None