Source code for optable.optical_component

import scipy
from .base import *
from .ray import *
from .surfaces import *
from mpl_toolkits.mplot3d.art3d import Poly3DCollection


class OpticalComponent(Vector):
    """Base class for optical elements represented in a local coordinate frame.

    The pose of a component is its lab-frame ``origin`` together with
    ``transform_matrix`` (a 3x3 local-to-lab matrix, identity by default).
    In the local frame the component normal is the +X axis. Subclasses
    supply the optical behaviour (``interact_local``) and the local
    bounding box (``get_bbox_local``).
    """

    # Largest ray parameter ``t`` searched when looking for hits on a
    # non-planar surface (see ``_intersect_curved_local``).
    _MAX_RAY_PARAMETER = 100

    def __init__(self, origin, **kwargs):
        """Initialize a component at a lab-frame origin.

        Args:
            origin: Component origin in lab coordinates.
            **kwargs: Optional display metadata such as ``name``, ``label``,
                ``label_position``, ``render_obj`` and ``render_comp_vec``,
                plus ``max_interact_count`` (per-ray interaction limit,
                ``None`` for unlimited).
        """
        super().__init__(origin, **kwargs)
        self.transform_matrix = np.identity(3)
        self.surface = Plane()  # Default surface is a plane
        # Cached lab-frame bbox: xmin, xmax, ymin, ymax, zmin, zmax
        self._bbox = (None, None, None, None, None, None)
        self.render_obj = kwargs.get("render_obj", True)
        self.render_comp_vec = kwargs.get("render_comp_vec", False)
        self.name = kwargs.get("name", None)
        self.label = kwargs.get("label", None)
        # [dx,dy] or [dx,dy,dz]
        self.label_position = kwargs.get("label_position", [1, 0, 0])
        # Maps ray id -> number of interactions recorded for that ray.
        self._interact_count = {}
        self.max_interact_count = kwargs.get("max_interact_count", None)

    def __repr__(self):
        return f"OpticalComponent(origin={self.origin}, transform_matrix=\n{self.transform_matrix})"

    # The normal in the local frame is always the x-axis.
    @property
    def normal(self):
        """Return the surface normal in lab coordinates."""
        return self.transform_matrix @ np.array([1, 0, 0])

    @property
    def tangent_Y(self):
        """Return local +Y axis expressed in lab coordinates."""
        return self.transform_matrix @ np.array([0, 1, 0])

    @property
    def tangent_Z(self):
        """Return local +Z axis expressed in lab coordinates."""
        return self.transform_matrix @ np.array([0, 0, 1])

    def get_bbox_local(self):
        """Return the local-frame bounding box; must be provided by subclasses."""
        raise NotImplementedError("Subclasses must implement get_bbox_local method")

    @property
    def bbox(self):
        """Return cached axis-aligned bounding box in lab coordinates."""
        # The all-None tuple is the "not computed yet" marker set in __init__.
        if self._bbox == (None, None, None, None, None, None):
            self._bbox = tuple(self.get_bbox())
        return tuple(self._bbox)

    def get_bbox(self) -> tuple:
        """Compute axis-aligned bounding box in lab frame.

        The eight corners of the local bounding box are mapped to the lab
        frame and the extrema along each axis are taken.

        Returns:
            Tuple ``(xmin, xmax, ymin, ymax, zmin, zmax)`` in lab coordinates.
        """
        bbox_local = self.get_bbox_local()
        x_bounds = (bbox_local[0], bbox_local[1])
        y_bounds = (bbox_local[2], bbox_local[3])
        z_bounds = (bbox_local[4], bbox_local[5])
        corners_local = np.array(
            [[x, y, z] for z in z_bounds for y in y_bounds for x in x_bounds]
        ).T  # shape (3, 8)
        corners_global = (self.transform_matrix @ corners_local) + self.origin.reshape(
            -1, 1
        )  # shape (3, 8)
        xmin, xmax = np.min(corners_global[0, :]), np.max(corners_global[0, :])
        ymin, ymax = np.min(corners_global[1, :]), np.max(corners_global[1, :])
        zmin, zmax = np.min(corners_global[2, :]), np.max(corners_global[2, :])
        return (xmin, xmax, ymin, ymax, zmin, zmax)

    def _RotAroundLocal(self, axis, localpoint, theta):
        """Rotate the component by ``theta`` about ``axis`` through ``localpoint``.

        Updates ``transform_matrix`` and ``origin`` in place and returns
        ``self`` so calls can be chained.
        """
        R = self.R(axis, theta)
        localpoint = np.array(localpoint)
        self.transform_matrix = np.dot(R, self.transform_matrix)
        self.origin = self.origin + np.dot(R, -localpoint) + localpoint
        return self

    def ray_to_local_coordinates(self, ray: Ray) -> Ray:
        """Transform a ray from lab frame into this component's local frame."""
        R = np.linalg.inv(self.transform_matrix)
        local_origin = np.dot(R, ray.origin - self.origin)
        local_direction = np.dot(R, ray.direction)
        return ray.copy(origin=local_origin, direction=local_direction)

    def point_to_lab_coordinates(self, point_local: np.ndarray) -> np.ndarray:
        """Transform a point from local frame back to the lab frame."""
        return np.dot(self.transform_matrix, point_local) + self.origin

    def ray_to_lab_coordinates(self, ray: Ray) -> Ray:
        """Transform a ray from local frame back to the lab frame."""
        R = self.transform_matrix
        global_origin = np.dot(R, ray.origin) + self.origin
        global_direction = np.dot(R, ray.direction)
        return ray.copy(origin=global_origin, direction=global_direction)

    def _find_zero(self, f, a, b, num_intervals: int = 10) -> np.ndarray:
        """Find roots of ``f`` on ``[a, b]`` by sampling and bracketing.

        ``f`` is sampled on ``int(num_intervals)`` equally spaced nodes
        (i.e. ``num_intervals - 1`` sub-intervals). Every sub-interval with a
        sign change is refined with Brent's method; a node where ``f`` is
        exactly zero is reported as a root itself.

        Returns:
            Array of roots in ascending order (empty if none were found).
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.optimize`` is loaded on every SciPy version.
        from scipy.optimize import brentq

        t_values = np.linspace(a, b, int(num_intervals))
        f_values = [f(t) for t in t_values]  # evaluate each node only once
        sols = []
        for i in range(len(t_values) - 1):
            f_start, f_end = f_values[i], f_values[i + 1]
            if f_start == 0:
                # Root exactly on a node: the sign-change test would miss it.
                sols.append(t_values[i])
            elif f_start * f_end < 0:
                sols.append(brentq(f, t_values[i], t_values[i + 1]))
        if len(t_values) > 0 and f_values[-1] == 0:
            sols.append(t_values[-1])
        return np.array(sols)

    def get_interact_count(self, ray_id):
        """Return how many interactions were recorded for ``ray_id``."""
        return self._interact_count.get(ray_id, 0)

    def should_interact(self, ray_id):
        """Return True while ``ray_id`` is below ``max_interact_count``.

        A ``max_interact_count`` of ``None`` means no limit.
        """
        if self.max_interact_count is None:
            return True
        return self._interact_count.get(ray_id, 0) < self.max_interact_count

    def increase_interact_count(self, ray_id):
        """Record one more interaction for ``ray_id``."""
        self._interact_count[ray_id] = self._interact_count.get(ray_id, 0) + 1

    def intersect_point_local(
        self, ray: Ray
    ) -> Union[Tuple[np.ndarray, float], Tuple[None, None]]:
        """Find first valid intersection with the component in local coordinates.

        Args:
            ray: Input ray already expressed in the component local frame.

        Returns:
            ``(P, t)`` where ``P`` is the local intersection point and ``t`` is
            the ray parameter. Returns ``(None, None)`` if no valid hit exists.
        """
        EPS = 1e-9
        if self.surface.planar:
            return self._intersect_planar_local(ray, EPS)
        return self._intersect_curved_local(ray, EPS)

    def _intersect_planar_local(self, ray: Ray, eps: float):
        """Intersect ``ray`` with the local plane ``x = 0``."""
        # Solve (P - P0) . n = 0 with P0 = (0, 0, 0), n = +X and
        # P = Ps + t * v  ->  t = (P0 - Ps) . n / (v . n)
        x_axis = np.array([1, 0, 0])
        numerator = np.dot(-ray.origin, x_axis)
        denominator = np.dot(ray.direction, x_axis)
        if denominator == 0:
            # Ray parallel to the plane: it either never reaches it or lies
            # inside it (t = 0), which is rejected like any other zero-length hit.
            return None, None
        t = numerator / denominator
        # Rejected hits:
        #   t < eps        -> ray just left this component, or backward ray
        #   t > ray.length -> ray ended before reaching the plane
        if t < eps or (ray.length is not None and t > ray.length):
            return None, None
        P = ray.origin + t * ray.direction
        if not self.surface.within_boundary(P):
            return None, None
        return P, t

    def _intersect_curved_local(self, ray: Ray, eps: float):
        """Intersect ``ray`` with a non-planar surface by numeric root finding."""

        # The surface is the zero set f(P) = 0; with P = Ps + t * v this
        # becomes a scalar equation in t.
        def f(t):
            return self.surface.f(ray.origin + t * ray.direction)

        # Ray parameters where the ray enters/leaves the surface bounding box.
        t1, t2, _ = self.surface.solve_crosssection_ray_bbox_local(
            ray.origin, ray.direction
        )
        t1, t2 = float(t1[0]), float(t2[0])
        if t2 + eps < t1:
            return None, None
        t1 = max(t1, 0)
        t2 = min(t2, self._MAX_RAY_PARAMETER)
        t_candidates = self._find_zero(f, t1 - eps, t2 + eps)
        if len(t_candidates) == 0:
            return None, None

        valid_mask = t_candidates >= eps  # drop backward and zero-length hits
        if ray.length is not None:
            valid_mask = valid_mask & (t_candidates <= ray.length)

        # The nearest root that lies inside the aperture wins.
        for t in np.sort(t_candidates[valid_mask]):
            P = ray.origin + t * ray.direction
            if self.surface.within_boundary(P):
                return P, t
        return None, None

    def interact_local(self, ray: Ray) -> List[Ray]:
        """Compute outgoing rays in local coordinates.

        Subclasses must implement optical behavior using local-frame geometry.
        """
        raise NotImplementedError("Subclasses must implement interact_local method")

    def _get_boundary_points(self, type: str):
        """Return the surface boundary sampled at 100 parameters, in lab frame."""
        t = np.linspace(0, 1, 100)
        points_local = self.surface.parametric_boundary(t, type)
        return (self.transform_matrix @ points_local) + self.origin.reshape(-1, 1)

    def render(self, ax, type: str, **kwargs):
        """Render the component boundary in 2D (``"X"``/``"Y"``/``"Z"``) or 3D (``"3D"``).

        Args:
            ax: Matplotlib axes to draw on.
            type: View selector, one of ``"X"``, ``"Y"``, ``"Z"`` or ``"3D"``.
            **kwargs: ``color``, ``linewidth``, ``switch_axis`` (2D only),
                ``label`` and ``label_fontsize`` (2D only) and
                ``detailed_render`` (3D only).

        Raises:
            ValueError: If ``type`` is not a supported view.
        """
        if not self.render_obj:
            return
        color = kwargs.get("color", "black")
        linewidth = kwargs.get("linewidth", 2)
        global_x, global_y, global_z = self._get_boundary_points(type)
        detailed_render = kwargs.get("detailed_render", False)
        label = kwargs.get("label", None)
        label_fontsize = kwargs.get("label_fontsize", 10)

        if type in ["X", "Y", "Z"]:
            # Lab axes shown on the plot's horizontal / vertical axis.
            dimension_dict = {"Z": [0, 1], "X": [1, 2], "Y": [2, 0]}
            horiz, vert = dimension_dict[type]
            if kwargs.get("switch_axis", False):
                horiz, vert = vert, horiz
            global_xyz = [global_x, global_y, global_z]
            ax.plot(
                global_xyz[horiz],
                global_xyz[vert],
                color=color,
                linewidth=linewidth,
            )
            if self.render_comp_vec:
                # Arrow along the component normal, drawn in the outline color.
                normal = self.normal
                ax.quiver(
                    self.origin[horiz],
                    self.origin[vert],
                    normal[horiz],
                    normal[vert],
                    color=color,
                )
            if label and self.label:
                if self.label_position is not None:
                    # NOTE(review): a two-element label_position cannot be
                    # indexed with 2, which the "X"/"Y" views need - confirm.
                    ax.text(
                        self.origin[horiz] + self.label_position[horiz],
                        self.origin[vert] + self.label_position[vert],
                        self.label,
                        color=color,
                        fontsize=label_fontsize,
                    )
        elif type == "3D":
            ax.plot(
                global_x,
                global_y,
                global_z,
                color=color,
                linewidth=linewidth,
            )
            if self.render_comp_vec:
                # Arrow along the component normal, drawn in the outline color.
                normal = self.normal
                ax.quiver(
                    self.origin[0],
                    self.origin[1],
                    self.origin[2],
                    normal[0],
                    normal[1],
                    normal[2],
                    color=color,
                )
            if detailed_render:
                # Fill the boundary with a translucent 3D polygon.
                poly = Poly3DCollection(
                    [list(zip(global_x, global_y, global_z))],
                    color=color,
                    linewidths=linewidth,
                    edgecolors=color,
                    alpha=0.25,
                )
                ax.add_collection3d(poly)
        else:
            raise ValueError(f"render: Invalid type: {type}")

    def interact(self, ray: Ray) -> Union[Tuple[float, List[Ray]], Tuple[None, None]]:
        """Apply component interaction in local frame and return lab-frame rays.

        Args:
            ray: Input ray in lab coordinates.

        Returns:
            Tuple ``(t, rays)`` where ``t`` is the hit distance and ``rays``
            contains the truncated incoming segment plus newly generated
            outgoing rays in lab coordinates. Returns ``(None, None)`` when no
            interaction occurs.
        """
        if not ray.alive:
            return None, None
        # Check the per-ray budget first: when it is exhausted the outcome is
        # (None, None) whether or not the ray hits, so the intersection
        # search can be skipped.
        if not self.should_interact(ray._id):
            return None, None

        local_ray = self.ray_to_local_coordinates(ray)
        P, t = self.intersect_point_local(local_ray)
        if P is None:
            return None, None
        self.increase_interact_count(ray._id)

        # Incoming segment, cut at the hit and marked as finished.
        truncated_ray = ray.copy(length=t, alive=False)
        # New rays produced by the interaction, mapped back to the lab frame.
        local_rays_after_interaction = self.interact_local(local_ray)
        lab_rays_after_interaction = [
            self.ray_to_lab_coordinates(out_ray)
            for out_ray in local_rays_after_interaction
        ]
        return t, [truncated_ray] + lab_rays_after_interaction

    def patch_block(self, width, height):
        """Create a ``Block`` patch sharing this component pose and aperture."""
        obj = Block(self.origin, hole=self.surface, width=width, height=height)
        obj.transform_matrix = self.transform_matrix
        return obj

    def gather_components(
        self,
        avoid_flatten_classname: Union[List, None] = None,
        ignore_classname: Union[List, None] = None,
    ) -> List:
        """Return a flattened metadata list for this component hierarchy.

        Args:
            avoid_flatten_classname: Class names whose children are not
                expanded. ``None`` (default) means no such classes.
            ignore_classname: Class names excluded from output. ``None``
                (default) means nothing is excluded.

        Returns:
            A list of dictionaries describing components.
        """
        # ``None`` defaults avoid sharing one mutable list between calls.
        if avoid_flatten_classname is None:
            avoid_flatten_classname = []
        if ignore_classname is None:
            ignore_classname = []

        class_name = self.__class__.__name__
        components = []
        if class_name not in ignore_classname:
            components.append(
                {
                    "name": get_attr_str(self, "name", "None"),
                    "class": class_name,
                    "origin": to_mathematical_str(str(self.origin.tolist())),
                    "transform_matrix": to_mathematical_str(
                        str(self.transform_matrix.tolist())
                    ),
                    "radius": get_attr_str(self, "radius", "None"),
                    "width": get_attr_str(self, "width", "None"),
                    "height": get_attr_str(self, "height", "None"),
                    "focal_length": get_attr_str(self, "focal_length", "None"),
                }
            )
        # Children are expanded only for classes that may be flattened.
        if hasattr(self, "components") and class_name not in avoid_flatten_classname:
            for component in self.components:
                components.extend(
                    component.gather_components(
                        avoid_flatten_classname=avoid_flatten_classname,
                        ignore_classname=ignore_classname,
                    )
                )
        return components
class PointObj(OpticalComponent):
    """Degenerate point-like marker component that transmits all rays."""

    def __init__(self, origin, **kwargs):
        """Create the marker at ``origin`` with a ``Point`` surface."""
        super().__init__(origin, **kwargs)
        self._edge_color = "orange"
        self.surface = Point()

    def interact_local(self, ray):
        """Pass the incoming ray through unchanged (perfect transmission)."""
        transmitted = [ray]
        return transmitted

    def render(self, ax, type: str, **kwargs):
        """Draw a ``+`` marker at the origin in a 2D view or in 3D.

        Any other ``type`` value draws nothing.
        """
        marker_style = {"color": self._edge_color, "marker": "+", "s": 20}
        if type in ["X", "Y", "Z"]:
            # Lab axes mapped to the plot's horizontal / vertical axis.
            axes_for_view = {"Z": [0, 1], "X": [1, 2], "Y": [2, 0]}
            horiz, vert = axes_for_view[type]
            if kwargs.get("switch_axis", False):
                horiz, vert = vert, horiz
            ax.scatter(self.origin[horiz], self.origin[vert], **marker_style)
        elif type == "3D":
            ax.scatter(
                self.origin[0], self.origin[1], self.origin[2], **marker_style
            )

    def get_bbox_local(self):
        """Delegate the local bounding box to the point surface."""
        return self.surface.get_bbox_local()
class Block(OpticalComponent):
    """Opaque rectangular blocker with an optional hole aperture."""

    def __init__(
        self,
        origin,
        hole: Union[Surface, None] = None,
        width: float = 1.0,
        height: float = 1.0,
        **kwargs,
    ):
        """Initialize a blocking rectangle.

        Args:
            origin: Component origin in lab coordinates.
            hole: Optional surface removed from the rectangle aperture.
            width: Rectangle width along local Y.
            height: Rectangle height along local Z.
            **kwargs: Forwarded to ``OpticalComponent.__init__``.
        """
        super().__init__(origin, **kwargs)
        self.width = width
        self.height = height
        rectangle = Rectangle(width, height)
        self.surface = rectangle.subtract(hole) if hole is not None else rectangle
        self._edge_color = "black"

    def interact_local(self, ray):
        """Absorb all incident rays (no outgoing rays are produced)."""
        return []

    def render(self, ax, type: str, **kwargs):
        """Render block outline.

        The outline uses ``self._edge_color`` unless the caller passes an
        explicit ``color`` keyword, which takes precedence. (Passing
        ``color`` used to raise ``TypeError`` because the keyword was
        supplied twice.)
        """
        render_kwargs = {"color": self._edge_color, **kwargs}
        super().render(ax, type, **render_kwargs)

    def get_bbox_local(self):
        """Return local bounding box from rectangle geometry."""
        return self.surface.get_bbox_local()
class BaseMirror(OpticalComponent):
    """Base reflective surface."""

    def __init__(
        self,
        origin,
        reflectivity: float = 1.0,
        transmission: float = 0.0,
        **kwargs,
    ):
        """
        Args:
            origin: Component origin in lab coordinates.
            reflectivity: Fraction of incoming intensity sent to reflected ray.
            transmission: Fraction of incoming intensity sent to transmitted ray.
            **kwargs: Forwarded to ``OpticalComponent.__init__``.
        """
        super().__init__(origin, **kwargs)
        self.reflectivity = reflectivity
        self.transmission = transmission
        self._edge_color = "green"

    def interact_local(self, ray):
        """Compute reflected/transmitted rays in local coordinates.

        The input ``ray`` must already be in this component local frame.
        A reflected branch is generated when ``reflectivity > 0`` and a
        straight-through branch when ``transmission > 0``.

        Returns:
            List with zero, one or two rays starting at the hit point.
        """
        P, t = self.intersect_point_local(ray)
        normal = self.surface.normal(P)
        # Beam parameter and accumulated path length at the hit point are
        # shared by both outgoing branches.
        qo = None if ray.qo is None else ray.q_at_z(t)
        pathlength = ray.pathlength(float(t))

        rays = []
        if self.reflectivity > 0:
            # Mirror the direction about the surface normal: d - 2 (d . n) n
            reflected_direction = (
                ray.direction - 2 * np.dot(ray.direction, normal) * normal
            )
            rays.append(
                ray.copy(
                    origin=P,
                    direction=reflected_direction,
                    intensity=ray.intensity * self.reflectivity,
                    qo=qo,
                    _pathlength=pathlength,
                )
            )
        if self.transmission > 0:
            rays.append(
                ray.copy(
                    origin=P,
                    direction=ray.direction,
                    intensity=ray.intensity * self.transmission,
                    qo=qo,
                    _pathlength=pathlength,
                )
            )
        return rays

    def render(self, ax, type: str, **kwargs):
        """Render mirror outline.

        The outline uses ``self._edge_color`` unless the caller passes an
        explicit ``color`` keyword, which takes precedence. (Passing
        ``color`` used to raise ``TypeError`` because the keyword was
        supplied twice.)
        """
        render_kwargs = {"color": self._edge_color, **kwargs}
        super().render(ax, type, **render_kwargs)

    def get_bbox_local(self):
        """Return local bounding box from the active surface."""
        return self.surface.get_bbox_local()
class BaseRefraciveSurface(OpticalComponent):
    """Base interface between two refractive media."""

    _n1 = RefractiveIndex("_n1")
    _n2 = RefractiveIndex("_n2")

    def __init__(
        self,
        origin,
        n1: Union[float, Material] = 1.0,
        n2: Union[float, Material] = 1.0,
        reflectivity: float = 0.0,
        transmission: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            origin: Component origin in lab coordinates.
            n1: Refractive index for local ``x > 0`` side.
            n2: Refractive index for local ``x < 0`` side.
            reflectivity: Additional reflected branch coefficient.
            transmission: Transmitted branch coefficient.
            **kwargs: Optional ``surface`` override and rendering options.
        """
        super().__init__(origin, **kwargs)
        self._n1 = n1
        self._n2 = n2
        self.reflectivity = reflectivity
        self.transmission = transmission
        self._edge_color = "gray"
        self.surface = kwargs.get("surface", Plane())
        # Radius of curvature: taken from the surface when it defines one,
        # otherwise the interface is treated as flat.
        self.roc = self.surface.roc if hasattr(self.surface, "roc") else np.inf

    def interact_local(self, ray):
        """Apply Snell refraction and optional reflection in local frame.

        The method determines incident side from the local surface normal,
        computes transmitted direction using Snell's law, handles total
        internal reflection, and updates Gaussian beam ``q`` with ABCD
        matrices when available.

        Under total internal reflection exactly one reflected ray carrying
        the full incoming intensity is returned; the partial
        ``reflectivity`` branch is only emitted when refraction is possible
        (previously it was added on top of the totally reflected ray).

        Returns:
            List of outgoing rays starting at the hit point.
        """
        P, t = self.intersect_point_local(ray)
        normal = self.surface.normal(P)
        n1 = self._n1(ray.wavelength * ray.unit)
        n2 = self._n2(ray.wavelength * ray.unit)

        # ROC is positive if center of curvature is toward nout side
        ROC = np.inf
        if hasattr(self, "roc"):
            # ``roc`` may be a constant or a function of the hit point.
            ROC = self.roc(P) if callable(self.roc) else self.roc

        # Normal component of the incoming direction decides the incident side.
        r_n = np.dot(ray.direction, normal)
        if r_n < 0:
            # incident from n1 to n2
            nin, nout = n1, n2
        else:
            # incident from n2 to n1
            nin, nout = n2, n1
            ROC = -ROC  # change the sign of ROC

        ABCD_refraction = np.array([[1, 0], [(nin - nout) / (ROC * nout), nin / nout]])
        ABCD_reflection = np.array([[1, 0], [2 / ROC, 1]])

        # Gaussian beam parameter after each branch: q' = (A q + B) / (C q + D)
        if ray.qo is not None:
            qin = ray.q_at_z(t)
            qo_trans = (ABCD_refraction[0, 0] * qin + ABCD_refraction[0, 1]) / (
                ABCD_refraction[1, 0] * qin + ABCD_refraction[1, 1]
            )
            qo_refl = (ABCD_reflection[0, 0] * qin + ABCD_reflection[0, 1]) / (
                ABCD_reflection[1, 0] * qin + ABCD_reflection[1, 1]
            )
        else:
            qo_trans = None
            qo_refl = None

        # Tangential part of the direction and the normal pointing into the
        # outgoing medium.
        r_t = ray.direction - r_n * normal
        transmitted_normal = normal if r_n > 0 else -normal

        cos_theta_i = np.clip(r_n, -1, 1)  # avoid numerical issues
        sin_theta_i = np.sqrt(1 - cos_theta_i**2)
        sin_theta_t = (nin * sin_theta_i) / nout  # Snell's law

        pathlength = ray.pathlength(float(t))
        # Mirror the direction about the normal: d - 2 (d . n) n
        reflected_direction = ray.direction + 2 * cos_theta_i * (-normal)

        if not (sin_theta_t < 1):
            # Total internal reflection: all intensity goes into one ray.
            return [
                ray.copy(
                    origin=P,
                    direction=reflected_direction,
                    intensity=ray.intensity,
                    qo=qo_refl,
                    _pathlength=pathlength,
                )
            ]

        rays = []
        if self.transmission > 0:
            cos_theta_t = np.sqrt(1 - sin_theta_t**2)
            transmitted_direction = (
                nin / nout
            ) * r_t + cos_theta_t * transmitted_normal
            rays.append(
                ray.copy(
                    origin=P,
                    direction=transmitted_direction,
                    intensity=ray.intensity * self.transmission,
                    qo=qo_trans,
                    _n=nout,
                    _pathlength=pathlength,
                )
            )
        if self.reflectivity > 0:
            rays.append(
                ray.copy(
                    origin=P,
                    direction=reflected_direction,
                    intensity=ray.intensity * self.reflectivity,
                    qo=qo_refl,
                    _pathlength=pathlength,
                )
            )
        return rays

    def render(self, ax, type: str, **kwargs):
        """Render refractive surface outline.

        The outline uses ``self._edge_color`` unless the caller passes an
        explicit ``color`` keyword, which takes precedence. (Passing
        ``color`` used to raise ``TypeError`` because the keyword was
        supplied twice.)
        """
        render_kwargs = {"color": self._edge_color, **kwargs}
        super().render(ax, type, **render_kwargs)

    def get_bbox_local(self):
        """Return local bounding box from the active surface."""
        return self.surface.get_bbox_local()
