import io
import numpy

# State space S = {0, 1}, action space A = {0, 1}.
# T[s][a][s'] will hold the transition probability P(s' | s, a) and
# R[s][a][s'] the reward for moving from s to s' under action a.
# Both tables start out zero-filled; every innermost row is a distinct list
# so individual entries can be assigned independently.
T = [[[0] * 2 for _ in range(0, 2)] for _ in range(0, 2)]
R = [[[0] * 2 for _ in range(0, 2)] for _ in range(0, 2)]

print(T)
print(R)

# Transition probabilities T[s][a] = [P(s'=0 | s, a), P(s'=1 | s, a)].
# Each row sums to one. Slice assignment overwrites the existing rows in
# place rather than replacing the row objects.
T[0][0][:] = [0.25, 0.75]
T[0][1][:] = [1, 0]
T[1][0][:] = [0.5, 0.5]
T[1][1][:] = [1, 0]

# Rewards R[s][a] = [reward on reaching s'=0, reward on reaching s'=1].
R[0][0][:] = [1, 0]
R[0][1][:] = [2, 0]  # placeholder: T[0][1][1] is 0, so this transition never occurs
R[1][0][:] = [-1, 10]
R[1][1][:] = [0, 0]

# Discount factor applied to future rewards.
gamma = 0.99

# TODO: factor the Bellman optimality backup out into a function
# bellmanOptimalityOperator(V) (left as a class exercise).

# Initial value estimate: every state starts at 0.
V = [0, 0]

# Value iteration: repeatedly apply the Bellman optimality backup
#   V_{k+1}(s) = max_a sum_{s'} T[s][a][s'] * (R[s][a][s'] + gamma * V_k(s'))
# until no state's value moves by more than epsilon between sweeps.
# The numbers of states, actions and successors are read from T itself, so
# the loop works for any finite MDP laid out as T[s][a][s'] / R[s][a][s'].
# V must have one entry per state.
epsilon = 1.0e-10
converged = False
while not converged:
    newV = []
    optimalPolicy = []

    for s in range(len(T)):
        # Expected one-step return of each action from s under the current V.
        # sum() adds left to right starting from 0, matching a running total.
        qValues = [
            sum(T[s][a][sPrime] * (R[s][a][sPrime] + gamma * V[sPrime])
                for sPrime in range(len(T[s][a])))
            for a in range(len(T[s]))
        ]
        # max() keeps the first maximiser, so ties go to the lowest action index.
        maxA = max(range(len(qValues)), key=qValues.__getitem__)
        newV.append(qValues[maxA])
        optimalPolicy.append(maxA)

    # Converged once no state's value changed by more than epsilon.
    converged = not any(abs(new - old) > epsilon for new, old in zip(newV, V))

    # Update in place so V remains the same list object.
    V[:] = newV
    print(V)
    print(optimalPolicy)
    print("***********************")

                

    
