from Tkinter import *
from visual import *
from random import random
import time,sys,thread


class Simulation:
    """Spin lattice shown in a VPython scene and driven by Tk control callbacks.

    The lattice has Ntotx x Ntoty sites with periodic nearest-neighbour links.
    Each site carries a unit spin vector. On every time step the spin is
    rotated toward the component of its torque that is perpendicular to it.
    The torque comes from the four neighbour couplings plus a
    sublattice-dependent z field.

    ``animate_spins`` runs the integration/drawing loop on a background
    thread. The ``t_*`` and ``set_*`` methods are Tk callbacks. They request
    a pause through ``self.pause`` and wait (``_wait_for_pause``) until the
    loop acknowledges it through ``self.paused`` before touching shared state.
    """

    def __init__(self,parent):
        """Set parameters, build scene and lattice, and start the animation thread.

        parent: the Tk root, stored as self.parent. A Tk root must already
        exist because the toggle variables below are Tk BooleanVars.
        """
        self.parent = parent

        self.Ntotx = 32
        self.Ntoty = 4

        self.Ntot = self.Ntotx*self.Ntoty

        self.dx = 1.0
        self.dy = 1.0

        self.dt = 0.054

        # pause: requested by control callbacks; paused: acknowledged by the
        # animation loop at the top of each iteration.
        self.pause = 1
        self.paused = 1

        self.Ja = -1.0                                  # Default coupling
        self.Jb = -0.8                                  # Alternate value of coupling

        self.mean_field_down = 0.35                     # Mean field strength for down spins
        self.mean_field_up = 0.25                       # Mean field strength for up spins

        # 2pi/L * (kx integer, ky integer): integer multiples keep the
        # modulation commensurate with the periodic lattice.
        self.k = 2*pi*vector(3./(self.Ntotx*self.dx),2./(self.Ntoty*self.dy),0)
        print('wavenumbers per lattice length %s %s' % (self.k.x, self.k.y))

        self.spin_length = 1.0
        self.spin_sigma = .5                            # sigma magnitude

        self.vis_spin_length = hypot(self.dx,self.dy)   # need new methodology for illustrating

        # Toggle Variables
        self.tv_sigma_rep = BooleanVar()                # Sigma spin representation (xy only)
        self.tv_circles = BooleanVar()                  # Circles for spin representation

        self.tv_box_rep = BooleanVar()                  # Box representation
        self.tv_flatten = BooleanVar()                  # Flatten box representation

        self.tv_sphere_rep = BooleanVar()               # Sphere representation
        self.tv_points = BooleanVar()                   # Represent spins as points instead of vectors

        self.tv_sublattices = BooleanVar()              # Sublattices
        self.tv_torques = BooleanVar()                  # Torques
        self.tv_colors = BooleanVar()                   # Colors
        self.tv_fm_filter = BooleanVar()                # Filter for FM nn coloring scheme

        self.tv_tz_zero = BooleanVar()                  # Torque Z component = 0
        self.tv_free_y_boundary = BooleanVar()          # Modified mean field

        self.setup_scene()
        self.set_spin_mag_ratio(self.k)

        self.spins = self.create_lattice([], self.Ntotx, self.Ntoty, self.dx, self.dy)
        self.spins = self.initialize_spins(self.spins, self.k)
        self.spins = self.initialize_circles(self.spins)
        thread.start_new_thread(self.animate_spins,(self.spins,))

    def _wait_for_pause(self):
        """Request a pause and block until the animation loop acknowledges it.

        Sets self.pause = 1, then waits until animate_spins has copied it into
        self.paused. Sleeping briefly (instead of a bare ``pass`` spin) lets the
        animation thread get the interpreter lock and acknowledge sooner.
        """
        self.pause = 1
        while self.paused == 0:
            time.sleep(0.001)

    def setup_scene(self):
        """Configure the global VPython ``scene`` window and camera."""
        scene.title = "Simulation"
        scene.width = 600
        scene.height = 540
        scene.x = 20
        scene.y = 20
        scene.autoscale = 0
        scene.up = (0,0,1)
        scene.forward = (0,2.0,-1.0)
        scene.lights = [vector(.5,.5,.5), vector(-.5,-.5,.5)]
        scene.background = (0,0,0)
        scene.uniform = 1
        scene.range = 1.3*self.Ntotx**(.8)