class Mirror(BaseMirror):
    """Circular mirror."""

    def __init__(
        self,
        origin,
        radius: float = 0.5,
        reflectivity: float = 1.0,
        transmission: float = 0.0,
        **kwargs,
    ):
        """Build a mirror whose aperture is a ``Circle`` of the given radius.

        ``reflectivity``, ``transmission`` and any extra keyword arguments
        are handed to ``BaseMirror``.
        """
        mirror_options = dict(reflectivity=reflectivity, transmission=transmission)
        super().__init__(origin, **mirror_options, **kwargs)
        self.surface = Circle(radius)
        self.radius = radius
class SquareMirror(BaseMirror):
    """Rectangular mirror."""

    def __init__(
        self,
        origin,
        width: float = 1.0,
        height: float = 1.0,
        reflectivity: float = 1.0,
        transmission: float = 0.0,
        **kwargs,
    ):
        """Build a mirror whose aperture is a ``Rectangle(width, height)``.

        ``reflectivity``, ``transmission`` and any extra keyword arguments
        are handed to ``BaseMirror``.
        """
        mirror_options = dict(reflectivity=reflectivity, transmission=transmission)
        super().__init__(origin, **mirror_options, **kwargs)
        self.surface = Rectangle(width, height)
        self.width, self.height = width, height
