package iitb.cfilt.cpost.crfpp;

import iitb.cfilt.cpost.*;
import iitb.cfilt.cpost.test.TestVGIdm;
import iitb.cfilt.cpost.dmstemmer.MAResult;
import iitb.cfilt.cpost.dmstemmer.NewStemmer;
import iitb.cfilt.cpost.ngi.NounGroupIdentifier3;
import iitb.cfilt.cpost.ngi.NgiResult;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;

/**
 * Evaluates CRF++ tagger output, optionally post-corrected by the noun group identifier (NGI) and
 * the verb group identifier (VGI), against the correct tags contained in the CRF++ output file.
 * It prints per-stage accuracy figures and confusion matrices.
 */
public class Tester_VGI_pref
{
	/** Tag inventory read from the CRF.tagfile configuration entry; a tag's index is its confusion-matrix row/column. */
	public static Vector<String> distinctTags;
	private static boolean VGI = true; // True when VGI corrections are to be applied. By Nikhilesh
	private static boolean NGI = true; // True when NGI corrections are to be applied.

	/**
	 * Returns the index of {@code tag} in {@link #distinctTags}, or -1 if the tag is not listed.
	 */
	public static int getTagNum(String tag)
	{
		return distinctTags.indexOf(tag);
	}

	/**
	 * Returns the tag stored at index {@code tagnum} of {@link #distinctTags}.
	 */
	public static String getNumTag(int tagnum)
	{
		return distinctTags.get(tagnum);
	}

	/**
	 * Like {@link #getTagNum(String)}, but fails with a descriptive message for a tag missing from the
	 * tag inventory instead of letting a -1 index surface later as an ArrayIndexOutOfBoundsException.
	 */
	private static int requireTagNum(String tag)
	{
		int index = getTagNum(tag);
		if(index < 0)
		{
			throw new IllegalArgumentException("Tag '" + tag + "' is not listed in the CRF.tagfile tag inventory");
		}
		return index;
	}

	/**
	 * Overwrites {@code oldtags} element-wise with the contents of {@code latesttags}.
	 * @throws IllegalArgumentException if the two vectors differ in size
	 */
	public static void copyTags(Vector<String> latesttags, Vector<String> oldtags)
	{
		// Checked explicitly: a Java assert is skipped unless the JVM runs with -ea, which would
		// silently copy only a prefix (or fail later with an index error) on a size mismatch.
		if(latesttags.size() != oldtags.size())
		{
			throw new IllegalArgumentException("Fatal Error - latesttags.size != OldTags.size ("
					+ latesttags.size() + " vs " + oldtags.size() + ")");
		}
		for(int i = 0; i < latesttags.size(); i++)
		{
			oldtags.set(i, latesttags.get(i));
		}
	}

	/**
	 * Adds one count per word to {@code confMatrix}. The row is the word's correct tag and the column
	 * is the tag under evaluation, so the diagonal holds correctly tagged words.
	 * @throws IllegalArgumentException if a tag is missing from {@link #distinctTags}
	 */
	public static void updateConfusionMatrix(Vector<String> words, Vector<String> correcttags, Vector<String> answertags, int confMatrix[][])
	{
		for(int i = 0; i < words.size(); i++)
		{
			confMatrix[requireTagNum(correcttags.get(i))][requireTagNum(answertags.get(i))]++;
		}
	}

	/**
	 * Zeroes the top-left {@code size} x {@code size} cells of {@code confMatrix}.
	 */
	public static void initConfMatrix(int confMatrix[][], int size)
	{
		for(int i = 0; i < size; i++)
		{
			Arrays.fill(confMatrix[i], 0, size, 0);
		}
	}

	/**
	 * Returns the off-diagonal sum of row {@code tag}: the words whose row tag is {@code tag} but
	 * which were counted under a different column tag.
	 */
	public static int incorrectPerTag(int confMatrix[][], int numTags, int tag)
	{
		int incorrect = 0;
		for(int j = 0; j < numTags; j++)
		{
			if(j != tag)
			{
				incorrect += confMatrix[tag][j];
			}
		}
		return incorrect;
	}

