//
// Programmer:    Craig Stuart Sapp <craig@ccrma.stanford.edu>
// Creation Date: Sun Jan  6 10:03:34 PST 2008
// Last Modified: Sun Jan  6 10:03:36 PST 2008
// Filename:      ...sig/examples/all/keyboundary.cpp
// Web Address:   http://sig.sapp.org/examples/museinfo/humdrum/keyboundary.cpp
// Syntax:        C++; museinfo
//
//  Description: Optimizing program for Krumhansl-Schmuckler key-finding 
//               algorithm at modulation points.

#include <time.h>
#include <math.h>
#include "humdrum.h"

#ifndef OLDCPP
   #include <iostream>
#else
   #include <iostream.h>
#endif


// function declarations:
void      checkOptions(Options& opts, int argc, char** argv);
void      example(void);
void      usage(const char* command);
void      fillBaseHistogram(Array<Array<double> >& beathist, 
                                 HumdrumFile& infile);
void      printHistograms(Array<Array<Array<double> > >& histograms);
void      printBaseHistograms(Array<Array<double> >& histograms);
void      addToHistogram(Array<Array<double> >& histogram, int pc, 
                                 double start, double dur);
void      fillAllHistograms(Array<Array<Array<double> > >& histograms);
void      fillCorrectData(Array<Array<int> >& correct, 
                                 HumdrumFile& infile);
double    bareCorrelation(int size, double* x, double* y);
double    pearsonCorrelation(int size, double* x, double* y);
void      printWeights(Array<double> weights);
void      runSearch(Array<double>& weights, 
                                 Array<Array<Array<Array<double> > > >& 
                                       histograms,
                                 Array<Array<Array<int> > >& correct, 
                                 int trialCount, 
                                 int stopLimit, double amplitude, 
                                 double decay);
double    getScore(Array<Array<int> >& correct, 
                                 Array<Array<int> >& answers);
double    getDistance(Array<double>& a, Array<double>& b);
void      makeAnswers(Array<Array<int> >& answers, 
                                 Array<Array<Array<double> > >& histograms, 
                                 Array<double>& weights);
int       findBestKey(Array<double>& histogram, 
                                 Array<double>& weights);
void      makeRandomWeights(Array<double>& weights, 
		                 Array<double>& initweights, double range);
void      normalizeWeights(Array<double>& weights);
void      printAnswerHistogram(Array<Array<int> >& answers,
                                 Array<Array<Array<double> > >& histograms, 
                                 Array<double>& weights);
void      printPPM(Array<Array<int> >& matrix);
void      generatePicture(Array<Array<int> >& answers);
void      printCorrect(Array<Array<int> >& correct);
void      getTwoKeys(int& lefthalf, int& righthalf, 
                                 HumdrumFile& infile);
double    getStandardDeviation(double mean, Array<double>& data);
double    getMean(int size, double* a);
double    getBoundaryLeft(int correct, Array<int>& answers);
double    getBoundaryRight(int correct, Array<int>& answers);

// User interface variables:
// (the "used with" notes refer to the command-line options that
// checkOptions() defines)
Options   options;                   // parsed command line; queried in main()
int       beatHistogramQ     = 0;    // used with -b option
int       allHistogramsQ     = 0;    // used with -a option
int       trialCount         = 100;  // used with -c option
double    decay              = 0.95; // used with -d option
double    amplitude          = 1.0;  // used with -s option
int       stopLimit          = 10;   // used with -l option
int       seed               = 0;    // echoed in runSearch() reports; presumably set from --seed -- TODO confirm
int       simpleQ            = 0;    // used with --simple
int       aardenQ            = 0;    // used with --aarden
int       krumhanslQ         = 0;    // used with --krumhansl
int       temperleyQ         = 0;    // used with --temperley
int       hybridQ            = 0;    // used with --hybrid
int       pictureQ           = 0;    // used with --picture
Array<double> inputweights;          // used with -w


//////////////////////////////////////////////////////////////////////////

