"""
NOTE: You are only allowed to edit this file between the lines that say:
    # TODO
    # END TODO
"""

import numpy as np
from bernoulli_bandit import *

    
def give_pull(num_arms, eps, values):
    '''
    Choose the next arm to pull using the epsilon-greedy rule.

    num_arms: int: number of arms in the bandit
    eps: float: exploration probability, expected to lie in [0, 1]
    values: list or np.ndarray: current estimated means of the arms

    Returns an integer representing the index of the arm that it wants to pull (0-indexed).

    Algorithm:
    1. With probability 1-epsilon, choose the action that has the highest mean (exploitation).
    2. With probability epsilon, choose a random action (exploration).
    Ties for the highest mean are broken uniformly at random.
    '''
    # TODO
    # np.random.random() is uniform on [0, 1), so this branch fires with
    # probability eps: never for eps == 0, always for eps == 1. All randomness
    # comes from np.random, so runs stay reproducible under np.random.seed().
    if np.random.random() < eps:
        return int(np.random.randint(num_arms))

    means = np.asarray(values, dtype=float)
    # Break ties at random. Every estimate starts at 0, so a plain argmax
    # would always return the first arm and the exploit step would never
    # consider the other arms that are equally good so far.
    best_arms = np.flatnonzero(means == means.max())
    return int(np.random.choice(best_arms))
    # END TODO


########################################################################
#                       DO NOT EDIT THE FOLLOWING                      #
########################################################################


def update(arm_index, reward, counts, values):
    '''
    Record one pull of `arm_index` and fold `reward` into that arm's running mean.

    arm_index: int: arm that was just pulled
    reward: int(0/1): reward observed for that pull
    counts: list: pull count per arm (mutated in place)
    values: list: estimated mean per arm (mutated in place)

    Returns the same `counts` and `values` objects, after modification.
    '''
    counts[arm_index] = counts[arm_index] + 1
    pulls = counts[arm_index]
    old_mean = values[arm_index]
    # Weighted blend of the previous mean and the new sample. This is the
    # incremental form of the sample mean over `pulls` observations.
    keep_weight = (pulls - 1) / pulls
    new_weight = 1 / pulls
    values[arm_index] = keep_weight * old_mean + new_weight * reward
    return counts, values


def eps_greedy(seed, PROBS, HORIZON):
    '''
    Run one epsilon-greedy episode on a BernoulliBandit and return its regret.

    seed: value passed to np.random.seed; this seeds NumPy's global RNG,
          which both the arm shuffle and give_pull draw from
    PROBS: list: per-arm probabilities handed to BernoulliBandit.
           NOTE(review): np.random.shuffle reorders PROBS IN PLACE, so the
           caller's object is permuted after this call. Reusing the same
           PROBS object across calls feeds each call an already-shuffled list.
    HORIZON: int: number of pulls performed

    Returns the result of bandit.regret() after HORIZON pulls. This is
    presumably the cumulative regret as defined in bernoulli_bandit; confirm
    there.
    '''
    np.random.seed(seed)
    # Randomise which arm gets which probability (mutates PROBS in place).
    np.random.shuffle(PROBS)
    bandit = BernoulliBandit(probs=PROBS)

    num_arms = len(PROBS)
    eps = 0.1  # fixed exploration probability passed to give_pull
    counts = np.zeros(num_arms) # number of times an arm is pulled
    values = np.zeros(num_arms) # estimated means of the arms, all start at 0

    for t in range(HORIZON):
        # One step: pick an arm, observe its reward, update that arm's estimate.
        arm_to_be_pulled = give_pull(num_arms, eps, values)
        reward = bandit.pull(arm_to_be_pulled)
        counts, values = update(arm_to_be_pulled, reward, counts, values)
    return bandit.regret()