from visual import *
from FFT import *


class Box:
    """Holder for a single VPython ``box`` primitive.

    The box created here is a hidden placeholder (``visible=0``);
    FFTPlot.addBoxes later swaps ``self.box`` for a positioned, sized box.
    """
    def __init__(self):
        placeholder = box(visible=0)
        self.box = placeholder



class FFTPlot:
    """Grey-scale VPython view of the 2-D FFT of the lattice spin components.

    The lattice is drawn as an nSpinsX x nSpinsY grid of unit boxes in the
    xy plane, with white kx/ky axis lines and labels along two edges.  Each
    ``plot*transform`` method reads the current state from ``nLattice``,
    Fourier-transforms the requested component(s) and shades every box by
    the normalised magnitude of the result.  The window starts hidden;
    ``toggleWindow`` shows or hides it.
    """

    def __init__(self, nSpinsX, nSpinsY, nLattice):
        """Create the (hidden) FFT window, its box grid and its axes.

        nSpinsX, nSpinsY -- grid dimensions (number of boxes along x and y).
        nLattice -- object whose ``returnState()`` supplies the spin array
            that the plot methods transform.
        """
        self.window = display(title='FFT', width=240, height=240, x=440, y=500, center=(-0.5,-0.5,-0.5), background=(0.0,0.0,0.0))
        self.window.fov = pi/90.0  # very narrow (2 degree) field of view
        self.window.lights = [vector(.5,.5,.5), vector(-.5,-.5,-.5)]
        self.window.up = (0,0.001,1)
        self.window.userzoom = 1
        self.window.userspin = 0
        # Make this the active display so the objects created below land in it.
        self.window.select()

        # Visibility flag shared by the window and all its objects: 0 hidden, 1 shown.
        self.tvWindow = 0

        self.nSpinsX = nSpinsX
        self.nSpinsY = nSpinsY
        self.nLattice = nLattice

        # visualArray[x][y] holds the Box for lattice site (x, y).
        self.visualArray = [[Box() for y in range(self.nSpinsY)]
                            for x in range(self.nSpinsX)]

        self.window.range = .9*(self.nSpinsX*self.nSpinsY)**0.5
        self.window.visible = self.tvWindow

        # Corner where the two axes meet, just outside the lower-left edge of the grid.
        self.axisOrigin = vector(-(self.nSpinsX+1)/2.0-.1, -(self.nSpinsY+1)/2.0-.1, 0.0)

        white = (1.0, 1.0, 1.0)

        # Each axis is a two-point curve spanning the full grid width / height.
        self.xAxis = [vector(x, 0.0, 0.0) + self.axisOrigin
                      for x in arange(0.0, self.nSpinsX+.1, self.nSpinsX)]
        self.xAxisLine = curve(pos=self.xAxis, color=white, visible=self.tvWindow)
        self.xAxisLabel = label(pos=vector(self.nSpinsX,-0.1,0) + self.axisOrigin, text='kx', xoffset=0.0, yoffset=-.001, space=.05, height=12, border=2, box=0, opacity=0, line=0, color=white, visible=self.tvWindow)

        self.yAxis = [vector(0.0, y, 0.0) + self.axisOrigin
                      for y in arange(0.0, self.nSpinsY+.1, self.nSpinsY)]
        self.yAxisLine = curve(pos=self.yAxis, color=white, visible=self.tvWindow)
        self.yAxisLabel = label(pos=vector(-0.1,self.nSpinsY,0) + self.axisOrigin, text='ky', xoffset=-.001, yoffset=0.0, space=.05, height=12, border=2, box=0, opacity=0, line=0, color=white, visible=self.tvWindow)

        self.addBoxes()

    def addBoxes(self):
        """Give every site a black unit box, with the grid centred on the origin."""
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                boxPos = (x-(self.nSpinsX)/2.0, y-(self.nSpinsY)/2.0)
                self.visualArray[x][y].box = box(pos=boxPos, length=1.0, width=1.0, height=1.0, color=(0.0,0.0,0.0), visible=self.tvWindow)

    def updateBoxes(self, array):
        """Shade each box by the normalised magnitude of the matching entry of *array*.

        ``array[x][y]`` is read for every x < nSpinsX and y < nSpinsY.  Each
        box becomes the grey level ``abs(array[x][y]) / max(abs(array))``,
        so the largest-magnitude site is white.  When every entry is zero
        the boxes are painted black instead of dividing by zero.
        """
        magnitudes = [[abs(array[x][y]) for y in range(self.nSpinsY)]
                      for x in range(self.nSpinsX)]
        # The 0.0 seed keeps max() defined for an empty grid.
        maxVal = float(max([0.0] + [value for row in magnitudes for value in row]))

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                if maxVal > 0.0:
                    shade = magnitudes[x][y] / maxVal
                else:
                    shade = 0.0  # all-zero transform: nothing to normalise
                self.visualArray[x][y].box.color = (shade, shade, shade)

    def _componentTransform(self, spinArray, component):
        """Return the real part of the 2-D FFT of one spin component.

        *spinArray* is indexed as ``[x, y, 0, component]`` with component
        0/1/2 = x/y/z, presumably shaped (nSpinsX, nSpinsY, layers, 3);
        TODO confirm against nLattice.returnState().
        """
        # NOTE(review): only the real part is kept, so the plot shows
        # |Re F(k)| rather than |F(k)|; confirm this is intended.
        return fft2d(spinArray[:, :, 0, component], axes=(-2, -1)).real

    def plotXtransform(self):
        """Display the FFT of the x spin component."""
        spinArray = self.nLattice.returnState()
        self.updateBoxes(self._componentTransform(spinArray, 0))

    def plotYtransform(self):
        """Display the FFT of the y spin component."""
        spinArray = self.nLattice.returnState()
        self.updateBoxes(self._componentTransform(spinArray, 1))

    def plotZtransform(self):
        """Display the FFT of the z spin component."""
        spinArray = self.nLattice.returnState()
        self.updateBoxes(self._componentTransform(spinArray, 2))

    def plotXYtransform(self):
        """Display the sum of the x- and y-component FFTs (one state snapshot)."""
        spinArray = self.nLattice.returnState()
        self.updateBoxes(self._componentTransform(spinArray, 0)
                         + self._componentTransform(spinArray, 1))

    def _setAxesVisible(self, flag):
        """Set the visibility of both axis lines and both axis labels."""
        self.xAxisLine.visible = flag
        self.yAxisLine.visible = flag
        self.xAxisLabel.visible = flag
        self.yAxisLabel.visible = flag

    def _setBoxesVisible(self, flag):
        """Set the visibility of every lattice box."""
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.visualArray[x][y].box.visible = flag

    def toggleWindow(self):
        """Flip between showing and hiding the window and everything in it.

        When showing, the window is made visible before its contents.  When
        hiding, the contents are hidden before the window.
        """
        self.tvWindow = (self.tvWindow+1) % 2

        if self.tvWindow == 1:
            self.window.visible = self.tvWindow
            self._setAxesVisible(self.tvWindow)
            self._setBoxesVisible(self.tvWindow)
        else:
            self._setBoxesVisible(self.tvWindow)
            self._setAxesVisible(self.tvWindow)
            self.window.visible = self.tvWindow












    


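####Debug Code -- stale: FFTPlot now takes (nSpinsX, nSpinsY, nLattice) and has no
####returnXtransform(); update these calls (and the Python 2 print syntax) before re-enabling.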
##nSpinsX = 3
##nSpinsY = 2
##
##stripeSpacingX = 3
##stripeSpacingY = 2
##
##Ja = -1.0
##Jbx = 2.0
##Jby = -1.0
##
##zArray = [[1.0,-1.0],[-1.0,1.0],[1.0,-1.0]]
##sigmaArray = [[.95,.39],[.03,.03],[.39,.95]]
##phaseArray = [[0.0,pi],[pi,pi],[pi,0.0]]
##
##
##
##k = 2.0*pi*vector(0.0,0.0,0.0)
##n = -1
##
##baseSigma = .99
##dt = .001
##
##new = Interface(nSpinsX, nSpinsY, stripeSpacingX, stripeSpacingY, Ja, Jbx, Jby, k, baseSigma, zArray, sigmaArray, phaseArray, dt)
##
##
##
##
##
##spinArray = new.simu.nLattice.returnState()
##
##print 'yes' 
##print spinArray
##        
##fft = FFTPlot(spinArray)
##
##print fft.returnXtransform()
##    
##
##
##
##