##        scene.stereo = 'passive'
        scene.select()

    def create_lattice(self, _lattice=None, _Ntotx=8, _Ntoty=8, _deltax=1.0, _deltay=1.0):
        """Append _Ntotx*_Ntoty lattice sites to _lattice and return it.

        Sites are appended row by row (ny outer, nx inner), so the site at
        (nx, ny) has list index _Ntotx*ny + nx. Each site is a VPython frame
        with these attributes:
        - the spin and its display objects: arrow, box, point, torque arrow
          and circle;
        - ``near``: list indices of its four periodic neighbours, ordered
          [left, right, down, up];
        - ``couplings``: placeholder values;
        - ``indices``: (nx, ny, sublattice), where sublattice = 1 - (nx+ny) % 2.

        _lattice defaults to a fresh empty list. The previous ``[]`` default was
        one list shared across calls, so default calls kept growing it.
        Also sets self.xmin / self.ymin (centred origin) and self.nz.
        """
        if _lattice is None:
            _lattice = []
        self.xmin = -(_Ntotx-1)*_deltax/2.
        self.ymin = -(_Ntoty-1)*_deltay/2.
        self.nz = 0

        for ny in range(_Ntoty):
            y = self.ymin + ny*_deltay
            for nx in range(_Ntotx):
                x = self.xmin + nx*_deltax

                site = frame()
                _lattice.append(site)
                site.pos = vector(x,y,0.0)
                site.spin = vector()
                site.spinvec = arrow(pos=(x,y,0.0), axis=(0,0,1), color=(.8,.8,.8), shaftwidth=.25)

                site.box = box(pos=(x,y,0.0), length=_deltax, width=0.1, height=_deltay, color=(.9,.9,.9))
                site.box.visible = 0

                site.point = sphere(pos=(x,y,0.0), radius=0.03, color=(.8,.8,.8))
                site.point.visible = 0

                site.torque = vector()
                site.torquevec = arrow(pos=(x,y,0), axis=(0,0,1), color=(.5,.5,1), shaftwidth=0.1)
                site.torquevec.visible = 0

                site.omega = 0.0

                site.circle = curve()

                # Periodic wrap-around: Python's % maps -1 to N-1 and N to 0.
                site.near = [
                    _Ntotx*ny + (nx-1) % _Ntotx,        # left
                    _Ntotx*ny + (nx+1) % _Ntotx,        # right
                    _Ntotx*((ny-1) % _Ntoty) + nx,      # down
                    _Ntotx*((ny+1) % _Ntoty) + nx,      # up
                ]
                site.couplings = list(range(4))         # overwritten by initialize_spins
                site.indices = (nx,ny,1-(nx+ny)%2)
        return _lattice

    def set_spin_mag_ratio(self, _k=(1/8.,1/8.,1/8.)):
        """Set self.chi and self.spin_mag_ratio from wavevector _k (needs _k.x, _k.y).

        chi = (cos(kx*dx) + cos(ky*dy)) / 2, and
        ratio = (1 + 1/chi - sqrt((1/chi - 1)(1/chi + 3))) / 2.
        The ratio is inverted when its magnitude exceeds 1.
        NOTE(review): chi == 0 (reachable e.g. for kx = 0, ky*dy = pi) gives
        a division by zero / nan here -- confirm the intended behaviour.
        """
        self.chi = (cos(_k.x*self.dx)+cos(_k.y*self.dy))/2
        self.spin_mag_ratio = .5*(1+1/self.chi-((1/self.chi-1)*(1/self.chi+3))**(.5))

        if abs(self.spin_mag_ratio) > 1:
            self.spin_mag_ratio = 1/self.spin_mag_ratio
        print('new spin up to down magnitude ratio %s' % (self.spin_mag_ratio,))

    def initialize_spins(self, _lattice=None, _k=(1/8.,1/8.,1/8.,)):
        """Set every spin to the modulated two-sublattice pattern and return _lattice.

        Sublattice 1 spins have in-plane magnitude spin_sigma and +z. Sublattice 0
        spins point oppositely in-plane with magnitude spinB_sigma and have -z.
        All spins are then normalised. Every coupling is reset to self.Ja, so
        callers must reapply t_free_y_boundary if needed. Leaves the animation
        paused; callers release it.

        The spin pattern is computed from s.indices, not s.pos, so the
        sublattice display offset is irrelevant here. The old code called
        t_sublattices() twice to "undo" the split. Because the checkbox was
        still set, both calls moved the sublattices further apart, and each
        call un-paused the animation mid-update.
        """
        if _lattice is None:
            _lattice = []
        self._wait_for_pause()

        if abs(self.spin_mag_ratio) > 1:
            self.spin_sigma = (1/self.spin_mag_ratio)*self.spin_sigma
            self.spinB_sigma = self.spin_sigma
        else:
            self.spinB_sigma = self.spin_mag_ratio*self.spin_sigma

        # Phase offset uses self.k (callers pass self.k as _k).
        if self.k.x != 0:
            k_angle = arctan(self.k.y/self.k.x)
        else:
            k_angle = pi/2.

        for s in _lattice:
            pos = vector(s.indices[0]*self.dx, s.indices[1]*self.dy,0.0)
            phase = dot(_k, pos)+k_angle

            if s.indices[2] == 1:
                spinx = self.spin_sigma*cos(phase)
                spiny = self.spin_sigma*sin(phase)
                spinz = (1-self.spin_sigma**2)**.5
            else:
                spinx = -self.spinB_sigma*cos(phase)
                spiny = -self.spinB_sigma*sin(phase)
                spinz = (-1)*(1-self.spinB_sigma**2)**.5

            for cc in range(4):
                s.couplings[cc] = self.Ja
            s.spin = norm(vector(spinx,spiny,spinz))

        return _lattice

    def initialize_circles(self, _lattice=None):
        """Create the hidden equator curve and one hidden circle per site; return _lattice.

        Sublattice 1 circles have radius vis_spin_length/2. Sublattice 0 circles
        are scaled by spin_mag_ratio. Leaves the animation paused.
        """
        if _lattice is None:
            _lattice = []
        self._wait_for_pause()

        equator_angles = arange(0,2.1*pi,.1565)
        self.equator_line = curve(x=(self.vis_spin_length)*cos(equator_angles), y=(self.vis_spin_length)*sin(equator_angles), color=(.3,.3,.3))
        self.equator_line.visible = 0

        angles = arange(0,2*pi,.31)
        for s in _lattice:
            radius = self.vis_spin_length/2
            if s.indices[2] != 1:
                radius = radius*self.spin_mag_ratio
            s.circle = curve(x=s.pos.x+radius*cos(angles), y=s.pos.y+radius*sin(angles), color=(.3,.3,.3))
            s.circle.visible = 0

        return _lattice

    def animate_spins(self, _lattice=None):
        """Integrate and draw the lattice forever; intended to run on its own thread.

        Each iteration first publishes self.pause into self.paused, which
        acknowledges pause requests from _wait_for_pause. When not paused it
        advances and redraws the spins. It then throttles with rate(100).
        Previously that call sat after this infinite loop, so it never ran and
        the thread spun at full speed even while paused.
        """
        if _lattice is None:
            _lattice = []
        while 1:
            self.paused = self.pause

            if self.paused != 1:
                self._step_spins(_lattice)
                self._draw_spins(_lattice)
            rate(100)

    def _step_spins(self, _lattice):
        """Advance every spin in _lattice by one time step self.dt."""
        # All torques are computed before any spin moves, so each update sees
        # the same snapshot of its neighbours.
        for s in _lattice:
            s.torque = vector(0.0,0.0,0.0)

            for nn in range(4):
                s.torque = s.torque + s.couplings[nn]*cross(s.spin,_lattice[s.near[nn]].spin)

            # +mean_field_up along z for sublattice 1, -mean_field_down for sublattice 0.
            s.torque = s.torque + vector(0.0,0.0,(s.indices[2]-1)*self.mean_field_down + s.indices[2]*self.mean_field_up)

        if self.tv_tz_zero.get() == 1:
            for s in _lattice:
                s.torque.z = 0.0

        for s in _lattice:
            torque_perp_to_s = s.torque - dot(s.torque,s.spin)*s.spin
            theta = mag(torque_perp_to_s)*self.dt

            # cross(s, cross(t, s)) is the part of t perpendicular to unit s;
            # rotate the spin toward it by angle theta.
            s.spin = s.spin*cos(theta) + norm(cross(s.spin, cross(s.torque, s.spin)))*sin(theta)
