from numpy import *
from scipy import *
from visual import *
import time


class initialWavefunction:
    """Trial wavefunction psi(x, y) = (1/sigma) * exp(-r**2/sigma**2) * exp(-2i*theta).

    r and theta are measured from ``self.center`` in the (x, y) plane.
    """

    def __init__(self, center, sigma):
        """
        Args:
            center: origin of the wavefunction, indexable as center[0], center[1].
            sigma: width of the radial Gaussian factor.
        """
        self.center = center
        self.sigma = sigma

    def returnFullWavefunction(self, x, y):
        """Return the complex amplitude at lattice point (x, y).

        Also stores the two factors on ``self.radialWavefunction`` and
        ``self.thetaWavefunction``.
        """
        # Use this instance's own center (not a module-level global).
        dx = x - self.center[0]
        dy = y - self.center[1]
        # float() guards against Python 2 integer division when sigma is an int.
        sigma = float(self.sigma)
        self.radialWavefunction = 1.0/sigma*e**(-(dx**2 + dy**2)/sigma**2)
        # atan2(0, 0) == 0, so the phase factor is 1 exactly at the center.
        self.thetaWavefunction = e**(-1j*2*atan2(dy, dx))

        return self.radialWavefunction*self.thetaWavefunction

    def setSigma(self, sigma):
        """Replace the Gaussian width used by later evaluations."""
        self.sigma = sigma

    

class electronPiece:
    """Electron state at one lattice site.

    The scalar amplitude is placed in a four-component spinor laid out as
    ``[a, 0, a, 0]``; ``getDensity`` sums over all four components.
    """

    def __init__(self, probabilityAmplitude):
        self.tonorm = 1.0  # most recent factor handed to normalize()
        self.probabilityAmplitude = self.tonorm*probabilityAmplitude
        self.spinor = self._spinorFrom(probabilityAmplitude)

    def _spinorFrom(self, amplitude):
        """Build the ``[a, 0, a, 0]`` spinor for amplitude ``a``."""
        return [amplitude, 0, amplitude, 0]

    def getDensity(self):
        """Return (and cache on ``self.probability``) Re(<spinor|spinor>)."""
        overlap = dot(conjugate(self.spinor), self.spinor)
        self.probability = overlap.real
        return self.probability

    def getPhase(self):
        """Return (and cache on ``self.phase``) the argument of spinor component 0.

        The value comes from atan2, so it lies in [-pi, pi].
        """
        leading = self.spinor[0]
        self.phase = atan2(leading.imag, leading.real)
        return self.phase

    def normalize(self, tonorm):
        """Scale the stored amplitude by ``tonorm`` and rebuild the spinor.

        Factors compound across calls; ``self.tonorm`` records only the latest.
        """
        self.tonorm = tonorm
        scaled = tonorm*self.probabilityAmplitude
        self.probabilityAmplitude = scaled
        self.spinor = self._spinorFrom(scaled)