int main(int argc, char** argv) {
   inputweights.setSize(0);
   checkOptions(options, argc, argv); // process the command-line options

   int i, j;

   Array<double> weights(24);
   weights.zero();
   // set the starting position of the key-position weightings
   // if not using all zeros:
   if (inputweights.getSize() == 24) {
      weights = inputweights;
   } else if (simpleQ) {
      weights[0]  = 2;  // C	(major)
      weights[1]  = 0;  // C#	(major)
      weights[2]  = 1;  // D	(major)
      weights[3]  = 0;  // E-	(major)
      weights[4]  = 1;  // E	(major)
      weights[5]  = 1;  // F	(major)
      weights[6]  = 0;  // F#	(major)
      weights[7]  = 2;  // G	(major)
      weights[8]  = 0;  // A-	(major)
      weights[9]  = 1;  // A	(major)
      weights[10] = 0;  // B-	(major)
      weights[11] = 1;  // B	(major)
      weights[12] = 2;  // c	(minor)
      weights[13] = 0;  // c#	(minor)
      weights[14] = 1;  // d	(minor)
      weights[15] = 1;  // e-	(minor)
      weights[16] = 0;  // e	(minor)
      weights[17] = 1;  // f	(minor)
      weights[18] = 0;  // f#	(minor)
      weights[19] = 2;  // g	(minor)
      weights[20] = 1;  // a-	(minor)
      weights[21] = 0;  // a	(minor)
      weights[22] = 1;  // b-	(minor)
      weights[23] = 0;  // b	(minor)
   } else if (aardenQ) {
      weights[0]  = 17.7661;	// C	(major)
      weights[1]  =  0.145624;	// C#	(major)
      weights[2]  = 14.9265;	// D	(major)
      weights[3]  =  0.160186;	// D#	(major)
      weights[4]  = 19.8049;	// E	(major)
      weights[5]  = 11.3587;	// F	(major)
      weights[6]  =  0.291248;	// F#	(major)
      weights[7]  = 22.062;	// G	(major)
      weights[8]  =  0.145624;	// G#	(major)
      weights[9]  =  8.15494;	// A	(major)
      weights[10] =  0.232998;	// A#	(major)
      weights[11] =  4.95122;	// B	(major)
      weights[12] = 18.2648;	// c	(minor)
      weights[13] =  0.737619;	// c#	(minor)
      weights[14] = 14.0499;	// d	(minor)
      weights[15] = 16.8599;	// e-	(minor)
      weights[16] =  0.702494;	// e	(minor)
      weights[17] = 14.4362;	// f	(minor)
      weights[18] =  0.702494;	// f#	(minor)
      weights[19] = 18.6161;	// g	(minor)
      weights[20] =  4.56621;	// a-	(minor)
      weights[21] =  1.93186;	// a	(minor)
      weights[22] =  7.37619;	// b-	(minor)
      weights[23] =  1.75623;	// b	(minor)
   } else if (krumhanslQ) {
      weights[0]  = 6.35;	// C	(major)
      weights[1]  = 2.23;	// C#	(major)
      weights[2]  = 3.48;	// D	(major)
      weights[3]  = 2.33;	// D#	(major)
      weights[4]  = 4.38;	// E	(major)
      weights[5]  = 4.09;	// F	(major)
      weights[6]  = 2.52;	// F#	(major)
      weights[7]  = 5.19;	// G	(major)
      weights[8]  = 2.39;	// G#	(major)
      weights[9]  = 3.66;	// A	(major)
      weights[10] = 2.29;	// A#	(major)
      weights[11] = 2.88;	// B	(major)
      weights[12] = 6.33;	// c	(minor)
      weights[13] = 2.68;	// c#	(minor)
      weights[14] = 3.52;	// d	(minor)
      weights[15] = 5.38;	// e-	(minor)
      weights[16] = 2.60;	// e	(minor)
      weights[17] = 3.53;	// f	(minor)
      weights[18] = 2.54;	// f#	(minor)
      weights[19] = 4.75;	// g	(minor)
      weights[20] = 3.98;	// a-	(minor)
      weights[21] = 2.69;	// a	(minor)
      weights[22] = 3.34;	// b-	(minor)
      weights[23] = 3.17;	// b	(minor)
   } else if (temperleyQ) {
      weights[0]  = 5.0;	// C	(major)
      weights[1]  = 2.0;	// C#	(major)
      weights[2]  = 3.5;	// D	(major)
      weights[3]  = 2.0;	// D#	(major)
      weights[4]  = 4.5;	// E	(major)
      weights[5]  = 4.0;	// F	(major)
      weights[6]  = 2.0;	// F#	(major)
      weights[7]  = 4.5;	// G	(major)
      weights[8]  = 2.0;	// G#	(major)
      weights[9]  = 3.5;	// A	(major)
      weights[10] = 1.5;	// A#	(major)
      weights[11] = 4.0;	// B	(major)
      weights[12] = 5.0;	// c	(minor)
      weights[13] = 2.0;	// c#	(minor)
      weights[14] = 3.5;	// d	(minor)
      weights[15] = 4.5;	// e-	(minor)
      weights[16] = 2.0;	// e	(minor)
      weights[17] = 4.0;	// f	(minor)
      weights[18] = 2.0;	// f#	(minor)
      weights[19] = 4.5;	// g	(minor)
      weights[20] = 3.5;	// a-	(minor)
      weights[21] = 2.0;	// a	(minor)
      weights[22] = 1.5;	// b-	(minor)
      weights[23] = 4.0;	// b	(minor)
   } else if (hybridQ) {
      weights[0]  = 11.38305;	// C    (major)
      weights[1]  =  1.072812;	// C#   (major)
      weights[2]  =  9.21325;	// D    (major)
      weights[3]  =  1.080093;	// E-   (major)
      weights[4]  = 12.15245;	// E    (major)
      weights[5]  =  7.67935;	// F    (major)
      weights[6]  =  1.145624;	// F#   (major)
      weights[7]  = 13.281;	// G    (major)
      weights[8]  =  1.072812;	// A-   (major)
      weights[9]  =  5.82747;	// A    (major)
      weights[10] =  0.866499;	// B-   (major)
      weights[11] =  4.47561;	// B    (major)
      weights[12] = 11.6324;	// c    (minor)
      weights[13] =  1.3688095;	// c#   (minor)
      weights[14] =  8.77495;	// d    (minor)
      weights[15] = 10.67995;	// e-   (minor)
      weights[16] =  1.351247;	// e    (minor)
      weights[17] =  9.2181;	// f    (minor)
      weights[18] =  1.351247;	// f#   (minor)
      weights[19] = 11.55805;	// g    (minor)
      weights[20] =  4.033105;	// a-   (minor)
      weights[21] =  1.96593;	// a    (minor)
      weights[22] =  4.438095;	// b-   (minor)
      weights[23] =  2.878115;	// b    (minor)
   }

   // allow user starting weights here later


   Array<HumdrumFile> infiles;
   if (options.getArgCount() == 0) {
      // if no command-line arguments read data file from standard input
      infiles.setSize(1);
      infiles.allowGrowth(0);
      infiles[0].clear();
      infiles[0].read(cin);
   } else {
      infiles.setSize(options.getArgCount());
      infiles.allowGrowth(0);
      for (i=0; i<options.getArgCount(); i++) {
         infiles[i].clear();
         infiles[i].read(options.getArg(i+1));
      }
   }

   Array<Array<Array<Array<double> > > > histograms;
   Array<Array<Array<int> > > correct;

   histograms.setSize(infiles.getSize());
   correct.setSize(infiles.getSize());
   histograms.allowGrowth(0);
   correct.allowGrowth(0);

   int totalbeats;
   int kk;
   for (kk=0; kk<infiles.getSize(); kk++) {
      infiles[kk].analyzeRhythm("4");
      totalbeats = (int)(infiles[kk].getTotalDuration()+0.9);
      histograms[kk].setSize(totalbeats);


      for (i=0; i<totalbeats; i++) {
         histograms[kk][i].setSize(i+1);
         histograms[kk][i].allowGrowth(0);
         for (j=0; j<i+1; j++) {
            histograms[kk][i][j].setSize(12);
            histograms[kk][i][j].allowGrowth(0);
            histograms[kk][i][j].zero();
         }
      }

      fillBaseHistogram(histograms[kk][histograms[kk].getSize()-1], 
         infiles[kk]);
      if (beatHistogramQ) {
         cout << "==========================================\n";
         printBaseHistograms(histograms[kk][histograms[kk].getSize()-1]);
      }

      fillAllHistograms(histograms[kk]);
      if (allHistogramsQ) {
         cout << "------------------------------------------\n";
         printHistograms(histograms[kk]);
      }

      fillCorrectData(correct[kk], infiles[kk]);
      //printCorrect(correct[kk]);
      //cout << "====================" <<endl;
      //exit(1);
   }

   if (beatHistogramQ || allHistogramsQ) {
      exit(0);
   }

   runSearch(weights, histograms, correct, trialCount, stopLimit, 
         amplitude, decay);

   printWeights(weights);

   return 0;
}

