"""
NOTE: You are only allowed to edit this file between the lines that say:
    # TODO
    # END TODO
"""

import numpy as np
from bernoulli_bandit import *

    
def give_pull(num_arms, successes, failures):
    '''
    Choose the next arm to pull using Thompson sampling.

    num_arms: int: number of arms in the bandit (must equal len(successes) and len(failures))
    successes: list: list of number of successes for each arm
    failures: list: list of number of failures for each arm

    Returns an integer representing the index of the arm that it wants to pull (0-indexed).

    Algorithm (Thompson sampling with a uniform Beta(1, 1) prior):
    1. For each arm a, sample theta[a] from the Beta distribution with
       alpha = successes[a] + 1 and beta = failures[a] + 1 (the arm's posterior).
    2. Pull the arm whose sample theta[a] is largest.

    Sampling uses NumPy's global RNG, so results are reproducible once
    np.random.seed has been called.
    '''
    # TODO
    # Accept plain lists or NumPy arrays of counts; +1 applies the Beta(1, 1) prior.
    alpha_params = np.asarray(successes, dtype=float) + 1.0
    beta_params = np.asarray(failures, dtype=float) + 1.0
    # One vectorized draw per arm from its Beta posterior.
    theta = np.random.beta(alpha_params, beta_params)
    # int() so callers get a plain Python int rather than np.int64.
    return int(np.argmax(theta))
    # END TODO


########################################################################
#                       DO NOT EDIT THE FOLLOWING                      #
########################################################################


def update(arm_index, reward, successes, failures):
    '''
    Record the outcome of one pull in the success/failure counts.

    arm_index: int: arm chosen to be pulled
    reward: int(0/1): reward obtained by pulling this arm
    successes: list: list of number of successes for each arm
    failures: list: list of number of failures for each arm

    Returns the same successes and failures objects. They are modified
    in place (the matching entry is incremented), not copied.
    '''
    if reward == 1 :
        successes[arm_index] += 1
    else :
        # Any reward other than exactly 1 is counted as a failure.
        failures[arm_index] += 1
    return successes, failures


def thompson(seed, PROBS, HORIZON):
    '''
    Run Thompson sampling on a Bernoulli bandit and return the bandit's regret.

    seed: value passed to np.random.seed (seeds NumPy's global RNG)
    PROBS: sequence: success probability of each arm. It is shuffled IN PLACE
        before the bandit is built, so the caller's object is modified.
    HORIZON: int: number of pulls to make

    Returns bandit.regret() after HORIZON pulls. The regret is computed by
    BernoulliBandit (presumably cumulative regret -- confirm in bernoulli_bandit).
    '''
    # Seeding the global RNG makes the shuffle below and give_pull's
    # np.random draws reproducible for a given seed.
    np.random.seed(seed)
    # In-place shuffle: the arm order depends on the seed.
    np.random.shuffle(PROBS)
    bandit = BernoulliBandit(probs=PROBS)

    num_arms = len(PROBS)
    successes = np.zeros(num_arms) # float NumPy array: number of successes for each arm
    failures = np.zeros(num_arms)  # float NumPy array: number of failures for each arm
    
    for t in range(HORIZON):
        arm_to_be_pulled = give_pull(num_arms, successes, failures)
        reward = bandit.pull(arm_to_be_pulled)
        successes, failures= update(arm_to_be_pulled, reward, successes, failures)
    return bandit.regret()