##            s.spin = s.spin + self.dt*s.torque
            s.spin = norm(s.spin)

    def _spin_color(self, s, amplitude, fm_filter):
        """Return the RGB tuple for site s, from its in-plane alignment with self.k."""
        color = amplitude*((-1)**(s.indices[2]*(1+fm_filter)))*dot(norm(vector(s.spin.x,s.spin.y,0)),norm(self.k))
        return (.3+color, .0 , .2-color)

    def _draw_spins(self, _lattice):
        """Update the VPython objects of every site for the active representation."""
        colors = self.tv_colors.get() == 1
        fm_filter = self.tv_fm_filter.get() if colors else 0
        vis_len = self.vis_spin_length

        if self.tv_sigma_rep.get() == 1:
            # In-plane (x, y) projection of each spin, centred on its site.
            for s in _lattice:
                s.spinvec.axis = .7*vector(s.spin.x,s.spin.y,0)*vis_len
                s.spinvec.pos = s.pos - s.spinvec.axis/2.0
                if colors:
                    s.spinvec.color = self._spin_color(s, .9, fm_filter)

        elif self.tv_box_rep.get() == 1:
            flatten = self.tv_flatten.get() == 1
            for s in _lattice:
                box_z = 0.0 if flatten else s.spin.z
                s.box.pos = (s.box.pos.x,s.box.pos.y,box_z)
                if colors:
                    s.box.color = self._spin_color(s, .6, fm_filter)

        elif self.tv_sphere_rep.get() == 1:
            # Spins drawn from the common origin, tracing the unit sphere.
            if self.tv_points.get() == 1:
                for s in _lattice:
                    s.point.pos = vector(s.spin.x,s.spin.y,s.spin.z)*vis_len
                    if colors:
                        s.point.color = self._spin_color(s, .6, fm_filter)
            else:
                for s in _lattice:
                    s.spinvec.axis = vector(s.spin.x,s.spin.y,s.spin.z)*vis_len
                    if colors:
                        s.spinvec.color = self._spin_color(s, .6, fm_filter)

        else:
            # Default: full 3D arrow centred on its site.
            for s in _lattice:
                s.spinvec.axis = vector(s.spin.x,s.spin.y,s.spin.z)*vis_len
                s.spinvec.pos = s.pos - s.spinvec.axis/2
                if colors:
                    s.spinvec.color = self._spin_color(s, .6, fm_filter)

        if self.tv_torques.get() == 1:
            if self.tv_points.get() == 1:
                for s in _lattice:
                    s.torquevec.axis = .1*s.torque
                    s.torquevec.pos = s.point.pos
            else:
                for s in _lattice:
                    s.torquevec.axis = .4*s.torque
                    s.torquevec.pos = s.spinvec.pos + s.spinvec.axis

    ####        Scale Values         ####

    def set_dt(self, dt_string):
        """Scale callback: set the time step to dt_string (thousandths) / 1000."""
        # Builtin float(): the `string` module was never imported by this file.
        self.dt = float(dt_string)/1000.0

    def set_kx(self, kx_string):
        """Scale callback: set k.x = n*2pi/(Ntotx*dx) and reinitialise the spins.

        Ignored when it would make both components of k zero.
        """
        self._wait_for_pause()

        if not (self.k.y == 0.0 and float(kx_string) == 0.0):
            self.k.x = float(kx_string)*(2*pi)/(self.Ntotx*self.dx)

        self.set_spin_mag_ratio(self.k)
        gammaratio_widget.set(self.spin_mag_ratio*1000.0)
        self.spins = self.initialize_spins(self.spins, self.k)
        self.pause = 0

    def set_ky(self, ky_string):
        """Scale callback: set k.y = n*2pi/(Ntoty*dy) and reinitialise the spins.

        Ignored when it would make both components of k zero.
        """
        self._wait_for_pause()

        if not (self.k.x == 0.0 and float(ky_string) == 0.0):
            self.k.y = float(ky_string)*(2*pi)/(self.Ntoty*self.dy)

        self.set_spin_mag_ratio(self.k)
        gammaratio_widget.set(self.spin_mag_ratio*1000.0)
        self.spins = self.initialize_spins(self.spins, self.k)
        self.pause = 0

    def set_gammaratio(self, gammaratio_string):
        """Scale callback: set spin_mag_ratio to value/1000 (non-zero only) and reinitialise."""
        if float(gammaratio_string) != 0:
            self.spin_mag_ratio = float(gammaratio_string)/1000.0

            self.spins = self.initialize_spins(self.spins, self.k)
            self.pause = 0

    def set_J_ratio(self, J_ratio_string):
        """Scale callback: set Jb = Ja * ratio and reapply the boundary couplings."""
        self._wait_for_pause()

        self.Jb = self.Ja*float(J_ratio_string)
        self.t_free_y_boundary()
        self.pause = 0

    def set_mean_field_up(self, mean_field_up_string):
        """Scale callback: set the z field applied to sublattice 1."""
        self._wait_for_pause()

        self.mean_field_up = float(mean_field_up_string)
        self.pause = 0

    def set_mean_field_down(self, mean_field_down_string):
        """Scale callback: set the z field magnitude applied (negated) to sublattice 0."""
        self._wait_for_pause()

        self.mean_field_down = float(mean_field_down_string)
        self.pause = 0

    ####        Representations        ####

    ### Sigma Representation
    def t_sigma_rep(self):
        """Checkbox callback: sigma (in-plane) view; turns off sphere/box views when enabled."""
        self._wait_for_pause()

        if self.tv_sigma_rep.get() == 1:
            self.tv_sphere_rep.set(0)
            self.t_sphere_rep()
            self.tv_box_rep.set(0)
            self.t_box_rep()

            self.tv_colors.set(1)
            self.t_colors()

        else:
            self.tv_circles.set(0)
            self.t_circles()

        self.pause = 0

    def t_circles(self):
        """Checkbox callback: show or hide every site's circle."""
        self._wait_for_pause()

        for s in self.spins:
            s.circle.visible = self.tv_circles.get()

        self.pause = 0

    ### Box Representation
    def t_box_rep(self):
        """Checkbox callback: switch between box and arrow display of the spins."""
        self._wait_for_pause()

        if self.tv_box_rep.get() == 1:
            self.tv_sigma_rep.set(0)
            self.t_sigma_rep()
            self.tv_sphere_rep.set(0)
            self.t_sphere_rep()

            self.tv_flatten.set(1)
            self.t_flatten()
            self.tv_colors.set(1)
            self.t_colors()
            self.tv_torques.set(0)
            self.t_torques()

            for s in self.spins:
                s.spinvec.visible = 0
                s.box.visible = 1
        else:
            self.tv_flatten.set(0)
            self.t_flatten()

            for s in self.spins:
                s.spinvec.visible = 1
                s.box.visible = 0

        self.pause = 0

    def t_flatten(self):
        """Checkbox callback: no-op; flattening is read each frame in _draw_spins."""
        pass

    ### Sphere Representation
    def t_sphere_rep(self):
        """Checkbox callback: draw all spins from the origin (sphere view) or restore the lattice view."""
        self._wait_for_pause()

        if self.tv_sphere_rep.get() == 1:
            self.tv_sigma_rep.set(0)
            self.t_sigma_rep()
            self.tv_box_rep.set(0)
            self.t_box_rep()

            self.tv_colors.set(0)
            self.t_colors()
            self.tv_torques.set(1)
            self.t_torques()
            self.tv_circles.set(0)
            self.t_circles()

            for s in self.spins:
                s.spinvec.pos = vector(0.0,0.0,0.0)
                s.spinvec.shaftwidth = .15

            self.equator_line.visible = 1

            scene.range = 3.0

        else:
            self.tv_points.set(0)
            self.t_points()
            self.tv_colors.set(1)
            self.t_colors()

            for s in self.spins:
                s.spinvec.shaftwidth = .25
                s.spinvec.pos = vector(s.indices[0]*self.dx - self.dx*self.Ntotx/2.0, s.indices[1]*self.dy - self.dy*self.Ntoty/2.0,0.0)

            self.equator_line.visible = 0

            scene.range = 1.3*self.Ntotx**(.8)

        self.pause = 0

    def t_points(self):
        """Checkbox callback: in sphere view, show spins as points instead of arrows."""
        self._wait_for_pause()

        if self.tv_sphere_rep.get() == 1:
            if self.tv_points.get() == 1:
                for s in self.spins:
                    s.spinvec.visible = 0
                    s.point.visible = 1
                    s.torquevec.shaftwidth = .01
            else:
                for s in self.spins:
                    s.spinvec.visible = 1
                    s.point.visible = 0
                    s.torquevec.shaftwidth = .1
        else:
            for s in self.spins:
                s.torquevec.shaftwidth = .1

        self.pause = 0

    ####        Other Options        ####

    def t_sublattices(self):
        """Checkbox callback: move the two sublattices apart in y (checked) or back (unchecked)."""
        self._wait_for_pause()

        shift = vector(0,(self.Ntoty*self.dy)+self.dy*2.,0)/2.
        # Unchecked: sublattice 1 moves +y and sublattice 0 moves -y, undoing
        # the opposite move made when the box was checked.
        direction = 1 if self.tv_sublattices.get() == 0 else -1
        for s in self.spins:
            if s.indices[2] == 1:
                offset = shift*direction
            else:
                offset = shift*(-direction)
            s.pos = s.pos + offset
            s.box.pos = s.box.pos + offset
            s.circle.pos = s.circle.pos + offset
        self.pause = 0

    def t_torques(self):
        """Checkbox callback: show or hide every torque arrow."""
        self._wait_for_pause()

        for s in self.spins:
            s.torquevec.visible = self.tv_torques.get()
        self.pause = 0

    def t_colors(self):
        """Checkbox callback: when colours are turned off, reset arrows and points to grey."""
        self._wait_for_pause()

        if self.tv_colors.get() == 0:
            for s in self.spins:
                s.spinvec.color = (.8,.8,.8)
                s.point.color = (.8,.8,.8)
        self.pause = 0

    def t_fm_filter(self):
        """Checkbox callback: no-op; the filter is read each frame when colouring."""
        pass

    def t_tz_zero(self):
        """Checkbox callback: no-op; the flag is read each step in _step_spins."""
        pass

    def t_free_y_boundary(self):
        """Set couplings across the periodic boundaries.

        x-boundary couplings always use Ja. y-boundary couplings use Jb when
        "Free y Boundary" is checked, otherwise Ja.
        """
        self._wait_for_pause()

        y_coupling = self.Jb if self.tv_free_y_boundary.get() == 1 else self.Ja
        for s in self.spins:
            if s.indices[0] == 0:
                s.couplings[0] = self.Ja
            if s.indices[0] == self.Ntotx - 1:
                s.couplings[1] = self.Ja
            if s.indices[1] == 0:
                s.couplings[2] = y_coupling
            if s.indices[1] == self.Ntoty - 1:
                s.couplings[3] = y_coupling

        self.pause = 0

    def randomize(self):
        """Button callback: give every spin a random unit direction."""
        self._wait_for_pause()

        for s in self.spins:
            s.spin = norm(vector(random()-.5,random()-.5,random()-.5))

        self.pause = 0

    def reset(self):
        """Button callback: reinitialise the spin pattern and reapply boundary couplings."""
        self._wait_for_pause()

        self.spins = self.initialize_spins(self.spins, self.k)
        self.t_free_y_boundary()
        self.pause = 0
