package iitb.cfilt.cpost.crfpp;

import iitb.cfilt.cpost.*;
import iitb.cfilt.cpost.test.TestVGIdm;
import iitb.cfilt.cpost.dmstemmer.MAResult;
import iitb.cfilt.cpost.dmstemmer.NewStemmer;
import iitb.cfilt.cpost.ngi.NounGroupIdentifier3;
import iitb.cfilt.cpost.ngi.NgiResult;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.List;
import java.util.Vector;

/**
 * Scores CRF++ decoding output against the reference tags: optionally post-corrects the
 * predicted tags (VGI, and NGI in the five-argument variant), writes per-word result files
 * and prints overall accuracy, per-tag statistics and a confusion matrix to standard output.
 */
public class Testerdm
{
	/** Tag inventory read from the file configured as "CRF.tagfile"; a tag's position is its confusion-matrix index. */
	public static Vector<String> distinctTags;
	private static boolean VGI = true; // true when the VGI correction is to be applied (presumably verb group identification). By Nikhilesh
	private static boolean NGI = true; // true when the NGI (noun group identifier) correction is to be applied; only used by the five-argument testAccuracy
	private static TestVGIdm tvgi;
	private static NounGroupIdentifier3 ngi;

	/** Running counters of one evaluation run. */
	private static final class Tally
	{
		int total;
		int correct;
		/** Rows are indexed by the assigned tag, columns by the correct tag. */
		final int[][] confMatrix;

		Tally(int tagCount)
		{
			// Java zero-initialises int arrays, so no explicit clearing is needed.
			confMatrix = new int[tagCount][tagCount];
		}
	}

	/**
	 * Returns the index of a tag in {@link #distinctTags}.
	 * @param tag The tag to look up.
	 * @return The zero-based index of the tag, or -1 when the tag is not in the list.
	 */
	public static int getTagNum(String tag)
	{
		return(distinctTags.indexOf(tag));
	}

	/**
	 * Returns the tag stored at the given index of {@link #distinctTags}.
	 * @param tagnum Zero-based index into the tag list.
	 * @return The tag at that index.
	 */
	public static String getNumTag(int tagnum)
	{
		return(distinctTags.get(tagnum));
	}

	/**
	 * This function is used to calculate accuracy and generate the confusion matrix from the output file generated by the CRF++.
	 * When VGI is enabled the predicted tags are first corrected from the result of TestVGIdm.doVGI_forSentence(sentence, false).
	 * Any exception raised while evaluating is printed and not propagated.
	 * @param testfile This is the file which is generated by CRF++ after testing/decoding. Tokens are tab-separated lines whose
	 *                 first column is the word and whose last two columns are the correct tag and the tag assigned by the CRF;
	 *                 sentences are separated by blank lines.
	 * @param resultfile This is the file which contains a word on each line with the tag assigned by the CRF and also the correct tag is mentioned in case of error made by the CRF.
	 * @param flagVGI "true" or "false" (case-insensitive) switches VGI on or off; any other value keeps the current setting.
	 */
	public static void testAccuracy(String testfile, String resultfile, String flagVGI)
	{
		VGI = parseFlag(flagVGI, VGI);
		evaluate(testfile, resultfile, null, false);
	}

	//************************************ Nikkhilesh *******************************************//
	/**
	 * Variant of {@link #testAccuracy(String, String, String)} that stems every word with NewStemmer, feeds the analyses to
	 * the VGI and NGI correctors, and additionally writes the final tagging of every sentence to a separate file.
	 * Any exception raised while evaluating is printed and not propagated.
	 * @param testfile This is the file which is generated by CRF++ after testing/decoding.
	 * @param resultfile This is the file which receives one word per line with its final tag, plus the correct tag on errors.
	 * @param finalFile This is the file which receives each sentence on one line as "word_[ tag ] " entries.
	 * @param flagVGI "true" or "false" (case-insensitive) switches VGI on or off; any other value keeps the current setting.
	 * @param ngiflag "true" or "false" (case-insensitive) switches NGI on or off; any other value keeps the current setting.
	 */
	public static void testAccuracy(String testfile, String resultfile, String finalFile, String flagVGI, String ngiflag)
	{
		VGI = parseFlag(flagVGI, VGI);
		NGI = parseFlag(ngiflag, NGI);
		evaluate(testfile, resultfile, finalFile, true);
	}
	//************************************ Nikkhilesh *******************************************//

