Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions virtaccl/PyORBIT_Model/pyorbit_va_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,17 @@ def getYAvg(self):
# Class for wire scanners. This class simply returns histograms of the vertical and horizontal positions.
class WSclass(BaseLinacNode):
node_type = "WireScanner"
parameter_list = ['x_histogram', 'y_histogram', 'x_avg', 'y_avg', 'x_sigma', 'y_sigma']
parameter_list = ['x_histogram', 'y_histogram', 'x_avg', 'y_avg', 'x_sigma', 'y_sigma', 'bin_number']

def __init__(self, node_name: str, bin_number: int = 50):
parameters = {'x_histogram': np.array([[-10, 0], [10, 0]]), 'y_histogram': np.array([[-10, 0], [10, 0]]),
'x_avg': 0.0, 'y_avg': 0.0, 'x_sigma': 0.0, 'y_sigma': 0.0}
default_histogram = np.column_stack((np.linspace(-10, 10, bin_number), np.zeros(bin_number)))
parameters = {'x_histogram': default_histogram, 'y_histogram': default_histogram,
'x_avg': 0.0, 'y_avg': 0.0, 'x_sigma': 0.0, 'y_sigma': 0.0, 'bin_number': bin_number}
BaseLinacNode.__init__(self, node_name)
for key, value in parameters.items():
self.addParam(key, value)
self.node_name = node_name
self.setType(WSclass.node_type)
self.bin_number = bin_number

def track(self, paramsDict):
if "bunch" not in paramsDict:
Expand All @@ -165,6 +165,7 @@ def track(self, paramsDict):
part_num = bunch.getSizeGlobal()
x_array = np.zeros(part_num)
y_array = np.zeros(part_num)
bin_number = self.getParam('bin_number')
if part_num > 0:
sync_part = bunch.getSyncParticle()
sync_beta = sync_part.beta()
Expand All @@ -179,13 +180,13 @@ def track(self, paramsDict):
y_avg += y

x_limits = np.array([np.min(x_array), np.max(x_array)]) * 1.1
x_bin_edges = np.linspace(x_limits[0], x_limits[1], self.bin_number + 1)
x_bin_edges = np.linspace(x_limits[0], x_limits[1], bin_number + 1)
x_hist, x_bins = np.histogram(x_array, bins=x_bin_edges)
x_positions = (x_bins[:-1] + x_bins[1:]) / 2
x_out = np.column_stack((x_positions, x_hist))

y_limits = np.array([np.min(y_array), np.max(y_array)]) * 1.1
y_bin_edges = np.linspace(y_limits[0], y_limits[1], self.bin_number + 1)
y_bin_edges = np.linspace(y_limits[0], y_limits[1], bin_number + 1)
y_hist, y_bins = np.histogram(y_array, bins=y_bin_edges)
y_positions = (y_bins[:-1] + y_bins[1:]) / 2
y_out = np.column_stack((y_positions, y_hist))
Expand All @@ -204,8 +205,9 @@ def track(self, paramsDict):
self.setParam('y_sigma', y_sigma)

else:
self.setParam('x_histogram', np.array([[-10, 0], [10, 0]]))
self.setParam('y_histogram', np.array([[-10, 0], [10, 0]]))
default_histogram = np.column_stack((np.linspace(-10, 10, bin_number), np.zeros(bin_number)))
self.setParam('x_histogram', default_histogram)
self.setParam('y_histogram', default_histogram)
self.setParam('x_avg', 0)
self.setParam('y_avg', 0)
self.setParam('x_sigma', 0)
Expand All @@ -223,6 +225,18 @@ def getXAvg(self):
def getYAvg(self):
return self.getParam('y_avg')

def getXSigma(self):
return self.getParam('x_sigma')

def getYSigma(self):
return self.getParam('y_sigma')

def getBinNumber(self):
return self.getParam('bin_number')

def setBinNumber(self, new_bin_number):
self.setParam('bin_number', new_bin_number)