#### ----------------------------- ####










######-------------------------------------------------------######
###                             TKR                             ###
        
##  Instance creation ##
def _labeled_checkbutton(parent, text, variable, command, row, column):
    """Grid a right-aligned Label at (row, column) and a Checkbutton at (row, column+1).

    Returns (label, checkbutton). The old code bound the result of
    ``.grid(...)`` (always None) to its names, so no widget was kept.
    """
    label = Label(parent, text=text)
    label.grid(row=row, column=column, sticky=E)
    button = Checkbutton(parent, text="", variable=variable, command=command)
    button.grid(row=row, column=column+1)
    return label, button


##  Instance creation ##
tkr = Tk()
simu = Simulation(tkr)

##  TKR appearance and widget creation ##

########
tkr.wm_geometry(newGeometry="360x540+620+20")
tkr.wm_title("Controls")

########
top_frame = Frame(tkr,relief=SUNKEN, borderwidth=0)
top_frame.grid(row=0, column=0, padx=0, pady=0)


######
toggle_frame = Frame(top_frame, relief=FLAT, borderwidth=0)
toggle_frame.grid(row=0, column=0, padx=4, pady=4)

#### Representation modes
modes_frame = Frame(toggle_frame, relief=FLAT, borderwidth=1)
modes_frame.grid(row=0, column=0, padx=4, pady=4)

