import java.util.Arrays;

/**
 * Top-down merge sort for {@code int} arrays, with a small demo in {@link #main}.
 */
public class MergeSort
{
  /**
   * Sorts the first {@code n} elements of {@code a} into ascending order, in place.
   *
   * <p>The range is split into two halves that are copied into temporary arrays,
   * each half is sorted recursively, and the sorted halves are merged back into
   * {@code a}. Elements at index {@code n} and beyond are left untouched. The
   * method writes nothing to standard output.
   *
   * @param a the array whose leading elements are sorted
   * @param n the number of leading elements to sort; values below 2 are a no-op
   * @throws ArrayIndexOutOfBoundsException if {@code n} is greater than
   *         {@code a.length}; the array is not modified in that case
   * @throws NullPointerException if {@code a} is {@code null} and {@code n >= 2}
   */
  public static void mergeSort(int[] a, int n)
  {
    if (n < 2)
    {
      return; // zero or one element is already sorted
    }
    // Explicit bounds check: Arrays.copyOfRange would silently zero-pad a range
    // that runs past the end of the array instead of failing.
    if (n > a.length)
    {
      throw new ArrayIndexOutOfBoundsException(
          "n (" + n + ") exceeds array length (" + a.length + ")");
    }

    int leftSize = n / 2; // the right half receives the extra element when n is odd
    int[] left = Arrays.copyOfRange(a, 0, leftSize);
    int[] right = Arrays.copyOfRange(a, leftSize, n);

    mergeSort(left, left.length);
    mergeSort(right, right.length);

    merge(left, right, a);
  }

  /**
   * Merges two ascending arrays into the front of {@code dest}.
   *
   * @param left  first sorted input
   * @param right second sorted input
   * @param dest  receives {@code left.length + right.length} elements starting at index 0
   */
  private static void merge(int[] left, int[] right, int[] dest)
  {
    int l = 0;
    int r = 0;
    int out = 0;

    while (l < left.length && r < right.length)
    {
      // On ties the element from the left half is taken first.
      if (left[l] <= right[r])
      {
        dest[out++] = left[l++];
      }
      else
      {
        dest[out++] = right[r++];
      }
    }

    // At most one of the inputs still has elements; the other copy has length 0.
    int leftRemaining = left.length - l;
    System.arraycopy(left, l, dest, out, leftRemaining);
    System.arraycopy(right, r, dest, out + leftRemaining, right.length - r);
  }

  /**
   * Demo: prints a sample array, sorts it, and prints the sorted result.
   *
   * @param args ignored
   */
  public static void main(String[] args)
  {
    int[] vals = {3, 12, 9, 7, 8, 5, 4, 1};

    System.out.println(Arrays.toString(vals));
    mergeSort(vals, vals.length);
    System.out.println("sorted:");
    System.out.println(Arrays.toString(vals));
  }
}

