"""
NOTE: You are only allowed to edit this file between the lines that say:
    # TODO
    # END TODO
"""

import numpy as np
import math
from bernoulli_bandit import *

    
def give_pull(num_arms, counts, values, time):
    '''
    Choose the next arm to pull using the UCB (upper confidence bound) rule.

    num_arms: int: number of arms in the bandit
    counts: list: number of times each arm has been pulled so far
    values: list: empirical mean reward of each arm
    time: int: number of episodes played so far

    Returns an integer representing the index of the arm that it wants to
    pull (0-indexed).

    Algorithm:
    1. While fewer than num_arms episodes have been played, pull arm = time,
       so that every arm is tried once.
    Afterwards
    2. For each arm a, compute the upper confidence bound
           UCB(a) = mean(a) + sqrt(log(time) / N(a))
       where N(a) is the number of times arm a has been pulled and log is
       the natural logarithm of the number of episodes so far, then pull the
       arm with the highest bound (the lowest index wins ties).
    '''
    # TODO
    # Initial round-robin phase: one pull per arm, in index order.
    if time < num_arms:
        return int(time)

    pull_counts = np.asarray(counts, dtype=float)
    mean_rewards = np.asarray(values, dtype=float)

    # An arm that has never been pulled would cause a division by zero below;
    # its confidence bound is effectively infinite, so pull it first.
    unpulled = np.flatnonzero(pull_counts == 0)
    if unpulled.size:
        return int(unpulled[0])

    exploration_bonus = np.sqrt(math.log(time) / pull_counts)
    # int(...) so callers get a plain Python int rather than a NumPy scalar.
    return int(np.argmax(mean_rewards + exploration_bonus))
    # END TODO


########################################################################
#                       DO NOT EDIT THE FOLLOWING                      #
########################################################################


def update(arm_index, reward, counts, values, time):
    '''
    Record the outcome of one pull.

    arm_index: int: arm that was pulled
    reward: int(0/1): reward obtained from that pull
    counts: list: number of times each arm has been pulled (modified in place)
    values: list: running mean reward of each arm (modified in place)
    time: int: number of episodes played before this pull

    Returns the (modified) counts and values together with time + 1.
    '''
    counts[arm_index] += 1
    pulls = counts[arm_index]

    # Incremental mean: weight the previous mean by (n - 1) / n and the new
    # reward by 1 / n, so no reward history has to be stored.
    old_weight = (pulls - 1) / pulls
    new_weight = 1 / pulls
    values[arm_index] = old_weight * values[arm_index] + new_weight * reward

    return counts, values, time + 1


def ucb(seed, PROBS, HORIZON):
    '''
    Run the UCB strategy on a Bernoulli bandit for HORIZON pulls.

    seed: value passed to np.random.seed before anything else happens
    PROBS: arm probabilities handed to BernoulliBandit; shuffled in place
    HORIZON: int: number of pulls to perform

    Returns whatever bandit.regret() reports after the last pull.
    '''
    np.random.seed(seed)
    np.random.shuffle(PROBS)  # in place: the caller's sequence is reordered
    env = BernoulliBandit(probs=PROBS)

    arm_total = len(PROBS)
    pull_counts = np.zeros(arm_total)     # pulls per arm
    mean_estimates = np.zeros(arm_total)  # running mean reward per arm
    episodes = 0                          # episodes completed so far

    for _ in range(HORIZON):
        chosen = give_pull(arm_total, pull_counts, mean_estimates, episodes)
        outcome = env.pull(chosen)
        pull_counts, mean_estimates, episodes = update(
            chosen, outcome, pull_counts, mean_estimates, episodes
        )

    return env.regret()