from numpy import *
from scipy import *
from visual import *
import time



class waveFunction:
    """One separable mode psi(r, theta) = tonorm * R(r) * Theta(theta).

    R(r) = exp(-r/n) * (2r/n)**l and Theta(theta) = exp(-i*m*theta).
    normalize() chooses tonorm so that the integral of |psi|**2 * r dr dtheta
    over r in [0, inf) and theta in [0, 2*pi] equals 1.

    NOTE(review): R(r) carries no associated-Laguerre factor, so modes that
    share l and m but differ in n are not mutually orthogonal -- confirm this
    is intended.
    """

    # Step for the central-difference radial derivative.  scipy's
    # misc.derivative defaults to dx=1.0, which is far too coarse for R(r),
    # whose exponential factor decays over a length of n.
    derivStep = 1e-5

    def __init__(self,n,l,m,tonorm):
        self.n = n
        self.l = l
        self.m = m
        self.tonorm = tonorm

        # Not read by any method of this class; kept in case code elsewhere uses them.
        self.limSize = 50.0
        self.maxIterations = 32
        self.tolerance = 1e-3

    def radialWaveFunction(self,r):
        """Return the unnormalised radial factor R(r) = exp(-r/n) * (2r/n)**l."""
        return e**(-r/self.n)*(2.0*r/self.n)**self.l

    def thetaWaveFunction(self,theta):
        """Return the angular factor Theta(theta) = exp(-i*m*theta)."""
        return e**(-1j*self.m*theta)

    def returnWaveFunction(self,r,theta):
        """Return tonorm * R(r) * Theta(theta)."""
        return self.tonorm*self.radialWaveFunction(r)*self.thetaWaveFunction(theta)

    def radialDerivative(self,r):
        """Return dR/dr at r using a central difference with step derivStep."""
        h = self.derivStep
        return (self.radialWaveFunction(r+h) - self.radialWaveFunction(r-h))/(2.0*h)

    def thetaDerivative(self,theta):
        """Return dTheta/dtheta = -i*m*Theta(theta) (exact)."""
        return -1j*self.m*self.thetaWaveFunction(theta)

    def integrateThetaSeparately(self):
        """Return the integral of |Theta(theta)|**2 over theta in [0, 2*pi]."""
        return integrate.quad(lambda theta: (conjugate(self.thetaWaveFunction(theta))*self.thetaWaveFunction(theta)).real, 0, 2*pi)[0]

    def integrateRadialSeparately(self):
        """Return the integral of |R(r)|**2 * r dr over r in [0, inf)."""
        return integrate.quad(lambda r: conjugate(self.radialWaveFunction(r))*r*self.radialWaveFunction(r), 0, Inf)[0]

    def normalize(self):
        """Set tonorm so the mode has unit norm under the r dr dtheta measure."""
        self.tonorm = 1/sqrt(self.integrateRadialSeparately()*self.integrateThetaSeparately())

    def getProbCurrent(self,r,theta):
        """Return the probability current of this mode at (r, theta).

        J = (1/2i)(psi* grad psi - psi grad psi*) = Im(psi* grad psi), returned
        as vector(J_r, J_theta, 0).  The angular gradient carries the 1/r
        factor and is taken as 0 at the origin.
        """
        psi = self.returnWaveFunction(r,theta)
        radial = self.radialWaveFunction(r)
        angular = self.thetaWaveFunction(theta)

        gradR = self.tonorm*self.radialDerivative(r)*angular
        if r != 0.0:
            gradTheta = self.tonorm*radial*self.thetaDerivative(theta)/r
        else:
            gradTheta = 0.0

        current = vector((conjugate(psi)*gradR).imag,
                         (conjugate(psi)*gradTheta).imag,
                         0.0)
        return current



