from visual import *
from numpy import *
from scipy import *
import time


class waveFunction:
    """Single-electron trial state on a ring of radius ``radius``.

    The scalar part is exp(-i*m*theta)/sqrt(2*pi*radius).  ``returnSpinor``
    dresses it into a 4-component spinor, and ``getTransformation`` builds a
    velocity-dependent 4x4 matrix that is applied to it in
    ``getTransformedSpinor``.
    """

    def __init__(self,m,radius):
        # m: azimuthal quantum number; radius: ring radius.
        self.m = m
        self.radius = radius
        self.c = 3e8  # compared against speeds in getRelVel/getTransformation

        self.tonorm = 1.0   # scale applied to the theta-derivative in getProbCurrent
        self.speed = .5     # placeholder; overwritten by getSpeed()
        self.vSign = +1.0   # scalar direction of circulation; overwritten by getSpeed()

        self.tolerance = 1e-6     # finite-difference step and quad relative tolerance
        self.maxIterations = 50   # quad subinterval limit

        self.gamma0 = array([[0.0,0.0,0.0,-1j],[0.0,0.0,1j,0.0],[0.0,-1j,0.0,0.0],[1j,0.0,0.0,0.0]])
        self.gamma1 = array([[1j,0.0,0.0,0.0],[0.0,-1j,0.0,0.0],[0.0,0.0,1j,0.0],[0.0,0.0,0.0,-1j]])
        self.gamma2 = array([[0.0,0.0,0.0,1j],[0.0,0.0,-1j,0.0],[0.0,-1j,0.0,0.0],[1j,0.0,0.0,0.0]])
        self.gamma3 = array([[0.0,-1j,0.0,0.0],[-1j,0.0,0.0,0.0],[0.0,0.0,0.0,-1j],[0.0,0.0,-1j,0.0]])

    def returnWaveFunction(self,theta):
        """Return the scalar plane wave exp(-i*m*theta)/sqrt(2*pi*radius)."""
        return 1/sqrt(2*pi*self.radius)*exp(-1j*self.m*theta)

    def returnSpinor(self,theta):
        """Return the untransformed 4-component spinor at angle ``theta``.

        The plane wave is multiplied by two unit-modulus, theta-dependent
        coefficients, laid out as [c0, c1, c0, c1] and scaled by 1/2.  Because
        the coefficients carry a factor sign(m), the spinor vanishes for m == 0.
        """
        psi = self.returnWaveFunction(theta)

        arg = abs(self.m)*theta/2.0
        coeffs = [-sign(self.m)*(-sin(arg)+1j*cos(arg)),-sign(self.m)*(-sin(arg)-1j*cos(arg))]

        return 1/sqrt(4.0)*array([coeffs[0]*psi,coeffs[1]*psi,coeffs[0]*psi,coeffs[1]*psi])

    def getProbCurrent(self,theta):
        """Return the probability-current vector at ``theta``.

        The scalar current is Im(psi* dpsi/dtheta), with the derivative scaled
        by ``self.tonorm``.  It is multiplied by ``radius`` and pointed along
        (-sin(|m| theta), cos(|m| theta), 0).
        """
        psi = self.returnWaveFunction(theta)
        dpsi = self.tonorm*derivative(self.returnWaveFunction,theta,n=1,dx=self.tolerance,order=5)

        # (psi* dpsi - psi dpsi*)/(2i) is the imaginary part of psi* dpsi.
        thetaCurrent = ((conjugate(psi)*dpsi - psi*conjugate(dpsi))/(2j)).real

        return vector(-thetaCurrent*self.radius*sin(abs(self.m)*theta),thetaCurrent*self.radius*cos(abs(self.m)*theta),0.0)

    def getDensity(self,theta):
        """Return the real spinor density spinor^dagger . spinor at ``theta``."""
        spinor = self.returnSpinor(theta)
        return (dot(conjugate(spinor),spinor)).real

    def getProbVel(self,theta):
        """Return the probability velocity current/density at ``theta``."""
        return self.getProbCurrent(theta)/self.getDensity(theta)

    def getSpeed(self):
        """Set ``self.speed`` and ``self.vSign`` from the velocity at theta=0.

        ``vSign`` is a scalar (+1, -1 or 0) giving the direction of
        circulation.  Modes with zero density (m == 0) are treated as at rest.
        """
        if self.getDensity(0.0) == 0.0:
            # sign(0) == 0 zeroes the spinor, so current/density would be 0/0.
            self.speed = 0.0
            self.vSign = 0.0
            return
        velocity = self.getProbVel(0.0)
        self.speed = mag(velocity)
        # At theta=0 the tangent used by getRelVel is (0, 1, 0), so the
        # y-component alone carries the direction of circulation.  Taking the
        # sign of the whole vector yielded a per-component array that zeroed
        # the x-component of every later getRelVel result.
        self.vSign = sign(velocity[1])

    def getRelVel(self,theta):
        """Return the velocity vector at ``theta``.

        Magnitude is vSign*speed, direction is (-sin(|m| theta), cos(|m| theta), 0),
        and the result is divided by sqrt(1 + speed**2/c**2).
        """
        amplitude = float(self.vSign*self.speed)
        vel = amplitude*vector(-sin(abs(self.m)*theta),cos(abs(self.m)*theta),0.0)
        return vel/sqrt(1+self.speed**2/self.c**2)

    def getMeanSpeed(self):
        """Return the magnitude of ``getRelVel`` at theta=0."""
        return mag(self.getRelVel(0.0))

    def getTransformation(self,theta):
        """Return the 4x4 boost-like matrix for the velocity at ``theta``.

        With rapidity phi = arctanh(|v|/c) and unit direction n, the entries
        are cosh(phi/2) on the diagonal and n-dependent multiples of
        sinh(phi/2) off the diagonal.
        """
        vel = self.getRelVel(theta)
        nVector = norm(vel)

        phi = arctanh(mag(vel)/self.c)
        ch = cosh(phi/2)
        sh = sinh(phi/2)

        nz = nVector[2]*sh
        nMinus = (nVector[0]-1j*nVector[1])*sh
        nPlus = (nVector[0]+1j*nVector[1])*sh

        return array([[ch,0.0,-nz,-nMinus],
                      [0.0,ch,-nPlus,nz],
                      [-nz,-nMinus,ch,0.0],
                      [-nPlus,nz,0.0,ch]])

    def getTransformedSpinor(self,theta):
        """Return getTransformation(theta) applied to returnSpinor(theta)."""
        return dot(self.getTransformation(theta),self.returnSpinor(theta))

    def integrateTheta(self):
        """Return radius * integral over [0, 2*pi] of the transformed density."""
        def density(theta):
            spinor = self.getTransformedSpinor(theta)
            # The density is real in exact arithmetic; quad needs a real value.
            return dot(conjugate(spinor),spinor).real
        return self.radius*integrate.quad(density, 0, 2*pi, epsrel=self.tolerance, limit=self.maxIterations)[0]

    def normalize(self):
        """Set ``tonorm`` to 1/sqrt(tonorm * integrateTheta()).

        NOTE(review): ``tonorm`` is only consumed by getProbCurrent; the
        spinors integrated here do not use it.
        """
        integral = self.tonorm*self.integrateTheta()
        self.tonorm = 1/sqrt(integral)

    def returnDifferentialSpinor(self,theta):
        """Return gammaT . d(transformed spinor)/dtheta at ``theta``.

        Side effect: stores the tangential gamma combination in ``self.gammaT``.
        """
        thetaDifferential = derivative(self.getTransformedSpinor,theta,n=1,dx=self.tolerance,order=5)
        self.gammaT = - self.gamma1*sin(abs(self.m)*theta) + self.gamma2*cos(abs(self.m)*theta)
        return array(dot(self.gammaT,thetaDifferential))

    def setRadius(self,radius):
        """Set the ring radius used by the wave-function normalization."""
        self.radius = radius