//////////////////////////////////////////////////////////////////////////



//////////////////////////////
//
// printCorrect -- for debugging
//

void printCorrect(Array >& correct) {
   int i,j;
   for (i=0; i<correct.getSize(); i++) {
      cout << i << ":";
      for (j=0; j<correct[i].getSize(); j++) {
         cout << "\t" << correct[i][j];
      }
      cout << "\n";
   }
}



//////////////////////////////
//
// runSearch --
//

void runSearch(Array<double>& initweights, 
      Array<Array<Array<Array<double> > > >& histograms,
      Array<Array<Array<int> > >& correct, int trialCount, 
      int stopLimit, double initrange, double decay) {
 
   double range               = initrange;
   Array<double> bestweights  = initweights;
   Array<Array<Array<int> > > answers;

   double distance;
   int i, j, kk;

   Array<double> bestscore;
   bestscore.setSize(histograms.getSize());
   bestscore.allowGrowth(0);
   bestscore.setAll(-1000000);

   Array<double>  stepweights;
   Array<double>  nextbestweights;

   Array<double> nextbestscore;
   nextbestscore.setSize(histograms.getSize());
   nextbestscore.allowGrowth(0);
   nextbestscore = bestscore;

   Array<double> stepscore;
   stepscore.setSize(histograms.getSize());
   stepscore.allowGrowth(0);

   answers.setSize(histograms.getSize());
   answers.allowGrowth(0);
   for (kk=0; kk<histograms.getSize(); kk++) {
      answers[kk].setSize(correct[kk].getSize());
      for (i=0; i<correct[kk].getSize(); i++) {
         answers[kk][i].setSize(correct[kk][i].getSize());
         answers[kk][i].allowGrowth(0);
         for (j=0; j<answers[kk][i].getSize(); j++) {
            answers[kk][i][j] = -1;
         }
      }

      makeAnswers(answers[kk], histograms[kk], initweights);

      bestscore[kk] = getScore(correct[kk], answers[kk]);
   
      printAnswerHistogram(answers[kk], histograms[kk], initweights);
   }

   int    currentLimit  = 0;
   int    step          = 0;
   int    betterQ       = 0;
   while (currentLimit < stopLimit) { 
      step++;
      stepweights     = bestweights;
      nextbestweights = bestweights;
      nextbestscore   = bestscore;
      for (i=0; i<trialCount; i++) {
         makeRandomWeights(stepweights, bestweights, range);
         //normalizeWeights(stepweights);
         for (kk=0; kk<histograms.getSize(); kk++) {
            makeAnswers(answers[kk], histograms[kk], stepweights);
            stepscore[kk] = getScore(correct[kk], answers[kk]);
         }
         betterQ = 1;
         for (kk=0; kk<stepscore.getSize(); kk++) {
            if (stepscore[kk] < nextbestscore[kk]) {
               betterQ = 0;
               break;
            }
         }

         if (betterQ) {
            nextbestscore = stepscore;
            nextbestweights = stepweights;
         }
      }


      distance = getDistance(bestweights, nextbestweights);

      if (distance > 0.0) {
         currentLimit = 0;
      } else {
         currentLimit++;
      }

      /*      if (nextbestscore <= bestscore) {
         currentLimit++;
      } else {
         currentLimit = 0;
      }
      */

      betterQ = 1;
      for (kk=0; kk<bestscore.getSize(); kk++) {
         if (nextbestscore[kk] < bestscore[kk]) {
            betterQ = 0;
            break;
         }
      }

      if (betterQ) {
         bestscore    = nextbestscore;
         bestweights  = nextbestweights;
      } 

      cout << "\n" 
           << "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" << "\n";
      cout << "!! Step:\t"          << step           << "\n";
      cout << "!! Trials:\t"        << trialCount     << "\n";
      cout << "!! Stop Limit:\t"    << stopLimit      << "\n";
      cout << "!! Iteration:\t"     << currentLimit   << "\n";
      cout << "!! Start Range:\t"   << initrange      << "\n";
      cout << "!! Range:\t"         << range          << "\n";
      cout << "!! Decay:\t"         << decay          << "\n";
      cout << "!! Best Score";
      if (bestscore.getSize() > 1) {
         cout << "s";
      }
      cout << ":\t";   
      for (kk=0; kk<bestscore.getSize(); kk++) {
         cout << bestscore[kk];
         if (kk < bestscore.getSize() - 1) {
            cout << " ";
         }
      }
      cout << "\n";
      cout << "!! Distance:\t"      << distance       << "\n";
      cout << "!! Random Seed:\t"   << seed << "\n";

      // normalizeWeights(bestweights);
      printWeights(bestweights);
      cout << endl;

      for (kk=0; kk<histograms.getSize(); kk++) {
         cout << "!! File == " << kk << "\n";
         printAnswerHistogram(answers[kk], histograms[kk], bestweights);
         //printCorrect(answers);
      }

      range *= decay;

   }

   initweights = bestweights;
}