# Class for wire scanners. This class simply returns histograms of the vertical and horizontal positions.
class ScreenClass(BaseLinacNode):
Expand Down
4 changes: 3 additions & 1 deletion virtaccl/site/SNS_Linac/virtual_SNS_linac.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ def build_sns(**kwargs):
beam_line.add_device(bend_device)

wire_scanners = devices_dict["Wire_Scanner"]
bin_number = 50
for name, model_name in wire_scanners.items():
if model_name in element_list:
ws_device = WireScanner(name, model_name)
model.get_element_controller(model_name).get_element().setBinNumber(bin_number)
ws_device = WireScanner(name, model_name, {'bin_number': bin_number})
beam_line.add_device(ws_device)

bpms = devices_dict["BPM"]
Expand Down
74 changes: 56 additions & 18 deletions virtaccl/site/SNS_Linac/virtual_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,21 +164,28 @@ class Cavity(Device):
phase_key = 'phase' # [radians]
amp_key = 'amp' # [arb. units]

# Device Defaults
default_initial_phase = 0 # [radians]
default_initial_amp = 1.0 # [arb. units]

def __init__(self, name: str, model_name: str = None, initial_dict: Dict[str, Any] = None, phase_offset=0,
design_amp=15):
if model_name is None:
self.model_name = name
else:
self.model_name = model_name
super().__init__(name, self.model_name)
initial_dict = {} if initial_dict is None else initial_dict

# Sets initial values for parameters.
if initial_dict is not None:
if Cavity.phase_key in initial_dict:
initial_phase = initial_dict[Cavity.phase_key]
else:
initial_phase = Cavity.default_initial_phase
if Cavity.amp_key in initial_dict:
initial_amp = initial_dict[Cavity.amp_key]
else:
initial_phase = 0
initial_amp = 1.0
initial_amp = Cavity.default_initial_amp

self.design_amp = design_amp # [MV]

Expand Down Expand Up @@ -297,42 +304,60 @@ class WireScanner(Device):
speed_pv = 'Speed_Set' # [mm/s]
x_avg_pv = 'Hor_Mean_gs' # [mm]
y_avg_pv = 'Ver_Mean_gs' # [mm]
x_sigma_pv = 'Hor_Sigma_gs'
y_sigma_pv = 'Ver_Sigma_gs'
x_sigma_pv = 'Hor_Sigma_gs' # [mm]
y_sigma_pv = 'Ver_Sigma_gs' # [mm]
x_profile_pv = 'Hor_Profile' # [arb. units]
x_axis_pv = 'Hor_Axis' # [mm]
y_profile_pv = 'Ver_Profile' # [arb. units]
y_axis_pv = 'Ver_Axis' # [mm]

# PyORBIT parameter keys
x_hist_key = 'x_histogram' # [arb. units]
y_hist_key = 'y_histogram' # [arb. units]
x_hist_key = 'x_histogram' # [m, arb. units]
y_hist_key = 'y_histogram' # [m, arb. units]
x_avg_key = 'x_avg' # [m]
y_avg_key = 'y_avg' # [m]
x_sigma_key = 'x_sigma'
y_sigma_key = 'y_sigma'
x_sigma_key = 'x_sigma' # [m]
y_sigma_key = 'y_sigma' # [m]
bin_number_key = 'bin_number' # [number]

# Device keys
position_key = 'wire_position' # [m]
speed_key = 'wire_speed' # [m]
speed_key = 'wire_speed' # [m/s]

# Device Constants
x_offset = -0.01 # [m]
y_offset = 0.01 # [m]
wire_coeff = 1 / math.sqrt(2)

# Device Defaults
default_initial_position = -0.05 # [m]
default_initial_speed = 1 # [m/s]
default_bin_number = 50 # number

def __init__(self, name: str, model_name: str = None, initial_dict: Dict[str, Any] = None):
if model_name is None:
self.model_name = name
else:
self.model_name = model_name
super().__init__(name, self.model_name)
initial_dict = {} if initial_dict is None else initial_dict

# Changes the units from meters to millimeters for associated PVs.
self.milli_units = LinearTInv(scaler=1e3)