modes_label = Label(modes_frame, text="Mode Options:")
modes_label.grid(row=0, column=0, sticky=W)

sigma_rep_label, sigma_rep_widget = _labeled_checkbutton(modes_frame, "Sigmas", simu.tv_sigma_rep, simu.t_sigma_rep, 1, 0)
box_rep_label, box_rep_widget = _labeled_checkbutton(modes_frame, "Boxes", simu.tv_box_rep, simu.t_box_rep, 2, 0)
sphere_rep_label, sphere_rep_widget = _labeled_checkbutton(modes_frame, "Sphere", simu.tv_sphere_rep, simu.t_sphere_rep, 3, 0)
circles_label, circles_widget = _labeled_checkbutton(modes_frame, "Circles", simu.tv_circles, simu.t_circles, 1, 2)
flatten_label, flatten_widget = _labeled_checkbutton(modes_frame, "Flatten", simu.tv_flatten, simu.t_flatten, 2, 2)
points_label, points_widget = _labeled_checkbutton(modes_frame, "Points", simu.tv_points, simu.t_points, 3, 2)

#### Visual options
visuals_frame = Frame(toggle_frame, relief=FLAT, borderwidth=1)
visuals_frame.grid(row=0, column=1, padx=4, pady=4)

visuals_label = Label(visuals_frame, text="Visual Options:")
visuals_label.grid(row=0, column=0, sticky=W)

