package weka.classifiers.custom;

import weka.core.Instance;
import weka.core.Instances;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import java.io.*;
import java.util.*;
import weka.core.*;
import newCode.*;
import riso.numerical.*;

public class Myloss extends Classifier {

	/** Number of correction pairs kept by L-BFGS (the {@code m} argument). */
	private static final int LBFGS_HISTORY = 3;

	/** Gradient convergence tolerance passed to L-BFGS as {@code eps}. */
	private static final double LBFGS_EPS = 0.00001;

	/** Machine-precision estimate passed to L-BFGS as {@code xtol}. */
	private static final double LBFGS_XTOL = .0000000000001;

	/**
	 * Averaged weight vector learned by {@link #buildClassifier(Instances)}.
	 * Index 0 is the bias (paired with a constant input of 1); index i+1 pairs
	 * with attribute i. {@code null} until a build has completed successfully.
	 */
	double[] finalW;

	/**
	 * Trains the model.
	 *
	 * For every training instance, a weight vector is obtained by minimizing
	 * {@code SLoss} for that single (x, y) pair with L-BFGS, starting from zero
	 * weights. {@link #finalW} is the running mean of those per-instance
	 * vectors. The class value is taken from the LAST attribute; every other
	 * attribute is used as a feature.
	 *
	 * NOTE(review): missing attribute values are passed through as whatever
	 * {@code Instance.value} returns; no special handling is done here.
	 *
	 * @param instances training data; the class attribute is assumed to be last
	 * @throws Exception if the optimizer or loss function fails
	 */
	public void buildClassifier(Instances instances) throws Exception {
		// Drop any previous model so a failed build never leaves a stale or
		// half-updated weight vector behind.
		finalW = null;

		SLoss sq = new SLoss();
		int n = instances.numAttributes() - 1;   // number of non-class attributes
		double[] weights = new double[n + 1];    // bias + one weight per feature
		int noofrec = 0;

		for (int r = 0; r < instances.numInstances(); r++) {
			Instance instance = instances.instance(r);
			noofrec++;

			double[] x = toFeatureVector(instance);
			double y = instance.value(n);        // class value is the last attribute
			double[] w = minimizeForInstance(sq, y, x);

			// Incremental mean over the per-instance solutions seen so far.
			for (int i = 0; i <= n; i++) {
				weights[i] = (weights[i] * (noofrec - 1) + w[i]) / noofrec;
			}
		}

		finalW = weights;
	}

	/**
	 * Minimizes {@code loss} for a single (x, y) pair using reverse-communication
	 * L-BFGS, starting from the zero vector.
	 *
	 * @param loss loss providing the objective {@code f} and gradient {@code g}
	 * @param y target value of the instance
	 * @param x feature vector (with the leading bias input)
	 * @return the weight vector left by L-BFGS when it stops requesting evaluations
	 * @throws Exception if L-BFGS reports an error
	 */
	private double[] minimizeForInstance(SLoss loss, double y, double[] x) throws Exception {
		int dim = x.length;
		double[] w = new double[dim];
		// Workspace for the diagonal; unused as input because diagco == false.
		double[] diag = new double[dim];
		// LBFGS takes fixed-size control arrays: two print-control entries and
		// a single state flag. Sizing them by dim breaks when dim == 1.
		int[] iprint = new int[2];
		int[] iflag = new int[1];   // 0 => start a fresh minimization

		do {
			double f = loss.f(y, x, w);
			double[] g = loss.g(y, x, w);
			LBFGS.lbfgs(dim, LBFGS_HISTORY, w, f, g, false, diag, iprint,
					LBFGS_EPS, LBFGS_XTOL, iflag);
			// Non-zero iflag[0] means LBFGS wants f and g at the updated w.
		} while (iflag[0] != 0);

		return w;
	}

	/**
	 * Builds the model input for an instance: a constant 1 (bias input)
	 * followed by the values of every attribute except the last (class) one.
	 */
	private double[] toFeatureVector(Instance instance) {
		int numFeatures = instance.numAttributes() - 1;
		double[] x = new double[numFeatures + 1];
		x[0] = 1;
		for (int i = 0; i < numFeatures; i++) {
			x[i + 1] = instance.value(i);
		}
		return x;
	}

	/** Returns the dot product of {@code x} and {@code w} over {@code x.length} entries. */
	private double dotproduct(double[] x, double[] w) {
		double res = 0;
		for (int i = 0; i < x.length; i++) {
			res += x[i] * w[i];
		}
		return res;
	}

	/**
	 * Predicts a value as the dot product of the learned weights with the
	 * instance's feature vector, including the bias weight {@code finalW[0]}
	 * that training learns against the constant input 1.
	 *
	 * @param instance instance whose last attribute is the (ignored) class
	 * @return the predicted value
	 * @throws IllegalStateException if the classifier has not been built
	 * @throws IllegalArgumentException if the instance's attribute count does
	 *         not match the training data
	 */
	public double classifyInstance(Instance instance) {
		if (finalW == null) {
			throw new IllegalStateException("Classifier has not been built");
		}
		double[] x = toFeatureVector(instance);
		if (x.length != finalW.length) {
			throw new IllegalArgumentException("Instance has " + (x.length - 1)
					+ " non-class attributes but the model expects " + (finalW.length - 1));
		}
		return dotproduct(x, finalW);
	}
}