//////////////////////////////
//
// printAnswerHistogram -- recompute the key answer for every window
//     with the given weights (storing each result back into answers),
//     then print how many windows were assigned to each of the 12
//     major keys, each of the 12 minor keys, and to the major/minor
//     tie value (24).  When the --picture flag is on, the answers
//     are also passed to generatePicture().
//

void printAnswerHistogram(Array<Array<int> >& answers,
      Array<Array<Array<double> > >& histograms, 
      Array<double>& weights) {

   // tally[0..11] = major keys, tally[12..23] = minor keys, tally[24] = ties
   Array<double> tally;
   tally.setSize(25);
   tally.allowGrowth(0);
   tally.zero();

   int row, col, key;
   for (row=0; row<histograms.getSize(); row++) {
      for (col=0; col<histograms[row].getSize(); col++) {
         key = findBestKey(histograms[row][col], weights);
         answers[row][col] = key;
         tally[key] += 1.0;
      }
   }

   cout << "!!MajKeyHist:\t";
   for (key=0; key<12; key++) {
      cout << tally[key] << " ";
   }
   cout << "\n!!MinKeyHist:\t";
   for (key=12; key<24; key++) {
      cout << tally[key] << " ";
   }
   cout << "\n!!MinMajTies: " << tally[24];
   cout << endl;

   if (!pictureQ) {
      return;
   }
   cout << "+++++++++++++++++++++++++++\n";
   generatePicture(answers);
   cout << "+++++++++++++++++++++++++++\n";
}



//////////////////////////////
//
// makeAnswers -- store the findBestKey() result for every window of
//     histograms into the matching cell of answers.  The answers 
//     table must already have the same row/column layout.
//

void makeAnswers(Array<Array<int> >& answers, 
      Array<Array<Array<double> > >& histograms, 
      Array<double>& weights) {
   const int rowcount = histograms.getSize();
   for (int row=0; row<rowcount; row++) {
      const int colcount = histograms[row].getSize();
      for (int col=0; col<colcount; col++) {
         answers[row][col] = findBestKey(histograms[row][col], weights);
      }
   }
}



//////////////////////////////
//
// findBestKey --
//

#define COMPARATOR bareCorrelation
//#define COMPARATOR pearsonCorrelation

int findBestKey(Array& histogram, Array& weights) {
   int i;
   double* maj = weights.getBase();
   double* min = weights.getBase() + 12;

   double hist[24];
   for (i=0; i<12; i++) {
      hist[i]    = histogram[i];
      hist[i+12] = histogram[i];
   }

   int bestmaji = 0;
   double bestmajscore = COMPARATOR(12, maj, hist);
   double testmajscore;

   int bestmini = 0;
   double bestminscore = COMPARATOR(12, min, hist);
   double testminscore;

   for (i=1; i<12; i++) {

      testmajscore = COMPARATOR(12, maj, hist+i);
      if (testmajscore > bestmajscore) {
         bestmajscore = testmajscore;
         bestmaji = i;
      }

      testminscore = COMPARATOR(12, min, hist+i);
      if (testminscore > bestminscore) {
         bestminscore = testminscore;
         bestmini = i;
      }
   }

   double bestmajpearson = pearsonCorrelation(12, maj, hist+bestmaji);
   double bestminpearson = pearsonCorrelation(12, min, hist+bestmini);

   int result = -1;
   if (bestmajpearson > bestminpearson) {
      result = bestmaji;
   } else if (bestmajpearson < bestminpearson) {
      result = bestmini + 12;
   } else {
      result = 24;
   }

   //cout << "\n\n";
   //cout << "RESULT = \t" << result << "\n";
   //cout << "weight=\t";
   //for (i=0; i<weights.getSize(); i++) {
   //   cout << weights[i] << "\t";
   //}
   //cout << "\nhistogram=\t";
   //for (i=0; i<histogram.getSize(); i++) {
   //   cout << histogram[i] << "\t";
   //}
   //cout << "\n\n";

   return result;
}



//////////////////////////////
//
// getDistance -- find the Euclidian distance between two vectors.
//

double getDistance(Array& a, Array& b)  {
   int i;
   double sum = 0.0;
   double value;
   for (i=0; i<a.getSize(); i++) {
      value = a[i] - b[i];
      sum += value * value;
   }
   
   return sqrt(sum);
}



/////////////////////////////
//
// makeRandomWeights -- set each element of weights to the matching 
//     element of initweights plus an offset derived from one 
//     drand48() call: drand48() * 2 * range - range.
//

void makeRandomWeights(Array<double>& weights, Array<double>& initweights, 
      double range) {
   const int count = weights.getSize();
   for (int k=0; k<count; k++) {
      const double jitter = drand48() * 2 * range - range;
      weights[k] = initweights[k] + jitter;
   }
}



//////////////////////////////
//
// getScore --
//