sublattices_label, sublattices_widget = _labeled_checkbutton(visuals_frame, "Sublattices", simu.tv_sublattices, simu.t_sublattices, 1, 0)
torque_label, torque_widget = _labeled_checkbutton(visuals_frame, "Torques", simu.tv_torques, simu.t_torques, 2, 0)
color_label, color_widget = _labeled_checkbutton(visuals_frame, "Colors", simu.tv_colors, simu.t_colors, 3, 0)
# Start with colours on; toggle() flips the variable without running the command.
color_widget.toggle()
fm_filter_label, fm_filter_widget = _labeled_checkbutton(visuals_frame, "FM Filter", simu.tv_fm_filter, simu.t_fm_filter, 4, 0)

#### Simulation options
simu_frame = Frame(toggle_frame, relief=FLAT, borderwidth=0)
simu_frame.grid(row=1, column=0, columnspan=2, padx=4, pady=4)

simu_label = Label(simu_frame, text="Simulation:")
simu_label.grid(row=0, column=0, sticky=W)

fix_sZ_label, fix_sZ_widget = _labeled_checkbutton(simu_frame, "Tz Zero", simu.tv_tz_zero, simu.t_tz_zero, 1, 0)
free_y_boundary_label, free_y_boundary_widget = _labeled_checkbutton(simu_frame, "Free y Boundary", simu.tv_free_y_boundary, simu.t_free_y_boundary, 1, 2)