class integrations:
    """Ring integrals that couple two ``waveFunction`` states.

    Every integral runs over theta in [0, 2*pi] through ``integrate.quad``,
    using the transformed spinors of ``spinorOne`` (bra) and ``spinorTwo``
    (ket).
    """

    def __init__(self,spinorOne,spinorTwo):
        # Bra and ket states.
        self.spinorOne = spinorOne
        self.spinorTwo = spinorTwo

        # quad controls shared by every integral.
        self.tolerance = 1e-6
        self.maxIterations = 50

        self.radius = 1.0
        self.c = 3e8

        self.gamma0 = array([[0.0,0.0,0.0,-1j],[0.0,0.0,1j,0.0],[0.0,-1j,0.0,0.0],[1j,0.0,0.0,0.0]])

    def _ringQuad(self, integrand, **extra):
        """Integrate ``integrand`` over [0, 2*pi]; return only the value."""
        outcome = integrate.quad(integrand, 0, 2*pi, epsrel=self.tolerance, limit=self.maxIterations, **extra)
        return outcome[0]

    def _pointDensity(self, theta):
        """Real part of <spinorOne(theta)|spinorTwo(theta)>."""
        bra = conjugate(self.spinorOne.getTransformedSpinor(theta))
        ket = self.spinorTwo.getTransformedSpinor(theta)
        return (dot(bra,ket)).real

    def integrateOverlap(self):
        """Overlap integral scaled by the ring radius."""
        return self._ringQuad(self._pointDensity)*self.radius

    def integrateNuclear(self):
        """Same integrand as the overlap, without the radius factor."""
        return self._ringQuad(self._pointDensity)

    def integrateKinetic(self):
        """4 * radius * integral of Re(<one| gamma0 . D two>).

        D is spinorTwo.returnDifferentialSpinor.
        """
        def kineticDensity(theta):
            bra = conjugate(self.spinorOne.getTransformedSpinor(theta))
            ket = matrixmultiply(self.gamma0,self.spinorTwo.returnDifferentialSpinor(theta))
            return (dot(bra,ket)).real
        return self._ringQuad(kineticDensity)*self.radius*4.0

    def getRelativisticDistance(self,theta):
        """Distance from the theta=0 point to angle ``theta`` on the ring.

        The y-offset is stretched by 1/sqrt(1 - speed**2/c**2), using
        spinorOne's speed.  A zero distance is replaced by ``self.tolerance``
        so the 1/r integrands stay finite.
        """
        lorentz = 1/sqrt(1-self.spinorOne.speed**2/self.c**2)
        separation = self.radius*vector(cos(theta)-1.0,sin(theta)*lorentz,0.0)
        length = mag(separation)
        if length != 0.0:
            return length
        return self.tolerance

    def integrandForExchange(self,theta):
        """Re(<one(0)|one(theta)> / r(theta) * <two(theta)|two(0)>)."""
        oneAtOrigin = self.spinorOne.getTransformedSpinor(0.0)
        oneAtTheta = self.spinorOne.getTransformedSpinor(theta)
        twoAtTheta = self.spinorTwo.getTransformedSpinor(theta)
        twoAtOrigin = self.spinorTwo.getTransformedSpinor(0.0)
        firstFactor = dot(conjugate(oneAtOrigin),oneAtTheta)
        secondFactor = dot(conjugate(twoAtTheta),twoAtOrigin)
        return (firstFactor/self.getRelativisticDistance(theta)*secondFactor).real

    def integrateExchangeTheta(self):
        """Angular part of the exchange integral (quad warnings suppressed)."""
        return self._ringQuad(self.integrandForExchange, full_output=1, points=[0,2*pi])

    def integrateExchange(self):
        """Exchange integral: 2*pi*radius**2 times the angular part."""
        prefactor = 2*pi*self.radius**2
        return prefactor*self.integrateExchangeTheta()

    def integrandForCoulombic(self,theta):
        """Re(|two(0)|^2 / r(theta) * |one(theta)|^2)."""
        twoAtOrigin = self.spinorTwo.getTransformedSpinor(0.0)
        oneAtTheta = self.spinorOne.getTransformedSpinor(theta)
        originDensity = dot(conjugate(twoAtOrigin),twoAtOrigin)
        thetaDensity = dot(conjugate(oneAtTheta),oneAtTheta)
        return (originDensity/self.getRelativisticDistance(theta)*thetaDensity).real

    def integrateCoulombicTheta(self):
        """Angular part of the Coulomb integral (quad warnings suppressed)."""
        return self._ringQuad(self.integrandForCoulombic, full_output=1, points=[0,2*pi])

    def integrateCoulombic(self):
        """Coulomb integral: 2*pi*radius**2 times the angular part."""
        prefactor = 2*pi*self.radius**2
        return prefactor*self.integrateCoulombicTheta()

    def setRadius(self,radius):
        """Propagate a new ring radius to this object and both states."""
        self.radius = radius
        for state in (self.spinorOne, self.spinorTwo):
            state.setRadius(radius)