class SquareRefractive(BaseRefraciveSurface):
    """Rectangular refractive interface."""

    def __init__(
        self,
        origin,
        width: float = 1.0,
        height: float = 1.0,
        n1: Union[float, Material] = 1.0,
        n2: Union[float, Material] = 1.0,
        reflectivity: float = 0.0,
        transmission: float = 1.0,
        **kwargs,
    ):
        """Build an interface whose aperture is a ``Rectangle(width, height)``.

        The optical parameters and any extra keyword arguments are handed to
        ``BaseRefraciveSurface``; the surface is then replaced by the
        rectangle.
        """
        interface_options = dict(
            n1=n1, n2=n2, reflectivity=reflectivity, transmission=transmission
        )
        super().__init__(origin, **interface_options, **kwargs)
        self.surface = Rectangle(width, height)
        self.width, self.height = width, height
class CircleRefractive(BaseRefraciveSurface):
    """Circular refractive interface."""

    def __init__(
        self,
        origin,
        radius: float = 0.5,
        n1: Union[float, Material] = 1.0,
        n2: Union[float, Material] = 1.0,
        reflectivity: float = 0.0,
        transmission: float = 1.0,
        **kwargs,
    ):
        """Build an interface whose aperture is a ``Circle`` of the given radius.

        The optical parameters and any extra keyword arguments are handed to
        ``BaseRefraciveSurface``; the surface is then replaced by the circle.
        """
        interface_options = dict(
            n1=n1, n2=n2, reflectivity=reflectivity, transmission=transmission
        )
        super().__init__(origin, **interface_options, **kwargs)
        self.surface = Circle(radius)
        self.radius = radius