randomize_widget = Button(simu_frame, text="Randomize", command=simu.randomize)
randomize_widget.grid(row=1, column=4, padx=4, pady=2)

reset_widget = Button(simu_frame, text="Reset", command=simu.reset)
reset_widget.grid(row=1, column=5, padx=4, pady=2)

#### Scales: wavevector and time step
scales_frame = Frame(top_frame, relief=FLAT, borderwidth=1)
scales_frame.grid(row=1, column=0, columnspan=2, padx=10, pady=10)

scales_label = Label(scales_frame, text="Scales:")
scales_label.grid(row=0, column=0, sticky=W)

# Scale passes its new value as a string; the bound methods accept it directly.
kx_widget = Scale(scales_frame, orient=HORIZONTAL, from_=-6, to=6, resolution=1.0, label="kx/2PiL", command=simu.set_kx)
kx_widget.set(simu.k.x*simu.Ntotx*simu.dx/(2.*pi))
kx_widget.grid(row=1, column=0)

ky_widget = Scale(scales_frame, orient=VERTICAL, from_=6, to=-6, resolution=1.0, label="ky/2PiL", command=simu.set_ky)
ky_widget.set(simu.k.y*simu.Ntoty*simu.dy/(2.*pi))
ky_widget.grid(row=1, column=1)

