import argparse, time, numpy as np, os
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, CheckButtons, RadioButtons, Button
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
from scipy.optimize import minimize
class InteractiveOpticalTable:
    """GUI/CLI helper around an experimental-setup description file (Python).

    The setup file is read once and executed with ``exec`` inside a private
    namespace; it is re-executed every time a parameter changes.  Before the
    first execution the namespace is populated with ``fig``, ``ax0``, ``gs1``
    and ``ax1`` so the setup file can draw on them.  Only load files you
    trust: their code runs with full interpreter privileges.

    The setup file must define

    ``vars``
        ``dict`` mapping *parameter name* -> ``[value, min, max]``.
        If both ``min`` and ``max`` are ``None`` the parameter is treated as
        a Boolean and rendered as a one-box ``CheckButtons`` widget.

    Optionally it can expose

    ``presets``
        ``dict[str, dict]`` mapping *preset name* to a *vars-like* dictionary.
        Each preset entry may list only the parameters it wants to override;
        any parameter omitted from a preset keeps the current slider state.

    ``PLOT_TYPE``
        Either ``"Z"`` (default) or ``"3D"`` (case-insensitive).  With
        ``"3D"`` the main axes are re-created with a 3-D projection.

    ``cost_<name>``
        Any number of names starting with ``cost_`` (case-insensitive
        prefix).  Each one is either a number or a zero-argument callable
        and becomes a selectable cost function called ``<name>``.
    """

    # ---------------------------------------------------------------------
    # Construction helpers
    # ---------------------------------------------------------------------
    def __init__(self, fileName: str, FPS: int = 20):
        """Load *fileName*, execute it once and collect its parameters.

        Parameters
        ----------
        fileName:
            Path of the setup file (read as UTF-8).
        FPS:
            Target number of redraws per second while widgets are being
            manipulated.  A value ``<= 0`` disables the throttle.

        Raises
        ------
        KeyError
            If the setup file does not define ``vars``.
        """
        self.fileName = fileName
        self.last_update_time = time.time()
        self.FPS = FPS  # target frames per second for slider drags
        self.render = True  # live-render flag
        self.changing_preset = False  # suppress redraws during bulk widget updates
        self.optimization_running = False  # True while optimize_selected() runs from the GUI
        self.optimization_interrupted = False  # set by the GUI to abort a running optimisation

        # ---- load experimental-setup file into its own namespace ----------
        self.namespace: dict[str, object] = {}
        self._create_axes()
        with open(self.fileName, "r", encoding="utf-8") as f:
            self.expsetup_code = f.read()
        # NOTE: executes arbitrary code from the setup file.
        exec(self.expsetup_code, self.namespace)

        self.tunable_vars_setting: dict[str, list] = self.namespace["vars"]
        self.presets: dict[str, dict] | None = self.namespace.get("presets")
        self._plot_type = self.namespace.get("PLOT_TYPE", "Z").upper()

        # Re-create main axes in 3-D if requested.  PLOT_TYPE is only known
        # after the first execution, so that first run still sees 2-D axes.
        if self._plot_type == "3D":
            self.ax0.remove()
            self.ax0 = self.fig.add_subplot(self.gs[0], projection="3d")
            self.namespace["ax0"] = self.ax0

        # Widgets are created later by create_sliders().
        self.sliders: dict[str, Slider | CheckButtons] = {}
        self.opt_boxes: dict[str, CheckButtons] = {}
        self.param_order = list(self.tunable_vars_setting.keys())

        # Detect available cost functions and select the first one (if any).
        self.cost_funcs = self._discover_cost_functions()
        self.selected_cost_name = next(iter(self.cost_funcs), None)
        self._display_optimization = False

    # ------------------------------------------------------------------ #
    #                        FIGURE / AXES HELPERS                        #
    # ------------------------------------------------------------------ #
    def _create_axes(self):
        """Initialise *fig*, *ax0* (main view) and the three *ax1* side panels."""
        self.fig = plt.figure(figsize=(12, 6))
        self.gs = GridSpec(1, 2, width_ratios=[2.5, 1])
        self.ax0 = self.fig.add_subplot(self.gs[0])
        self.gs1 = GridSpecFromSubplotSpec(3, 1, subplot_spec=self.gs[1], hspace=0.3)
        self.ax1 = [self.fig.add_subplot(self.gs1[i]) for i in range(3)]
        self.fig.subplots_adjust(left=0.1, right=0.7)
        # expose to experiment file
        self.namespace.update(
            dict(fig=self.fig, ax0=self.ax0, gs1=self.gs1, ax1=self.ax1)
        )
        self.slider_axes_top = 0.95

    # ------------------------------------------------------------------ #
    #                               HELPERS                               #
    # ------------------------------------------------------------------ #
    def _clear_axes(self):
        """Clear the main axes and every side panel."""
        self.ax0.clear()
        for ax in self.ax1:
            ax.clear()

    # ---- slider & checkbox placement helpers -------------------------
    def _get_slider_ax(self, idx: int):
        """Return a new axes for the slider of the *idx*-th parameter."""
        return self.fig.add_axes(
            [0.8, self.slider_axes_top * (1 - idx / 30), 0.1, 0.03]
        )

    def _get_checkbox_ax(self, idx: int):
        """Return a new axes for the "Opt." box of the *idx*-th parameter."""
        # tiny square to the right of the slider (slider spans x = 0.8 .. 0.9)
        return self.fig.add_axes(
            [0.97, self.slider_axes_top * (1 - idx / 30) + 0.005, 0.03, 0.03]
        )

    # ---- cost-function discovery -------------------------------------
    def _discover_cost_functions(self):
        """Return ``{short_name: evaluator}`` for every ``cost_*`` name.

        Each evaluator takes the namespace and returns a ``float``.  The
        value is looked up by name on every call, so a number or function
        that the setup file re-creates on each execution is always current.
        Callables are invoked without arguments.
        """

        def _make(key):
            def _evaluate(ns):
                value = ns[key]
                return float(value() if callable(value) else value)

            return _evaluate

        return {
            name[5:]: _make(name)  # strip the "cost_" prefix
            for name in self.namespace
            if name.lower().startswith("cost_")
        }

    # ---- utility conversions -----------------------------------------
    def _sliders_to_params(self):
        """Return ``{name: value}`` read from the current widget states."""
        return {
            name: widget.val if isinstance(widget, Slider) else widget.get_status()[0]
            for name, widget in self.sliders.items()
        }

    def _optimisable_param_names(self):
        """Return list of non-Boolean parameter names whose 'Opt.' box is ticked."""
        names = []
        for name, box in self.opt_boxes.items():
            _, vmin, vmax = self.tunable_vars_setting[name]
            is_boolean = vmin is None and vmax is None
            if box.get_status()[0] and not is_boolean:
                names.append(name)
        return names

    # ------------------------------------------------------------------ #
    #                               CORE GUI                              #
    # ------------------------------------------------------------------ #
    def create_sliders(self):
        """Create a slider (or Boolean toggle) plus "Opt." checkbox per parameter."""
        for i, name in enumerate(self.param_order):
            val, vmin, vmax = self.tunable_vars_setting[name]
            is_boolean_param = vmin is None and vmax is None

            # --- optimisation enable box ---------------------------------
            cb_ax = self._get_checkbox_ax(i)
            self.opt_boxes[name] = CheckButtons(cb_ax, [""], [False])
            if is_boolean_param:
                # Boolean parameters cannot be optimised: hide their box and
                # leave it unticked.
                cb_ax.set_visible(False)
            for side in ("top", "right", "left", "bottom"):
                cb_ax.spines[side].set_visible(False)

            # --- actual slider / bool toggle ----------------------------
            sl_ax = self._get_slider_ax(i)
            if is_boolean_param:
                self.sliders[name] = CheckButtons(sl_ax, [name], [bool(val)])
            else:
                self.sliders[name] = Slider(sl_ax, name, vmin, vmax, valinit=val)

    def _update_cost_function_val(self):
        """Refresh the cost read-out; no-op without a read-out or cost function."""
        if not hasattr(self, "cost_value_text"):
            return
        if getattr(self, "selected_cost_name", None) not in self.cost_funcs:
            return
        cost_value = self._current_cost_value()
        self.cost_value_text.set_text(f"Value: {cost_value:.6g}")

    def _redraw(self):
        """Re-run the setup file with the widget values and refresh the figure."""
        self._clear_axes()
        self.update_table(**self._sliders_to_params())
        self._update_cost_function_val()
        if self.render:
            plt.draw()

    def _on_slider_changed(self, _event=None):
        """Widget callback: redraw, throttled to at most ``FPS`` per second.

        Ignored while ``changing_preset`` is set.  Events arriving faster
        than the throttle allows are dropped, not queued.
        """
        if self.changing_preset:
            return
        now = time.time()
        if self.FPS > 0 and now - self.last_update_time < 1 / self.FPS:
            return
        self.last_update_time = now
        self._redraw()

    def _connect_widget(self, widget):
        """Attach the redraw callback to a Slider or CheckButtons widget."""
        if isinstance(widget, Slider):
            widget.on_changed(self._on_slider_changed)
        else:
            widget.on_clicked(self._on_slider_changed)

    def _replace_slider(self, name, vmin, vmax, val):
        """Swap the slider of *name* for a new one with the given range/value."""
        old = self.sliders[name]
        old.disconnect_events()
        old.ax.remove()
        new_ax = self._get_slider_ax(self.param_order.index(name))
        self.sliders[name] = Slider(new_ax, name, vmin, vmax, valinit=val)
        self._connect_widget(self.sliders[name])

    def _build_finetune_panel(self):
        """Create the "Finetune Multiplier" radio buttons and return the widget."""
        FINETUNE_X, FINETUNE_Y = 0.02, 0.01
        FINTUNE_DX, FINTUNE_DY = 0.15, 0.11
        finetune_ax = self.fig.add_axes(
            [FINETUNE_X, FINETUNE_Y, FINTUNE_DX, FINTUNE_DY]
        )
        finetune = RadioButtons(finetune_ax, ["x1", "x10", "x100", "x1000"], active=0)
        for side in ("top", "right", "left", "bottom"):
            finetune_ax.spines[side].set_visible(False)
        self.fig.text(
            FINETUNE_X,
            FINETUNE_Y + 0.11,
            "Finetune Multiplier",
            fontsize=12,
            weight="bold",
        )
        factors = dict(x1=1, x10=10, x100=100, x1000=1000)

        def _retune(label):
            # Narrow every slider to +/- span/factor around its current
            # value, clipped to the range declared in ``vars``.
            factor = factors[label]
            for name in self.param_order:
                slider = self.sliders[name]
                if not isinstance(slider, Slider):
                    continue
                cur = slider.val
                vmin, vmax = self.tunable_vars_setting[name][1:]
                span = vmax - vmin
                self._replace_slider(
                    name,
                    max(cur - span / factor, vmin),
                    min(cur + span / factor, vmax),
                    cur,
                )
            self.fig.canvas.draw_idle()

        finetune.on_clicked(_retune)
        return finetune

    def _build_preset_panel(self):
        """Create the "Presets" radio buttons and return the widget."""
        preset_ax = self.fig.add_axes([0.01, 0.8, 0.1, 0.15])
        for side in ("top", "right", "left", "bottom"):
            preset_ax.spines[side].set_visible(False)

        # Pre-select the preset equal to ``vars`` (first one if none matches).
        preset_idx = 0
        for i, preset in enumerate(self.presets.values()):
            if preset == self.tunable_vars_setting:
                preset_idx = i
                break
        preset_rb = RadioButtons(
            preset_ax, list(self.presets.keys()), active=preset_idx
        )
        self.fig.text(0.02, 0.95, "Presets", fontsize=12, weight="bold")

        def _load_preset(label):
            preset = self.presets[label]
            # Suppress per-widget redraws; one redraw follows at the end.
            self.changing_preset = True
            try:
                for name, values in preset.items():
                    if name not in self.sliders:
                        continue
                    target_val, pmin, pmax = values
                    widget = self.sliders[name]
                    if isinstance(widget, Slider):
                        if (widget.valmin != pmin) or (widget.valmax != pmax):
                            # Bounds differ: the slider has to be re-created.
                            self._replace_slider(name, pmin, pmax, target_val)
                        else:
                            widget.set_val(target_val)
                    elif widget.get_status()[0] != bool(target_val):
                        # CheckButtons (single box): toggle to desired state.
                        widget.set_active(0)
            finally:
                self.changing_preset = False
            # Redraw directly so the FPS throttle cannot drop this update.
            self._redraw()

        preset_rb.on_clicked(_load_preset)
        return preset_rb

    def _build_optimisation_panel(self):
        """Create cost selector, cost read-out, progress plot and OPTIMISE button.

        Returns the list of created widgets so the caller can keep them alive.
        """
        COST_X, COST_Y = 0.03, 0.38
        cost_rb_ax = self.fig.add_axes([COST_X, COST_Y, 0.17, 0.1])
        self.fig.text(
            COST_X, COST_Y + 0.12, "Cost Function", fontsize=12, weight="bold"
        )
        BTN_X, BTN_Y = COST_X + 0.12, COST_Y + 0.10
        btn_ax = self.fig.add_axes([BTN_X, BTN_Y, 0.06, 0.04])
        btn = Button(btn_ax, "OPTIMISE", hovercolor="lightgray")
        for s in cost_rb_ax.spines.values():
            s.set_visible(False)

        # current cost value display
        COST_VAL_X, COST_VAL_Y = COST_X, COST_Y + 0.10
        cost_value_text_ax = self.fig.add_axes([COST_VAL_X, COST_VAL_Y, 0.17, 0.02])
        cost_value_text_ax.axis("off")
        self.cost_value_text = cost_value_text_ax.text(
            0, 0, "Value: --", fontsize=10
        )

        # optimisation progress plot (cost versus evaluation number)
        COST_PROGRESS_X, COST_PROGRESS_Y = COST_X, COST_Y - 0.18
        self.opt_progress_ax = self.fig.add_axes(
            [COST_PROGRESS_X, COST_PROGRESS_Y, 0.17, 0.15]
        )
        self.opt_progress_ax.set_xlabel("Iter Num", fontsize=8)
        self.opt_progress_ax.set_ylabel("Cost Function", fontsize=8)
        self.opt_progress_ax.set_yscale("log")
        (self.opt_progress_line,) = self.opt_progress_ax.plot([], [], "b-")
        self.opt_progress_ax.grid(True)

        widgets = [btn]

        def _pick_cost(label):
            self.selected_cost_name = label
            self._update_cost_function_val()
            self.fig.canvas.draw_idle()

        if self.cost_funcs:
            cost_rb = RadioButtons(
                cost_rb_ax, list(self.cost_funcs.keys()), active=0
            )
            cost_rb.on_clicked(_pick_cost)
            widgets.append(cost_rb)
            self._update_cost_function_val()

        def _run(_):
            # A click while an optimisation is running requests an abort.
            if self.optimization_running:
                self.optimization_interrupted = True
                return
            # Otherwise start a new optimisation.
            self.optimization_running = True
            self.optimization_interrupted = False
            original_color = btn.color
            original_hovercolor = btn.hovercolor
            original_label = btn.label.get_text()
            btn.color = "lightgreen"
            btn.hovercolor = "red"
            btn.label.set_text("Optimizing...")
            # NOTE(review): private attribute kept from the original code;
            # presumably meant to force the hover colour - confirm it is used.
            btn._hover_fill_color = "red"
            # Use the axes title as a hint while running.
            orig_tooltip = btn.ax.get_title()
            btn.ax.set_title("Click to abort")
            self.fig.canvas.draw_idle()
            plt.draw()
            plt.pause(0.01)
            try:
                self.optimize_selected()
            finally:
                # Restore the button to its idle appearance.
                self.optimization_running = False
                btn.color = original_color
                btn.hovercolor = original_hovercolor
                btn.label.set_text(original_label)
                btn._hover_fill_color = original_hovercolor
                btn.ax.set_title(orig_tooltip)
                self.fig.canvas.draw_idle()

        btn.on_clicked(_run)
        return widgets

    def _print_final_values(self):
        """Print the widget values, then again as a copy-paste ``vars`` dict."""
        print("Final slider values:")
        for k, v in self._sliders_to_params().items():
            print(f" {k} = {v}")
        # Second dump uses the ``[value, min, max]`` layout of ``vars``;
        # Boolean parameters get None for min/max.
        print("\nFinal parameters for copy-paste:")
        print("sol0 = {")
        for name, slider in self.sliders.items():
            if isinstance(slider, Slider):
                val = slider.val
                vmin, vmax = slider.valmin, slider.valmax
            else:
                val = slider.get_status()[0]
                vmin, vmax = None, None
            print(f' "{name}": [{val}, {vmin}, {vmax}],')
        print("}")

    # ------------------------------------------------------------------ #
    def slider_interactive(self):
        """Build the full GUI, show it and print the final values on close.

        The optimisation panel is only built when ``_display_optimization``
        is true; the preset panel only when the setup file defines presets.
        """
        self.create_sliders()
        for w in list(self.sliders.values()) + list(self.opt_boxes.values()):
            self._connect_widget(w)

        # Keep references to the widgets so they are not garbage collected.
        self._widgets = [self._build_finetune_panel()]
        if self.presets:
            self._widgets.append(self._build_preset_panel())
        if self._display_optimization:
            self._widgets.extend(self._build_optimisation_panel())

        self._redraw()
        plt.show()
        self._print_final_values()

    # ------------------------------------------------------------------ #
    #                            BACK-END LOGIC                           #
    # ------------------------------------------------------------------ #
    def update_table(self, **params):
        """Store *params* in the namespace and re-execute the setup file."""
        self.namespace.update(params)
        # NOTE: executes arbitrary code from the setup file.
        exec(self.expsetup_code, self.namespace)

    # ------------ optimisation (selected subset) -----------------------
    def _current_cost_value(self):
        """Return the selected cost function evaluated on the namespace.

        Raises
        ------
        RuntimeError
            If no cost function is available/selected.
        """
        name = getattr(self, "selected_cost_name", None)
        if name is None or name not in self.cost_funcs:
            raise RuntimeError(
                "No cost function available: define a 'cost_<name>' number "
                "or function in the setup file."
            )
        return float(self.cost_funcs[name](self.namespace))

    def optimize_selected(self, maximise=False):
        """Nelder-Mead optimise the parameters whose "Opt." box is ticked.

        All other parameters are held at their current widget values.  The
        run can be aborted by setting ``optimization_interrupted``.  On
        success the optimum is written back to the sliders.

        Parameters
        ----------
        maximise:
            If true, maximise the cost function instead of minimising it.
        """
        if not self.cost_funcs:
            print("No cost function defined; nothing to optimise.")
            return
        selected = set(self._optimisable_param_names())
        # Fix the ordering explicitly: it must match x in _cost().
        opt_names = [n for n in self.param_order if n in selected]
        if not opt_names:
            print("No parameter selected for optimisation.")
            return

        # separate vectors for varying vs fixed parameters
        fixed = {}
        initials, bounds = [], []
        for n in self.param_order:
            w = self.sliders[n]
            if n in selected:  # always a Slider, Booleans are never selected
                initials.append(w.val)
                bounds.append((w.valmin, w.valmax))
            elif isinstance(w, Slider):
                fixed[n] = w.val
            else:
                fixed[n] = w.get_status()[0]

        # The progress plot only exists when the optimisation panel was built.
        progress_line = getattr(self, "opt_progress_line", None)
        progress_ax = getattr(self, "opt_progress_ax", None)
        iterations = []
        costs = []

        def _refresh_progress():
            if progress_line is None or progress_ax is None:
                return
            progress_line.set_data(iterations, costs)
            progress_ax.relim()
            progress_ax.autoscale_view()

        _refresh_progress()
        iteration_count = [0]

        # -------------------------------------------------------------- #
        def _cost(x):
            if self.optimization_interrupted:
                raise InterruptedError("Optimization aborted by user")
            iteration_count[0] += 1
            params = fixed | {n: v for n, v in zip(opt_names, x)}
            self._clear_axes()
            self.update_table(**params)
            c = self._current_cost_value()
            print(f"Iter {iteration_count[0]}: cost={c: .6g}")
            signed = -c if maximise else c
            iterations.append(iteration_count[0])
            costs.append(signed)
            # Update the drawing every 10 evaluations; plt.pause also lets
            # the GUI process events such as the abort click.
            if iteration_count[0] % 10 == 0 and self.render:
                self._update_cost_function_val()
                _refresh_progress()
                self.fig.canvas.draw_idle()
                plt.draw()
                plt.pause(0.01)
            return signed

        print(f"Optimising over: {opt_names}")
        try:
            res = minimize(
                _cost,
                initials,
                bounds=bounds,
                method="Nelder-Mead",
                options={
                    "maxiter": 150,  # cap on iterations
                    "maxfev": 150,  # cap on function evaluations
                    "xatol": 1e-5,  # tolerance on parameter change
                    "fatol": 1e-7,  # tolerance on function value change
                },
            )
        except InterruptedError:
            print("Optimization interrupted by user.")
            res = None
        finally:
            # Always clear the abort request.
            self.optimization_interrupted = False

        if res is not None:
            print("Result:")
            for n, v in zip(opt_names, res.x):
                print(f" {n} = {v}")
            # Push final values back to the sliders; per-slider redraws are
            # suppressed because one full redraw follows below.
            self.changing_preset = True
            try:
                for n, v in zip(opt_names, res.x):
                    self.sliders[n].set_val(v)
            finally:
                self.changing_preset = False

        # One last redraw so picture, namespace and cost read-out all match
        # the slider state (also after an aborted run).
        self._clear_axes()
        self.update_table(**self._sliders_to_params())
        self._update_cost_function_val()
        _refresh_progress()
        plt.draw()

    def _batch_cost_value(self):
        """Return the cost used by :meth:`optimize`.

        Prefers the name ``cost_func`` from the namespace (number or
        zero-argument callable); otherwise the selected cost function.
        """
        if "cost_func" in self.namespace:
            value = self.namespace["cost_func"]
            return float(value() if callable(value) else value)
        return self._current_cost_value()

    def optimize(self, maximize=False):
        """Nelder-Mead optimise all numeric parameters without any widgets.

        Boolean parameters (both bounds ``None``) are held at their initial
        value and are not part of ``result.x``.

        Parameters
        ----------
        maximize:
            If true, maximise the cost instead of minimising it.

        Returns
        -------
        The ``scipy.optimize.minimize`` result; ``result.x`` follows the
        order of the numeric parameters in ``vars``.

        Raises
        ------
        ValueError
            If ``vars`` contains no numeric parameter.
        """
        names, initials, bounds = [], [], []
        fixed = {}
        for name in self.param_order:
            value, vmin, vmax = self.tunable_vars_setting[name]
            if vmin is None and vmax is None:
                fixed[name] = bool(value)
            else:
                names.append(name)
                initials.append(value)
                bounds.append((vmin, vmax))
        if not names:
            raise ValueError("optimize() needs at least one numeric parameter")

        n = 0

        def cost(vals):
            nonlocal n
            n += 1
            params = fixed | dict(zip(names, vals))
            self._clear_axes()
            self.update_table(**params)
            c = self._batch_cost_value()
            print(f"Iter {n}: cost={c:.5g}; vals={vals}")
            return -c if maximize else c

        result = minimize(cost, initials, bounds=bounds, method="Nelder-Mead", tol=1e-6)
        print("Optimized parameters:")
        for name, val in zip(names, result.x):
            print(f" {name} = {val}")
        return result

    def scan(
        self,
        param_x: str,
        values_x,
        param_y: str,
        values_y,
        *,
        cmap="viridis",
        show=True,
    ):
        """Evaluate the selected cost on a 2-D parameter grid and plot it.

        Parameters
        ----------
        param_x, param_y:
            Names of the two parameters to scan.
        values_x, values_y:
            Non-empty sequences of values for each parameter.
        cmap:
            Colormap passed to ``imshow``.
        show:
            If true, call ``plt.show()`` after plotting.

        Returns
        -------
        ``numpy.ndarray`` of shape ``(len(values_y), len(values_x))`` with
        ``cost_map[j, i]`` the cost at ``(values_x[i], values_y[j])``.

        Raises
        ------
        ValueError
            If either value sequence is empty.
        """
        Nx, Ny = len(values_x), len(values_y)
        if Nx == 0 or Ny == 0:
            raise ValueError("scan() needs at least one value on each axis")
        cost_map = np.empty((Ny, Nx), dtype=float)

        original_render = self.render
        self.render = False
        try:
            for j, y in enumerate(values_y):
                for i, x in enumerate(values_x):
                    # Clear first so artists do not pile up across runs.
                    self._clear_axes()
                    self.update_table(**{param_x: x, param_y: y})
                    cost_map[j, i] = self._current_cost_value()
        finally:
            self.render = original_render

        self._clear_axes()
        # NOTE: extent puts the outer pixel edges (not centres) on the first
        # and last scanned values.
        im = self.ax0.imshow(
            cost_map,
            origin="lower",
            aspect="auto",
            extent=[values_x[0], values_x[-1], values_y[0], values_y[-1]],
            cmap=cmap,
        )
        self.fig.colorbar(im, ax=self.ax0).set_label(self.selected_cost_name)
        self.ax0.set_xlabel(param_x)
        self.ax0.set_ylabel(param_y)
        self.ax0.set_title("Parameter scan – cost map")
        if show:
            plt.show()
        return cost_map
# ---------------------------------------------------------------------
# Small CLI entry‑point ------------------------------------------------
# ---------------------------------------------------------------------
if __name__ == "__main__":
    # Setup file to load, resolved relative to this module's directory.
    # Alternatives left by the author: directory "../examples"; files
    # "ripa_gen2_lensless.py" and "prism_refl.py".
    FILE_NAME = os.path.join(os.path.dirname(__file__), "../demo", "ripa_gen2_2nd.py")
    MODE = "interact"  # 'interact' | 'optimize' | 'scan'

    table = InteractiveOpticalTable(fileName=FILE_NAME)
    # Enable the cost-function / optimisation panel (set to False to hide it).
    table._display_optimization = True

    if MODE == "interact":
        table.slider_interactive()
    elif MODE == "optimize":
        table.optimize(maximize=True)
    elif MODE == "scan":
        scan_x = np.linspace(-3.3, 2.7, 11)
        scan_y = np.linspace(0.7, 6.7, 11)
        table.scan("V2d4F", scan_x, "V2dXMLA", scan_y)