class basis:
    def __init__(self,maxn):
        self.maxn = maxn
        self.tonorm = 1

        self.maxNbasis = (maxn*2+1)*2
        
        print "Total modes:", self.maxNbasis
        
        i = 0
        self.basis = [0]*self.maxNbasis
        for elec in xrange(2):
            for n in xrange(2*maxn+1):
                print "Azimuthal QN", n-maxn
                print "Electron Number", elec
              
                self.basis[i] = waveFunction(n-maxn,1.0)
                self.basis[i].getSpeed()
                
##                self.basis[i].normalize()
##                print "Normalization was:", self.basis[i].tonorm
##                print "Average speed was:", self.basis[i].getMeanSpeed()
                
                print ""
                i += 1

        self.integrations = [0]*self.maxNbasis
        for i in xrange(self.maxNbasis):
            self.integrations[i] = [0]*self.maxNbasis
            for j in xrange(self.maxNbasis):       
                self.integrations[i][j] = integrations(self.basis[i],self.basis[j])

        self.overlaps = ones((self.maxNbasis,self.maxNbasis),Float)
        self.Noverlaps = ones((self.maxNbasis,self.maxNbasis),Float)
        self.KEoverlaps = ones((self.maxNbasis,self.maxNbasis),Float)
        self.XCoverlaps = ones((self.maxNbasis,self.maxNbasis),Float)
        self.COoverlaps = ones((self.maxNbasis,self.maxNbasis),Float)

        print "Basis and Integrations Set Up and Normalized at:", time.ctime()

    def fillOverlaps(self):
        for i in xrange(self.maxNbasis):
            for j in xrange(i+1):
                overlap = self.integrations[i][j].integrateOverlap()
                if abs(overlap) < 1e-15:
                    overlap = 0.0
                self.overlaps[i][j] = overlap
                self.overlaps[j][i] = overlap
                print "Overlap",i,j,"is:",overlap,"at:",time.ctime()

    def fillNoverlaps(self):
        for i in xrange(self.maxNbasis):
            for j in xrange(i+1):
                Noverlap = self.integrations[i][j].integrateNuclear()
                if abs(Noverlap) < 1e-15:
                    Noverlap = 0.0
                self.Noverlaps[i][j] = Noverlap
                self.Noverlaps[j][i] = Noverlap
                print "Nuclear",i,j,"is:",Noverlap,"at:",time.ctime()

    def fillKEoverlaps(self):
        for i in xrange(self.maxNbasis):
            for j in xrange(i+1):
                KEoverlap = self.integrations[i][j].integrateKinetic()
                if abs(KEoverlap) < 1e-15:
                    KEoverlap = 0.0
                self.KEoverlaps[i][j] = KEoverlap
                self.KEoverlaps[j][i] = KEoverlap
                print "Kinetic",i,j,"is:",KEoverlap,"at:",time.ctime()

    def fillXCEnergies(self):
        for i in xrange(self.maxNbasis):
            for j in xrange(i+1):
                XCoverlap = self.integrations[i][j].integrateExchange()
                if abs(XCoverlap) < 1e-15:
                    XCoverlap = 0.0
                self.XCoverlaps[i][j] = XCoverlap
                self.XCoverlaps[j][i] = XCoverlap
                print "Exchange Overlap",i,j,"is:",XCoverlap,"at:",time.ctime()

    def fillCOEnergies(self):
        for i in xrange(self.maxNbasis):
            for j in xrange(i+1):
                COoverlap = self.integrations[i][j].integrateCoulombic()
                if abs(COoverlap) < 1e-15:
                    COoverlap = 0.0
                self.COoverlaps[i][j] = COoverlap
                self.COoverlaps[j][i] = COoverlap
                print "Coulombic Overlap",i,j,"is:",COoverlap,"at:",time.ctime()

    def short(self):
        for n in xrange(self.maxNbasis):
            print "time is:", time.ctime()
            print n, n, "overlap is:", self.integrations[n][n].integrateOverlap()
            print n, n, "nuclear overlap is:",self.integrations[n][n].integrateNuclear()
            print n, n, "kinetic overlap is:",self.integrations[n][n].integrateKinetic()
            print n, n, "exchange overlap is:",self.integrations[n][n].integrateExchange()
            print n, n, "coulombic overlap is:",self.integrations[n][n].integrateCoulombic()

    def converge(self,Z):
        Z = Z
        radius = 1.0
        maxIter = 64
        step = radius/maxIter
        convergenceTolerance = .005

        energyOld = 5.0
        energyNew = 0.0

        i = 0
        while abs(energyOld-energyNew) > convergenceTolerance and i < maxIter-2:
            energyOld = energyNew

            radius -= step

            self.integrations[0][4].setRadius(radius)

            ionicEnergy = -Z*(self.integrations[0][0].integrateNuclear()+self.integrations[4][4].integrateNuclear())
            print ionicEnergy
            kineticEnergy = 137.036*(self.integrations[0][0].integrateKinetic()+self.integrations[4][4].integrateKinetic())
            print kineticEnergy
            exchangeEnergy = 2.0*self.integrations[0][4].integrateExchange()
            print exchangeEnergy
            coulombicEnergy = self.integrations[0][4].integrateCoulombic()
            print coulombicEnergy
            
            totalEnergy = ionicEnergy + kineticEnergy + exchangeEnergy + coulombicEnergy

            energyNew = totalEnergy
            
            print "New Energy:", energyNew, "and trial number:", i, "with radius:", radius
            print "At time:", time.ctime(), "difference was:", energyNew-energyOld
            i += 1


# Script driver (runs at import time; there is no __main__ guard).
# Build a basis with azimuthal quantum numbers -maxn..maxn for each of two
# electrons, then run the radius scan in basis.converge.  Z is the factor
# applied to the (negated) nuclear integrals inside converge.
maxn = 2
Z = 150*2.0

Basis = basis(maxn)

##Basis.short()

Basis.converge(Z)