class basis:
    """A truncated superposition of waveFunction modes.

    Modes are enumerated as n = 1..maxn, l = 0..n-1, m = -l..l, and coeffs[i]
    is the (unnormalised) amplitude of the i-th mode in that order.  The
    represented state is psi = tonorm * sum_i coeffs[i] * mode_i.
    """

    # Step for the central-difference radial derivative in getTotalProbCurrent.
    derivStep = 1e-5

    def __init__(self,maxn,coeffs):
        self.maxn = maxn
        self.coeffs = coeffs
        self.tonorm = 1

        quantumNumbers = [(n+1, l, m-l)
                          for n in range(maxn)
                          for l in range(n+1)
                          for m in range(2*l+1)]
        # Single-argument print(...) gives the same output under Python 2 and 3.
        for radialQN, angularQN, magneticQN in quantumNumbers:
            print("Radial QN %s" % radialQN)
            print("Angular QN %s" % angularQN)
            print("Magnetic QN %s" % magneticQN)
        self.maxNbasis = len(quantumNumbers)
        print("Total modes: %s" % self.maxNbasis)

        self.basis = []
        for radialQN, angularQN, magneticQN in quantumNumbers:
            mode = waveFunction(radialQN, angularQN, magneticQN, 1.0)
            mode.normalize()
            print("Normalization was: %s" % mode.tonorm)
            self.basis.append(mode)

        print("Basis finished initialization: %s" % time.ctime())

    def _weightedModes(self):
        """Yield (tonorm * coeffs[i], mode_i) for each of this basis' modes."""
        # Driven by this instance's own size rather than a module-level global.
        for i in range(self.maxNbasis):
            yield self.tonorm*self.coeffs[i], self.basis[i]

    def getFullWaveFunction(self,r,theta):
        """Return psi(r, theta) = sum_i tonorm * coeffs[i] * mode_i(r, theta)."""
        self.fullWavefunction = 0.0
        for amplitude, mode in self._weightedModes():
            self.fullWavefunction += amplitude*mode.returnWaveFunction(r,theta)
        return self.fullWavefunction

    def getTotalProbCurrent(self,r,theta):
        """Return the probability current of the full superposition at (r, theta).

        J = Im(psi* grad psi) is quadratic in psi, so it is built from the
        summed wavefunction and gradient (keeping cross terms and |c|**2
        weights), not from a linear sum of per-mode currents.  Returned as
        vector(J_r, J_theta, 0); J_theta is 0 at the origin.
        """
        h = self.derivStep
        psi = 0.0
        gradR = 0.0
        gradTheta = 0.0
        for amplitude, mode in self._weightedModes():
            weight = amplitude*mode.tonorm
            radial = mode.radialWaveFunction(r)
            angular = mode.thetaWaveFunction(theta)
            psi += weight*radial*angular

            dRdr = (mode.radialWaveFunction(r+h) - mode.radialWaveFunction(r-h))/(2.0*h)
            gradR += weight*dRdr*angular
            if r != 0.0:
                dThetadtheta = (mode.thetaWaveFunction(theta+h) - mode.thetaWaveFunction(theta-h))/(2.0*h)
                gradTheta += weight*radial*dThetadtheta/r

        self.current = vector((conjugate(psi)*gradR).imag,
                              (conjugate(psi)*gradTheta).imag,
                              0.0)
        return self.current

    def normalize(self):
        """Set tonorm = 1/sqrt(sum_i |coeffs[i]|**2).

        NOTE(review): this equals the true norm of psi only if the modes are
        orthonormal -- confirm for the mode set in use.

        Raises:
            ValueError: if every coefficient is zero (the state cannot be normalised).
        """
        norm = 0.0
        for i in range(self.maxNbasis):
            coeff = self.coeffs[i]
            # |c|**2 is real; drop the zero imaginary part so tonorm stays real.
            norm += (conjugate(coeff)*coeff).real
        if norm == 0.0:
            raise ValueError("cannot normalize: every basis coefficient is zero")
        self.tonorm = 1/sqrt(norm)


       
class electron:
    """An electron whose spatial state is a basis superposition.

    The spinor puts the same spatial wavefunction in its first and last
    components (see returnSpinor).
    """
    def __init__(self,basis,maxn):
        self.basis = basis
        self.maxn = maxn  # not read by any method of this class
        self.velScale = .2  # overall scale applied by getVelocity

    def returnSpinor(self,r,theta):
        """Return the spinor [psi, 0, 0, psi] at (r, theta)."""
        # Evaluate the (expensive) basis sum once and reuse it.
        psi = self.basis.getFullWaveFunction(r,theta)
        return [psi, 0, 0, psi]

    def getDensity(self,r,theta):
        """Return and store the density Re(spinor^dagger . spinor) at (r, theta)."""
        spinor = self.returnSpinor(r,theta)
        self.density = (dot(conjugate(spinor),spinor)).real
        return self.density

    def getPhase(self,r,theta):
        """Return and store the phase (atan2(imag, real)) of the first spinor component."""
        upper = self.returnSpinor(r,theta)[0]
        self.phase = atan2(upper.imag,upper.real)
        return self.phase

    def getElecProbCurrent(self,r,theta):
        """Return and store the basis' total probability current at (r, theta)."""
        self.current = self.basis.getTotalProbCurrent(r,theta)
        return self.current

    def getVelocity(self,r,theta):
        """Return the display velocity at (r, theta), or a zero vector where the density is 0."""
        density = self.getDensity(r,theta)
        if density == 0.0:
            return vector(0.0,0.0,0.0)
        # NOTE(review): only the angular current component (index 1) is used,
        # scaled by r**2 and laid along (-sin, cos); confirm this is intended.
        currentTheta = self.getElecProbCurrent(r,theta)[1]
        swirl = currentTheta*r**2
        return self.velScale*vector(-swirl*sin(theta),swirl*cos(theta),0.0)/density

    def getVelocityRelLattice(self,r,theta):
        """Return and store v / sqrt(1 + |v|**2), which keeps the magnitude below 1."""
        velocity = self.getVelocity(r,theta)
        self.relVelocity = velocity/sqrt(1+mag(velocity)**2)
        return self.relVelocity
        