# Sets initial values for parameters.
if initial_dict is not None:
# Use defaults for any unspecified parameters
if WireScanner.position_key in initial_dict:
initial_position = initial_dict[WireScanner.position_key]
else:
initial_position = WireScanner.default_initial_position
if WireScanner.speed_key in initial_dict:
initial_speed = initial_dict[WireScanner.speed_key]
else:
initial_position = -0.05 # [mm]
initial_speed = 1 # [mm/s]
initial_speed = WireScanner.default_initial_speed
if WireScanner.bin_number_key in initial_dict:
bin_number = initial_dict[WireScanner.bin_number_key]
else:
bin_number = WireScanner.default_bin_number

# Defines internal parameters to keep track of the wire position.
self.last_wire_pos = initial_position
Expand All @@ -348,8 +373,12 @@ def __init__(self, name: str, model_name: str = None, initial_dict: Dict[str, An
self.register_measurement(WireScanner.y_charge_pv, noise=xy_noise)
self.register_measurement(WireScanner.x_avg_pv, noise=xy_noise, transform=self.milli_units)
self.register_measurement(WireScanner.y_avg_pv, noise=xy_noise, transform=self.milli_units)
self.register_measurement(WireScanner.x_sigma_pv, transform=self.milli_units)
self.register_measurement(WireScanner.y_sigma_pv, transform=self.milli_units)
self.register_measurement(WireScanner.x_sigma_pv, noise=xy_noise, transform=self.milli_units)
self.register_measurement(WireScanner.y_sigma_pv, noise=xy_noise, transform=self.milli_units)
self.register_measurement(WireScanner.x_profile_pv, definition={'count': bin_number})
self.register_measurement(WireScanner.x_axis_pv, transform=self.milli_units, definition={'count': bin_number})
self.register_measurement(WireScanner.y_profile_pv, definition={'count': bin_number})
self.register_measurement(WireScanner.y_axis_pv, transform=self.milli_units, definition={'count': bin_number})

self.register_setting(WireScanner.speed_pv, default=initial_speed, transform=self.milli_units)
self.register_setting(WireScanner.position_pv, default=initial_position, transform=self.milli_units)
Expand Down Expand Up @@ -394,17 +423,26 @@ def update_measurements(self, new_params: Dict[str, Dict[str, Any]] = None):

ws_params = new_params[self.model_name]
x_hist = ws_params[WireScanner.x_hist_key]
x_axis = x_hist[:, 0]
x_profile = x_hist[:, 1]
y_hist = ws_params[WireScanner.y_hist_key]
y_axis = y_hist[:, 0]
y_profile = y_hist[:, 1]

# Find the location of the vertical wire. Then interpolate the histogram from the model at that value.
x_pos = WireScanner.wire_coeff * wire_pos + WireScanner.x_offset
x_value = np.interp(x_pos, x_hist[:, 0], x_hist[:, 1], left=0, right=0)
x_value = np.interp(x_pos, x_axis, x_profile, left=0, right=0)
self.update_measurement(WireScanner.x_charge_pv, x_value)

y_pos = WireScanner.wire_coeff * wire_pos + WireScanner.y_offset
y_value = np.interp(y_pos, y_hist[:, 0], y_hist[:, 1], left=0, right=0)
y_value = np.interp(y_pos, y_axis, y_profile, left=0, right=0)
self.update_measurement(WireScanner.y_charge_pv, y_value)

self.update_measurement(WireScanner.x_profile_pv, x_profile)
self.update_measurement(WireScanner.x_axis_pv, x_axis)
self.update_measurement(WireScanner.y_profile_pv, y_profile)
self.update_measurement(WireScanner.y_axis_pv, y_axis)

self.update_measurement(WireScanner.x_avg_pv, ws_params[WireScanner.x_avg_key])
self.update_measurement(WireScanner.y_avg_pv, ws_params[WireScanner.y_avg_key])
self.update_measurement(WireScanner.x_sigma_pv, ws_params[WireScanner.x_sigma_key])
Expand Down
Loading