class numericLattice:
    """Square lattice of electronPiece sites sampled from a trial wavefunction.

    For every integer site (x, y) with 0 <= x, y < sizeOfLattice it keeps:

    * ``probabilityAmplitudes[x][y]`` -- the value returned by
      ``initialWavefunction.returnFullWavefunction(x, y)`` (not rescaled by
      ``normalizeDensity``);
    * ``numericLattice[x][y]`` -- the electronPiece built from that amplitude;
    * ``differential[x][y]`` / ``laplacian[x][y]`` -- scratch grids written by
      ``getDifferential`` and ``getLaplacian``.

    Finite differences assume unit lattice spacing.
    """

    def __init__(self, sizeOfLattice, initialWavefunction, center):
        """Allocate the grids and sample ``initialWavefunction`` onto them.

        Args:
            sizeOfLattice: number of sites along each axis.
            initialWavefunction: object providing ``returnFullWavefunction(x, y)``.
            center: origin of the ionic potential; indexed as center[0] /
                center[1] and subtracted from ``vector(x, y, 0)``.
        """
        self.sizeOfLattice = sizeOfLattice
        self.initialWavefunction = initialWavefunction
        self.center = center

        n = sizeOfLattice
        self.numericLattice = [[0]*n for _ in range(n)]
        self.probabilityAmplitudes = [[0]*n for _ in range(n)]
        self.differential = [[0]*n for _ in range(n)]
        self.laplacian = [[0]*n for _ in range(n)]
        self.setWavefunction()

    def setWavefunction(self):
        """Resample every site from ``self.initialWavefunction`` and reset scratch grids.

        Uses this instance's own size and wavefunction rather than
        module-level globals.
        """
        wavefunction = self.initialWavefunction
        for x in range(self.sizeOfLattice):
            for y in range(self.sizeOfLattice):
                amplitude = wavefunction.returnFullWavefunction(x, y)
                self.probabilityAmplitudes[x][y] = amplitude
                self.numericLattice[x][y] = electronPiece(amplitude)
                self.differential[x][y] = vector(0.0, 0.0, 0.0)
                self.laplacian[x][y] = 0.0

    def normalizeDensity(self):
        """Rescale every site so the summed ``getDensity()`` over the lattice is 1.

        Raises:
            ValueError: if the total density is zero (nothing to normalize).
        """
        density = 0.0
        for row in self.numericLattice:
            for piece in row:
                density += piece.getDensity()
        if not density > 0:
            raise ValueError("cannot normalize a lattice with zero total density")
        factor = 1/sqrt(density)
        for row in self.numericLattice:
            for piece in row:
                piece.normalize(factor)

    def getIonicPotential(self, Z):
        """Return sum over sites of -Z * density / |(x, y, 0) - center|.

        The site that coincides with the center (if any) is skipped to avoid
        dividing by zero. The result is also stored on ``self.ionicPotential``.
        """
        self.ionicPotential = 0.0
        cx, cy = self.center[0], self.center[1]
        for x in range(self.sizeOfLattice):
            for y in range(self.sizeOfLattice):
                if x == cx and y == cy:
                    continue
                distance = mag(vector(x, y, 0) - self.center)
                self.ionicPotential += -Z*self.numericLattice[x][y].getDensity()/distance
        return self.ionicPotential

    def _difference(self, value, i, n):
        """Return the first difference of ``value(.)`` at index i on an n-point axis.

        Central difference in the interior and first-order one-sided
        differences at both edges (no wrap-around); 0.0 for a 1-point axis.
        """
        if n < 2:
            return 0.0
        if i == 0:
            return value(1) - value(0)
        if i == n - 1:
            return value(n - 1) - value(n - 2)
        return (value(i + 1) - value(i - 1))/2.0

    def getDifferential(self, lattice):
        """Return the gradient of a scalar grid as ``(d/dx, d/dy, 0)`` tuples.

        ``lattice[x][y]`` must support subtraction (real or complex numbers).
        Results are written into, and returned as, ``self.differential``.
        Edge sites use one-sided differences instead of wrapping to the
        opposite edge.
        """
        n = self.sizeOfLattice
        for x in range(n):
            for y in range(n):
                dX = self._difference(lambda i: lattice[i][y], x, n)
                dY = self._difference(lambda j: lattice[x][j], y, n)
                self.differential[x][y] = (dX, dY, 0)
        return self.differential

    def getLaplacian(self, lattice):
        """Return the divergence of a gradient grid (i.e. the Laplacian).

        ``lattice[x][y][0]`` is the x-derivative and ``[1]`` the y-derivative,
        as produced by ``getDifferential``. Results are written into, and
        returned as, ``self.laplacian``.
        """
        n = self.sizeOfLattice
        for x in range(n):
            for y in range(n):
                dXX = self._difference(lambda i: lattice[i][y][0], x, n)
                dYY = self._difference(lambda j: lattice[x][j][1], y, n)
                self.laplacian[x][y] = dXX + dYY
        return self.laplacian

    def getKineticEnergy(self):
        """Return sum over sites and spinor components c of -Re(conj(c) * lap(c)) / 2.

        The spinors held by the electronPiece objects are used, so the result
        reflects ``normalizeDensity`` and counts every spinor component, the
        same way ``getDensity`` (and therefore ``getIonicPotential``) does.
        The result is also stored on ``self.kineticEnergy``.
        """
        self.kineticEnergy = 0.0
        n = self.sizeOfLattice
        if n == 0:
            return self.kineticEnergy
        componentCount = len(self.numericLattice[0][0].spinor)
        for k in range(componentCount):
            field = [[self.numericLattice[x][y].spinor[k] for y in range(n)]
                     for x in range(n)]
            differential = self.getDifferential(field)
            laplacian = self.getLaplacian(differential)
            for x in range(n):
                for y in range(n):
                    self.kineticEnergy += -(conjugate(field[x][y])*laplacian[x][y]).real/2.0
        return self.kineticEnergy
        