	/** Interprets "true"/"false" (trimmed, case-insensitive); any other text leaves the current value untouched. */
	private static boolean parseFlag(String flag, boolean current)
	{
		String normalized = flag.trim();
		if(normalized.equalsIgnoreCase("true"))
		{
			return true;
		}
		if(normalized.equalsIgnoreCase("false"))
		{
			return false;
		}
		return current;
	}

	/**
	 * Shared driver of both testAccuracy variants.
	 * @param morphPipeline true for the five-argument variant (stemming, NGI and the final-tag file), false for the three-argument one.
	 */
	private static void evaluate(String testfile, String resultfile, String finalFile, boolean morphPipeline)
	{
		distinctTags = new Vector<String>();
		UTFWriter ob = null;
		UTFWriter obFinal = null;
		BufferedReader bf = null;
		try
		{
			ob = openFreshWriter(resultfile);
			if(morphPipeline)
			{
				obFinal = openFreshWriter(finalFile);
			}
			loadDistinctTags();
			Tally tally = new Tally(distinctTags.size());

			bf = new BufferedReader(new InputStreamReader(new FileInputStream(testfile), "UTF8"));
			Vector<String> words = new Vector<String>();
			Vector<String> oldtags = new Vector<String>();
			Vector<String> newtags = new Vector<String>();
			String line;
			while((line = bf.readLine()) != null)
			{
				// A whitespace-only line is treated like an empty one: it cannot be parsed as a token.
				if(line.trim().length() != 0)
				{
					String[] cols = line.split("\t");
					int l = cols.length;
					if(l < 2)
					{
						throw new IllegalArgumentException("Malformed line in " + testfile
								+ " (expected at least 2 tab-separated columns): " + line);
					}
					words.add(cols[0].trim());
					oldtags.add(cols[l-2].trim());
					newtags.add(cols[l-1].trim());
				}
				else
				{
					scoreSentence(words, oldtags, newtags, morphPipeline, ob, obFinal, tally);
					words.clear();
					oldtags.clear();
					newtags.clear();
				}
			}
			// The last sentence is still pending when the file does not end with a blank line.
			if(!words.isEmpty())
			{
				scoreSentence(words, oldtags, newtags, morphPipeline, ob, obFinal, tally);
			}

			printReport(tally);
			System.out.println("Done!");
		}
		catch(Exception e)
		{
			System.out.println(e.toString());
			e.printStackTrace();
		}
		finally
		{
			// Release the files on failure as well, so nothing leaks and written output is not lost.
			closeWriter(ob);
			closeWriter(obFinal);
			closeReader(bf);
		}
	}

	/** Deletes a previous file at the given path, if any, and opens a writer on it. */
	private static UTFWriter openFreshWriter(String path) throws Exception
	{
		File outfile = new File(path);
		if(outfile.exists())
		{
			outfile.delete();
			outfile = new File(path);
		}
		return new UTFWriter(outfile);
	}

	/** Fills {@link #distinctTags} with the trimmed lines of the file configured as "CRF.tagfile". */
	private static void loadDistinctTags() throws Exception
	{
		BufferedReader tagReader = new BufferedReader(new InputStreamReader(new FileInputStream(ConfigReader.get("CRF.tagfile")), "UTF8"));
		try
		{
			String line;
			while((line = tagReader.readLine()) != null)
			{
				distinctTags.add(line.trim());
			}
		}
		finally
		{
			tagReader.close();
		}
	}