	/**
	 * Returns the percentage of row {@code tag} that lies on the diagonal (NaN if the row is empty).
	 */
	public static double accuracyPerTag(int confMatrix[][], int numTags, int tag)
	{
		int correctForTag = confMatrix[tag][tag];
		int totalForTag = correctForTag + incorrectPerTag(confMatrix, numTags, tag);
		return percentage(correctForTag, totalForTag);
	}

	/** Returns 100 * part / whole as a double; NaN when both are 0. */
	private static double percentage(int part, int whole)
	{
		return 100.0 * part / whole;
	}

	/**
	 * Prints per-tag accuracy, the sparse confusion matrix, optionally the full formatted matrix,
	 * and the overall totals for one evaluation stage.
	 * @param name label printed in the section headers (e.g. "CRF")
	 */
	public static void printAccuracyReport(int confMatrix[][], int numTags, boolean printFormattedMatrix, String name)
	{
		int total = 0;
		int correct = 0;

		System.out.println("\n" + name + " Per Tag Accuracy :\n");
		for(int i = 0; i < numTags; i++)
		{
			int incorrectForTag = incorrectPerTag(confMatrix, numTags, i);
			int correctForTag = confMatrix[i][i];
			int totalForTag = correctForTag + incorrectForTag;
			System.out.println(getNumTag(i) + "\t" + "Correct:" + correctForTag + "\tErrors:" + incorrectForTag
					+ "\tTotal:" + totalForTag + "\tAccuracy:" + percentage(correctForTag, totalForTag));
			correct += correctForTag;
			total += totalForTag;
		}

		System.out.println("\n" + name + " Confusion Matrix :\n");
		for(int i = 0; i < numTags; i++)
		{
			System.out.print(getNumTag(i) + "\t");
			for(int j = 0; j < numTags; j++)
			{
				if(confMatrix[i][j] != 0)
				{
					System.out.print(getNumTag(j) + ":" + confMatrix[i][j] + " ");
				}
			}
			System.out.println();
		}

		if(printFormattedMatrix)
		{
			System.out.print("\nConfusion Matrix (Formated) :\n\n\t");
			for(int i = 0; i < numTags; i++)
			{
				System.out.print(getNumTag(i) + "\t");
			}
			System.out.println("");

			for(int i = 0; i < numTags; i++)
			{
				System.out.print(getNumTag(i) + "\t");
				for(int j = 0; j < numTags; j++)
				{
					System.out.print(confMatrix[i][j] + "\t");
				}
				System.out.println();
			}
		}

		System.out.println("Total words : " + total);
		System.out.println("Correct words : " + correct);
		System.out.println("Accuracy : " + percentage(correct, total));
		System.out.println("\n");
	}

	/**
	 * Prints a table with one row per tag holding that tag's CRF, NGI and VGI accuracy.
	 * A stage that was disabled has an empty matrix and therefore shows NaN.
	 */
	public static void printComparativeAccuracies(int crfConfMatrix[][], int ngiConfMatrix[][], int vgiConfMatrix[][], int numTags)
	{
		System.out.print("\n\t" + "CRF" + "\tNGI" + "\tVGI\n");
		for(int i = 0; i < numTags; i++)
		{
			// println, not print: each tag needs its own row.
			System.out.println(getNumTag(i) + "\t"
					+ accuracyPerTag(crfConfMatrix, numTags, i) + "\t"
					+ accuracyPerTag(ngiConfMatrix, numTags, i) + "\t"
					+ accuracyPerTag(vgiConfMatrix, numTags, i));
		}
	}

