from visual import *
from random import random

from Numeric import *

from Hamiltonian import *



class NumericLattice:
    """Classical spin lattice evolved by mean-field torques with an RK4 step.

    The lattice holds nSpinsX * nSpinsY spin vectors in spinArray, which has
    shape (nSpinsX, nSpinsY, 1, 3).  Nearest-neighbour couplings along each
    axis are stored in square matrices (couplingXArray, couplingYArray)
    that include a periodic wrap-around bond between the first and the last
    spin, so the mean field on every spin is obtained with two matrix
    products.  The lattice is tiled by a unit cell of
    stripeSpacingX by stripeSpacingY sites: bonds crossing a unit-cell
    boundary use Jbx / Jby, bonds inside a cell use Ja.
    """

    def __init__(self, nSpinsX, nSpinsY, stripeSpacingX, stripeSpacingY):
        """Allocate every lattice array.

        nSpinsX, nSpinsY               -- number of spins along x and y.
        stripeSpacingX, stripeSpacingY -- unit-cell size along x and y.

        Couplings stay zero until setCouplings() is called, and every spin
        is the zero vector until one of the set*/randomizeState methods runs.
        """
        self.nSpinsX = nSpinsX
        self.nSpinsY = nSpinsY

        self.stripeSpacingX = stripeSpacingX
        self.stripeSpacingY = stripeSpacingY

        self.EBu = 0.0
        self.EBd = 0.0
        self.Baxis = vector(0.0,0.0,1.0)
        self.temp = 0.0

        # Coupling matrices, one per lattice axis (scalars).
        self.couplingXArray = zeros((nSpinsX,nSpinsX),Float32)
        self.couplingYArray = zeros((nSpinsY,nSpinsY),Float32)

        # Unit-cell profiles: in-plane magnitude and phase of each cell site.
        self.sigmaArray = zeros((stripeSpacingX,stripeSpacingY),Float32)
        self.phaseArray = zeros((stripeSpacingX,stripeSpacingY),Float32)

        self.Ham = Hamiltonian(stripeSpacingX,stripeSpacingY)
        self.Ham.setzArray()

        # Computation arrays: one 3-vector per lattice site.
        self.spinArray = zeros((nSpinsX,nSpinsY,1,3),Float32)
        self.meanFieldArray = zeros((nSpinsX,nSpinsY,1,3),Float32)
        self.torqueArray = zeros((nSpinsX,nSpinsY,1,3),Float32)

        # Output array: per-site energy (scalars), filled by getEnergies().
        self.energyArray = zeros((nSpinsX,nSpinsY),Float32)


    def _randomUnitVector(self):
        """Return a random direction drawn uniformly from the unit sphere.

        The azimuth theta is uniform in [0, 2*pi).  The polar angle is
        phi = arccos(2v - 1), which makes cos(phi) uniform in [-1, 1].  The
        sample is (cos(theta) sin(phi), sin(theta) sin(phi), cos(phi)).
        """
        theta = 2.0*pi*random()
        phi = arccos(2.0*random() - 1.0)
        # norm() only guards against round-off; the vector is already unit length.
        return norm(vector(cos(theta)*sin(phi), sin(theta)*sin(phi), cos(phi)))


    def _isCellBoundary(self, i, j, spacing):
        """Return True when adjacent sites i and j lie in different unit cells.

        The bond crosses a boundary exactly when the higher-index site starts
        a new cell.  Requiring the ordering (j == i - 1 or i == j - 1) matters
        for spacing == 2: there, site i = 0 and site j = 1 satisfy the
        residue test alone, although they share one cell.
        """
        return ((i % spacing == 0 and j == i - 1) or
                (j % spacing == 0 and i == j - 1))


    def _buildCouplingArray(self, nSpins, spacing, Jinterior, Jboundary):
        """Return the nSpins x nSpins (Float32) coupling matrix for one axis."""
        couplingArray = zeros((nSpins,nSpins),Float32)
        for i in range(nSpins):
            for j in range(nSpins):
                separation = i - j
                if separation == 1 or separation == -1:
                    if self._isCellBoundary(i, j, spacing):
                        couplingArray[i][j] += Jboundary
                    else:
                        couplingArray[i][j] += Jinterior
                # Periodic wrap-around bond between spin 0 and spin nSpins-1.
                # This is deliberately not an elif: with nSpins == 2 the same pair
                # is both an ordinary neighbour and the wrap-around neighbour.
                if separation == nSpins - 1 or separation == -(nSpins - 1):
                    couplingArray[i][j] += Jboundary
        return couplingArray


    def setCouplings(self, Ja, Jbx, Jby):
        """Store the couplings and rebuild both coupling matrices.

        Ja  -- coupling between neighbours inside one unit cell (both axes).
        Jbx -- coupling across a unit-cell boundary along x, also used for
               the periodic bond between spin 0 and spin nSpinsX-1.
        Jby -- the same as Jbx, but along y.

        The three couplings are also forwarded to the unit-cell Hamiltonian.
        """
        self.Ja = Ja
        self.Jbx = Jbx
        self.Jby = Jby

        self.couplingXArray = self._buildCouplingArray(self.nSpinsX, self.stripeSpacingX, self.Ja, self.Jbx)
        self.couplingYArray = self._buildCouplingArray(self.nSpinsY, self.stripeSpacingY, self.Ja, self.Jby)

        self.Ham.setCouplings(self.Ja, self.Jbx, self.Jby)


    def setStateFromH(self, k, baseSigma, n):
        """Build a spin-wave state from an eigenstate of the unit-cell Hamiltonian.

        k         -- wave vector (k.x and k.y are used).
        baseSigma -- overall scale of the in-plane spin components.
        n         -- eigenstate index for Ham.getState(); -1 selects
                     Ham.getGroundState() instead.

        For a unit-cell amplitude c, the in-plane components of a spin are
        baseSigma*Re(c*exp(i*angle)) and baseSigma*Im(c*exp(i*angle)).  Here
        angle is the site's phase plus k.R for the origin R of the spin's
        unit cell.  The z component is z*sqrt(1 - sigma**2), with
        z = (-1)**(x%2 + y%2) taken over the cell coordinates.
        self.sigmaArray and self.phaseArray are filled as a side effect.
        """
        self.Ham.makeHamiltonian(k)

        if n == -1:
            state = self.Ham.getGroundState()
        else:
            state = self.Ham.getState(n)

        realUCArray = state.real
        imagUCArray = state.imaginary

        # Magnitude and phase depend only on the position inside the unit cell,
        # so they are computed once per cell site rather than once per spin.
        # NOTE(review): this assumes state spans the full stripeSpacingX x
        # stripeSpacingY unit cell (Ham is built with that size) -- confirm.
        for x in range(self.stripeSpacingX):
            for y in range(self.stripeSpacingY):
                z = (-1.0)**(x%2 + y%2)
                reAmp = realUCArray[x][y]
                imAmp = imagUCArray[x][y]
                self.sigmaArray[x][y] = baseSigma*(reAmp**2 + imAmp**2)**0.5
                self.phaseArray[x][y] = k.x*x + k.y*y + (1-z)*pi/2

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                xCell = x%self.stripeSpacingX
                yCell = y%self.stripeSpacingY
                xR = x - xCell                  # origin of this spin's unit cell
                yR = y - yCell

                z = (-1.0)**(xCell%2 + yCell%2)
                reAmp = realUCArray[xCell][yCell]
                imAmp = imagUCArray[xCell][yCell]

                tempSigma = self.sigmaArray[xCell][yCell]
                tempAngle = self.phaseArray[xCell][yCell] + k.x*xR + k.y*yR

                # Real and imaginary parts of baseSigma * c * exp(i*tempAngle).
                tempSpinX = baseSigma*(reAmp*cos(tempAngle) - imAmp*sin(tempAngle))
                tempSpinY = baseSigma*(reAmp*sin(tempAngle) + imAmp*cos(tempAngle))
                tempSpinZ = z*(1.0-tempSigma**2)**.5

                self.spinArray[x][y] = (tempSpinX,tempSpinY,tempSpinZ)


    def setStateManually(self, k, baseSigma, zArray, sigmaArray, phaseArray):
        """Set every spin from explicitly supplied unit-cell profiles.

        k          -- wave vector (k.x, k.y); each unit cell at origin R adds
                      the phase k.R.
        baseSigma  -- scale applied to sigmaArray.
        zArray     -- per-site factor on the z component (presumably +/-1).
        sigmaArray -- per-site in-plane magnitude before scaling.
        phaseArray -- per-site in-plane angle inside the unit cell.

        Every profile is indexed [x % stripeSpacingX][y % stripeSpacingY].
        """
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                xCell = x%self.stripeSpacingX
                yCell = y%self.stripeSpacingY
                xR = x - xCell
                yR = y - yCell

                z = zArray[xCell][yCell]

                tempSigma = baseSigma*sigmaArray[xCell][yCell]
                tempAngle = phaseArray[xCell][yCell] + k.x*xR + k.y*yR

                tempSpinX = tempSigma*cos(tempAngle)
                tempSpinY = tempSigma*sin(tempAngle)
                tempSpinZ = z*(1.0-tempSigma**2)**.5

                self.spinArray[x][y] = (tempSpinX,tempSpinY,tempSpinZ)


    def setStateManuallyPhi(self, k, scalePhi, zArray, phiArray, phaseArray):
        """Set every spin from per-site polar angles given in degrees.

        k          -- wave vector (k.x, k.y); each unit cell at origin R adds
                      the phase k.R.
        scalePhi   -- factor applied to phiArray.
        zArray     -- per-site factor on the z component (presumably +/-1).
        phiArray   -- per-site polar angle, in degrees.
        phaseArray -- per-site in-plane angle inside the unit cell.
        """
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                xCell = x%self.stripeSpacingX
                yCell = y%self.stripeSpacingY
                xR = x - xCell
                yR = y - yCell

                z = zArray[xCell][yCell]

                polar = scalePhi*phiArray[xCell][yCell]*(pi/180.0)    # degrees -> radians
                tempSigma = sin(polar)
                tempAngle = phaseArray[xCell][yCell] + k.x*xR + k.y*yR

                tempSpinX = tempSigma*cos(tempAngle)
                tempSpinY = tempSigma*sin(tempAngle)
                tempSpinZ = z*cos(polar)

                self.spinArray[x][y] = (tempSpinX,tempSpinY,tempSpinZ)


    def moveSpin(self, spherePoint):
        """Replace the spin best aligned with spherePoint by spherePoint itself.

        "Best aligned" means the largest dot product; ties keep the first
        site found in row-major order.
        """
        maxDot = -1.0
        selectX = 0
        selectY = 0

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                newDot = dot(spherePoint, self.spinArray[x][y][0])
                if newDot > maxDot:
                    maxDot = newDot
                    selectX = x
                    selectY = y

        self.spinArray[selectX][selectY] = spherePoint


    def getTorques(self, spinArray, t):
        """Fill self.meanFieldArray and self.torqueArray for spinArray at time t.

        The mean field on a spin is the coupling-weighted sum of its
        neighbours, plus two optional terms:
          * a random kick of magnitude self.temp (when temp != 0);
          * an in-plane field of magnitude self.EBu whose direction
            oscillates as A*sin(w*t).
        The torque on each spin is spin x meanField.
        """
        # Put the lattice axis that must be contracted with the coupling matrix
        # last, so that matrixmultiply sums over neighbours along that axis.
        flippedLatticeArray = transpose(spinArray)
        twiceflippedLatticeArray = transpose(transpose(spinArray,(1,0,2,3)))

        meanFieldXArray = matrixmultiply(flippedLatticeArray,self.couplingXArray)
        meanFieldYArray = matrixmultiply(twiceflippedLatticeArray,self.couplingYArray)

        # Restore the (nSpinsX, nSpinsY, 1, 3) layout and add both axes.
        self.meanFieldArray = transpose(meanFieldXArray) + transpose(meanFieldYArray, (2,3,1,0))

        w = (6.5/10.0)*2.0*pi
        A = .7
        arg = A*sin(w*t)

        EField = self.EBu*vector(cos(arg),sin(arg),0.0)
        # NOTE(review): EBd enables the field below but never changes EField;
        # only EBu sets its magnitude -- confirm this is intended.
        applyField = self.EBu != 0.0 or self.EBd != 0.0

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                if self.temp != 0.0:
                    # Constant-magnitude kick in a uniformly random direction.
                    # NOTE(review): a fresh kick is drawn at every RK4 stage.
                    kick = self.temp*self._randomUnitVector()
                    self.meanFieldArray[x][y] += kick

                if applyField:
                    self.meanFieldArray[x][y] += EField

                self.torqueArray[x][y] = cross(spinArray[x][y][0],self.meanFieldArray[x][y][0])


    def timeEvolve(self, t, dt, EBu, EBd, temp):
        """Advance the spins from t to t + dt with one classical RK4 step.

        EBu, EBd and temp are stored on the instance and read by getTorques().
        After the step every spin is renormalised to unit length.
        """
        self.EBu = EBu
        self.EBd = EBd
        self.temp = temp

        self.getTorques(self.spinArray, t)
        k1 = dt*self.torqueArray

        self.getTorques(self.spinArray + k1/2.0, t + dt/2.0)
        k2 = dt*self.torqueArray

        self.getTorques(self.spinArray + k2/2.0, t + dt/2.0)
        k3 = dt*self.torqueArray

        self.getTorques(self.spinArray + k3, t + dt)
        k4 = dt*self.torqueArray

        self.spinArray = self.spinArray + k1/6.0 + k2/3.0 + k3/3.0 + k4/6.0

        # RK4 does not preserve |S| exactly, so project back onto the unit sphere.
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.spinArray[x][y][0] = norm(self.spinArray[x][y][0])


    def randomizeState(self):
        """Point every spin in an independent, uniformly random direction."""
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.spinArray[x][y][0] = self._randomUnitVector()


    def _distortCouplingArray(self, couplingArray, nSpins, scale):
        """Add noise uniform in [-scale, scale) to each bond entry of couplingArray.

        A random number is drawn for every (i, j) entry, bond or not, and
        then added only to bond entries (neighbours and the wrap-around pair).
        """
        for i in range(nSpins):
            for j in range(nSpins):
                distortion = 2.0*scale*(random()-.5)
                separation = i - j
                if separation == 1 or separation == -1:
                    couplingArray[i][j] += distortion
                # Not an elif: it mirrors the double counting used when the
                # couplings are built, which happens for nSpins == 2.
                if separation == nSpins - 1 or separation == -(nSpins - 1):
                    couplingArray[i][j] += distortion


    def couplingDistortion(self, scale):
        """Add random bond disorder of amplitude scale to both coupling matrices.

        NOTE(review): entries [i][j] and [j][i] receive independent draws, so
        the coupling matrices stop being symmetric -- confirm this is intended.
        """
        self._distortCouplingArray(self.couplingXArray, self.nSpinsX, scale)
        self._distortCouplingArray(self.couplingYArray, self.nSpinsY, scale)


    def getEnergies(self):
        """Fill self.energyArray with -spin . meanField for every site.

        The result uses whatever meanFieldArray the most recent getTorques()
        call left behind.  After timeEvolve() that is the last RK4 stage.
        """
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.energyArray[x][y] = -dot(self.spinArray[x][y][0],self.meanFieldArray[x][y][0])

    def returnTotalEnergy(self):
        """Return the sum of self.energyArray over all sites.

        Each pair coupling appears in the mean field of both of its spins,
        so every bond is counted twice in this sum.
        """
        energy = 0.0
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                energy = energy + self.energyArray[x][y]
        return energy

    def returnSpinSum(self):
        """Return the vector sum of all spins (the total magnetisation)."""
        spinSum = vector(0.0,0.0,0.0)
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                spinSum += self.spinArray[x][y][0]
        return spinSum

    def returnCommonAxis(self):
        """Return the normalised direction of the total spin."""
        return norm(self.returnSpinSum())

    def returnState(self):
        """Return the spin array itself (not a copy)."""
        return self.spinArray














