# Source code for optable.optical_table

from collections import deque
from typing import List, Union

from .optical_component import *
from .component_group import *
from .monitor import *
import matplotlib.pyplot as plt
import numpy as np, time, csv
from copy import deepcopy


class OpticalTable:
    """A scene of optical components and monitors that rays are traced through.

    Attributes:
        components: objects exposing ``interact(ray)``, ``render(...)``,
            ``bbox`` and ``gather_components(...)``.
        monitors: objects exposing ``record(rays)`` and ``render(...)``.
        rays: every ray returned by :meth:`ray_tracing` so far.
        norender_set: ``_id`` values that :meth:`render` skips.
        unit: size of one table length unit in metres (default ``1e-2``,
            i.e. centimetres).
    """

    def __init__(self, **kwargs):
        self.components = []
        self.rays = []
        self.monitors = []
        self.norender_set = set()
        # (xmin, xmax, ymin, ymax, zmin, zmax); all None until get_bbox() runs.
        self._bbox = (None,) * 6
        self.unit = kwargs.get("unit", 1e-2)  # default unit is cm

    @staticmethod
    def _flatten_items(items, item_type):
        """Yield every ``item_type`` instance in *items*, recursing into lists.

        Entries that are neither ``item_type`` instances nor lists are ignored.
        """
        if isinstance(items, item_type):
            yield items
        elif isinstance(items, list):
            for item in items:
                yield from OpticalTable._flatten_items(item, item_type)

    def add_components(self, component: Union[OpticalComponent, List]):
        """Add a component or an (arbitrarily nested) list of components.

        Entries that are not ``OpticalComponent`` instances are ignored.
        The cached :attr:`bbox` is invalidated because the table extent may
        have changed.
        """
        self.components.extend(self._flatten_items(component, OpticalComponent))
        self._bbox = (None,) * 6

    def add_monitors(self, monitor: Union[Monitor, List]):
        """Add a monitor or an (arbitrarily nested) list of monitors.

        Entries that are not ``Monitor`` instances are ignored.
        """
        self.monitors.extend(self._flatten_items(monitor, Monitor))

    @property
    def bbox(self):
        """Bounding box of all components, computed lazily and cached.

        The cache is reset by :meth:`add_components`; call :meth:`get_bbox`
        explicitly after moving components that are already on the table.
        """
        if self._bbox[0] is None:
            self.get_bbox()
        return self._bbox

    def get_bbox(self):
        """Recompute, cache and return the merged bounding box of all components."""
        bbox = base_merge_bboxs([c.bbox for c in self.components])
        self._bbox = bbox
        return bbox

    def ray_tracing(self, rays: Union[Ray, List[Ray]], perfomance_limit=None):
        """Trace one or more rays through the table.

        Each input ray is traced independently by :meth:`_single_ray_tracing`.
        The resulting rays are appended to ``self.rays`` and recorded by every
        monitor on the table.

        Parameters:
            rays: a Ray or an iterable of Ray objects.
            perfomance_limit: optional dict with ``"max_trace_time"`` (seconds,
                default 600) and/or ``"max_trace_num"`` (default 2000). Both
                limits apply per input ray. (The parameter name's spelling is
                kept for backward compatibility.)
        """
        if isinstance(rays, Ray):
            rays = [rays]
        for ray in rays:
            self.rays.extend(
                self._single_ray_tracing(ray, perfomance_limit=perfomance_limit)
            )

    def _single_ray_tracing(self, ray: Ray, perfomance_limit=None):
        """Trace *ray* and every ray it spawns until none are left alive.

        Rays are processed in FIFO order. For each ray, the component whose
        ``interact`` returns the smallest ``t`` wins, and its returned rays are
        re-queued if ``alive``, otherwise kept as finished. A ray that hits no
        component is finished as-is.

        Tracing also stops when the time or trace-count limit is reached. Rays
        still queued at that point are NOT returned; a warning is printed.

        Returns:
            The list of finished rays, which is also passed to
            ``record`` of every monitor.
        """
        hint_interval = 1  # seconds between progress messages
        max_trace_time = 600
        max_trace_num = 2000
        if perfomance_limit is not None:
            max_trace_time = perfomance_limit.get("max_trace_time", max_trace_time)
            max_trace_num = perfomance_limit.get("max_trace_num", max_trace_num)

        alive_rays = deque([ray])  # deque: O(1) pops from the left
        dead_rays = []
        t_start = time.time()
        last_hint_time = t_start
        trace_num = 0
        stop_reason = None

        while alive_rays:
            tnow = time.time()
            if tnow - t_start >= max_trace_time:
                stop_reason = "maximum tracing time ({} s)".format(max_trace_time)
                break
            if trace_num >= max_trace_num:
                stop_reason = "maximum trace number ({})".format(max_trace_num)
                break
            trace_num += 1

            if tnow - last_hint_time > hint_interval:
                print(
                    "Tracing... Time elapsed: {:.2f} s, Trace num: {}, Alive rays: {}, Dead rays: {}".format(
                        tnow - t_start, trace_num, len(alive_rays), len(dead_rays)
                    )
                )
                last_hint_time = tnow

            current = alive_rays.popleft()

            # 1. find the component the ray intersects first (smallest t)
            t_min = None
            rays_min = None
            for component in self.components:
                t, new_rays = component.interact(current)
                if t is not None and (t_min is None or t < t_min):
                    t_min = t
                    rays_min = new_rays

            # 2. no intersection: the ray leaves the table unchanged
            if rays_min is None:
                dead_rays.append(current)
                continue

            # 3. intersection: keep propagating the rays that are still alive
            for r in rays_min:
                if r.alive:
                    alive_rays.append(r)
                else:
                    dead_rays.append(r)

        # Only warn when work was actually cut short.
        if alive_rays:
            print(
                "Ray tracing stopped at the {} after {} traces; {} alive rays were not traced.".format(
                    stop_reason, trace_num, len(alive_rays)
                )
            )

        for monitor in self.monitors:
            monitor.record(dead_rays)
        return dead_rays

    def _render_items(self, ax, render_type, render_kwargs):
        """Render every ray, component and monitor whose ``_id`` is not in norender_set."""
        for group in (self.rays, self.components, self.monitors):
            for item in group:
                if item._id not in self.norender_set:
                    item.render(ax, type=render_type, **render_kwargs)

    def render(self, ax=None, type: str = "Z", roi=None, **kwargs):
        """Draw the table's rays, components and monitors.

        Parameters:
            ax: matplotlib axes to draw on. When None, a new 10x8 figure is
                created (with a 3-D projection for ``type="3D"``).
            type: ``"X"``, ``"Y"`` or ``"Z"`` for a 2-D view along that axis,
                or ``"3D"``.
            roi: optional ``(xmin, xmax, ymin, ymax, zmin, zmax)`` view limits.
            **kwargs: forwarded to every item's ``render``. Also read here:
                ``switch_axis`` (2-D only) swaps horizontal and vertical axes;
                ``aspect`` (only with *roi*) sets the aspect ratio, default
                ``"equal"``.

        Returns:
            The axes that were drawn on.

        Raises:
            ValueError: if *type* is not one of the values above.
        """
        # (horizontal, vertical) axis names shown for each view type
        label_dict = {
            "Z": ["X", "Y"],
            "X": ["Y", "Z"],
            "Y": ["Z", "X"],
            "3D": ["X", "Y"],
        }
        if type in ("X", "Y", "Z"):
            if ax is None:
                _fig, ax = plt.subplots(figsize=(10, 8))
                # reserve space to the right of the plot area
                plt.subplots_adjust(left=0.1, right=0.7)
            if kwargs.get("switch_axis", False):
                label_dict[type] = label_dict[type][::-1]
            ax.set_xlabel(label_dict[type][0])
            ax.set_ylabel(label_dict[type][1])
            self._render_items(ax, type, kwargs)
            ax.set_aspect("auto")
        elif type == "3D":
            if ax is None:
                fig = plt.figure(figsize=(10, 8))
                ax = fig.add_subplot(111, projection="3d")
            ax.set_xlabel("X")
            ax.set_ylabel("Y")
            self._render_items(ax, "3D", kwargs)
            ax.set_aspect("equal")
        else:
            raise ValueError(f"render: Invalid type: {type}")

        if roi is not None:
            # roi index pair holding the (min, max) limits of each axis
            axis2roiidx = {"X": (0, 1), "Y": (2, 3), "Z": (4, 5)}
            x_axislabel, y_axislabel = label_dict[type]
            xidx0, xidx1 = axis2roiidx[x_axislabel]
            yidx0, yidx1 = axis2roiidx[y_axislabel]
            ax.set_xlim(roi[xidx0], roi[xidx1])
            ax.set_ylim(roi[yidx0], roi[yidx1])
            if ax.name == "3d":
                ax.set_zlim(roi[4], roi[5])
            # NOTE(review): the aspect override applies only together with an
            # roi; confirm this is the intended scope.
            ax.set_aspect(kwargs.get("aspect", "equal"))
        return ax

    # >>> ABCD MATRIX CALCULATION
    def calculate_abcd_matrix(
        self,
        mon0: Monitor,
        mon1: Monitor,
        rays: List[Ray],
        disp=1e-5,
        rot=1e-5,
        debugaxs=None,
    ) -> np.ndarray:
        """Estimate the ABCD (ray-transfer) matrix from *mon0* to *mon1*.

        The rays are traced three times, each time from fresh copies of the
        inputs:
          1. unperturbed;
          2. translated by ``disp`` along ``mon0.tangent_Y``;
          3. rotated by ``rot`` about ``mon0.tangent_Z``, pivoting on each ray's
             unperturbed intersection point with *mon0*.

        Forward differences of ``mon1.get_yList`` (y) and ``mon1.get_tYList``
        (tY) give, per ray::

            A = dy/d(disp)   B = dy/d(rot)
            C = dtY/d(disp)  D = dtY/d(rot)

        Parameters:
            mon0, mon1: input and output monitors. They must be on this table,
                since only the table's monitors record traced rays.
            rays: rays to trace. ``_id`` values must be unique, and every ray
                must reach both monitors. The caller's objects are not traced
                directly.
            disp: positional perturbation, in table units.
            rot: angular perturbation, passed to ``Ray._RotAround``
                (presumably radians).
            debugaxs: if given, each of the three traces is rendered
                (``type="Z"``) into these axes.

        Returns:
            ``(Nrays, 2, 2)`` array of per-ray matrices, ordered by ray ``_id``.

        Side effects:
            ``self.rays`` is replaced by the last trace's rays. *mon0* and
            *mon1* are cleared and hold that trace's records on return.
        """
        pax0_dispvec = mon0.tangent_Y
        pax0_rotvec = mon0.tangent_Z

        def _simulate(sim_rays):
            self.rays = []
            mon0.clear()
            mon1.clear()
            self.ray_tracing(sim_rays)
            if debugaxs is not None:
                self.render(ax=debugaxs, type="Z")

        # VALIDATIONS
        assert len(rays) > 0, "No rays to trace in ABCD calculation."
        Nrays = len(rays)
        raysids = [ray._id for ray in rays]
        assert len(set(raysids)) == Nrays, "Redundant ray ids in ABCD calculation."

        # Pristine copies sorted by id. Monitors are read back with sort="ID",
        # so index i refers to the same ray in every list below.
        base_rays = [deepcopy(rays[i]) for i in np.argsort(raysids)]

        # RAY TRACING (reference). Fresh copies are used for every trace so that
        # in-place changes made by tracing or by _Translate/_RotAround never
        # leak into the next trace.
        _simulate([deepcopy(r) for r in base_rays])
        raysid_mon0 = [ray._id for ray in mon0.get_rays(sort="ID")]
        assert set(raysids) == set(raysid_mon0), "Rays at mon0 do not match the input rays."
        PListmon0 = mon0.get_PList(sort="ID")
        raysid_mon1 = [ray._id for ray in mon1.get_rays(sort="ID")]
        assert set(raysids) == set(raysid_mon1), "Rays at mon1 do not match the input rays."
        yList00 = np.asarray(mon1.get_yList(sort="ID"))
        tYList00 = np.asarray(mon1.get_tYList(sort="ID"))

        # Positional perturbation -> A, C
        _simulate(
            [deepcopy(r)._Translate(pax0_dispvec * disp) for r in base_rays]
        )
        yList10 = np.asarray(mon1.get_yList(sort="ID"))
        tYList10 = np.asarray(mon1.get_tYList(sort="ID"))
        A = (yList10 - yList00) / disp
        C = (tYList10 - tYList00) / disp

        # Angular perturbation around each ray's hit point on mon0 -> B, D
        _simulate(
            [
                deepcopy(r)._RotAround(pax0_rotvec, PListmon0[idx], rot)
                for idx, r in enumerate(base_rays)
            ]
        )
        yList01 = np.asarray(mon1.get_yList(sort="ID"))
        tYList01 = np.asarray(mon1.get_tYList(sort="ID"))
        B = (yList01 - yList00) / rot
        D = (tYList01 - tYList00) / rot

        Ms = np.zeros((Nrays, 2, 2))  # one ABCD matrix per ray
        Ms[:, 0, 0] = A
        Ms[:, 0, 1] = B
        Ms[:, 1, 0] = C
        Ms[:, 1, 1] = D
        return Ms

    @staticmethod
    def calibrate_symmetric_4f(
        lens: Union[OpticalComponent, ComponentGroup],
        rays: List[Ray],
        F10: float,
        F20: float,
        criterion: str = "M=-I",
        debugaxs=None,
        optimize=True,
        display_M=False,
    ):
        """Calibrate a symmetric 4f system built from two copies of *lens*.

        Layout along +X: mon0 (x=0) - F1 - lens - 2*F2 - lens rotated by pi
        about Z - F1 - mon1. The system is rebuilt on a new table for every
        (F1, F2) tried.

        Parameters:
            lens: the component (or group) to duplicate; it must provide
                ``origin`` and ``copy()``.
            rays: rays to trace.
            F10, F20: initial (or, with ``optimize=False``, the only) F1, F2.
            criterion: cost minimised when optimising:
                ``"M=-I"``: mean Frobenius norm of ``M + I`` over the rays;
                ``"flat_field"``: mean ``|ds - d0|`` with ``d0 = 1.5`` and
                ``ds = (d0*A - B) / (D - C*d0)``;
                ``"min_stdtY"``: std of the tY values recorded at mon1.
            debugaxs: axes for rendering when ``optimize=False``. If None, a
                new figure is created.
            optimize: run Nelder-Mead on (F1, F2) when True, otherwise
                simulate once at (F10, F20) and render the result.
            display_M: print each ray's ABCD matrix on every simulation.

        Returns:
            ``(F1, F2)`` optimised distances when *optimize* is True, otherwise
            ``(Ms, yList, tYList)`` from the single simulation.

        Raises:
            ValueError: unknown *criterion* (only checked when optimising).
        """

        def simulate(F1, F2):
            # place two copies of the lens; the second is flipped by pi about Z
            dr0 = np.array([F1, 0, 0]) - lens.origin
            l0 = lens.copy()._Translate(dr0)
            dr1 = np.array([F1 + 2 * F2, 0, 0]) - lens.origin
            l1 = lens.copy()._Translate(dr1).RotZ(np.pi)
            Mon0 = Monitor(origin=[0, 0, 0], width=5, height=5)
            Mon1 = Monitor(origin=[2 * F1 + 2 * F2, 0, 0], width=5, height=5)
            table = OpticalTable()
            table.add_components([l0, l1])
            table.add_monitors([Mon0, Mon1])
            table.ray_tracing(rays)
            yList = Mon1.get_yList()
            tYList = Mon1.get_tYList()
            Ms = table.calculate_abcd_matrix(Mon0, Mon1, rays)
            if display_M:
                for M in Ms:
                    print(M)
            if not optimize:
                # the ABCD calculation left perturbed rays on the table; retrace
                table.rays = []
                table.ray_tracing(rays)
                print(
                    "Rendering 4f system with F1={:.4f}, F2={:.4f} ...".format(F1, F2)
                )
                # render() creates its own axes when debugaxs is None
                table.render(ax=debugaxs, type="Z")
            return Ms, yList, tYList

        def _cost_M_equal_mI(Ms, yList, tYList):
            return sum(np.linalg.norm(M + np.eye(2)) for M in Ms) / len(Ms)

        def _cost_flat_field(Ms, yList, tYList):
            d0 = 1.5  # NOTE(review): fixed test distance; meaning not documented
            total = 0
            for M in Ms:
                (A, B), (C, D) = M
                ds = (d0 * A - B) / (D - C * d0)
                total += np.abs(ds - d0)
            return total / len(Ms)

        def _cost_min_stdtY(Ms, yList, tYList):
            return np.std(tYList)

        if not optimize:
            return simulate(F10, F20)

        cost_terms = {
            "M=-I": _cost_M_equal_mI,
            "flat_field": _cost_flat_field,
            "min_stdtY": _cost_min_stdtY,
        }
        if criterion not in cost_terms:
            raise ValueError(f"Unknown criterion: {criterion}")
        cost_term = cost_terms[criterion]

        from scipy.optimize import minimize
        from tqdm import tqdm

        pbar = tqdm(total=500, desc="Optimizing 4f system")

        def objective(x):
            pbar.update(1)
            F1, F2 = x[0], x[1]
            cost = cost_term(*simulate(F1, F2))
            pbar.set_description(f"Testing F1={F1:.4f}, F2={F2:.4f}, cost={cost:.6f}")
            return cost

        try:
            res = minimize(
                objective,
                x0=[F10, F20],
                method="Nelder-Mead",
                options={"disp": True, "xatol": 1e-5, "maxiter": 50},
            )
        finally:
            pbar.close()
        F1_opt, F2_opt = res.x
        return F1_opt, F2_opt

    # >>> VISUALIZATION FUNCTIONS
    def add_wavelength_legend(self, ax, wavelengths):
        """Append one coloured legend entry per wavelength to *ax*'s legend.

        wavelengths: iterable of wavelengths in table units, each ``self.unit``
            metres (with the default cm unit, 532 nm is ``5.32e-5``). Labels
            show the value rounded to whole nanometres.
        Existing legend handles on *ax* are kept.
        """
        from matplotlib.lines import Line2D

        # proxy artists: they only exist to carry a colour/label into the legend
        proxies = []
        for wl in wavelengths:
            wl_m = wl * self.unit  # table units -> metres
            color = wavelength_to_rgb(wl_m)
            # round, not truncate: 5.32e-7 * 1e9 can evaluate to 531.999...
            label = f"{int(round(wl_m * 1e9))} nm"
            proxies.append(Line2D([0], [0], color=color, lw=3, label=label))
        existing_handles, _ = ax.get_legend_handles_labels()
        ax.legend(handles=existing_handles + proxies, loc="best")

    # >>> EXPORTING FUNCTIONS
    @staticmethod
    def _write_dict_rows_csv(filename: str, rows: List[dict]):
        """Write *rows* to *filename* as UTF-8 CSV with a header row.

        The header is the union of all row keys in first-seen order. Each row
        is written by key, so columns line up even when rows differ in keys or
        key order; missing cells are left empty.
        """
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))
        with open(filename, "w", newline="", encoding="utf-8") as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writeheader()
            writer.writerows(rows)

    def gather_rays_csv(self):
        """Return one dict of CSV-ready strings per ray in ``self.rays``.

        Keys: origin, transform_matrix, intensity, length, qo, n. Optional
        attributes fall back to "None" via ``get_attr_str``.
        """

        def _repr_self_dict(ray):
            return {
                "origin": to_mathematical_str(str(ray.origin.tolist())),
                "transform_matrix": to_mathematical_str(
                    str(ray.transform_matrix.tolist())
                ),
                "intensity": get_attr_str(ray, "intensity", "None"),
                "length": get_attr_str(ray, "length", "None"),
                "qo": to_mathematical_str(str(get_attr_str(ray, "qo", "None"))),
                "n": to_mathematical_str(str(get_attr_str(ray, "n", "None"))),
            }

        return [_repr_self_dict(ray) for ray in self.rays]

    def gather_components(
        self, avoid_flatten_classname: List = None, ignore_classname: List = None
    ) -> List[dict]:
        """Collect component description dicts from every top-level component.

        Both filters are forwarded to each component's ``gather_components``.
        ``None`` means an empty list.
        """
        avoid_flatten_classname = [] if avoid_flatten_classname is None else avoid_flatten_classname
        ignore_classname = [] if ignore_classname is None else ignore_classname
        components = []
        for component in self.components:
            components.extend(
                component.gather_components(
                    avoid_flatten_classname=avoid_flatten_classname,
                    ignore_classname=ignore_classname,
                )
            )
        return components

    def export_rays_csv(self, filename: str):
        """Write :meth:`gather_rays_csv` output to *filename* as CSV."""
        rays_traced = self.gather_rays_csv()
        print(f"Exporting rays to {filename} ...")
        self._write_dict_rows_csv(filename, rays_traced)

    def export_components_csv(
        self,
        filename: str,
        avoid_flatten_classname: List = None,
        ignore_classname: List = None,
    ):
        """Write :meth:`gather_components` output to *filename* as CSV.

        The filters have the same meaning as in :meth:`gather_components`.
        """
        components = self.gather_components(
            avoid_flatten_classname=avoid_flatten_classname,
            ignore_classname=ignore_classname,
        )
        print(f"Exporting components to {filename} ...")
        self._write_dict_rows_csv(filename, components)