class densityPoint:
    """A sphere marker showing the density at one lattice site."""
    def __init__(self):
        self.visibility = 1
        self.point = sphere(visible=self.visibility, radius=.001)

    def setAttributes(self,position,size,color):
        """Move, resize and recolor the sphere."""
        self.point.pos = position
        self.point.radius = size
        self.point.color = color

    def toggleVisibility(self):
        """Flip the visibility flag between 1 and 0 and apply it to the sphere."""
        self.visibility = (1+self.visibility)%2
        # Push the new flag to the sphere; flipping only the flag had no visible effect.
        self.point.visible = self.visibility

class currentVector:
    """An arrow marker showing the current direction at one lattice site."""
    def __init__(self):
        self.visibility = 1
        self.currentArrow = arrow(visible=self.visibility)

    def setAttributes(self,position,direction):
        """Place the arrow's tail at position and point it along direction."""
        self.currentArrow.pos = position
        self.currentArrow.axis = direction

    def toggleVisibility(self):
        """Flip the visibility flag between 1 and 0 and apply it to the arrow."""
        self.visibility = (1+self.visibility)%2
        # Push the new flag to the arrow; flipping only the flag had no visible effect.
        self.currentArrow.visible = self.visibility

        

class visualLattice:
    """A square grid of density spheres and current arrows for one electron.

    Site (x, y) sits at (x - size/2, y - size/2, 0).  At each site the sphere
    radius is scale * density, the color encodes |phase|/pi (red grows, blue
    shrinks), and the arrow is the electron's getVelocityRelLattice there.
    """
    def __init__(self,electronOne,sizeOfLattice,center):
        self.electronOne = electronOne
        self.sizeOfLattice = sizeOfLattice
        self.center = center  # stored only; not read by this class
        self.scale = sizeOfLattice

        self.pointLattice = [[None]*sizeOfLattice for _ in range(sizeOfLattice)]
        self.currentLattice = [[None]*sizeOfLattice for _ in range(sizeOfLattice)]
        for x in range(sizeOfLattice):
            for y in range(sizeOfLattice):
                position, size, color, currentDirection = self._sampleSite(x,y)
                self.pointLattice[x][y] = densityPoint()
                self.currentLattice[x][y] = currentVector()
                self.pointLattice[x][y].setAttributes(position,size,color)
                self.currentLattice[x][y].setAttributes(position,currentDirection)

    def _sampleSite(self,x,y):
        """Return (position, size, color, currentDirection) for grid site (x, y)."""
        # Centre the grid with this lattice's own size, not a module-level global.
        half = self.sizeOfLattice/2.0
        position = vector(x-half,y-half,0.0)
        r = mag(position)
        theta = atan2(position[1],position[0])
        currentDirection = self.electronOne.getVelocityRelLattice(r,theta)

        density = self.electronOne.getDensity(r,theta)
        size = self.scale*density
        phase = self.electronOne.getPhase(r,theta)
        colorArg = phase/pi
        color = ((1+abs(colorArg))/2.0,0.0,(1-abs(colorArg))/2.0)
        return position, size, color, currentDirection

    def updateLattice(self,electronOne=None):
        """Recompute every site and update the existing markers in place.

        Args:
            electronOne: if given, replaces the stored electron before
                sampling, so the lattice can follow a different electron.
        """
        if electronOne is not None:
            self.electronOne = electronOne
        for x in range(self.sizeOfLattice):
            for y in range(self.sizeOfLattice):
                position, size, color, currentDirection = self._sampleSite(x,y)
                self.pointLattice[x][y].setAttributes(position,size,color)
                self.currentLattice[x][y].setAttributes(position,currentDirection)



maxn = 3
scale = 1.0

count = 0
for n in xrange(maxn):
    for l in xrange(n+1):
        for m in xrange(2*l+1):
            count += 1

coeffsOne = [None]*count
coeffsTwo = [None]*count
i = 0
for n in xrange(maxn):
    for l in xrange(n+1):
        for m in xrange(2*l+1):
            if m-l==-2:
                coeffsOne[i] = 1.0
            else:
                coeffsOne[i] = 0.0
            if m-l==2:
                coeffsTwo[i] = 1.0
            else:
                coeffsTwo[i] = 0.0
            i += 1

print "Electron One Coefficients", coeffsOne
print "Electron Two Coefficients", coeffsTwo

BasisOne = basis(maxn, coeffsOne)
BasisTwo = basis(maxn, coeffsOne)

electronOne = electron(BasisOne, maxn)
electronTwo = electron(BasisTwo, maxn)

electronOne.basis.normalize()
electronTwo.basis.normalize()

sizeOfLattice = 24
center = (0.0,0.0,0.0)
scene = display(title='Density Plot', x=0, y=0, width=600, height=600, center=center, background=(0,0,0))
visualLattice = visualLattice(electronOne, sizeOfLattice, center)
visualLattice.updateLattice(electronOne)




