from visual import *
from FFT import *
import time,sys,thread


class Box:
    """Holds one VPython box for a single cell of the FFT grid.

    The box is created invisible; FFTPlot.addBoxes later replaces it with a
    visible, positioned box.
    """

    def __init__(self):
        placeholder = box(visible=0)
        self.box = placeholder



class FFTPlot:
    """VPython window that shows a Fourier transform of a spin lattice as a
    grid of grey-shaded boxes.

    The nSpinsX x nSpinsY spin lattice is tiled `quality` times along each
    axis, so the plotted grid has (quality*nSpinsX) x (quality*nSpinsY)
    cells.  Each cell is a Box whose grey level is set by updateBoxes().

    NOTE(review): the plot* methods call `fft` from the FFT module on a 2-D
    array.  That presumably transforms along the last axis only.  If a full
    2-D (kx, ky) transform is intended, FFT.fft2d may be what is wanted.
    TODO confirm.
    """

    def __init__(self, nSpinsX, nSpinsY, quality):
        """Open the 'FFT' window and build the placeholder boxes and kx/ky axes.

        nSpinsX, nSpinsY -- lattice dimensions.
        quality          -- number of times the lattice is tiled along each axis.
        """
        self.window = display(title='FFT', width=300, height=300, x=20, y=420, center=(0,0,0), background=(.0,.0,.0))
        self.window.fov = pi/90.0
        self.window.lights = [vector(.5,.5,.5), vector(-.5,-.5,-.5)]
        self.window.up = (0,0.001,1)
        self.window.userzoom = 1
        self.window.userspin = 0
        self.window.visible = 1
        self.window.select()

        # 1 while the window is shown, 0 while hidden; flipped by toggle().
        self.tvWindow = 1

        self.spinArray = []

        self.nSpinsX = nSpinsX
        self.nSpinsY = nSpinsY
        self.quality = quality

        # Public attributes kept for callers; the plot* methods use locals.
        self.xTransformFFT = []
        self.yTransformFFT = []
        self.zTransformFFT = []

        dimX = quality*nSpinsX
        dimY = quality*nSpinsY

        # Scratch buffers holding one tiled spin component each.
        self.xComp = zeros((dimX,dimY),Float32)
        self.yComp = zeros((dimX,dimY),Float32)
        self.zComp = zeros((dimX,dimY),Float32)

        # dimX x dimY grid of placeholders, indexed visualArray[x][y];
        # addBoxes() later gives each one a visible box.
        self.visualArray = [[Box() for y in range(dimY)] for x in range(dimX)]
        self.window.range = 1.4*self.quality*(self.nSpinsX*self.nSpinsY)**0.5

        # The axes start just outside the lower-left corner of the grid.
        self.axisOrigin = vector(-(self.quality*(self.nSpinsX+1))/2.0-0.1,-(self.quality*(self.nSpinsY+1))/2.0-0.1,0.0)

        self.xAxis = []
        self.yAxis = []
        for x in arange(0.0,self.quality*self.nSpinsX+1.1,1.0):
            self.xAxis.append(vector(x,0,0)+self.axisOrigin)
        for y in arange(0.0,self.quality*self.nSpinsY+1.1,1.0):
            self.yAxis.append(vector(0,y,0)+self.axisOrigin)

        self.xAxisLine = curve(pos=self.xAxis, color=(.9,.9,.9))
        self.yAxisLine = curve(pos=self.yAxis, color=(.9,.9,.9))

        self.xAxislabel = label(pos=vector(self.nSpinsX,-0.1,0) + self.axisOrigin, text='kx', xoffset=0.0, yoffset=-.001, space=.5, height=10, border=4)
        self.yAxislabel = label(pos=vector(-0.1,self.nSpinsY,0) + self.axisOrigin, text='ky', xoffset=-.001, yoffset=0.0, space=.5, height=10, border=4)

    def addBoxes(self,spinArray):
        """Replace every placeholder with a visible dark-grey unit box.

        The boxes are centred on the origin.  `spinArray` is accepted for
        interface compatibility but is not used.
        """
        halfX = (self.quality*self.nSpinsX)/2.0
        halfY = (self.quality*self.nSpinsY)/2.0
        for x in range(self.quality*self.nSpinsX):
            for y in range(self.quality*self.nSpinsY):
                boxPos = (x-halfX,y-halfY)
                self.visualArray[x][y].box = box(pos=boxPos, length=1.0, width=1.0, height=1.0, color=(.2,.2,.2), visible=1)

    def updateBoxes(self,array):
        """Shade each box by the magnitude of the matching entry of `array`.

        `array` is indexed array[x][y] over the full tiled grid.  Each
        magnitude is divided by the largest magnitude in the grid, which
        gives a grey level in [0, 1].  An all-zero grid paints every box
        black instead of dividing by zero.
        """
        dimX = self.quality*self.nSpinsX
        dimY = self.quality*self.nSpinsY

        maxVal = 0.0
        for x in range(dimX):
            for y in range(dimY):
                maxVal = max(maxVal, abs(array[x][y]))

        for x in range(dimX):
            for y in range(dimY):
                if maxVal > 0.0:
                    # Use the magnitude so the real part of an FFT, which can
                    # be negative, still maps to a valid colour component.
                    shade = abs(array[x][y])/maxVal
                else:
                    shade = 0.0
                self.visualArray[x][y].box.color = (shade,shade,shade)

    def _sampleComponent(self,spinArray,target,component):
        """Copy spin component `component` into `target`, tiling the lattice.

        Reads spinArray[i][j][0][component]; each lattice site presumably
        holds a list whose first entry is the spin vector (TODO confirm).
        """
        for x in range(self.quality*self.nSpinsX):
            for y in range(self.quality*self.nSpinsY):
                target[x][y] = spinArray[x%self.nSpinsX][y%self.nSpinsY][0][component]

    def plotXYtransform(self,spinArray):
        """Plot the real part of fft(x) + fft(y) for the in-plane components."""
        self._sampleComponent(spinArray,self.xComp,0)
        self._sampleComponent(spinArray,self.yComp,1)
        xTransformFFT = fft(self.xComp).real
        yTransformFFT = fft(self.yComp).real
        self.updateBoxes(xTransformFFT + yTransformFFT)

    def plotXtransform(self,spinArray):
        """Plot the real part of fft of the x spin component."""
        self._sampleComponent(spinArray,self.xComp,0)
        self.updateBoxes(fft(self.xComp).real)

    def plotYtransform(self,spinArray):
        """Plot the real part of fft of the y spin component."""
        self._sampleComponent(spinArray,self.yComp,1)
        self.updateBoxes(fft(self.yComp).real)

    def plotZtransform(self,spinArray):
        """Plot the real part of fft of the z spin component."""
        self._sampleComponent(spinArray,self.zComp,2)
        self.updateBoxes(fft(self.zComp).real)

    def _setBoxesVisible(self,flag):
        """Set `visible` on every grid box."""
        for x in range(self.quality*self.nSpinsX):
            for y in range(self.quality*self.nSpinsY):
                self.visualArray[x][y].box.visible = flag

    def _setAxesVisible(self,flag):
        """Set `visible` on the two axis lines and the two axis labels."""
        self.xAxisLine.visible = flag
        self.yAxisLine.visible = flag
        self.xAxislabel.visible = flag
        self.yAxislabel.visible = flag

    def toggle(self):
        """Flip between shown and hidden for the window and everything in it.

        When showing, the window is made visible before its contents.  When
        hiding, the contents are hidden before the window.
        """
        self.tvWindow = (self.tvWindow+1)%2

        if self.tvWindow == 1:
            self.window.visible = self.tvWindow
            self._setAxesVisible(self.tvWindow)
            self._setBoxesVisible(self.tvWindow)
        else:
            self._setBoxesVisible(self.tvWindow)
            self._setAxesVisible(self.tvWindow)
            self.window.visible = self.tvWindow












    


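####Debug code (stale: FFTPlot now takes (nSpinsX, nSpinsY, quality), not a spin array, and it has no returnXtransform() method)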
##nSpinsX = 3
##nSpinsY = 2
##
##stripeSpacingX = 3
##stripeSpacingY = 2
##
##Ja = -1.0
##Jbx = 2.0
##Jby = -1.0
##
##zArray = [[1.0,-1.0],[-1.0,1.0],[1.0,-1.0]]
##sigmaArray = [[.95,.39],[.03,.03],[.39,.95]]
##phaseArray = [[0.0,pi],[pi,pi],[pi,0.0]]
##
##
##
##k = 2.0*pi*vector(0.0,0.0,0.0)
##n = -1
##
##baseSigma = .99
##dt = .001
##
##new = Interface(nSpinsX, nSpinsY, stripeSpacingX, stripeSpacingY, Ja, Jbx, Jby, k, baseSigma, zArray, sigmaArray, phaseArray, dt)
##
##
##
##
##
##spinArray = new.simu.nLattice.returnState()
##
##print 'yes' 
##print spinArray
##        
##fft = FFTPlot(spinArray)
##
##print fft.returnXtransform()
##    
##
##
##
##