class SphereRefractive(BaseRefraciveSurface):
    """Spherical-cap refractive interface."""

    def __init__(
        self,
        origin,
        radius: float = 0.5,
        height: float = 0.5,
        n1: Union[float, Material] = 1.0,
        n2: Union[float, Material] = 1.0,
        reflectivity: float = 0.0,
        transmission: float = 1.0,
        **kwargs,
    ):
        """Build an interface on a ``Sphere(radius, height)`` surface.

        The optical parameters and any extra keyword arguments are handed to
        ``BaseRefraciveSurface``; afterwards the surface is replaced by the
        sphere and ``roc`` is set to ``radius``.
        """
        interface_options = dict(
            n1=n1, n2=n2, reflectivity=reflectivity, transmission=transmission
        )
        super().__init__(origin, **interface_options, **kwargs)
        self.surface = Sphere(radius, height)
        self.roc = radius
        self.radius, self.height = radius, height
class BeamSplitter(SquareMirror):
    """Rectangular beamsplitter modeled as a partially reflective mirror."""

    def __init__(self, origin, width=1.0, height=1.0, eta: float = 0.5, **kwargs):
        """Initialize beamsplitter from splitting ratio ``eta``.

        Args:
            origin: Component origin in lab coordinates.
            width: Aperture width.
            height: Aperture height.
            eta: Power ratio sent to reflected branch.
            **kwargs: Forwarded to ``SquareMirror``; ``edgecolor`` selects
                the outline color.
        """
        # NOTE(review): the mirror coefficients are sqrt(eta) and
        # sqrt(1 - eta) and are applied to ``ray.intensity`` as-is - confirm
        # that ``intensity`` is an amplitude, otherwise the power split is
        # not ``eta``.
        super().__init__(
            origin,
            width=width,
            height=height,
            reflectivity=np.sqrt(eta),
            transmission=np.sqrt(1 - eta),
            **kwargs,
        )
        self._edge_color = kwargs.get("edgecolor", Color.SCIENCE_BLUE_DARK)

    def render(self, ax, type, **kwargs):
        """Render the beamsplitter boundary and optional filled face in 2D.

        Nothing is drawn when ``render_obj`` is false, matching the base
        class (previously the filled face was still added in the ``"Z"``
        view). The filled face is only drawn for ``type == "Z"``.
        """
        if not self.render_obj:
            return
        super().render(ax, type, **kwargs)
        facecolor = kwargs.get("facecolor", Color.SCIENCE_BLUE_LIGHT)
        linewidth = kwargs.get("linewidth", 2)
        if type == "Z":
            # Diamond with vertices at +/- width/2 on the local X and Y axes,
            # mapped to the lab frame and projected onto the XY plane.
            half_width = self.width / 2
            rect_pts = np.array(
                [
                    [-half_width, 0, 0],
                    [0, -half_width, 0],
                    [half_width, 0, 0],
                    [0, half_width, 0],
                ]
            )
            rect_pts = self.transform_matrix @ np.transpose(rect_pts) + np.array(
                self.origin
            ).reshape(-1, 1)
            ax.add_patch(
                plt.Polygon(
                    np.transpose(rect_pts)[:, :2],
                    facecolor=facecolor,
                    edgecolor=self._edge_color,
                    linewidth=linewidth,
                )
            )
