Friday, May 28, 2010

Quick sort

UPDATE: I noticed a lot of traffic coming to this entry. Since it seems popular, I'll clean up and simplify the code some time later this week (currently it's 2010/06/17), then repost it in a new entry and note the new link here.


Quick sort..

How it works.
1) First choose a pivot value.
2) Then divide the list into 2 sections: repeatedly find the leftmost item that is not less than the pivot and swap it with the rightmost item that is not greater than the pivot.
3) Then recurse on the 2 halves.

There are several hidden fail points
1) If a bad pivot value is chosen, the list won't divide into 2 sections.
2) If the data becomes too similar, or even completely identical, the chances of a bad pivot increase.

As a result of these fail points, care needs to be taken to correctly find the left and right swap items and the point where the list is split.

Speed:
Quick sort is O(n log(n)) on average and O(n^2) in the worst case.

//compile with g++
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <utility>
using namespace std;

void print(int data[], int size);

/*
 * quicksort: sort data[sectionLow, sectionHigh) in place, ascending.
 *
 * data        - array to sort; only the half-open range is touched
 * sectionLow  - index of the first element of the range
 * sectionHigh - one past the index of the last element of the range
 *
 * Ranges with fewer than two elements are left unchanged. The pivot is the
 * value of the middle element of the range. After partitioning, every element
 * in [sectionLow, split) is <= pivot and every element in [split, sectionHigh)
 * is >= pivot, and the two halves are sorted independently.
 *
 * Only the smaller half is sorted recursively; the larger half is handled by
 * the next iteration of the outer loop, so the recursion depth is at most
 * log2(n). Running time is still O(n^2) in the worst case.
 */
void quicksort(int data[], int sectionLow, int sectionHigh)
{
  while(sectionLow < sectionHigh-1)
    {
      int low  = sectionLow;
      int high = sectionHigh-1;

      // Same value as (low+high)/2 for valid indices, but cannot overflow.
      int pivotValue = data[low + (high-low)/2];

      // Invariant: data[sectionLow, low) <= pivotValue and
      //            data(high, sectionHigh-1] >= pivotValue
      while(low < high)
        {
          // Move both ends together while both are already on the correct
          // side, so on very similar data one end does not run ahead alone.
          while((low < high) &&
                (data[low]  < pivotValue) &&
                (data[high] > pivotValue))
            {
              low++;
              high--;
            }

          // At most one of these two loops makes progress; the other end has
          // already stopped on an out-of-place element.
          while((low < high) && (data[low]  < pivotValue)) low++;
          while((low < high) && (data[high] > pivotValue)) high--;

          if(low < high)
            {
              // Here data[low] >= pivotValue and data[high] <= pivotValue.
              std::swap(data[low], data[high]);

              // Step past the swapped pair; otherwise two elements equal to
              // pivotValue would be swapped back and forth forever.
              low++;
              high--;
            }
        }

      // When the scan ends with low == high, data[low] has not been
      // classified yet: put it on the correct side of the split.
      int split = (data[low] < pivotValue) ? low+1 : low;

      // Keep both halves non-empty so the range always shrinks.
      if(split == sectionHigh) split--;
      if(split == sectionLow)  split++;

      if(split - sectionLow < sectionHigh - split)
        {
          quicksort(data, sectionLow, split);
          sectionLow = split;
        }
      else
        {
          quicksort(data, split, sectionHigh);
          sectionHigh = split;
        }
    }
}

//test
// Number of elements in each test data set.
#define SIZE 20
// Deterministically mixes an index x with a previous value y to build
// "semi repeated" test data. The arguments are parenthesized so that
// expressions such as SCRAMBLE(i, a+b) expand with the intended precedence.
#define SCRAMBLE(x, y) ((0xa57b & ~(y)) + ((0x3829 & (x)) << 1))

/*
 * check: report whether data[0, size) is in non-decreasing order.
 *
 * Prints "FAIL!" if some element is smaller than its predecessor, otherwise
 * prints "PASS". Returns true when the range is sorted, false otherwise.
 */
bool check(int data[], int size)
{
  // Walk forward until the first out-of-order pair (or the end of the range).
  int idx = 1;
  while(idx < size && !(data[idx] < data[idx-1]))
    ++idx;

  const bool sorted = (idx >= size);
  cout << (sorted ? "PASS" : "FAIL!") << endl;
  return sorted;
}

/*
 * print: write data[0, size) to stdout on a single line.
 *
 * Each value is right-aligned in a 5-character field and followed by one
 * space; the line is terminated with endl (newline + flush).
 */
void print(int data[], int size)
{
  int idx = 0;
  while(idx < size)
    {
      cout << setw(5) << data[idx] << " ";
      ++idx;
    }
  cout << endl;
}

bool test(int data[], int size)
{
  print(data, SIZE);
  quicksort(data,0, SIZE);
  print(data, SIZE);
  return check(data, SIZE);
}

int main()
{
  int data[SIZE];
  bool pass = true;

  //easy data
  data[0] = 1;
  for(int i = 0; i < SIZE; i++)
    data[i] = SIZE - i;
  pass &= test(data, SIZE);

  //semi repeated data
  data[0] = 1;
  for(int i = 0; i < SIZE; i++)
    data[i] = SCRAMBLE(i, data[i-1]);
  pass &= test(data, SIZE);

  //the sort killer!
  for(int i = 0; i < SIZE; i++)
    data[i] = 5;
  pass &= test(data, SIZE);

  //and some randoms to catch anything i missed
  srand ( time(NULL) );

  for(int j = 0; j < 100; j++)
    {
      for(int i = 0; i < SIZE; i++)
        data[i] = (int)((float)(j+1)*((float)rand()/(float)RAND_MAX));
      pass &= test(data, SIZE);
    }
  
  if(pass)
    cout << "ALL PASSED" << endl;
  else
    cout << "FAILED" << endl;
}

No comments:

Post a Comment