##    def measurek(self):
##        commonAxis = self.getCommonAxis()
##
##        projectedUL = norm(self.spinArray[0][0][0] - dot(self.spinArray[0][0][0],commonAxis)*commonAxis)
##        projectedUR = norm(self.spinArray[self.nSpinsX-1][0][0] - dot(self.spinArray[self.nSpinsX-1][0][0],commonAxis)*commonAxis)
##        projectedBL = norm(self.spinArray[0][self.nSpinsY-1][0] - dot(self.spinArray[0][self.nSpinsY-1][0],commonAxis)*commonAxis)
##        
##        dotx = dot(projectedUL,((-1.0)**(self.nSpinsX-1))*projectedUR)
##        doty = dot(projectedUL,((1.0)**(self.nSpinsY-1))*projectedBL)
##       
##        if self.nSpinsX > 1:
##            if abs(dotx) < 1.0:     kx = arccos(dotx)
##            else:                   kx = pi
##        else:                       kx = 0.0
##
##        if self.nSpinsY > 1:
##            if abs(doty) < 1.0:     ky = arccos(doty)
##            else:                   ky = pi
##        else:                       ky = 0.0
##
##        return vector((kx%(2.0*pi)),(ky%(2.0*pi)),0.0)




##############################################################
##  TESTING CODE:                                           ##