	/**
	 * Applies the enabled corrections to one sentence, scores every word and writes the sentence separator.
	 * An empty sentence (consecutive blank lines in the input) only produces the separator.
	 */
	@SuppressWarnings("unchecked") // kept from the original method; presumably needed for the NGI/VGI calls - TODO confirm
	private static void scoreSentence(Vector<String> words, Vector<String> oldtags, Vector<String> newtags,
			boolean morphPipeline, UTFWriter ob, UTFWriter obFinal, Tally tally) throws Exception
	{
		if(!words.isEmpty())
		{
			int [] vga = new int[words.size()];
			Vector <NgiResult> ngioutput = new Vector<NgiResult>();
			if(morphPipeline)
			{
				// Stem the words themselves rather than re-splitting a joined sentence, so that
				// the analyses stay aligned one-to-one with words/newtags.
				NewStemmer dmstemmer = new NewStemmer();
				Vector<MAResult> mrv = new Vector<MAResult>();
				for(int w=0; w<words.size(); w++)
				{
					mrv.add(dmstemmer.stem(words.get(w)));
				}
				if(VGI)
				{
					tvgi = new TestVGIdm();
					vga = tvgi.doVGI_forSentence(mrv,newtags);
				}
				if(NGI)
				{
					ngi = new NounGroupIdentifier3();
					ngi.nounGroupIdentify(mrv,newtags);
					ngi.printfinaltags();
					ngioutput = ngi.getngiresults();
				}
			}
			else if(VGI)
			{
				tvgi = new TestVGIdm();
				vga = tvgi.doVGI_forSentence(joinWords(words), false);
			}

			for(int i=0; i<words.size(); i++)
			{
				if(VGI)
				{
					applyVgiCode(vga[i], i, words, newtags);
				}
				if(morphPipeline)
				{
					if(NGI)
					{
						// NGI may replace the tag, but never with a tag starting with "V".
						NgiResult ngiResult = ngioutput.get(i);
						if(ngiResult.getchanged() && !ngiResult.getfinaltag().startsWith("V"))
						{
							newtags.set(i, ngiResult.getfinaltag());
						}
					}
					obFinal.writeUTF(words.get(i) + "_[ " + newtags.get(i) + " ] ");
				}
				scoreWord(words.get(i), newtags.get(i), oldtags.get(i), ob, tally);
			}
		}
		ob.writeUTF("\n");
		if(morphPipeline)
		{
			obFinal.writeUTF("\n");
		}
	}

	/** Joins the words of a sentence with single spaces. */
	private static String joinWords(Vector<String> words)
	{
		StringBuilder sentence = new StringBuilder();
		for(int i=0; i<words.size(); i++)
		{
			if(i > 0)
			{
				sentence.append(' ');
			}
			sentence.append(words.get(i));
		}
		return sentence.toString();
	}

	/**
	 * Overrides the predicted tag of word i according to its VGI code:
	 * 1 forces "VM", 100 forces "NN", any other value above 1 forces "VAUX"; values of 0 or less change nothing.
	 */
	private static void applyVgiCode(int code, int i, Vector<String> words, Vector<String> newtags)
	{
		if(code <= 0)
		{
			return;
		}
		String forced;
		if(code == 1)
		{
			forced = "VM";
		}
		else if(code == 100)
		{
			forced = "NN";
		}
		else
		{
			forced = "VAUX";
		}
		if(!newtags.get(i).equals(forced))
		{
			System.out.println("Changing tag for " + words.get(i) + " from " + newtags.get(i) +" to " + forced + " \n");
			newtags.set(i, forced);
		}
	}