dt_widget = Scale(scales_frame, orient=VERTICAL, from_=200, to=0, label="dt", command=simu.set_dt)
dt_widget.set(simu.dt*1000.0)
dt_widget.grid(row=1, column=2)

#### Scales: fields, magnitude ratio, coupling ratio
scales2_frame = Frame(top_frame, relief=FLAT, borderwidth=1)
scales2_frame.grid(row=2, column=0, columnspan=2, padx=10, pady=10)

mean_field_down_widget = Scale(scales2_frame, orient=VERTICAL, from_=1.0, to=0.0, resolution=0.01, label="MF D", command=simu.set_mean_field_down)
mean_field_down_widget.set(simu.mean_field_down)
mean_field_down_widget.grid(row=0, column=0)

mean_field_up_widget = Scale(scales2_frame, orient=VERTICAL, from_=1.0, to=0.0, resolution=0.01, label="MF U", command=simu.set_mean_field_up)
mean_field_up_widget.set(simu.mean_field_up)
mean_field_up_widget.grid(row=0, column=1)

# Module-level name: Simulation.set_kx / set_ky reference gammaratio_widget.
gammaratio_widget = Scale(scales2_frame, orient=VERTICAL, from_=1000, to=-1000, resolution=1, label="GR", command=simu.set_gammaratio)
gammaratio_widget.set(simu.spin_mag_ratio*1000.0)
gammaratio_widget.grid(row=0, column=2)

J_ratio_widget = Scale(scales2_frame, orient=VERTICAL, from_=1.0, to=-1.0, resolution=.1, label="JR", command=simu.set_J_ratio)
J_ratio_widget.set(simu.Jb/simu.Ja)
J_ratio_widget.grid(row=0, column=3)

##  Enter the main TKR loop, starting the interface and program ##
tkr.mainloop()




###### SCRATCH

##                if self.tv_colors.get() == 1:
##                    for s in _lattice:
##                        color = .9*((-1)**(s.indices[2]*(1+self.tv_fm_filter.get())))*(dot(norm(vector(s.spin.x,s.spin.y,0)),norm(self.k)))**1
##                        s.spinvec.color = (.3+color, .0 , .2-color)
##                    if self.tv_mf_shape.get() == 1:
                            
####                        mf_asy_steepness = -1.0
####                        mf_asy_offset = 1.2
####                        mf_shape = mf_asy_offset + mf_asy_steepness*((s.indices[1]+0.5-self.Ntoty/2.0)/(0.5-self.Ntoty/2.0))**2
####                        mf_shape = (2.0/pi)*arcsin(s.spin.z)
####                        mf_shape = s.indices[2]
####                        s.torque = s.torque + mf_shape*vector(0,0,(s.indices[2]-1)*self.mean_field_down + s.indices[2]*self.mean_field_up)
##                        pass
##                    else:

# switch 1 to 2d

##        self.spin_mag_ratio = (4-(16-4*(cos(_k.x*self.dx)+cos(_k.y*self.dy))**2)**(.5))/(2*(cos(_k.x*self.dx)+cos(_k.y*self.dy)))
##        print 'old spin up to down magnitude ratio', self.spin_mag_ratio
        