double getScore(Array >& correct, Array >& answers) {
   int size = correct.getSize();
   Array<double> leftboundary;
   Array<double> rightboundary;
   leftboundary.setSize(size);
   rightboundary.setSize(size);
   leftboundary.allowGrowth(0);
   rightboundary.allowGrowth(0);
  
   int leftside  = correct[size-1][0];
   int rightside = correct[size-1][correct[size-1].getSize()-1];

   int i;
   for (i=0; i<size; i++) {
      leftboundary[i] = getBoundaryLeft(leftside, answers[i]);
      rightboundary[i] = getBoundaryRight(rightside, answers[i]);
   }

   double meanl = getMean(leftboundary.getSize(), leftboundary.getBase());
   double meanr = getMean(rightboundary.getSize(), rightboundary.getBase());

   double sdleft  = getStandardDeviation(meanl, leftboundary);
   double sdright = getStandardDeviation(meanr, rightboundary);

   return -(sdleft + sdright);
}



//////////////////////////////
//
// getBoundaryLeft --
//

double getBoundaryLeft(int correct, Array& answers) {
   int i;
   int size = answers.getSize();
   int firsterror = -1;
   for (i=0; i<size; i++) {
      if (answers[i] != correct) {
         firsterror = i;
         break;
      }
   }
  
   if (firsterror < 0) {
      return 0.0 - size / 2.0;
   }

   return firsterror - size / 2.0;
}



//////////////////////////////
//
// getBoundaryRight --
//

double getBoundaryRight(int correct, Array& answers) {
   int i;
   int size = answers.getSize();
   int firsterror = -1;
   for (i=size-1; i>=0; i--) {
      if (answers[i] != correct) {
         firsterror = i;
         break;
      }
   }
  
   if (firsterror < 0) {
      return 0.0 - size / 2.0;
   }

   return firsterror - size / 2.0;
}



//////////////////////////////
//
// getScoreOld -- maximizes the area of the two correct answers
//

double getScoreOld(Array >& correct, Array >& answers) {
   int output = 0;
   int size = correct.getSize();
   int leftside  = correct[size-1][0];
   int rightside = correct[size-1][correct[size-1].getSize()-1];
   int i, j;
   int ans, cor;
   int leftcount = 0;
   int rightcount = 0;
   for (i=0; i<size; i++) {
      for (j=0; j<correct[i].getSize(); j++) {
         cor = correct[i][j];
         ans = answers[i][j];
	 //if (cor < 0) {
         //   if (drand48() < 0.5) {
         //      cor = leftside;
         //   } else {
         //      cor = rightside;
         //   }
         //}
         if (cor == ans) {
            output += 2;
	    if (cor == leftside) {
               leftcount++;
            } else {
               rightcount++;
            }
         } else if (cor == leftside && ans == rightside) {
            output += -1;
         } else if (cor == rightside && ans == leftside) {
            output += -1;
         } else if (ans > 23 || cor < 0) { // ignore boundaries
            output += 0;
         } else {
            output += -2;
         }
		    
      }
   }

   //int difference = abs(leftcount - rightcount);
   int difference = 0;

   return (double)(output - difference);
}



//////////////////////////////
//
// fillCorrectData -- infile is currently not really used, add flexibility
//     later.
//
// statically set to f minor in the first half and A major in the second
// half.
//

void fillCorrectData(Array >& correct, HumdrumFile& infile) {
   int totalbeats = (int)(infile.getTotalDuration()+0.9);
   int i, j;
   int lefthalf = -1;
   int righthalf = -1;
   getTwoKeys(lefthalf, righthalf, infile);
   double halfj;
   int rowsize;
   correct.setSize(totalbeats);
   for (i=0; i<totalbeats; i++) {
      correct[i].setSize(i+1);
      correct[i].allowGrowth(0);
      halfj = (i+1)/2 + 0.000001;
      rowsize = i+1;
      for (j=0; j<rowsize; j++) {
         if (j < (int)halfj) {
            correct[i][j] = lefthalf;
         } else { 
            correct[i][j] = righthalf;
         } 
      }
      if (rowsize % 2 == 1) {  // borderline case exactly 50/50
         //if (rowsize % 4 == 1) {
         //   correct[i][rowsize/2] = lefthalf;
         //} else {
         //   correct[i][rowsize/2] = righthalf;
         //}
         correct[i][rowsize/2] = -1;   

	 // also blank out the region just off-center
	 /*         if (rowsize/2 - 1 >= 0) {
            correct[i][rowsize/2-1] = -1;   
         }
         if (rowsize/2 + 1 < rowsize) {
            correct[i][rowsize/2+1] = -1;   
         }
	 */
      } else {
	 // blank out the region just off-center
	 /*         if (rowsize/2 - 1 >= 0) {
            correct[i][rowsize/2-1] = -1;   
         }
         if (rowsize/2 < rowsize) {
            correct[i][rowsize/2] = -1;   
         }
	 */

      }
   }
}



//////////////////////////////
//
// getTwoKeys -- find the first two key designations in the first 
//     spine of the file.  A key designation is taken to be an 
//     interpretation record whose first token is at least three 
//     characters long and ends with a colon.  Each key is returned 
//     as a pitch class (0-11), with 12 added when the character after
//     the leading character of the token is lowercase.  A key which 
//     is not found is returned as -1.
//

void getTwoKeys(int& lefthalf, int& righthalf, HumdrumFile& infile) {
   int i;
   int counter = 0;
   int keyz[2] = {-1, -1};
   int key;
   int length;
   for (i=0; i<infile.getNumLines(); i++) {
      if (infile[i].getType() != E_humrec_interpretation) {
         continue;
      }
      length = strlen(infile[i][0]);
      if (length < 3) {
         continue;
      }
      if (infile[i][0][length-1] != ':') {
         continue;
      }
      key =  Convert::kernToMidiNoteNumber(infile[i][0]);
      key = key % 12;
      // islower() requires a value representable as unsigned char
      if (islower((unsigned char)infile[i][0][1])) {
         key += 12;
      }
      
      keyz[counter++] = key;
      if (counter >= 2) {
         break;
      } 
   }

    lefthalf  = keyz[0];
    righthalf = keyz[1];
}



