Scripts to produce publication-ready figures.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

318 lines
14 KiB

  1. import numpy, imageio, os
  2. import matplotlib.pyplot as plt
  3. from qtutils import *
  4. from labscript_utils.connections import ConnectionTable
  5. from labscript_utils import device_registry
  6. from labscript_c_extensions.runviewer.resample import resample as _resample
  7. import h5py
  8. class Shot(object):
  9. def __init__(self, path):
  10. self.path = path
  11. # Store list of traces
  12. self._traces = None
  13. # store list of channels
  14. self._channels = None
  15. # store list of markers
  16. self._markers = None
  17. # store list of shutter changes and callibrations
  18. self._shutter_times = None
  19. self._shutter_calibrations = {}
  20. # Load connection table
  21. self.connection_table = ConnectionTable(path)
  22. # open h5 file
  23. with h5py.File(path, 'r') as file:
  24. # Get master pseudoclock
  25. self.master_pseudoclock_name = file['connection table'].attrs['master_pseudoclock']
  26. if isinstance(self.master_pseudoclock_name, bytes):
  27. self.master_pseudoclock_name = self.master_pseudoclock_name.decode('utf8')
  28. else:
  29. self.master_pseudoclock_name = str(self.master_pseudoclock_name)
  30. # get stop time
  31. self.stop_time = file['devices'][self.master_pseudoclock_name].attrs['stop_time']
  32. self.device_names = list(file['devices'].keys())
  33. # Get Shutter Calibrations
  34. if 'calibrations' in file and 'Shutter' in file['calibrations']:
  35. for name, open_delay, close_delay in numpy.array(file['calibrations']['Shutter']):
  36. name = name.decode('utf8') if isinstance(name, bytes) else str(name)
  37. self._shutter_calibrations[name] = [open_delay, close_delay]
  38. def _load(self):
  39. if self._channels is None:
  40. self._channels = {}
  41. if self._traces is None:
  42. self._traces = {}
  43. if self._markers is None:
  44. self._markers = {}
  45. if self._shutter_times is None:
  46. self._shutter_times = {}
  47. self._load_markers()
  48. # Let's walk the connection table, starting with the master pseudoclock
  49. master_pseudoclock_device = self.connection_table.find_by_name(self.master_pseudoclock_name)
  50. self._load_device(master_pseudoclock_device)
  51. def _load_markers(self):
  52. with h5py.File(self.path, 'r') as file:
  53. if "time_markers" in file:
  54. for row in file["time_markers"]:
  55. self._markers[row['time']] = {'color': row['color'].tolist()[0], 'label': row['label']}
  56. elif "runviewer" in file:
  57. for time, val in file["runviewer"]["markers"].attrs.items():
  58. props = val.strip('{}}').rsplit(",", 1)
  59. color = list(map(int, props[0].split(":")[1].strip(" ()").split(",")))
  60. label = props[1].split(":")[1]
  61. self._markers[float(time)] = {'color': color, 'label': label}
  62. if 0 not in self._markers:
  63. self._markers[0] = {'color': [0,0,0], 'label': 'Start'}
  64. if self.stop_time not in self._markers:
  65. self._markers[self.stop_time] = {'color': [0,0,0], 'label' : 'End'}
  66. def add_trace(self, name, trace, parent_device_name, connection):
  67. name = str(name)
  68. self._channels[name] = {'device_name': parent_device_name, 'port': connection}
  69. self._traces[name] = trace
  70. # add shutter times
  71. con = self.connection_table.find_by_name(name)
  72. if con.device_class == "Shutter" and 'open_state' in con.properties:
  73. self.add_shutter_times([(name, con.properties['open_state'])])
  74. # Temporary solution to physical shutter times
  75. def add_shutter_times(self, shutters):
  76. for name, open_state in shutters:
  77. x_values, y_values = self._traces[name]
  78. if len(x_values) > 0:
  79. change_indices = numpy.where(y_values[:-1] != y_values[1:])[0]
  80. change_indices += 1 # use the index of the value that is changed to
  81. change_values = list(zip(x_values[change_indices], y_values[change_indices]))
  82. change_values.insert(0, (x_values[0], y_values[0])) # insert first value
  83. self._shutter_times[name] = {x_value + (self._shutter_calibrations[name][0] if y_value == open_state else self._shutter_calibrations[name][1]): 1 if y_value == open_state else 0 for x_value, y_value in change_values}
  84. def _load_device(self, device, clock=None):
  85. try:
  86. module = device.device_class
  87. device_class = device_registry.get_runviewer_parser(module)
  88. device_instance = device_class(self.path, device)
  89. clocklines_and_triggers = device_instance.get_traces(self.add_trace, clock)
  90. for name, trace in clocklines_and_triggers.items():
  91. child_device = self.connection_table.find_by_name(name)
  92. for grandchild_device_name, grandchild_device in child_device.child_list.items():
  93. self._load_device(grandchild_device, trace)
  94. except Exception:
  95. pass
  96. def resample(self, data_x, data_y, xmin, xmax, stop_time, num_pixels):
  97. """This is a function for downsampling the data before plotting
  98. it. Unlike using nearest neighbour interpolation, this method
  99. preserves the features of the plot. It chooses what value to
  100. use based on what values within a region are most different
  101. from the values it's already chosen. This way, spikes of a short
  102. duration won't just be skipped over as they would with any sort
  103. of interpolation."""
  104. # TODO: Only finely sample the currently visible region. Coarsely sample the rest
  105. # x_out = numpy.float32(numpy.linspace(data_x[0], data_x[-1], 4000*(data_x[-1]-data_x[0])/(xmax-xmin)))
  106. x_out = numpy.float64(numpy.linspace(xmin, xmax, 3 * 2000 + 2))
  107. y_out = numpy.empty(len(x_out) - 1, dtype=numpy.float64)
  108. data_x = numpy.float64(data_x)
  109. data_y = numpy.float64(data_y)
  110. # TODO: investigate only resampling when necessary.
  111. # Currently pyqtgraph sometimes has trouble rendering things
  112. # if you don't resample. If a point is far off the graph,
  113. # and this point is the first that should be drawn for stepMode,
  114. # because there is a long gap before the next point (which is
  115. # visible) then there is a problem.
  116. # Also need to explicitly handle cases where none of the data
  117. # is visible (which resampling does by setting NaNs)
  118. #
  119. # x_data_slice = data_x[(data_x>=xmin)&(data_x<=xmax)]
  120. # print len(data_x)
  121. # if len(x_data_slice) < 3*2000+2:
  122. # x_out = x_data_slice
  123. # y_out = data_y[(data_x>=xmin)&(data_x<=xmax)][:-1]
  124. # logger.info('skipping resampling')
  125. # else:
  126. resampling = True
  127. if resampling:
  128. _resample(data_x, data_y, x_out, y_out, numpy.float64(stop_time))
  129. # self.__resample4(data_x, data_y, x_out, y_out, numpy.float32(stop_time))
  130. else:
  131. x_out, y_out = data_x, data_y
  132. return x_out, y_out
  133. def find_nearest(self, array, value):
  134. array = numpy.asarray(array)
  135. idx = (numpy.abs(array - value)).argmin()
  136. return idx, array[idx]
  137. def generate_ylabel(self, channel_name):
  138. if channel_name == 'AO_MOT_Grad_Coil_current':
  139. label = '$ \\nabla B_{AHH}$'
  140. elif channel_name == 'AO_MOT_CompZ_Coil_current':
  141. label = '$B_{HH}$'
  142. elif channel_name == 'AO_MOT_3D_freq':
  143. label = '$\Delta \\nu$'
  144. elif channel_name == 'AO_MOT_3D_amp':
  145. label = '$P_{3D}$'
  146. elif channel_name == 'AO_Red_Push_amp':
  147. label = '$P_{Push}$'
  148. elif channel_name == 'MOT_2D_Shutter':
  149. label = '$P_{2D}$'
  150. elif channel_name == 'AO_ODT1_Pow':
  151. label = '$P_{cODT1}$'
  152. elif channel_name == 'Imaging_RF_Switch':
  153. label = '$P_{img}$'
  154. elif channel_name == 'MOT_3D_Camera_Trigger':
  155. label = '$Cam$'
  156. return label
  157. def plotSequence(self, Channels, Switches, PlotRange, animate = False, idx = 0):
  158. Traces = self.traces
  159. axs = plt.subplots(len(Channels), figsize = (10, 6), constrained_layout=True, sharex=True)[1]
  160. for i, channel_name in enumerate(Channels):
  161. channel_time = numpy.asarray(Traces[channel_name])[0]
  162. channel_trace = numpy.asarray(Traces[channel_name])[1]
  163. xmin = 0
  164. xmax = self.stop_time
  165. dx = 1e-9
  166. resampled_channel_trace = self.resample(channel_time, channel_trace, xmin, xmax, self.stop_time, dx)[1]
  167. time = numpy.arange(xmin, xmax, (xmax-xmin)/len(resampled_channel_trace))
  168. switch_time = numpy.asarray(Traces[Switches[i]])[0]
  169. switch_trace = numpy.asarray(Traces[Switches[i]])[1]
  170. xmin = 0
  171. xmax = self.stop_time
  172. dx = 1e-9
  173. resampled_switch_trace = self.resample(switch_time, switch_trace, xmin, xmax, self.stop_time, dx)[1]
  174. trace = numpy.multiply(resampled_channel_trace, resampled_switch_trace)
  175. TrimRange = [self.find_nearest(time, PlotRange[0])[0], self.find_nearest(time, PlotRange[1])[0]]
  176. trimmed_time = time[TrimRange[0]:TrimRange[1]]
  177. trimmed_trace = trace[TrimRange[0]:TrimRange[1]]
  178. if not animate:
  179. axs[i].plot(trimmed_time, trimmed_trace, '-b') #'-ob'
  180. axs[i].fill_between(trimmed_time, trimmed_trace, alpha=0.4)
  181. else:
  182. axs[i].plot(time[0:TrimRange[0] + idx], trace[0:TrimRange[0] + idx], '-b') #'-ob'
  183. axs[i].fill_between(time[0:TrimRange[0] + idx], trace[0:TrimRange[0] + idx], alpha=0.4)
  184. axs[i].axvline(x=0, color = 'b', linestyle = '--')
  185. axs[i].axvline(x=4, color = 'b', linestyle = '--')
  186. axs[i].axvline(x=4.315, color = 'b', linestyle = '--')
  187. axs[i].axvline(x=self.stop_time, color = 'b', linestyle = '--')
  188. axs[i].set_xlim(0, self.stop_time)
  189. axs[i].set_ylim(0, max(resampled_channel_trace) + 0.2)
  190. if i == len(Channels)-1:
  191. axs[i].set_xlabel('Time (s)', fontsize = 16)
  192. axs[i].set_ylabel(self.generate_ylabel(channel_name), fontsize = 16)
  193. if not animate:
  194. plt.savefig(f'seq.png', format='png', bbox_inches = "tight")
  195. plt.show()
  196. else:
  197. plt.savefig(f'seq-{idx}.png')
  198. plt.close()
  199. def animateSequence(self, Channels, Switches, PlotRange):
  200. SIZE = 6000
  201. STEP = 58
  202. for i in range(2, SIZE, STEP):
  203. self.plotSequence(Channels, Switches, PlotRange, animate = True, idx = i)
  204. with imageio.get_writer('seq_animated.gif', mode='i', fps = 24, loop = 1) as writer:
  205. for i in range(2, SIZE, STEP):
  206. image = imageio.imread(f'seq-{i}.png')
  207. writer.append_data(image)
  208. for i in range(2, SIZE, STEP):
  209. os.remove(f'seq-{i}.png')
  210. @property
  211. def channels(self):
  212. if self._channels is None:
  213. self._load()
  214. return self._channels.keys()
  215. @property
  216. def markers(self):
  217. if self._markers is None:
  218. self._load()
  219. return self._markers
  220. @property
  221. def traces(self):
  222. if self._traces is None:
  223. self._load()
  224. return self._traces
  225. @property
  226. def shutter_times(self):
  227. if self._shutter_times is None:
  228. self._load()
  229. return self._shutter_times
  230. if __name__ == "__main__":
  231. filepath = 'C:/Users/Karthik/Desktop/2023-09-27_0003_Phase_Transition_0.h5'
  232. shotObj = Shot(filepath)
  233. shotObj._load()
  234. Channels = list(shotObj.channels)
  235. """
  236. 'prawn_clock_line_0', 'prawn_clock_line_1', 'Dummy_1', 'Imaging_RF_Switch', 'Imaging_Shutter', 'MOT_2D_Shutter', 'MOT_3D_RF_Switch', 'MOT_3D_Shutter', 'Push_Beam_Blue_Shutter',
  237. 'Push_Beam_Blue_Switch', 'Push_Beam_Red_Shutter', 'Push_Beam_Red_Switch', 'CDT1_Switch', 'CDT2_Switch', 'MOT_3D_Camera_Trigger', 'MOT_3D_Camera_trigger', 'MOT_CompX_Coil_Switch',
  238. 'MOT_CompY_Coil_Switch', 'MOT_CompZ_Coil_Switch', 'MOT_Grad_Coil_Switch', 'ODT_Axis_Camera_Trigger', 'ODT_Axis_Camera_trigger', 'AO_Blue_Push_amp', 'AO_Blue_Push_freq', 'AO_Imaging_amp',
  239. 'AO_Imaging_freq', 'AO_MOT_3D_amp', 'AO_MOT_3D_freq', 'AO_MOT_CompX_Coil_current', 'AO_MOT_CompX_Coil_voltage', 'AO_MOT_CompY_Coil_current', 'AO_MOT_CompY_Coil_voltage', 'AO_MOT_CompZ_Coil_current',
  240. 'AO_MOT_CompZ_Coil_voltage', 'AO_MOT_Grad_Coil_current', 'AO_MOT_Grad_Coil_voltage', 'AO_Red_Push_amp', 'AO_Red_Push_freq', 'AO_Dummy', 'AO_ODT1_Mod', 'AO_ODT1_Pow', 'AO_ODT2_Pow'
  241. """
  242. """ Without Imaging """
  243. Channels = ['MOT_2D_Shutter', 'AO_Red_Push_amp', 'AO_MOT_3D_amp', 'AO_MOT_3D_freq', 'AO_MOT_Grad_Coil_current', 'AO_MOT_CompZ_Coil_current', 'AO_ODT1_Pow']
  244. Switches = ['MOT_2D_Shutter', 'Push_Beam_Red_Switch', 'MOT_3D_RF_Switch', 'MOT_3D_RF_Switch', 'MOT_Grad_Coil_Switch', 'MOT_CompZ_Coil_Switch', 'CDT1_Switch']
  245. """ With Imaging """
  246. # Channels = ['MOT_2D_Shutter', 'AO_Red_Push_amp', 'AO_MOT_3D_amp', 'AO_MOT_3D_freq', 'AO_MOT_Grad_Coil_current', 'AO_MOT_CompZ_Coil_current', 'AO_ODT1_Pow', 'Imaging_RF_Switch', 'MOT_3D_Camera_Trigger']
  247. # Switches = ['MOT_2D_Shutter', 'Push_Beam_Red_Switch', 'MOT_3D_RF_Switch', 'MOT_3D_RF_Switch', 'MOT_Grad_Coil_Switch', 'MOT_CompZ_Coil_Switch', 'CDT1_Switch', 'Imaging_RF_Switch', 'MOT_3D_Camera_Trigger']
  248. """ Plot Full Sequence """
  249. # TimeRange = [0.0, shotObj.stop_time]
  250. """ Plot till loading of MOT """
  251. # TimeRange = [0.0, 4.0]
  252. """ Plot from loading of MOT till end of sequence"""
  253. TimeRange = [4.0, shotObj.stop_time]
  254. """ Plot sequence """
  255. shotObj.plotSequence(Channels, Switches, PlotRange = TimeRange)
  256. """ Animate sequence """
  257. # shotObj.animateSequence(Channels, Switches, PlotRange = TimeRange)