常见排序算法C++实现


C++ 实现

#pragma once

// Exchanges the values of a and b.
// A plain temporary is used instead of the XOR trick: it is clearer, and it is correct
// when a and b are the same object (XOR-swapping an object with itself would zero it).
// `inline` because this file is a header: without it, including it from more than one
// translation unit causes multiple-definition link errors.
inline void swap(int& a, int& b)
{
    const int tmp = a;
    a = b;
    b = tmp;
}

// Stores the smallest and largest of arr[0..n) in mi and ma.
// An empty range (n <= 0) has no element to read, so both outputs are set to 0
// instead of dereferencing arr[0].
inline void get_minmax(int* arr, int n, int& mi, int& ma)
{
    if (n <= 0) {
        mi = ma = 0;
        return;
    }
    mi = ma = arr[0];
    for (int i = 1; i < n; ++i) {
        if (arr[i] < mi) {
            mi = arr[i];
        } else if (arr[i] > ma) { // a new minimum can never also be a new maximum
            ma = arr[i];
        }
    }
}

// https://www.bigocheatsheet.com/

// unstable, time: O(n^2) <= O(n^2) <= O(n^2), space: O(1)
// Repeatedly selects the minimum of the unsorted suffix arr[i..n) and swaps it into arr[i].
inline void selection_sort(int* arr, int n)
{
    // The last position needs no pass: once arr[0..n-1) holds the n-1 smallest values,
    // arr[n-1] is necessarily the largest.
    for (int i = 0; i + 1 < n; ++i) {
        int p = i;
        for (int j = i + 1; j < n; ++j) {
            if (arr[j] < arr[p]) {
                p = j;
            }
        }
        if (p != i) {
            swap(arr[i], arr[p]);
        }
    }
}

// stable, time: O(n) <= O(n^2) <= O(n^2), space: O(1)
void bubble_sort(int* arr, int n)
{
    for (; n > 1; --n) {
        bool swapped = false;
        for (int i = 1; i < n; ++i) {
            if (arr[i - 1] > arr[i]) {
                swap(arr[i - 1], arr[i]);
                swapped = true;
            }
        }
        if (!swapped) {
            break;
        }
    }
}

// stable, time: O(n) <= O(n^2) <= O(n^2), space: O(1)
// Grows a sorted prefix arr[0..i) one element at a time, shifting larger elements one slot
// to the right to open a hole for arr[i].
// `inline` because this file is a header included from other translation units.
inline void insertion_sort(int* arr, int n)
{
    for (int i = 1; i < n; ++i) {
        const int tmp = arr[i];
        int j = i;
        // Strict '>' stops at equal keys, which keeps the sort stable.
        for (; j > 0 && arr[j - 1] > tmp; --j) {
            arr[j] = arr[j - 1];
        }
        arr[j] = tmp;
    }
}

// unstable, time: O(nlogn) <= (depends on the gap sequence) <= O(n^2), space: O(1)
// Shell's original gap sequence n/2, n/4, ..., 1 is used; with it the worst case is O(n^2).
// Each pass is an insertion sort over elements `gap` apart; the final gap == 1 pass is a
// plain insertion sort over an almost-sorted array.
inline void shell_sort(int* arr, int n)
{
    for (int gap = n / 2; gap > 0; gap /= 2) {
        for (int i = gap; i < n; ++i) {
            const int tmp = arr[i];
            int j = i;
            for (; j >= gap && arr[j - gap] > tmp; j -= gap) {
                arr[j] = arr[j - gap];
            }
            arr[j] = tmp;
        }
    }
}