//////////////////////////////
//
// printWeights --
//

void printWeights(Array weights) {
   char buffer[1024] = {0};
   cout << "**kern\t**weight\n";
   cout << "!!major weights (using C as the tonic)\n";
   int i;
   for (i=0; i<12; i++) {
      cout << Convert::base12ToKern(buffer, i+48) 
           << "\t" << fixed << weights[i] << "\n";
      cout.unsetf(ios_base::floatfield);
   }
   cout << "!!minor weights (using c as the tonic)\n";
   for (i=0; i<12; i++) {
      cout << Convert::base12ToKern(buffer, i+60) 
           << "\t" << fixed << weights[i+12] << "\n";
      cout.unsetf(ios_base::floatfield);
   }
   cout << "*-\t*-\n";
}



//////////////////////////////
//
// fillAllHistograms -- 
//

void fillAllHistograms(Array > >& histograms) {
   int size = histograms.getSize();
   int i, j, k;
   for (i=size-2; i>=0; i--) {
      for (j=0; j<histograms[i].getSize(); j++) {
         for (k=0; k<12; k++) {
            histograms[i][j][k] = histograms[i+1][j][k] + 
                                  histograms[size-1][size-1-i+j][k];
         }
      }
   }
}



//////////////////////////////
//
// fillBaseHistogram -- get pitch durations for all beat segmentations
//      of the music.
//

void fillBaseHistogram(Array >& beathist, HumdrumFile& infile) {
   int    i;
   int    j;
   int    k;
   int    tokencount   = 0;
   int    pitch        = 0;
   char   buffer[1024] = {0};
   double start        = 0.0;
   double duration     = 0.0;

   for (i=0; i<infile.getNumLines(); i++) {
      if (!infile[i].isData()) {
         continue;
      }
      start = infile[i].getAbsBeat();
      for (j=0; j<infile[i].getFieldCount(); j++) {
         if (strcmp(infile[i].getExInterp(j), "**kern") != 0) {
            continue;
         }
         tokencount = infile[i].getTokenCount(j);
         for (k=0; k<tokencount; k++) {
            infile[i].getToken(buffer, j, k);
            if (strcmp(buffer, ".") == 0) {
               continue;  // ignore null tokens
            }
            pitch = Convert::kernToMidiNoteNumber(buffer);
            if (pitch < 0) {
              continue;  // ignore rests or funny objects
            } 
            pitch = pitch % 12;
            duration = Convert::kernToDuration(buffer);
            if (duration <= 0.0) {
               continue;  // ignore grace notes and funny objects
            }
            addToHistogram(beathist, pitch, start, duration);
         }
      }
   }
}



//////////////////////////////
//
// addToHistogram -- distribute the duration of one note over the 
//     per-beat histogram rows which it overlaps.  The note begins at
//     beat position start and lasts for dur beats; the part of the 
//     duration falling inside each beat is added to that beat's entry
//     for pitch class pc.  Any part of the note which lies outside of
//     the rows of the histogram is discarded rather than indexing 
//     past the end of the table.
//

void addToHistogram(Array<Array<double> >& histogram, int pc, 
      double start, double dur) {
   const int beatcount = histogram.getSize();
   int    starti = (int)start;
   double startf = start - starti;

   if (starti < 0 || starti >= beatcount) {
      return;   // note starts outside of the table
   }

   // portion of the starting beat which is still available to the note
   const double room = 1.0 - startf;
   
   if (dur <= room) {
      histogram[starti][pc] += dur;
      return;
   } else if (room > 0.0) {
      histogram[starti][pc] += room;
      dur -= room;
   }

   // spread the remaining duration over the following beats
   int i = starti + 1;
   while (dur > 0.0 && i < beatcount) {
      if (dur < 1.0) {
         histogram[i][pc] += dur;
         dur = 0.0;
      } else {
         histogram[i][pc] += 1.0;
         dur -= 1.0;
      }
      i++;
   }
}



//////////////////////////////
//
// printHistograms -- print the measured pitch histogram array
//

void printHistograms(Array > >& histograms) {

   cout << "range\tC\tC#\tD\tEb\tE\tF\tF#\tG\tAb\tA\tBb\tB\n";
   int i, j, k;
   for (i=0; i<histograms.getSize(); i++) {
      for (j=0; j<histograms[i].getSize(); j++) {
         cout << i << ',' << j << ':';
         for (k=0; k<histograms[i][j].getSize(); k++) {
            cout << '\t' << histograms[i][j][k];
         }
         cout << '\n';
      }
   }
}



//////////////////////////////
//
// printBaseHistograms -- print the measured pitch histogram array
//

void printBaseHistograms(Array >& histograms) {

   cout << "beat\tC\tC#\tD\tEb\tE\tF\tF#\tG\tAb\tA\tBb\tB\n";
   int i, j;
   for (i=0; i<histograms.getSize(); i++) {
      cout << i << ':';
      for (j=0; j<histograms[i].getSize(); j++) {
         cout << '\t' << histograms[i][j];
      }
      cout << '\n';
   }
}



//////////////////////////////
//
// bareCorrelation -- sum of the products of matching elements in
//     the two arrays x and y (each holding at least size values).  
//     Returns 0.0 when size is not positive.
//

