#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Fri Oct  4 11:48:03 2019

@author: swaprava

plot for difference in W* and W^ERM for symmetric graders
"""
import numpy as np
import matplotlib.pyplot as plt
# Render all figure text through LaTeX (requires a LaTeX installation).
plt.rcParams['text.usetex'] = True

# Prior over true scores: they are drawn from Normal(mu, 1/sqrt(gamma)),
# i.e. gamma acts as the precision (inverse variance) of the true score.
mu = 1
gamma = 16

# Number of true scores sampled per sweep, and number of independent sets of
# observed scores drawn for each true score.
trueScoreRepeat = 100
observedScoreRepeat = 100

# Number of observed scores drawn per set (one per grader).
numOfGraders = 5

# Bias values swept in the "Error vs Bias" plot; the same bias is added to the
# mean of every grader's observed score.
biases = np.linspace(0,1,11)

# Reliability values swept in the "Error vs Reliability" plot; observation
# noise has standard deviation 1/sqrt(reliability).
reliabilities = np.linspace(6,15,10)

def computeR(observedYvector, bias, reliability, k):
    """Combine the prior mean and the de-biased observed scores into one estimate.

    Computes

        (gamma**k * mu + reliability**k * sum(y_i - bias))
        / (gamma**k + n * reliability**k)

    where ``y_i`` are the entries of ``observedYvector``, ``n`` is its length,
    and ``gamma`` / ``mu`` are the module-level constants.

    Parameters
    ----------
    observedYvector : sequence of observed scores.
    bias : value subtracted from every observed score.
    reliability : weight base applied to the observations.
    k : exponent applied to both ``gamma`` and ``reliability``.

    Returns
    -------
    The weighted estimate as a float.
    """
    priorWeight = gamma**k
    observationWeight = reliability**k

    debiasedScores = np.array(observedYvector) - bias

    numerator = priorWeight * mu + observationWeight * sum(debiasedScores)
    denominator = priorWeight + len(observedYvector) * observationWeight

    return numerator/denominator

def computeW(r, y, kind="quadratic"):
    """Return the welfare of reporting ``r`` when the true value is ``y``.

    Parameters
    ----------
    r : estimated score.
    y : true score.
    kind : welfare function to use. Only ``"quadratic"`` is supported, which
        yields the negated squared error ``-(r - y)**2``.

    Returns
    -------
    The welfare value (always <= 0 for the quadratic kind).

    Raises
    ------
    ValueError
        If ``kind`` is not a supported welfare function. Previously this case
        fell through and returned ``None``, which only surfaced later as an
        unrelated TypeError in the caller.
    """
    if kind == "quadratic":
        return -(r - y)**2

    raise ValueError("unsupported welfare kind: %r" % (kind,))
    


def _simulateRelativeError(bias, rel, trueScores):
    """Monte-Carlo estimate of the relative welfare gap for one (bias, rel) pair.

    For every true score, ``observedScoreRepeat`` sets of ``numOfGraders``
    observed scores are drawn from Normal(trueScore + bias, 1/sqrt(rel)).
    Each set is aggregated twice with ``computeR`` (k=0.5, labelled Trupeqa,
    and k=1.0, labelled ERM) and the relative gap

        (W_ERM - W_Trupeqa) / |W_ERM|

    is recorded, where W is ``computeW`` rounded to 6 decimals. Samples with
    |W_ERM| < 1e-5 are dropped to avoid dividing by a (near-)zero welfare.

    Returns
    -------
    (avg, std) : the mean gap over all retained samples, and the mean over
        true scores of the per-true-score standard deviation of the gap.
    """
    noiseStd = 1.0/np.sqrt(rel)
    errorVector = []
    perScoreStd = []

    for trueScore in trueScores:
        errorsForScore = []

        for _ in range(observedScoreRepeat):
            observedYvector = np.random.normal(trueScore+bias, noiseStd, size=numOfGraders)
            rTrupeqa = computeR(observedYvector, bias, rel, k=0.5)
            rERM = computeR(observedYvector, bias, rel, k=1.0)

            welfareERM = round(computeW(rERM, trueScore), 6)
            welfareTrupeqa = round(computeW(rTrupeqa, trueScore), 6)

            if abs(welfareERM) >= 1e-5:
                errorsForScore.append((welfareERM - welfareTrupeqa) / abs(welfareERM))

        errorVector.extend(errorsForScore)
        # A true score whose samples were all filtered out has no spread to
        # report; skipping it keeps one empty bucket from turning the whole
        # error bar into NaN.
        if errorsForScore:
            perScoreStd.append(np.std(errorsForScore))

    return np.mean(errorVector), np.mean(perScoreStd)


def _plotError(figureName, xValues, avgError, stdError, fmt, xLabel):
    """Draw one error-bar figure of the relative welfare gap and show it."""
    plt.figure(figureName)
    plt.errorbar(xValues, avgError, yerr=stdError, fmt=fmt, elinewidth=2, capsize=2, capthick=2, label='Error')
    plt.grid()
    plt.xticks(xValues)
    plt.xlabel(xLabel, fontsize=16)
    plt.tick_params(labelsize=16)
    # tight_layout must run after the tick labels are enlarged, otherwise the
    # margins are computed for the smaller default labels.
    plt.tight_layout()
    plt.show()


# PLOT WRT BIAS
# Reliability is held fixed at the mean of the swept reliabilities.
rel = np.mean(reliabilities)
trueScores = np.random.normal(mu, 1.0/np.sqrt(gamma), size=trueScoreRepeat)

avgError = []
stdError = []

for bias in biases:
    meanError, meanStd = _simulateRelativeError(bias, rel, trueScores)
    avgError.append(meanError)
    stdError.append(meanStd)

_plotError("Error vs Bias", biases, avgError, stdError, 'bx-', 'Bias (symmetric)')


# PLOT WRT RELIABILITY
# Bias is held fixed at the mean of the swept biases; true scores are redrawn.
bias = np.mean(biases)
trueScores = np.random.normal(mu, 1.0/np.sqrt(gamma), size=trueScoreRepeat)

avgError = []
stdError = []

for rel in reliabilities:
    meanError, meanStd = _simulateRelativeError(bias, rel, trueScores)
    avgError.append(meanError)
    stdError.append(meanStd)

_plotError("Error vs Reliability", reliabilities, avgError, stdError, 'rs-', 'Reliability (symmetric)')