// unstable, time: O(nlogn) <= O(nlogn) <= O(n^2), space: O(logn)
// Lomuto-partition quicksort with the last element as pivot. Stack depth stays O(logn)
// even in the O(n^2) worst case (e.g. already-sorted input): only the smaller partition is
// sorted recursively, and the larger one is handled by the loop.
inline void quick_sort(int* arr, int n)
{
    struct helper {
        // Partitions arr[low..high] around the pivot arr[high] and returns the pivot's final
        // index p. Afterwards arr[low..p) < pivot <= arr(p..high].
        static int partition(int* arr, int low, int high)
        {
            const int pivot = arr[high];
            int i = low - 1;
            for (int j = low; j < high; ++j) {
                if (arr[j] < pivot) {
                    swap(arr[++i], arr[j]);
                }
            }
            swap(arr[++i], arr[high]);
            return i;
        }

        // Sorts arr[low..high] (inclusive bounds).
        static void quick_sort(int* arr, int low, int high)
        {
            while (low < high) {
                const int pi = partition(arr, low, high);
                if (pi - low < high - pi) {
                    quick_sort(arr, low, pi - 1);
                    low = pi + 1;
                } else {
                    quick_sort(arr, pi + 1, high);
                    high = pi - 1;
                }
            }
        }
    };

    helper::quick_sort(arr, 0, n - 1);
}

// stable, time: O(nlogn) <= O(nlogn) <= O(nlogn), space: O(n)
// Top-down merge sort that allocates one scratch buffer of n ints up front and reuses it
// for every merge.
inline void merge_sort(int* arr, int n)
{
    struct helper {
        int* buf; // scratch space shared by all merges; owned, freed in the destructor

        explicit helper(int n) : buf(new int[n]) {}
        ~helper() { delete[] buf; }
        // helper owns buf, so a copy would lead to a double delete[].
        helper(const helper&) = delete;
        helper& operator=(const helper&) = delete;

        // Sorts arr[low..high] (inclusive bounds).
        void merge_sort(int* arr, int low, int high)
        {
            if (low < high) {
                const int mid = low + (high - low) / 2; // avoids overflow of low + high
                merge_sort(arr, low, mid);
                merge_sort(arr, mid + 1, high);
                merge(arr, low, mid, high);
            }
        }

        // Merges the sorted runs arr[low..mid] and arr[mid+1..high] into arr[low..high].
        void merge(int* arr, int low, int mid, int high)
        {
            // buf[0..left_end) copies the left run; buf[left_end..total) copies the right run.
            const int total = high - low + 1;
            const int left_end = mid - low + 1;
            for (int t = 0; t < total; ++t) {
                buf[t] = arr[low + t];
            }

            int i = 0, j = left_end, k = low;
            while (i < left_end && j < total) {
                // '<=' takes from the left run on ties, which keeps the sort stable.
                if (buf[i] <= buf[j]) {
                    arr[k++] = buf[i++];
                } else {
                    arr[k++] = buf[j++];
                }
            }
            while (i < left_end) {
                arr[k++] = buf[i++];
            }
            // Leftover right-run elements need no copy. When the left run is exhausted,
            // k == low + j, so buf[j..total) already sits, untouched, in arr[k..high].
        }
    };

    if (n < 2) {
        return;
    }
    helper h(n);
    h.merge_sort(arr, 0, n - 1);
}

// unstable, time: O(nlogn) <= O(nlogn) <= O(nlogn), space: O(1)
// Builds a max-heap in place, then repeatedly swaps the maximum to the end of the
// shrinking heap.
inline void heap_sort(int* arr, int n)
{
    struct helper {
        // Sifts arr[root] down within the heap arr[0..n) until it is no smaller than its
        // children, assuming both child subtrees are already max-heaps. It is iterative
        // rather than recursive, so the O(1) space bound above actually holds.
        static void heapify(int* arr, int n, int root)
        {
            // Nodes at index >= n / 2 have no children. Testing this first also keeps
            // root * 2 + 1 from overflowing int.
            while (root < n / 2) {
                const int l = root * 2 + 1; // l < n is implied by root < n / 2
                const int r = l + 1;
                int largest = root;
                if (arr[l] > arr[largest]) {
                    largest = l;
                }
                if (r < n && arr[r] > arr[largest]) {
                    largest = r;
                }
                if (largest == root) {
                    return;
                }
                swap(arr[root], arr[largest]);
                root = largest;
            }
        }
    };

    // Build the heap bottom-up, starting from the last node that has a child.
    for (int i = n / 2 - 1; i >= 0; --i) {
        helper::heapify(arr, n, i);
    }
    // Move the current maximum to arr[i], then restore the heap on arr[0..i).
    // Stop at i == 1: a single remaining element is already in place.
    for (int i = n - 1; i > 0; --i) {
        swap(arr[0], arr[i]);
        helper::heapify(arr, i, 0);
    }
}