double bareCorrelation(int size, double* x, double* y) {
   double total = 0.0;
   int k = 0;
   while (k < size) {
      total += x[k] * y[k];
      ++k;
   }
   return total;
}



//////////////////////////////
//
// pearsonCorrelation -- correlation coefficient of the first size
//     values of x and y, calculated in one pass with running means.
//     The result is the covariance divided by the product of the two
//     population standard deviations; no check is made for a zero
//     denominator.  size must be at least 1.
//

double pearsonCorrelation(int size, double* x, double* y) {

   // running means, seeded with the first sample of each array
   double avgx = x[0];
   double avgy = y[0];

   // running sums of squared (and crossed) deviations
   double ssx = 0.0;
   double ssy = 0.0;
   double sxy = 0.0;

   for (int k=1; k<size; k++) {
      const int    count  = k + 1;          // samples seen, including this one
      const double factor = (count - 1.0) / count;
      const double dx     = x[k] - avgx;
      const double dy     = y[k] - avgy;
      ssx  += dx * dx * factor;
      ssy  += dy * dy * factor;
      sxy  += dx * dy * factor;
      avgx += dx / count;
      avgy += dy / count;
   }

   const double sdx = sqrt(ssx / size);
   const double sdy = sqrt(ssy / size);
   const double cov = sxy / size;

   return cov / (sdx * sdy);
}



//////////////////////////////
//
// normalizeWeights --
//

void normalizeWeights(Array& weights) {
   double minmaj = weights[0];
   double minmin = weights[12];
   double maxmaj = weights[0];
   double maxmin = weights[12];
   int i;
   for (i=1; i<12; i++) {
      if (weights[i] < minmaj)    { minmaj = weights[i]; }
      if (weights[i+12] < minmin) { minmin = weights[i+12]; }
      if (weights[i] > maxmaj)    { maxmaj = weights[i]; }
      if (weights[i+12] > maxmin) { maxmin = weights[i+12]; }
   }

   maxmaj += minmaj;
   maxmin += minmin;

   for (i=0; i<12; i++) {
      weights[i] = (weights[i] + minmaj) / maxmaj;
      weights[i+12] = (weights[i+12] + minmin) / maxmin;
   }

}



//////////////////////////////
//
// checkOptions -- Define the command-line options, process the
//     command line, handle the informational options (author, version,
//     help, example), which each terminate the program, and then copy 
//     the option values into the program's global settings.  The 
//     random number generator is seeded here (from the clock when the
//     seed option is 0).  If a weights file is given, it is read into
//     the 24-element inputweights array.
//

void checkOptions(Options& opts, int argc, char* argv[]) {
   opts.define("a|all=b", "all segmentation histograms display");
   opts.define("b|beat=b", "beat histogram display");
   opts.define("c|count=i:100", "number of trials per step");
   opts.define("d|decay=d:0.95", "decay factor for random amplitude");
   opts.define("s|amplitude=d:1.0", "starting random amplitude");
   opts.define("l|stoplimit=i:10", "number of trials with same error rate");
   opts.define("seed=i:0", "random number generator seed");
   opts.define("picture=b",    "generate a picture of the current state"); 
   opts.define("w|weights=s",  "input starting weights");

   opts.define("simple=b",    "use simple weights"); 
   opts.define("aarden=b",    "use aarden weights"); 
   opts.define("krumhansl=b", "use krumhansl weights"); 
   opts.define("temperley=b", "use temperley weights"); 
   opts.define("hybrid=b",    "use hybrid (1/2 aarden, 1/2 temperley) weights"); 

   opts.define("author=b",  "author of program"); 
   opts.define("version=b", "compilation info");
   opts.define("example=b", "example usages");   
   opts.define("help=b",  "short description");
   opts.process(argc, argv);
   
   // handle basic options:
   if (opts.getBoolean("author")) {
      cout << "Written by Craig Stuart Sapp, "
           << "craig@ccrma.stanford.edu, Jan 2008" << endl;
      exit(0);
   } else if (opts.getBoolean("version")) {
      cout << argv[0] << ", version: Jan 2008" << endl;
      cout << "compiled: " << __DATE__ << endl;
      cout << MUSEINFO_VERSION << endl;
      exit(0);
   } else if (opts.getBoolean("help")) {
      usage(opts.getCommand());
      exit(0);
   } else if (opts.getBoolean("example")) {
      example();
      exit(0);
   }
   
   // copy option values into the global settings:
   beatHistogramQ = opts.getBoolean("beat");
   allHistogramsQ = opts.getBoolean("all");
   trialCount     = opts.getInteger("count");
   stopLimit      = opts.getInteger("stoplimit");
   decay          = opts.getDouble("decay");
   amplitude      = opts.getDouble("amplitude");
   seed           = opts.getInteger("seed");

   simpleQ        = opts.getBoolean("simple");
   aardenQ        = opts.getBoolean("aarden");
   krumhanslQ     = opts.getBoolean("krumhansl");
   temperleyQ     = opts.getBoolean("temperley");
   hybridQ        = opts.getBoolean("hybrid");
   pictureQ       = opts.getBoolean("picture");

   // a seed of 0 (the default) means "seed from the current time"
   if (seed == 0) {
     seed = time(NULL);
   }

   srand48(seed);


   int i;
   int j;
   int key;
   double value;
   // presumably true when the -w/--weights option was given -- verify 
   // against the Options class
   int weightsQ = opts.getBoolean("weights");
   if (weightsQ) {
      inputweights.setSize(24);
      inputweights.allowGrowth(0);
      inputweights.zero();
      HumdrumFile wfile;
      wfile.read(opts.getString("weights"));
      for (i=0; i<wfile.getNumLines(); i++) {
         if (wfile[i].getType() != E_humrec_data) {
            continue;
         }
         // markers meaning "nothing found on this line yet":
         key = -1;
         value = -1000000.0;
         for (j=0; j<wfile[i].getFieldCount(); j++) {
            if (strcmp(wfile[i].getExInterp(j), "**kern") == 0) {
               // pitch class 0-11 from the **kern token; a token which
               // starts with a lowercase letter is stored 12 slots higher 
               // (indices 12-23).
               key = Convert::kernToMidiNoteNumber(wfile[i][j]) % 12;
               // islower() requires a value representable as unsigned
               // char (or EOF), so convert the plain char before the call.
               if (islower((unsigned char)wfile[i][j][0])) {
                  key += 12;
               }
            } else if (strcmp(wfile[i].getExInterp(j), "**weight") == 0) {
               // value keeps its marker if no number can be parsed
               sscanf(wfile[i][j], "%lf", &value);
            }
         }
         // store only when both a valid slot and a weight were found
         if ((key >= 0) && (key < 24) && (value != -1000000.0)) {
            inputweights[key] = value;
         }
      }
   }


}