	/**
	 * Evaluates a file generated by CRF++ after testing/decoding. The CRF tags can optionally be
	 * post-corrected by NGI and then VGI, and the method prints accuracy reports for each stage.
	 *
	 * <p>File format and tag sources:
	 * <ul>
	 * <li>Each non-blank line of {@code testfile} is a tab-separated row. The first column is the word,
	 * the second-to-last column is the correct tag, and the last column is the tag assigned by CRF++.</li>
	 * <li>Blank lines separate sentences. End of file also closes the last sentence.</li>
	 * <li>The tag inventory is read from the file named by the CRF.tagfile configuration entry.</li>
	 * </ul>
	 *
	 * <p>Errors are printed to standard output rather than thrown.
	 *
	 * @param testfile file generated by CRF++ after testing/decoding
	 * @param resultfile output file (deleted first if it exists). It gets one "word : tag" line per
	 *        word; when the final tag is wrong, the correct tag is appended after "Cor :".
	 * @param finalFile output file (deleted first if it exists) with one line per sentence of
	 *        "word_[ tag ]" pairs carrying the final tags
	 * @param flagVGI "true" (case-insensitive) to apply VGI corrections
	 * @param ngiflag "true" (case-insensitive) to apply NGI corrections
	 */
	public static void testAccuracy(String testfile, String resultfile, String finalFile, String flagVGI, String ngiflag)
	{
		distinctTags = new Vector<String>();
		VGI = flagVGI.trim().equalsIgnoreCase("true");
		NGI = ngiflag.trim().equalsIgnoreCase("true");

		UTFWriter ob = null;
		UTFWriter obFinal = null;
		BufferedReader bf = null;
		try
		{
			ob = openFreshWriter(resultfile);
			obFinal = openFreshWriter(finalFile);
			loadDistinctTags();

			int numTags = distinctTags.size();
			// Java zero-initialises int arrays, so these need no explicit initConfMatrix call.
			int[][] crfConfMatrix = new int[numTags][numTags];
			int[][] ngiConfMatrix = new int[numTags][numTags];
			int[][] vgiConfMatrix = new int[numTags][numTags];
			int[][] finalConfMatrix = new int[numTags][numTags];

			bf = new BufferedReader(new InputStreamReader(new FileInputStream(testfile), "UTF8"));
			Vector<String> words = new Vector<String>();
			Vector<String> oldtags = new Vector<String>();
			Vector<String> crftags = new Vector<String>();
			int lineNo = 0;
			String line;
			do
			{
				line = bf.readLine();
				if(line != null)
				{
					lineNo++;
				}

				if(line != null && line.trim().length() != 0)
				{
					String[] cols = line.split("\t");
					if(cols.length < 2)
					{
						throw new IllegalArgumentException(testfile + ":" + lineNo
								+ ": expected tab-separated word, correct-tag and CRF-tag columns but got '" + line + "'");
					}
					words.add(cols[0].trim());
					oldtags.add(cols[cols.length - 2].trim());
					crftags.add(cols[cols.length - 1].trim());
				}
				else if(!words.isEmpty())
				{
					// A blank line, or end of file, closes the current sentence. Repeated blank lines
					// give an empty sentence, which is skipped instead of crashing on substring(0, -1).
					String sentence = joinWords(words);
					Vector<MAResult> mrv = stemWords(words);
					Vector<String> latesttags = new Vector<String>(crftags);
					// Snapshots of the tags after each stage. A disabled stage keeps the CRF tags.
					Vector<String> ngitags = crftags;
					Vector<String> vgitags = crftags;
					boolean changed = false;

					updateConfusionMatrix(words, oldtags, crftags, crfConfMatrix);

					if(NGI)
					{
						if(applyNgi(words, mrv, latesttags))
						{
							changed = true;
						}
						ngitags = new Vector<String>(latesttags);
						updateConfusionMatrix(words, oldtags, ngitags, ngiConfMatrix);
					}

					if(VGI)
					{
						if(applyVgi(words, mrv, latesttags))
						{
							changed = true;
						}
						vgitags = new Vector<String>(latesttags);
						updateConfusionMatrix(words, oldtags, vgitags, vgiConfMatrix);
					}

					writeSentenceResults(words, oldtags, crftags, ngitags, vgitags, latesttags, ob, obFinal);
					// Same row/column orientation (correct tag, assigned tag) as the per-stage matrices.
					updateConfusionMatrix(words, oldtags, latesttags, finalConfMatrix);

					if(changed)
					{
						System.out.println(sentence);
					}

					ob.writeUTF("\n");
					obFinal.writeUTF("\n");
					words.clear();
					oldtags.clear();
					crftags.clear();
				}
			}
			while(line != null);

			printAccuracyReport(crfConfMatrix, numTags, false, "CRF");
			if(NGI)
			{
				printAccuracyReport(ngiConfMatrix, numTags, false, "NGI");
			}
			if(VGI)
			{
				printAccuracyReport(vgiConfMatrix, numTags, false, "VGI");
			}
			// Accuracy of the tags actually written to resultfile/finalFile after all enabled stages.
			printAccuracyReport(finalConfMatrix, numTags, false, "Final");

			printComparativeAccuracies(crfConfMatrix, ngiConfMatrix, vgiConfMatrix, numTags);

			System.out.println("Done!");
		}
		catch(Exception e)
		{
			System.out.println(e.toString());
			e.printStackTrace();
		}
		finally
		{
			// Close on every path; previously an exception leaked all three handles.
			closeReader(bf);
			closeWriter(ob);
			closeWriter(obFinal);
		}
	}