class Lens(OpticalComponent):
    """Thin ideal lens with circular aperture."""

    def __init__(
        self,
        origin,
        focal_length: float,
        radius: float = 0.5,
        transmission: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            origin: Component origin in lab coordinates.
            focal_length: focal length of the lens.
            radius: Circular aperture radius.
            transmission: Intensity scaling applied to transmitted ray.
            **kwargs: Forwarded to ``OpticalComponent.__init__``.
        """
        super().__init__(origin, **kwargs)
        self.focal_length = focal_length
        self.transmission = transmission
        self.radius = radius
        self.surface = Circle(radius)
        self._edge_color = "purple"

    def interact_local(self, ray):
        """Apply thin-lens deflection and Gaussian ``q`` propagation locally.

        Returns:
            A single-element list with the deflected ray starting at the
            hit point.
        """
        P, t = self.intersect_point_local(ray)
        f = self.focal_length

        if ray.qo is None:
            qo = None
        else:
            # Beam parameter through the lens: q' = q / (1 - q / f)
            q1 = ray.q_at_z(t)
            qo = q1 / (1 - (q1 / f))

        # lens equation: v' = v - P/f
        deflected_direction = ray.direction - P / f
        # NOTE(review): unlike the mirror and refractive components, no
        # ``_pathlength`` is passed to ``ray.copy`` here - confirm intended.
        deflected_ray = ray.copy(
            origin=P,
            direction=deflected_direction,
            intensity=ray.intensity * self.transmission,
            qo=qo,
        )
        return [deflected_ray]

    def render(self, ax, type: str, **kwargs):
        """Render lens outline.

        The outline uses ``self._edge_color`` unless the caller passes an
        explicit ``color`` keyword, which takes precedence. (Passing
        ``color`` used to raise ``TypeError`` because the keyword was
        supplied twice.)
        """
        render_kwargs = {"color": self._edge_color, **kwargs}
        return super().render(ax, type, **render_kwargs)

    def get_bbox_local(self):
        """Return local bounding box from circular aperture geometry."""
        return self.surface.get_bbox_local()
class CylMirror(BaseMirror):
    """Cylindrical mirror segment."""

    def __init__(
        self,
        origin,
        radius: float = 0.5,
        height: float = 1.0,
        theta_range=(-np.pi, np.pi),
        **kwargs,
    ):
        """Build a mirror on a ``Cylinder(radius, height, theta_range)`` surface.

        Extra keyword arguments are handed to ``BaseMirror``.
        """
        super().__init__(origin, **kwargs)
        self.surface = Cylinder(radius, height, theta_range)
        self.radius, self.height = radius, height