//////////////////////////////
//
// example -- Called by checkOptions() for the "example" option.
//     The body is empty, so nothing is printed.
//

void example(void) {
   // no example usages are defined
}



//////////////////////////////
//
// usage -- Called by checkOptions() for the "help" option with the
//     command name.  The body produces no output.
//

void usage(const char* command) {
   // no usage text is defined; the command name is not used
   (void)command;
}


//////////////////////////////
//
// printPPM -- 
//

void printPPM(Array >& matrix) {

   // Pitch to Color translations
   int red[40] = {
	0, 9, 18, 0, 63, 63, 63, 73, 82, 0, 218, 237, 255, 255, 255,
	255, 255, 255, 218, 182, 0, 45, 54, 63, 63, 63, 0, 109, 118,
	127, 145, 164, 0, 255, 255, 255, 255, 255, 73, 36
   };
   int green[40] = {
	255, 246, 237, 0, 123, 109, 95, 86, 77, 0, 9, 4, 0, 18, 36, 218,
	237, 255, 255, 255, 0, 209, 200, 191, 177, 164, 0, 50, 41, 31,
	27, 22, 0, 91, 109, 127, 145, 164, 255, 255
   };
   int blue[40] = {
	0, 36, 73, 0, 255, 255, 255, 255, 255, 0, 73, 36, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 182, 218, 255, 255, 255, 0, 255, 255, 255, 219,
	182, 0, 0, 0, 0, 0, 0, 0, 0
   };

   double minorscale = 0.75;   // darkening of color for minor keys
   const char* background   = "255 255 255";

   int row, column;
   int maxrow = matrix.getSize();
   int maxcolumn = matrix[0].getSize();
   int i, j;
   int root;
   double modescale = 1.0;

   cout << "P3\n";
   cout << maxcolumn << " " << maxrow << "\n";
   cout << "255\n";

   for (row=0; row<maxrow; row++) {
   // for (row=maxrow-1; row>=0; row--) {
      i = row;
      for (column=0; column<maxcolumn; column++) {
         j = column;
         root = matrix[i][j];
         if (root >= 40) {
            root -= 40;
            modescale = minorscale;
         } else {
            modescale = 1.0;
         }
         if (root >= 0) {
            cout << (int)(red[root]   * modescale + 0.5) << " "
                 << (int)(green[root] * modescale + 0.5) << " "
                 << (int)(blue[root]  * modescale + 0.5) << "  ";
         } else {
            cout << " " << background << "  ";
         }
      }
      cout << "\n";
   }
}



//////////////////////////////
//
// generatePicture --
//

void generatePicture(Array >& answers) {
   int rows = answers.getSize() * 2;
   int columns = answers[answers.getSize()-1].getSize() * 2;

   
   Array<Array<int> > matrix;
   int i, j;
   matrix.setSize(rows);
   for (i=0; i<rows; i++) {
      matrix[i].setSize(columns);
      matrix[i].setAll(-1);
   }

   int ii, jj;
   int value;
   int base12;
   int offset;
   for (i=0; i<answers.getSize(); i++) {
      for (j=0; j<answers[i].getSize(); j++) {
         ii = i*2;
         offset = answers[answers.getSize()-1].getSize() - i - 1;
         jj = j*2 + offset;
         base12 = answers[i][j];
         if (base12 >= 24 || base12 < 0) {
            value = -1;
         } else if (base12 < 12) { // major key
            value = Convert::base12ToBase40(base12) % 40 - 2;
         } else { // minor key
            value = Convert::base12ToBase40(base12) % 40 + 40 - 2;
         }
	 // cout << "BASE12 = " << base12 << "\tBASE40 = " << value << "\n";
         matrix[ii][jj]     = value;
         matrix[ii+1][jj]   = value;
         matrix[ii][jj+1]   = value;
         matrix[ii+1][jj+1] = value;
      }
   }

   printPPM(matrix);
}




//////////////////////////////
//
// getMean -- Return the arithmetic mean of the first "size" entries
//     of a, or 0.0 if size is zero or negative.
//

double getMean(int size, double* a) {
   if (size > 0) {
      double total = 0.0;
      double* end = a + size;
      for (double* ptr = a; ptr != end; ptr++) {
         total += *ptr;
      }
      return total / size;
   }

   return 0.0;
}



//////////////////////////////
//
// getStandardDeviation --
//

double getStandardDeviation(double mean, Array& data) {
   double sum = 0.0;
   double value;
   int i;
   for (i=0; i<data.getSize(); i++) {
      value = data[i] - mean;
      sum += value * value;
   }
   
   return sqrt(sum / data.getSize());
}


// md5sum: adfa05847e89b3fb64ae83fb9b5c7d01 keyboundary.cpp [20080518]