	/** Deletes any existing file at {@code path} and opens a new UTFWriter on it. */
	private static UTFWriter openFreshWriter(String path) throws Exception
	{
		File file = new File(path);
		if(file.exists() && !file.delete())
		{
			System.out.println("Warning: could not delete existing file " + path);
		}
		return new UTFWriter(file);
	}

	/**
	 * Fills {@link #distinctTags} from the file named by the CRF.tagfile configuration entry, which
	 * holds one tag per line. Blank lines and repeated tags are skipped.
	 */
	private static void loadDistinctTags() throws Exception
	{
		BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(ConfigReader.get("CRF.tagfile")), "UTF8"));
		try
		{
			String line;
			while((line = reader.readLine()) != null)
			{
				String tag = line.trim();
				// Without this check, a trailing blank line would become a bogus "" tag row, and a
				// repeated tag would create a row that indexOf can never reach.
				if(tag.length() != 0 && !distinctTags.contains(tag))
				{
					distinctTags.add(tag);
				}
			}
		}
		finally
		{
			reader.close();
		}
	}

	/** Joins the words of a sentence with single spaces. */
	private static String joinWords(Vector<String> words)
	{
		StringBuilder sb = new StringBuilder();
		for(int i = 0; i < words.size(); i++)
		{
			if(i > 0)
			{
				sb.append(' ');
			}
			sb.append(words.get(i));
		}
		return sb.toString();
	}

	/**
	 * Stems every word of the sentence. The result has exactly one entry per word, so it stays
	 * index-aligned with the tag vectors. (Re-splitting the joined sentence on spaces did not
	 * guarantee this for empty or space-containing tokens.)
	 */
	private static Vector<MAResult> stemWords(Vector<String> words) throws Exception
	{
		NewStemmer dmstemmer = new NewStemmer();
		Vector<MAResult> mrv = new Vector<MAResult>(words.size());
		for(int i = 0; i < words.size(); i++)
		{
			mrv.add(dmstemmer.stem(words.get(i)));
		}
		return mrv;
	}

	/**
	 * Runs the noun group identifier on one sentence and writes its changed, non-"V*" final tags into
	 * {@code tags}, logging each change.
	 * @return true if any tag was changed
	 */
	// NOTE(review): getngiresults() may return a raw Vector (its declaration is not visible here),
	// which is presumably what the original method-level suppression was for.
	@SuppressWarnings("unchecked")
	private static boolean applyNgi(Vector<String> words, Vector<MAResult> mrv, Vector<String> tags) throws Exception
	{
		NounGroupIdentifier3 ngi = new NounGroupIdentifier3();
		ngi.nounGroupIdentify(mrv, tags);
		ngi.printfinaltags();
		Vector<NgiResult> ngioutput = ngi.getngiresults();

		boolean changed = false;
		for(int i = 0; i < words.size(); i++)
		{
			NgiResult result = ngioutput.get(i);
			if(result.getchanged() && !result.getfinaltag().startsWith("V"))
			{
				System.out.println("NGI: Ch " + i + " " + words.get(i) + " from " + tags.get(i) + " to " + result.getfinaltag() + "\n");
				tags.set(i, result.getfinaltag());
				changed = true;
			}
		}
		return changed;
	}

	/**
	 * Runs the verb group identifier on one sentence and forces the tags its per-word codes
	 * call for into {@code tags}, logging each change.
	 * @return true if any tag was changed
	 */
	private static boolean applyVgi(Vector<String> words, Vector<MAResult> mrv, Vector<String> tags) throws Exception
	{
		TestVGIdm tvgi = new TestVGIdm();
		int[] vga = tvgi.doVGI_forSentence(mrv, tags);

		boolean changed = false;
		for(int i = 0; i < words.size(); i++)
		{
			String target = vgiTargetTag(vga[i]);
			if(target != null && !tags.get(i).equals(target))
			{
				System.out.println("Ch " + i + " " + words.get(i) + " from " + tags.get(i) + " to " + target + " \n");
				tags.set(i, target);
				changed = true;
			}
		}
		return changed;
	}

	/**
	 * Maps a per-word VGI code to the tag this evaluator forces: 1 gives VM, 100 gives NN, and any
	 * other value above 1 gives VAUX. Codes of 0 or below leave the tag unchanged (returns null).
	 */
	private static String vgiTargetTag(int code)
	{
		if(code <= 0)
		{
			return null;
		}
		if(code == 1)
		{
			return "VM";
		}
		if(code == 100)
		{
			return "NN";
		}
		return "VAUX";
	}

	/**
	 * Writes one sentence's final tags to both output files. It also prints a trace line for every
	 * error, and for every correct word that some earlier stage had tagged wrongly.
	 */
	private static void writeSentenceResults(Vector<String> words, Vector<String> oldtags, Vector<String> crftags,
			Vector<String> ngitags, Vector<String> vgitags, Vector<String> latesttags,
			UTFWriter ob, UTFWriter obFinal) throws Exception
	{
		for(int i = 0; i < words.size(); i++)
		{
			String word = words.get(i);
			String correctTag = oldtags.get(i);
			String finalTag = latesttags.get(i);
			String trace = i + " " + word + "\tCor: " + correctTag + "\tcrf: " + crftags.get(i)
					+ "\tngi: " + ngitags.get(i) + "\tvgi: " + vgitags.get(i) + "\n";

			obFinal.writeUTF(word + "_[ " + finalTag + " ] ");
			if(finalTag.equals(correctTag))
			{
				ob.writeUTF(word + " : " + finalTag + "\n");
				if(!crftags.get(i).equals(correctTag) || !ngitags.get(i).equals(correctTag) || !vgitags.get(i).equals(correctTag))
				{
					System.out.println("Finally Correct: " + trace);
				}
			}
			else
			{
				ob.writeUTF(word + " : " + finalTag + " Cor : " + correctTag + "\n");
				System.out.println("Finally Error: " + trace);
			}
		}
	}

	/** Closes {@code reader} if non-null, reporting (not propagating) any failure. */
	private static void closeReader(BufferedReader reader)
	{
		if(reader == null)
		{
			return;
		}
		try
		{
			reader.close();
		}
		catch(Exception e)
		{
			System.out.println(e.toString());
		}
	}

	/** Closes {@code writer} if non-null, reporting (not propagating) any failure. */
	private static void closeWriter(UTFWriter writer)
	{
		if(writer == null)
		{
			return;
		}
		try
		{
			writer.close();
		}
		catch(Exception e)
		{
			System.out.println(e.toString());
		}
	}

	/**
	 * Entry point. Arguments: config file, CRF++ output file, result file, final file,
	 * VGI flag, NGI flag.
	 */
	public static void main(String args[])
	{
		if(args.length < 6)
		{
			System.err.println("Usage: Tester_VGI_pref <configFile> <crfOutputFile> <resultFile> <finalFile> <applyVGI true|false> <applyNGI true|false>");
			System.exit(1);
			return;
		}
		ConfigReader.read(args[0].trim());
		testAccuracy(args[1].trim(), args[2].trim(), args[3].trim(), args[4].trim(), args[5].trim());
	}
}