// stable, time: O(n+k) <= O(n+k) <= O(n+k), space: O(n+k), where k = max - min + 1
inline void counting_sort(int* arr, int n)
{
    if (n < 2) {
        return;
    }

    int mi = 0, ma = 0;
    get_minmax(arr, n, mi, ma);

    // Work in long long. For values spanning more than INT_MAX, both ma - mi + 1 and
    // arr[i] - mi would overflow int, which is undefined behaviour.
    const long long range = static_cast<long long>(ma) - mi + 1;

    // One allocation holds both the zeroed count table (the first `range` slots) and the
    // output buffer (the last n slots), so a failing allocation cannot leak an earlier one.
    int* const mem = new int[range + n]();
    int* const count = mem;
    int* const tmp = mem + range;

    for (int i = 0; i < n; ++i) {
        ++count[static_cast<long long>(arr[i]) - mi];
    }
    // Prefix sums: count[v] becomes the number of elements <= mi + v.
    for (long long v = 1; v < range; ++v) {
        count[v] += count[v - 1];
    }
    // Walk backwards so that equal values keep their relative order (stability).
    for (int i = n - 1; i >= 0; --i) {
        tmp[--count[static_cast<long long>(arr[i]) - mi]] = arr[i];
    }
    for (int i = 0; i < n; ++i) {
        arr[i] = tmp[i];
    }

    delete[] mem;
}

// stable, time: O(nk) <= O(nk) <= O(nk), space: O(n), where k = decimal digits of (max - min)
// LSD radix sort on the keys (value - min), so negative values are handled as well.
inline void radix_sort(int* arr, int n)
{
    struct helper {
        int* buf; // output buffer for each digit pass; owned, freed in the destructor

        explicit helper(int n) : buf(new int[n]) {}
        ~helper() { delete[] buf; }
        // helper owns buf, so a copy would lead to a double delete[].
        helper(const helper&) = delete;
        helper& operator=(const helper&) = delete;

        // Decimal digit of the key (v - mi) selected by exp (1, 10, 100, ...). The key is
        // computed in long long because v - mi can exceed INT_MAX.
        static int digit(int v, int mi, long long exp)
        {
            return static_cast<int>((static_cast<long long>(v) - mi) / exp % 10);
        }

        // One stable counting-sort pass over arr[0..n) on the digit selected by exp.
        void counting_sort(int* arr, int n, int mi, long long exp)
        {
            int count[10] = { 0 };
            for (int i = 0; i < n; ++i) {
                ++count[digit(arr[i], mi, exp)];
            }
            for (int d = 1; d < 10; ++d) {
                count[d] += count[d - 1];
            }
            // Walk backwards so equal digits keep the order produced by earlier passes.
            for (int i = n - 1; i >= 0; --i) {
                buf[--count[digit(arr[i], mi, exp)]] = arr[i];
            }
            for (int i = 0; i < n; ++i) {
                arr[i] = buf[i];
            }
        }
    };

    if (n < 2) {
        return;
    }

    int mi = 0, ma = 0;
    get_minmax(arr, n, mi, ma);
    const long long max_key = static_cast<long long>(ma) - mi;
    helper h(n);
    // exp is long long, so stepping past the highest digit (up to 10^10) cannot overflow.
    for (long long exp = 1; exp <= max_key; exp *= 10) {
        h.counting_sort(arr, n, mi, exp);
    }
}