##dt = .001
##
##
##Nx = 3
##Ny = 2
##uc = NumericUnitCell(Nx,Ny)
##
##Ja = -1.0
##Jbx = 1.8
##Jby = -1.0
##uc.setCouplings(Ja,Jbx,Jby)
##
##k = vector(1.0,.5,0.0)
##baseSigma = 0.001
##uc.setState(k,baseSigma,-1)


##
##print '\nspinArray:'
##print uc.getState()
##
##
##uc.getTorques()
##print '\nmeanFieldArray:'
##print uc.meanFieldArray
##print '\ntorqueArray:'
##print uc.torqueArray
##
##uc.timeEvolve(dt)
##print '\ntimeEvolved State:'
##print uc.getState()
##


##print '\n-----------------------------'
##print "Calculation Instance Started"
##print time.ctime()
##
##
##
##k = 0
##loops = 100000
##while k <= loops:
##    k+=1
##
##    uc.getTorques()
##    uc.timeEvolve(dt)
##
##
##print "\nCalculation Instance Ended For",loops,"Loops"
##print time.ctime()
##print '-----------------------------'
##
##print uc.getState()


#  NumericUnitCell:
##  63 seconds for 100,000 iterations with a 3 by 2 unit cell = 6 spins
##   6 seconds for 10,000 iterations with a 3 by 2 unit cell = 6 spins

##   4 seconds for 1,000 iterations with an 8 by 4 unit cell = 32 spins
##  33 seconds for 100 iterations with an 80 by 40 unit cell = 3,200 spins

## 301 seconds for 10 iterations with an 800 by 400 unit cell = 320,000 spins



#   newUnitCell:
##  43 seconds for 100,000 iterations with a 3 by 2 unit cell = 6 spins
##   4 seconds for 10,000 iterations with a 3 by 2 unit cell = 6 spins

##   2 seconds for 1,000 iterations with an 8 by 4 unit cell = 32 spins
##  18 seconds for 100 iterations with an 80 by 40 unit cell = 3,200 spins
## 343 seconds for 10 iterations with an 800 by 400 unit cell = 320,000 spins





## Notes:
## I would expect the new method (Numeric matrices) to be faster whenever we can actually use a matrix to perform an operation.
## This makes sense and seems obvious, but since this is the defining quality, we must now ask where we can use the matrix method.
## Matrix method works for:

## Coupling:  the coupling can be generated once as a single matrix and does not need to be 'recreated' on each iteration.


## But not Torques (yet?):  this should be more straightforward than the current implementation, which iterates through the spin entries one at a time.