	/** Counts one word, writes its line to the result file and updates the confusion matrix. */
	private static void scoreWord(String word, String assigned, String expected, UTFWriter ob, Tally tally) throws Exception
	{
		tally.total++;
		if(assigned.equals(expected))
		{
			tally.correct++;
			ob.writeUTF(word + " : " + assigned + "\n");
		}
		else
		{
			ob.writeUTF(word + " : " + assigned + " Correct tag : " + expected + "\n");
			System.out.println("Tag : " + assigned + ",Correct tag : " + expected + "\tWord : " + word + "\t" + tally.total +"\n");
		}

		int row = getTagNum(assigned);
		int col = getTagNum(expected);
		if(row < 0 || col < 0)
		{
			// A tag missing from the tag file has no matrix cell; report it instead of aborting the whole run.
			System.err.println("Warning: tag not listed in CRF.tagfile, left out of the confusion matrix (assigned: "
					+ assigned + ", correct: " + expected + ", word: " + word + ")");
			return;
		}
		tally.confMatrix[row][col]++;
	}

	/** Prints the overall accuracy, the per-tag statistics and both renderings of the confusion matrix. */
	private static void printReport(Tally tally)
	{
		int[][] confMatrix = tally.confMatrix;
		int tagCount = distinctTags.size();

		System.out.println("Total words : " + tally.total);
		System.out.println("Correct words : " + tally.correct);
		System.out.println("Accuracy : " + (double)tally.correct*100/(double)tally.total);

		System.out.println("\n");

		for(int i=0;i<tagCount;i++)
		{
			System.out.print(getNumTag(i) + "\t");
			int incorrect = 0;
			for(int j=0;j<tagCount;j++)
			{
				if(i != j)
					incorrect = incorrect + confMatrix[i][j];
			}
			// Yields NaN for a tag that was never assigned (0/0 in floating point).
			double accuracy = 100*(double)(confMatrix[i][i])/(double)(confMatrix[i][i]+incorrect);
			System.out.println("Correct:" + confMatrix[i][i] + "\tErrors:" + incorrect + "\tTotal:" + (confMatrix[i][i]+incorrect) + "\tAccuracy:" + accuracy);
		}

		System.out.println("\nConfusion Matrix :\n");

		// Sparse rendering: only the non-zero cells of each row.
		for(int i=0;i<tagCount;i++)
		{
			System.out.print(getNumTag(i) + "\t");
			for(int j=0;j<tagCount;j++)
			{
				if(confMatrix[i][j] != 0)
					System.out.print(getNumTag(j) + ":" + confMatrix[i][j] + " ");
			}
			System.out.println();
		}

		System.out.print("\nConfusion Matrix (Formated) :\n\n\t");

		// Dense, tab-separated rendering with a header row.
		for(int i=0;i<tagCount;i++)
		{
			System.out.print(getNumTag(i) + "\t");
		}
		System.out.println("");

		for(int i=0;i<tagCount;i++)
		{
			System.out.print(getNumTag(i) + "\t");
			for(int j=0;j<tagCount;j++)
			{
				System.out.print(confMatrix[i][j] + "\t");
			}
			System.out.println();
		}
	}

	/** Closes a writer if it was opened; a failure to close is printed, not propagated. */
	private static void closeWriter(UTFWriter writer)
	{
		if(writer == null)
		{
			return;
		}
		try
		{
			writer.close();
		}
		catch(Exception e)
		{
			System.out.println(e.toString());
			e.printStackTrace();
		}
	}

	/** Closes a reader if it was opened; a failure to close is printed, not propagated. */
	private static void closeReader(BufferedReader reader)
	{
		if(reader == null)
		{
			return;
		}
		try
		{
			reader.close();
		}
		catch(Exception e)
		{
			System.out.println(e.toString());
			e.printStackTrace();
		}
	}

	/**
	 * Command-line entry point.
	 * @param args configuration file, CRF++ output file, result file, final-tag file, VGI flag, NGI flag (in this order).
	 */
	public static void main(String args[])
	{
		if(args.length < 6)
		{
			throw new IllegalArgumentException("Usage: Testerdm <configFile> <testFile> <resultFile> <finalFile> <flagVGI> <flagNGI>");
		}
		ConfigReader.read(args[0].trim());
		testAccuracy(args[1].trim(),args[2].trim(),args[3].trim(),args[4].trim(),args[5].trim());
	}
}
