import random
import os

# Number of non-terminal states; an episode's start state is drawn from 0 .. numStates - 1.
numStates = 2
# Sentinel value that marks the terminal state (it is never stored in an episode).
TERMINAL = -1
# Probability that state 0 stays in state 0 on a step that does not terminate.
p0 = 0.5
# Reward that getReward returns for state 0.
r0 = 1.0

# Probability that state 1 stays in state 1 on a step that does not terminate.
p1 = 0.5
# Reward that getReward returns for state 1.
r1 = 1.0  # NOTE(review): 2.0 is presumably an alternative setting tried earlier - confirm

# Probability, checked first on every step, that the episode terminates.
epsilon = 0.1

# We model an MDP with two non-terminal states, s0 and s1, and a single action. Both states transition to the terminal state with probability epsilon.
# s0 remains in s0 with probability p0(1 - epsilon), and transitions to s1 with probability (1 - p0)(1 - epsilon).
# s1 remains in s1 with probability p1(1 - epsilon), and transitions to s0 with probability (1 - p1)(1 - epsilon).
# There's a reward of r0 for any transition starting in s0, and a reward of r1 for any transition starting in s1.
# No discounting is used to compute values.

def generateEpisode():
    """Sample one episode and return the states it visits.

    The start state is drawn uniformly from the ``numStates`` non-terminal
    states. On every step the episode first terminates with probability
    ``epsilon``; otherwise state 0 moves to state 1 with probability
    ``1 - p0`` and state 1 moves to state 0 with probability ``1 - p1``,
    staying put in the remaining cases.

    Returns:
        The list of visited states in order, starting with the initial
        state. The terminal state itself is never stored.
    """
    current = int((random.random() * numStates) % numStates)
    visited = [current]

    while current != TERMINAL:
        # The termination draw always comes first; a second draw is made
        # only when the episode survives and the state is 0 or 1.
        if random.random() < epsilon:
            upcoming = TERMINAL
        elif current == 0:
            upcoming = 1 if random.random() >= p0 else current
        elif current == 1:
            upcoming = 0 if random.random() >= p1 else current
        else:
            upcoming = current

        if upcoming != TERMINAL:
            visited.append(upcoming)
        current = upcoming

    return visited

def getReward(s):
    """Return the reward for a transition that starts in state ``s``.

    State 0 yields ``r0`` and state 1 yields ``r1``. Any other value
    yields ``None``.
    """
    if s == 0:
        return r0
    if s == 1:
        return r1
    return None

def getValues(e):
    """Return the undiscounted return observed from every time step of ``e``.

    Args:
        e: An episode, i.e. the list of states visited in order.

    Returns:
        A list as long as ``e`` whose entry ``t`` is the sum of
        ``getReward`` over ``e[t:]``.
    """
    # Walk the episode backwards so each return is the previous running
    # total plus one more reward.
    suffixTotals = []
    runningTotal = 0
    for visitedState in reversed(e):
        runningTotal += getReward(visitedState)
        suffixTotals.append(runningTotal)

    # The totals were collected last-step-first; restore time order.
    suffixTotals.reverse()
    return suffixTotals

def average(x):
    """Return the arithmetic mean of the numbers in ``x`` as a float.

    Raises:
        ZeroDivisionError: If ``x`` is empty.
    """
    total = 0
    for item in x:
        total += item

    # Multiplying by 1.0 forces float division even for all-integer input.
    return total * 1.0 / len(x)

def getkthVisitValueEstimate(state, k, episodes, values):
    """Return the k-th visit Monte Carlo value estimate for ``state``.

    For every episode in which ``state`` occurs at least ``k`` times, the
    return recorded at its k-th occurrence is collected; the estimate is the
    average of the collected returns.

    Args:
        state: The state whose value is estimated.
        k: Which visit within an episode to use (1 = first visit).
        episodes: List of episodes, each a list of visited states.
        values: Per-episode lists of returns, indexed like ``episodes``.
    """
    samples = []
    for episodeIndex, episode in enumerate(episodes):
        visitCount = 0
        for step, visitedState in enumerate(episode):
            if visitedState != state:
                continue
            visitCount += 1
            if visitCount == k:
                samples.append(values[episodeIndex][step])

    return average(samples)

def getEveryVisitValueEstimate(state, episodes, values):
    """Return the every-visit Monte Carlo value estimate for ``state``.

    The return recorded at every occurrence of ``state`` in every episode is
    collected, and the estimate is the average of those returns.

    Args:
        state: The state whose value is estimated.
        episodes: List of episodes, each a list of visited states.
        values: Per-episode lists of returns, indexed like ``episodes``.
    """
    samples = [
        values[episodeIndex][step]
        for episodeIndex, episode in enumerate(episodes)
        for step, visitedState in enumerate(episode)
        if visitedState == state
    ]

    return average(samples)

def getLastVisitValueEstimate(state, episodes, values):
    """Return the last-visit Monte Carlo value estimate for ``state``.

    For every episode that contains ``state``, the return recorded at its
    final occurrence is collected; the estimate is the average of the
    collected returns.

    Args:
        state: The state whose value is estimated.
        episodes: List of episodes, each a list of visited states.
        values: Per-episode lists of returns, indexed like ``episodes``.
    """
    samples = []
    for episodeIndex, episode in enumerate(episodes):
        # Scan from the end; the first match found is the last visit.
        for step in range(len(episode) - 1, -1, -1):
            if episode[step] == state:
                samples.append(values[episodeIndex][step])
                break

    return average(samples)


# Seed the generator so that repeated runs sample the same episodes.
randomSeed = 1
random.seed(randomSeed)

# Sample N episodes and keep each episode's per-time-step returns next to it.
N = 10000
episodes, values = [], []
for j in range(N):
    e = generateEpisode()
    val = getValues(e)
    episodes.append(e)
    values.append(val)

# First-, second- and third-visit estimates for state 0, then for state 1.
est01, est02, est03 = [
    getkthVisitValueEstimate(0, k, episodes, values) for k in (1, 2, 3)
]
est11, est12, est13 = [
    getkthVisitValueEstimate(1, k, episodes, values) for k in (1, 2, 3)
]

# Last-visit and every-visit estimates for both states.
est0l, est1l = [
    getLastVisitValueEstimate(s, episodes, values) for s in (0, 1)
]
est0e, est1e = [
    getEveryVisitValueEstimate(s, episodes, values) for s in (0, 1)
]

# One output row per estimator: [value of state 0, value of state 1]: name.
print("Estimated value function")
for est0, est1, suffix in (
    (est01, est11, "]: first-visit MC"),
    (est02, est12, "]: second-visit MC"),
    (est03, est13, "]: third-visit MC"),
    (est0e, est1e, "]: every-visit MC"),
    (est0l, est1l, "]: last-visit MC"),
):
    print("[" + str(est0) + ", " + str(est1) + suffix)