class DensityPoint:
    """A single VPython sphere marking one lattice site."""

    def __init__(self):
        self.visibility = 1  # 1 = shown, 0 = hidden; mirrored onto the sphere
        self.point = sphere(visible=self.visibility, radius=.001)

    def setAttributes(self, position, size, color):
        """Move, resize and recolor the sphere."""
        self.point.pos = position
        self.point.radius = size
        self.point.color = color

    def toggleVisibility(self):
        """Flip visibility between 1 and 0 and apply it to the sphere."""
        self.visibility = (1 + self.visibility) % 2
        # Without this the flag changed but the sphere stayed as it was.
        self.point.visible = self.visibility



class visualLattice:
    """Grid of DensityPoint spheres that visualizes a numericLattice.

    Site (x, y) is drawn at (x, y, 0) with radius ``scale * density``. Its
    color blends from (0.5, 0, 0.5) at zero phase to (1, 0, 0) at phase +/-pi.
    """

    def __init__(self, numericLattice, sizeOfLattice, center):
        """
        Args:
            numericLattice: lattice whose ``.numericLattice[x][y]`` pieces
                provide ``getDensity()`` and ``getPhase()``.
            sizeOfLattice: number of sites along each axis.
            center: stored for reference; not used when drawing.
        """
        self.numericLattice = numericLattice
        self.sizeOfLattice = sizeOfLattice
        self.center = center
        self.scale = sizeOfLattice*.1

        self.pointLattice = [None]*sizeOfLattice
        for x in range(sizeOfLattice):
            self.pointLattice[x] = [None]*sizeOfLattice
            for y in range(sizeOfLattice):
                point = DensityPoint()
                point.setAttributes(*self._pointAttributes(x, y))
                self.pointLattice[x][y] = point

    def _pointAttributes(self, x, y):
        """Return ``(position, radius, color)`` for site (x, y) of ``self.numericLattice``."""
        piece = self.numericLattice.numericLattice[x][y]
        position = vector(x, y, 0.0)
        size = self.scale*piece.getDensity()
        colorArg = abs(piece.getPhase()/pi)
        color = ((1+colorArg)/2.0, 0.0, (1-colorArg)/2.0)
        return position, size, color

    def updateLattice(self, numericLattice=None):
        """Redraw every sphere from ``numericLattice``.

        If a lattice is passed it replaces ``self.numericLattice``; otherwise
        the stored lattice is redrawn. The new lattice is assumed to have
        the same size as the one used at construction -- TODO confirm at
        call sites.
        """
        if numericLattice is not None:
            self.numericLattice = numericLattice
        for x in range(self.sizeOfLattice):
            for y in range(self.sizeOfLattice):
                self.pointLattice[x][y].setAttributes(*self._pointAttributes(x, y))


            
print '\n-----------------------------'
print "Calculation Instance Started"
print time.ctime()

sizeOfLattice = 16
sigma = 2.0
center = vector(sizeOfLattice/2.0,sizeOfLattice/2.0,0.0)
Z = 2.0

scene = display(title='Density Plot', x=0, y=0, width=600, height=600, center=center, background=(0,0,0))

initialWavefunction = initialWavefunction(center, sigma)
numericLattice = numericLattice(sizeOfLattice, initialWavefunction, center)
numericLattice.normalizeDensity()
visualLattice = visualLattice(numericLattice, sizeOfLattice, center)
print "Initialization finalized at:", time.ctime()
oldEnergy = numericLattice.getIonicPotential(Z) + numericLattice.getKineticEnergy()
print "Old Energy:", oldEnergy

maxIter = 64
step = sigma/maxIter
tolerance = 1.5*sigma/maxIter**2
print "Tolerance:", tolerance
oldEnergy = -.5
newEnergy = 0.0

i = 0
while tolerance < abs(newEnergy-oldEnergy) and i < maxIter-1:
    oldEnergy = newEnergy

    sigma -= step
    
    initialWavefunction.setSigma(sigma)
    numericLattice.setWavefunction()
    numericLattice.normalizeDensity()
    visualLattice.updateLattice(numericLattice)
    
    newEnergy = numericLattice.getIonicPotential(Z) + numericLattice.getKineticEnergy()
    print "New Energy:", newEnergy, "and trial number:", i, "with sigma:", sigma
    i += 1

print "\nCalculation Instance Ended"
print time.ctime()
print '-----------------------------'

##
##
