from .error import *
from .element import *
from .rk4 import *
import numpy as np
import scipy.constants as const
from scipy.integrate import RK45
from scipy.optimize import brentq
import logging
# Public API exported by "from ... import *".
__all__ = ["RayTrace"]
# Module logger, named after this module so it can be configured per package.
logger = logging.getLogger(__name__)
class RayTrace(Element):
    """Non-relativistic ray-tracing of particles through electromagnetic fields.

    Set up a ray-tracing object with fields E and B. The ``planes``
    argument is a sequence of planes: the starting plane, possible
    intermediate planes and the exit plane. Each plane is a tuple
    ``(c, u, v, col)`` holding three numpy arrays and a fourth item:

    - ``c``: the plane center point;
    - ``u`` and ``v``: vectors spanning the plane;
    - ``col``: None, or a value passed to ``Collimator()`` if the plane
      is to collimate particles.

    The u and v vectors are orthonormalized internally. The unit normal
    w = u x v is added, so the stored planes are ``(c, u, v, w, col)``.
    The caller's sequence is not modified.

    The length of the element is the path length of the reference
    particle trajectory if a reference particle is given. Otherwise it
    is the straight-line distance between the centers of the first and
    last planes. Intermediate planes are not in use currently.

    The step length ``steplen`` is converted to a time step using the
    initial particle speed. With RK45 that time step is used as both
    the initial and the maximum step size. With the fixed step size
    integrator it is the step size.

    The ``xsize`` and ``ysize`` define the half-size of the box
    representing the element in plots.
    """

    def __init__(self, refple, E, B, planes, steplen, xsize=0.1, ysize=0.1,
                 iterator='RK45', maxsteps=1000):
        super().__init__('RayTrace')
        self.isLeaf = False
        self.iterator = iterator
        self.maxsteps = maxsteps
        self.refple = refple
        self.E = E
        self.B = B
        self.ds = 0
        self.dr = 0
        self.dx = 0
        self.dy = 0
        self.steplen = steplen
        self.xsize = xsize
        self.ysize = ysize
        logger.debug('Raytrace defined')
        # Build a new list rather than rewriting the caller's list in place.
        # In-place rewriting turned the caller's 4-tuples into 5-tuples, so
        # reusing the same list for another RayTrace wrapped the plane
        # normal in a Collimator.
        self.planes = [self._make_plane(i, p) for i, p in enumerate(planes)]
        first, last = self.planes[0], self.planes[-1]
        # Exit plane center relative to the entrance plane center.
        self.dr = last[0] - first[0]
        # Exit plane u and v axes expressed in the entrance (u, v, w) frame.
        self.dx = np.array([np.dot(last[1], first[k]) for k in (1, 2, 3)])
        self.dy = np.array([np.dot(last[2], first[k]) for k in (1, 2, 3)])
        # Define optical axis length if reference particle defined
        # (could also test here for 0-order terms)
        if refple:
            # Energy and charge are scaled by e, mass by the atomic mass
            # constant, to get SI values.
            K0 = const.e*self.refple.energy
            m0 = const.u*self.refple.mass
            q0 = const.e*self.refple.charge
            vz = np.sqrt(2*K0/m0)
            # Reference particle starts at the entrance center along its normal.
            x1 = self.calc(q0, m0, np.array([0, 0, 0, 0, 0, vz]))
            if x1 is None:
                raise Exception('RayTrace(): could not calculate reference trajectory')
            logger.debug('Raytraced reference particle, length = {:f}'.format(self.s[-1]))
            self.ds = self.s[-1]
        else:
            # Default length to straight path from entrance to exit
            self.ds = np.linalg.norm(self.dr)

    @staticmethod
    def _make_plane(i, p):
        """Return the stored form ``(c, u, v, w, col)`` of plane tuple *p*.

        u is normalized, v is made orthogonal to u and normalized, and
        w = u x v is the unit normal. *i* is the plane index, used in
        messages.

        Raises Exception if the vectors are degenerate (zero or parallel).
        """
        # Threshold is sqrt(machine epsilon), matching the original variable
        # name; the old value was plain epsilon (~2e-16).
        tol = np.sqrt(np.finfo(float).eps)
        c = np.array(p[0], dtype=float)
        u = np.asarray(p[1], dtype=float)
        v = np.asarray(p[2], dtype=float)
        w = np.cross(u, v)
        # -(u x (u x v)) = |u|^2 v - (u.v) u: the part of v orthogonal to u.
        v = -np.cross(u, w)
        norms = [np.linalg.norm(vec) for vec in (u, v, w)]
        if min(norms) <= tol:
            raise Exception('Invalid vectors defining plane {:d}'.format(i))
        u, v, w = (vec/n for vec, n in zip((u, v, w), norms))
        logger.debug('Plane {:d}:'.format(i))
        logger.debug('u = {:s}'.format(np.array2string(u)))
        logger.debug('v = {:s}'.format(np.array2string(v)))
        logger.debug('w = {:s}'.format(np.array2string(w)))
        col = None if p[3] is None else Collimator(p[3])
        return (c, u, v, w, col)

    def dxdt(self, t, x):
        """Right-hand side of the equation of motion for the ODE solver.

        *x* is the state (x, y, z, vx, vy, vz) in the field frame. The
        return value is (vx, vy, vz, ax, ay, az), where
        a = q/m (E + v x B) and q/m is ``self.qm``. *t* is unused: the
        fields are evaluated from position only.
        """
        E = np.zeros(3) if self.E is None else self.E.field(x[:3])
        B = np.zeros(3) if self.B is None else self.B.field(x[:3])
        F = np.zeros(len(x))
        F[0] = x[3]
        F[1] = x[4]
        F[2] = x[5]
        # Explicit components of E + v x B (cheaper than np.cross per call).
        F[3] = self.qm*(E[0] + x[4]*B[2] - x[5]*B[1])
        F[4] = self.qm*(E[1] + x[5]*B[0] - x[3]*B[2])
        F[5] = self.qm*(E[2] + x[3]*B[1] - x[4]*B[0])
        return F

    def intrproot(self, t, intrp):
        """Signed distance of the interpolated position at *t* from the exit
        plane, measured along the exit plane normal (root function for brentq)."""
        return np.dot(intrp(t)[:3] - self.planes[-1][0], self.planes[-1][3])

    def _make_solver(self, traj0, t0, tmax, dt):
        """Create the ODE solver selected by ``self.iterator``.

        *dt* is used as the initial and maximum step for RK45, and as the
        fixed step for RK4.
        """
        if self.iterator == 'RK45':
            # The scipy default tolerances (1e-3/1e-6) leave clearly
            # observable "noise", hence the tighter values.
            return RK45(self.dxdt, t0, traj0, tmax, max_step=dt,
                        rtol=1e-4, atol=1e-7, vectorized=False,
                        first_step=dt)
        return RK4(self.dxdt, t0, traj0, tmax, dt)

    def _to_exit_frame(self, state):
        """Express field-frame *state* (position, velocity) relative to the
        exit plane center along the exit plane axes (u, v, w)."""
        c, eu, ev, ew = self.planes[-1][:4]
        rel = state[:3] - c
        vel = state[3:]
        return np.array([np.dot(rel, eu), np.dot(rel, ev), np.dot(rel, ew),
                         np.dot(vel, eu), np.dot(vel, ev), np.dot(vel, ew)])

    def calc(self, q, m, xs0):
        """Trace one particle from the starting plane to the exit plane.

        Parameters
        ----------
        q, m : float
            Particle charge and mass. The callers in this class pass values
            already multiplied by ``const.e`` and ``const.u``.
        xs0 : array_like
            Starting coordinates (x1, x2, x3, v1, v2, v3) along the starting
            plane axes (u, v, w), relative to the starting plane center.

        Returns
        -------
        numpy.ndarray or None
            Coordinates (x1, x2, x3, v1, v2, v3) relative to the exit plane
            center along the exit plane axes (u, v, w). Returns None if the
            particle does not reach the end plane, which happens when:

            - the particle leaves the defined field;
            - the solver fails;
            - ``maxsteps`` is exceeded.

        The calculation point times and coordinates (in the coordinate
        system of the field) are stored in ``time`` and ``traj``. On
        success, the cumulative path length along ``traj`` is stored in
        ``s``.
        """
        # Transform starting-plane coordinates into the field frame.
        c_in, eu, ev, ew = self.planes[0][:4]
        pos0 = c_in + xs0[0]*eu + xs0[1]*ev + xs0[2]*ew
        vel0 = xs0[3]*eu + xs0[4]*ev + xs0[5]*ew
        c_out = self.planes[-1][0]
        w_out = self.planes[-1][3]
        # Initialize solver
        t0 = 0
        traj0 = np.concatenate((pos0, vel0))
        tmax = 1
        self.qm = q/m
        self.niter = 0
        time = [t0]
        traj = [traj0]
        # Time needed to travel steplen at the initial speed.
        dt = self.steplen/np.linalg.norm(vel0)
        try:
            solver = self._make_solver(traj0, t0, tmax, dt)
            # Check the step budget before stepping, so no step is computed
            # only to be thrown away. step() returns None on success.
            while self.niter < self.maxsteps and solver.step() is None:
                time.append(solver.t)
                traj.append(solver.y)
                self.niter += 1
                # Check if exit plane crossed
                # Solving (\vec{x}(t) - \vec{c}) \cdot \hat{w} = 0 linearized as
                # (\vec{x_1} + (\vec{x_2}-\vec{x_1}) K - \vec{c}) \cdot \hat{w} = 0.
                div = np.dot(traj[-1][:3] - traj[-2][:3], w_out)
                if div == 0.0:
                    continue
                K = np.dot(c_out - traj[-2][:3], w_out) / div
                if 0.0 <= K < 1.0:
                    # Refine the crossing point on the solver's dense output.
                    intrp = solver.dense_output()
                    tcol = brentq(self.intrproot, solver.t_old, solver.t,
                                  args=(intrp,))
                    time[-1] = tcol
                    traj[-1] = intrp(tcol)
                    self.time = np.array(time)
                    self.traj = np.array(traj)
                    steps = np.diff(self.traj[:, :3], axis=0)
                    self.s = np.concatenate(([0], np.cumsum(np.linalg.norm(steps, axis=1))))
                    return self._to_exit_frame(traj[-1])
        except FieldError:
            # Particle out of defined field
            self.time = np.array(time)
            self.traj = np.array(traj)
            return None
        # Calculation failed
        self.time = np.array(time)
        self.traj = np.array(traj)
        logger.error('RayTrace calculation failed')
        if self.niter >= self.maxsteps:
            logger.error('Too many steps')
        return None

    def transfer(self, driver):
        """Ray-trace every particle of the driver's particle set.

        Particles are read with ``driver.getParticlesAsCopy()``, whose keys
        include:

        - 'x', 'y': positions on the starting plane;
        - 'a', 'b': slopes vx/vz and vy/vz;
        - 'K', 'm', 'q': kinetic energy, mass and charge, which are
          scaled here by ``const.e``, ``const.u`` and ``const.e``.

        The x, a, y, b and K values are updated at the exit plane.
        Particles that do not reach the exit plane are removed before
        the set is handed to ``driver.registerTransfer``.
        """
        driver.onElementEnter(self)
        pleset = driver.getParticlesAsCopy()
        x = pleset['x']
        a = pleset['a']
        y = pleset['y']
        b = pleset['b']
        K = const.e*pleset['K']
        m = const.u*pleset['m']
        q = const.e*pleset['q']
        npart = len(x)
        # Boolean dtype is explicit so an empty set still gives a valid mask.
        mask = np.ones(npart, dtype=bool)
        niter_total = 0
        nfailed = 0
        for i in range(npart):
            # Split the speed so that vx/vz = a and vy/vz = b.
            vz = np.sqrt(2*K[i]/(m[i]*(a[i]**2 + b[i]**2 + 1)))
            x0 = np.array([x[i], y[i], 0, a[i]*vz, b[i]*vz, vz])
            x1 = self.calc(q[i], m[i], x0)
            niter_total += self.niter
            if x1 is None:
                mask[i] = False
                nfailed += 1
                continue
            pleset['x'][i] = x1[0]
            pleset['a'][i] = x1[3]/x1[5]
            pleset['y'][i] = x1[1]
            pleset['b'][i] = x1[4]/x1[5]
            pleset['K'][i] = 0.5*m[i]*(x1[3]**2 + x1[4]**2 + x1[5]**2)/const.e
        # Guard the average against an empty particle set.
        avg = niter_total/npart if npart else 0.0
        logger.debug('Raytraced {:d} particles with {:d} steps total, average {:f} steps per particle, {:d} particles not transmitted'.format(npart, niter_total, avg, nfailed))
        # Mask out collimated particles
        for c in pleset:
            pleset[c] = pleset[c][mask]
        driver.registerTransfer(pleset, self.ds, self.dr, self.dx, self.dy)
        driver.onElementExit(self)

    def opticalAxisExit(self):
        """Return the exit plane geometry as ``(dr, dx, dy)``.

        The fields are:

        - ``dr``: exit center relative to entrance center;
        - ``dx``, ``dy``: exit u and v axes expressed in the entrance
          (u, v, w) frame.

        These are the same values passed to ``registerTransfer``. The
        previous ``exit_r``/``exit_x``/``exit_y`` attributes were never
        assigned in this class.
        """
        return (self.dr, self.dx, self.dy)

    def plot(self, plotter, view, s):
        """Draw the element as a box of length ``ds`` starting at *s*.

        The box half-height is ``xsize`` for view 'sx' and ``ysize`` for
        view 'sy'. Other views draw nothing.
        """
        if view == 'sx':
            half = self.xsize
        elif view == 'sy':
            half = self.ysize
        else:
            return
        plotter.polyline(s + np.array([0, 0, self.ds, self.ds, 0]),
                         np.array([-half, half, half, -half, -half]), view)

    def length(self):
        """Return the element length ``ds``."""
        return self.ds

    def printDebug(self, indent=''):
        """No debug output for this element."""
        pass