// unstable (buckets are sorted with quick_sort), time: O(nlogn) average, O(n^2) worst,
// space: O(n)
// Distributes the values into 16 equal-width buckets over [min, max], lays the buckets out
// contiguously in arr, then sorts each bucket's slice in place.
inline void bucket_sort(int* arr, int n)
{
    if (n < 2) {
        return;
    }

    const int n_bucket = 16;

    int mi = 0, ma = 0;
    get_minmax(arr, n, mi, ma);
    const long long range = static_cast<long long>(ma) - mi + 1;

    // Bucket of v: (v - mi) * n_bucket / range. This is computed in integer long long
    // arithmetic, so it cannot overflow and is always < n_bucket. A float ratio can round
    // up to 1.0 for wide value ranges and index past the last bucket.
    auto bucket_of = [&](int v) {
        return static_cast<int>((static_cast<long long>(v) - mi) * n_bucket / range);
    };

    // After the prefix sum, bucket b occupies arr[start[b] .. start[b + 1]).
    int start[n_bucket + 1] = { 0 };
    for (int i = 0; i < n; ++i) {
        ++start[bucket_of(arr[i]) + 1];
    }
    for (int b = 0; b < n_bucket; ++b) {
        start[b + 1] += start[b];
    }

    // Scatter into a scratch array bucket by bucket, then copy back.
    int* const tmp = new int[n];
    int next[n_bucket];
    for (int b = 0; b < n_bucket; ++b) {
        next[b] = start[b];
    }
    for (int i = 0; i < n; ++i) {
        tmp[next[bucket_of(arr[i])]++] = arr[i];
    }
    for (int i = 0; i < n; ++i) {
        arr[i] = tmp[i];
    }
    delete[] tmp;

    for (int b = 0; b < n_bucket; ++b) {
        quick_sort(arr + start[b], start[b + 1] - start[b]);
    }
}

测试代码

#include "sort.h"
#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <numeric>
#include <random>
#include <vector>
using namespace std;

// Sorts a shuffled run of n consecutive integers starting at -n/2 with f. It prints the
// array before and after sorting, then asserts the result is exactly that run in
// ascending order.
void test_sort(int n, void (*f)(int*, int), const char* name)
{
    vector<int> data(n);
    const int first = -n / 2;
    iota(data.begin(), data.end(), first);
    mt19937 rng { random_device {}() };
    shuffle(data.begin(), data.end(), rng);

    // Prints every element followed by a single space, then ends the line.
    auto print = [&data]() {
        for (int x : data) {
            cout << x << " ";
        }
        cout << endl;
    };

    cout << name << ":" << endl;
    print();

    f(data.data(), n);

    print();
    cout << endl;

    assert(is_sorted(data.begin(), data.end()));
    // Each neighbour must be exactly one larger: nothing was lost, duplicated or reordered.
    for (int k = 1; k < n; ++k) {
        assert(data[k - 1] + 1 == data[k]);
    }
}

// Runs every algorithm from sort.h, in a fixed order, on a shuffled run of 9 consecutive ints.
int main()
{
    const struct {
        void (*fn)(int*, int);
        const char* label;
    } suite[] = {
        { selection_sort, "selection_sort" },
        { insertion_sort, "insertion_sort" },
        { bubble_sort, "bubble_sort" },
        { shell_sort, "shell_sort" },
        { quick_sort, "quick_sort" },
        { merge_sort, "merge_sort" },
        { heap_sort, "heap_sort" },
        { counting_sort, "counting_sort" },
        { radix_sort, "radix_sort" },
        { bucket_sort, "bucket_sort" },
    };

    const int n = 9;
    for (const auto& entry : suite) {
        test_sort(n, entry.fn, entry.label);
    }
    return 0;
}

算法复杂度

https://www.bigocheatsheet.com/


文章作者: Kiba Amor
版权声明: 本博客所有文章除特别声明外,均采用 CC BY-NC-ND 4.0 许可协议。转载请注明来源 Kiba Amor !
  目录