pqcprep.plotting_tools

Collection of functions relating to plotting and visualisation.

  1"""
  2Collection of functions relating to plotting and visualisation. 
  3"""
  4
  5import numpy as np 
  6import matplotlib.pyplot as plt 
  7from matplotlib import rcParams, cm
  8import os
  9from .psi_tools import x_trans_arr, get_phase_target, psi, A 
 10from .file_tools import vars_to_name_str, vars_to_name_str_ampl
 11from .resource_tools import load_data_H23
 12from .phase_tools import full_encode, phase_from_state
 13from .binary_tools import dec_to_bin, bin_to_dec
 14
 15# general settings
 16rcParams['mathtext.fontset'] = 'stix' 
 17rcParams['font.family'] = 'STIXGeneral' 
 18width=0.75 
 19""" @private """
 20color='black' 
 21""" @private """
 22fontsize=28 
 23""" @private """
 24titlesize=32
 25""" @private """
 26ticksize=22
 27""" @private """
 28figsize=(10,10)
 29""" @private """
 30
 31def benchmark_plots(arg_dict,DIR, show=False, pdf=False):  
 32    """
 33    Generates plots visualising the outputs produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()`. 
 34
 35    Arguments:
 36    ---- 
 37    - **arg_dict** : *dict* 
 38
 39        Dictionary containing varialbe information created using `pqcprep.file_tools.compress_args()`. 
 40
 41    - **show** : *boolean* 
 42
 43        If True, display plots. Default is False. 
 44
 45    - **pdf** : *boolean* 
 46
 47        If True, save plots in pdf format. If False, save plots in png format. Default is False. 
 48
 49    - **DIR** : *str*
 50
 51        Parent directory for output files.     
 52
 53    Returns:
 54    ---
 55
 56    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()` (apart from `metrics_<NAME_STR>.npy`)
 57    in the directory `DIR/plots`. 
 58   
 59    """
 60
 61    name_str = vars_to_name_str(arg_dict)
 62    pdf_str = ".pdf" if pdf else ".png"
 63
 64    # data to plot 
 65    arrs =["loss", "mismatch", "grad", "vargrad"]
 66    labels=["Loss", "Mismatch", r"$|\nabla_\theta W|^2$",r"Var($\partial_\theta W$)" ]
 67
 68    for i in np.arange(len(arrs)):
 69        arr = np.load(os.path.join(DIR, "outputs", f"{arrs[i]}{name_str}.npy"))
 70
 71        plt.figure(figsize=figsize)
 72        plt.xlabel("Epoch", fontsize=fontsize)
 73        plt.ylabel(labels[i], fontsize=fontsize)
 74        plt.yscale('log')
 75        plt.tick_params(axis="both", labelsize=ticksize)
 76        plt.scatter(np.arange(len(arr))+1,arr,color="red")
 77
 78        plt.tight_layout()
 79        plt.savefig(os.path.join(DIR, "plots", f"{arrs[i]}{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)
 80
 81        if show:
 82            plt.show()
 83        plt.close()      
 84
 85    # plot mismatch by state 
 86    dic = np.load(os.path.join(DIR, "outputs", f"mismatch_by_state{name_str}.npy"),allow_pickle='TRUE').item()
 87    mismatch = list(dic.values())
 88    x_arr = x_trans_arr(arg_dict["n"])
 89
 90    plt.figure(figsize=figsize)
 91    plt.xlabel(r"$f$ (Hz)", fontsize=fontsize)
 92    plt.ylabel("Mismatch", fontsize=fontsize)
 93    plt.yscale('log')
 94    plt.tick_params(axis="both", labelsize=ticksize)
 95    plt.scatter(x_arr,mismatch,color="red")
 96
 97    plt.tight_layout()
 98    plt.savefig(os.path.join(DIR, "plots", f"mismatch_by_state{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)
 99
100    if show:
101        plt.show()
102    plt.close()      
103
104    # plot extracted phase function   
105    mint = arg_dict["mint"] if arg_dict["mint"] != None else arg_dict["m"] 
106    if arg_dict["phase_reduce"]:
107        mint = 0      
108    phase = np.load(os.path.join(DIR, "outputs", f"phase{name_str}.npy"))
109    phase_target_rounded = get_phase_target(m=arg_dict["m"], psi_mode=arg_dict["func_str"], phase_reduce=arg_dict["phase_reduce"], mint=mint)
110    phase_target = psi(np.arange(2**arg_dict["n"]),mode=arg_dict["func_str"])
111
112    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})
113
114    ax[0].plot(x_arr,phase_target, color="black")
115    ax[0].plot(x_arr,phase_target_rounded, color="gray", ls="--")
116    ax[0].scatter(x_arr,phase, color="red")
117
118    ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize)
119    ax[0].tick_params(axis="both", labelsize=ticksize)
120    ax[0].set_xticks([])
121
122    ax[1].scatter(x_arr,phase_target_rounded-phase, color="red")
123    ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize)
124    ax[1].tick_params(axis="both", labelsize=ticksize)
125    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)
126
127    fig.tight_layout()
128    fig.savefig(os.path.join(DIR, "plots", f"phase{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)
129
130    if show:
131        plt.show()
132    plt.close()      
133
134
135    return 0
136
def benchmark_plots_ampl(arg_dict, DIR, show=False, pdf=False):
    """
    Generates plots visualising the outputs produced by `pqcprep.training_tools.ampl_train_QNN()`.

    Arguments:
    ----
    - **arg_dict** : *dict*

        Dictionary containing variable information created using `pqcprep.file_tools.compress_args_ampl()`.
        The entries `"n"` and `"func_str"` are read directly.

    - **DIR** : *str*

        Parent directory for output files. Input arrays are read from `DIR/ampl_outputs` and plots are written to `DIR/ampl_plots`.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    Returns:
    ---

    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.ampl_train_QNN()`
    in the directory `DIR/ampl_plots`. Returns 0.
    """
    name_str = vars_to_name_str_ampl(arg_dict)
    pdf_str = ".pdf" if pdf else ".png"
    in_dir = os.path.join(DIR, "ampl_outputs")
    plot_dir = os.path.join(DIR, "ampl_plots")

    # per-epoch training metrics, each plotted against the epoch on a log scale
    arrs = ["loss", "mismatch"]
    labels = ["Loss", "Mismatch"]

    for arr_name, label in zip(arrs, labels):
        arr = np.load(os.path.join(in_dir, f"{arr_name}{name_str}.npy"))

        plt.figure(figsize=figsize)
        plt.xlabel("Epoch", fontsize=fontsize)
        plt.ylabel(label, fontsize=fontsize)
        plt.yscale('log')
        plt.tick_params(axis="both", labelsize=ticksize)
        # epochs are numbered starting from 1
        plt.scatter(np.arange(len(arr)) + 1, arr, color="red")

        plt.tight_layout()
        plt.savefig(os.path.join(plot_dir, f"{arr_name}{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)

        if show:
            plt.show()
        plt.close()

    # plot amplitude
    x_arr = x_trans_arr(arg_dict["n"])

    ampl_vec = np.abs(np.load(os.path.join(in_dir, f"statevec{name_str}.npy")))
    ampl_target = np.array([A(i, mode=arg_dict["func_str"]) for i in x_arr])
    # out-of-place division: an in-place `/=` fails if A() yields an integer array
    ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2))

    # top panel: normalised target and prepared amplitudes; bottom panel: residual
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})
    ax[0].plot(x_arr, ampl_target, color="black")
    ax[0].scatter(x_arr, ampl_vec, color="red")
    ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])
    ax[1].scatter(x_arr, ampl_target - ampl_vec, label="QCNN", color="red")
    ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(plot_dir, f"amplitude{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    return 0
215
def waveform_plots(name_str_phase, name_str_ampl, in_dir, out_dir, comp=None, no_UA=False, show=False, pdf=False):
    """
    Plot the amplitude, phase, and full waveform of a state prepared using QCNNs.

    This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

    Arguments:
    ----

    - **name_str_phase** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding.

    - **name_str_ampl** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation.
        The amplitude weights are not applied if `no_UA` is True. However, the number of input qubits, the number of amplitude layers and the
        amplitude function are still parsed from this string, so it must be well-formed in either case.

    - **in_dir** : *str*

        Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases.

    - **out_dir** : *str*

        Directory for output files.

    - **comp** : *str*, *optional*

        Compare outputs to results from Hayes 2023. If `'GR'`, compare to results obtained using the Grover-Rudolph algorithm; if `'QGAN'`, compare to results
        obtained using the QGAN. Any other value (including the default None) disables the comparison. The comparison is also skipped if `no_UA` is True
        or if the phase and amplitude functions are not `'psi'` and `'x76'`.

    - **no_UA** : *boolean*

        If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    Returns:
    ----

    Plots of the amplitude, phase, and full waveform saved in the directory `out_dir` (files `amplitude`, `phase` and `waveform`). Returns 0.

    Raises:
    ----

    - **ValueError** : if the name strings cannot be interpreted.
    """
    pdf_str = ".pdf" if pdf else ".png"

    # read in given files
    weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy")
    weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy")

    # extract information from name strings
    phase_reduce = '(PR)' in name_str_phase
    real_p = '(r)' in name_str_phase
    if '(CL)' in name_str_phase:
        repeat_params = "CL"
    elif '(IL)' in name_str_phase:
        repeat_params = "IL"
    elif '(both)' in name_str_phase:
        repeat_params = "both"
    else:
        repeat_params = None

    psi_mode = next((i for i in ["psi", "linear", "quadratic", "sine"] if name_str_phase.count(i) == 1), None)
    if psi_mode is None:
        raise ValueError("Name string could not be interpreted.")
    func_A = next((i for i in ["x76", "linear", "uniform"] if name_str_ampl.count(i) == 1), None)
    if func_A is None:
        raise ValueError("Name string could not be interpreted.")

    # name strings start with "_<n>_"; each field may have one or two digits
    n = name_str_ampl[1] if name_str_ampl[2] == "_" else name_str_ampl[1:3]
    if int(n) < 10:
        L_A = name_str_ampl[3] if name_str_ampl[4] == "_" else name_str_ampl[3:5]  # this assumes nint=None
        m = name_str_phase[3] if (name_str_phase[4] == "_" or name_str_phase[4] == "(") else name_str_phase[3:5]
    else:
        L_A = name_str_ampl[4] if name_str_ampl[5] == "_" else name_str_ampl[4:6]  # this assumes nint=None
        m = name_str_phase[4] if (name_str_phase[5] == "_" or name_str_phase[5] == "(") else name_str_phase[4:6]

    mint = None
    for i in range(int(m) + 1):
        if name_str_phase.count(f"({i})") == 1:
            mint = i
            break
    s = f"_{n}_{m}({mint})_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")
    # the number of phase layers follows the prefix and may have one or two digits
    start = name_str_phase.find(s) + len(s)
    L_phase = name_str_phase[start] if name_str_phase[start + 1] == "_" else name_str_phase[start:start + 2]
    s = f"_{n}_{m}({mint})_{L_phase}_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")

    n = int(n)
    m = int(m)
    L_phase = int(L_phase)
    L_A = int(L_A)
    # a missing mint falls back to m, mirroring benchmark_plots
    mint = int(mint) if mint is not None else m

    # generate x array
    x_arr = x_trans_arr(n)

    # calculate target outputs
    ampl_target = np.array([A(i, mode=func_A) for i in x_arr])
    ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2))

    phase_rounded = get_phase_target(m=m, psi_mode=psi_mode, phase_reduce=phase_reduce, mint=mint)
    phase_target = psi(np.arange(2**n), mode=psi_mode)

    # full waveform h = A * exp(2*pi*i*Psi), for both the exact and the rounded phase targets
    h_target = ampl_target * np.exp(2 * 1.j * np.pi * phase_target)
    wave_real_target = np.real(h_target)
    wave_im_target = np.imag(h_target)

    h_target_rounded = ampl_target * np.exp(2 * 1.j * np.pi * phase_rounded)
    wave_real_target_rounded = np.real(h_target_rounded)
    wave_im_target_rounded = np.imag(h_target_rounded)

    # load Hayes 2023 data for comparison
    if comp in ("GR", "QGAN"):
        ampl_vec_comp = np.abs(load_data_H23(f"amp_state_{comp}"))
        h_comp = load_data_H23(f"full_state_{comp}")
        psi_LPF = load_data_H23("psi_LPF_processed")
        comp_label = comp
        show_comp = True
    else:
        show_comp = False

    if show_comp and no_UA:
        show_comp = False
        print("Comparison to Hayes 2023 not shown due to incompatible amplitude function.")
    if show_comp and not (psi_mode == "psi" and func_A == "x76"):
        show_comp = False
        print("Comparison to Hayes 2023 not shown due to incompatible amplitude or phase function.")

    # get PQC state
    state_vec = full_encode(n, m, weights_ampl, weights_phase, L_A, L_phase, real_p=real_p, repeat_params=repeat_params, no_UA=no_UA)
    phase = phase_from_state(state_vec)

    ampl_vec = np.abs(state_vec)
    real_wave = np.real(state_vec)
    im_wave = np.imag(state_vec)

    # plot amplitude
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, ampl_target, color="black")
    if show_comp:
        ax[0].scatter(x_arr, ampl_vec_comp, label=comp_label, color="blue")
    ax[0].scatter(x_arr, ampl_vec, label="QCNN", color="red")

    ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize)
    if show_comp:
        ax[0].legend(fontsize=fontsize, loc='upper right')
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    if show_comp:
        ax[1].scatter(x_arr, ampl_target - ampl_vec_comp, label=comp_label, color="blue")
    ax[1].scatter(x_arr, ampl_target - ampl_vec, label="QCNN", color="red")

    ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"amplitude{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    # plot phase
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, phase_target, color="black")
    ax[0].plot(x_arr, phase_rounded, color="gray", ls="--")
    if show_comp:
        ax[0].scatter(x_arr, psi_LPF, label="LPF", color="blue")
    ax[0].scatter(x_arr, phase, label="QCNN", color="red")

    ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize)
    if show_comp:
        ax[0].legend(fontsize=fontsize, loc='upper right')
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    if show_comp:
        ax[1].scatter(x_arr, phase_target - psi_LPF, label="LPF", color="blue")
    ax[1].scatter(x_arr, phase_rounded - phase, label="QCNN", color="red")

    ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"phase{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    # plot full waveform: real part (left column) and imaginary part (right column)
    fig, ax = plt.subplots(2, 2, figsize=(2 * figsize[0], figsize[1]), gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0, 0].plot(x_arr, wave_real_target, color="black")
    ax[0, 0].plot(x_arr, wave_real_target_rounded, color="gray", ls="--")
    if show_comp:
        ax[0, 0].scatter(x_arr, np.real(h_comp), label="LPF + " + comp_label, color="blue")
    ax[0, 0].scatter(x_arr, real_wave, label="QCNN", color="red")

    ax[0, 0].set_ylabel(r"$\Re[\tilde h(f)]$", fontsize=fontsize)
    if show_comp:
        ax[0, 0].legend(fontsize=fontsize, loc='upper right')
    ax[0, 0].tick_params(axis="both", labelsize=ticksize)
    ax[0, 0].set_xticks([])

    if show_comp:
        ax[1, 0].scatter(x_arr, wave_real_target - np.real(h_comp), label="LPF + " + comp_label, color="blue")
    ax[1, 0].scatter(x_arr, wave_real_target_rounded - real_wave, label="QCNN", color="red")

    ax[1, 0].set_ylabel(r"$\Delta \Re[\tilde h(f)]$", fontsize=fontsize)
    ax[1, 0].tick_params(axis="both", labelsize=ticksize)
    ax[1, 0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    ax[0, 1].plot(x_arr, wave_im_target, color="black")
    ax[0, 1].plot(x_arr, wave_im_target_rounded, color="gray", ls="--")
    if show_comp:
        ax[0, 1].scatter(x_arr, np.imag(h_comp), label="LPF + " + comp_label, color="blue")
    ax[0, 1].scatter(x_arr, im_wave, label="QCNN", color="red")

    ax[0, 1].set_ylabel(r"$\Im[\tilde h(f)]$", fontsize=fontsize)
    if show_comp:
        ax[0, 1].legend(fontsize=fontsize, loc='upper right')
    ax[0, 1].tick_params(axis="both", labelsize=ticksize)
    ax[0, 1].set_xticks([])

    if show_comp:
        ax[1, 1].scatter(x_arr, wave_im_target - np.imag(h_comp), label="LPF + " + comp_label, color="blue")
    ax[1, 1].scatter(x_arr, wave_im_target_rounded - im_wave, label="QCNN", color="red")

    ax[1, 1].set_ylabel(r"$\Delta \Im[\tilde h(f)]$", fontsize=fontsize)
    ax[1, 1].tick_params(axis="both", labelsize=ticksize)
    ax[1, 1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"waveform{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    return 0
482
483
def two_register_plots(name_str_phase, name_str_ampl, in_dir, out_dir, operators="QRQ", no_UA=False, show=False, pdf=False):
    """
    Plot the two-register statevector resulting from applying amplitude and phase preparation QCNNs.

    This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

    Arguments:
    ----

    - **name_str_phase** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding.

    - **name_str_ampl** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation.
        The amplitude weights are not applied if `no_UA` is True. However, the number of input qubits and the number of amplitude layers are still parsed
        from this string, so it must be well-formed in either case.

    - **in_dir** : *str*

        Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases.

    - **out_dir** : *str*

        Directory for output files.

    - **operators** : *str*

        Which operators to apply to the registers. Options are `'QRQ'` for full phase extraction, `'RQ'` for partial phase extraction, and `'Q'` for function
        evaluation only. Default is `'QRQ'`.

    - **no_UA** : *boolean*

        If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    Returns:
    ----

    Plots of the amplitude and phase of the two-register state saved in the directory `out_dir` (files `two_reg_ampl`, `two_reg_phase` and `two_reg_3D`). Returns 0.

    Raises:
    ----

    - **ValueError** : if the name strings cannot be interpreted.
    """
    pdf_str = ".pdf" if pdf else ".png"

    # read in given files
    weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy")
    weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy")

    # extract information from name strings
    real_p = '(r)' in name_str_phase
    if '(CL)' in name_str_phase:
        repeat_params = "CL"
    elif '(IL)' in name_str_phase:
        repeat_params = "IL"
    elif '(both)' in name_str_phase:
        repeat_params = "both"
    else:
        repeat_params = None

    psi_mode = next((i for i in ["psi", "linear", "quadratic", "sine"] if name_str_phase.count(i) == 1), None)
    if psi_mode is None:
        raise ValueError("Name string could not be interpreted.")
    func_A = next((i for i in ["x76", "linear", "uniform"] if name_str_ampl.count(i) == 1), None)
    if func_A is None:
        raise ValueError("Name string could not be interpreted.")

    # name strings start with "_<n>_"; each field may have one or two digits
    n = name_str_ampl[1] if name_str_ampl[2] == "_" else name_str_ampl[1:3]
    if int(n) < 10:
        L_A = name_str_ampl[3] if name_str_ampl[4] == "_" else name_str_ampl[3:5]  # this assumes nint=None
        m = name_str_phase[3] if (name_str_phase[4] == "_" or name_str_phase[4] == "(") else name_str_phase[3:5]
    else:
        L_A = name_str_ampl[4] if name_str_ampl[5] == "_" else name_str_ampl[4:6]  # this assumes nint=None
        m = name_str_phase[4] if (name_str_phase[5] == "_" or name_str_phase[5] == "(") else name_str_phase[4:6]

    # mint is only needed to locate the layer count inside the phase name string
    mint = None
    for i in range(int(m) + 1):
        if name_str_phase.count(f"({i})") == 1:
            mint = i
            break
    s = f"_{n}_{m}({mint})_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")
    # the number of phase layers follows the prefix and may have one or two digits
    start = name_str_phase.find(s) + len(s)
    L_phase = name_str_phase[start] if name_str_phase[start + 1] == "_" else name_str_phase[start:start + 2]
    s = f"_{n}_{m}({mint})_{L_phase}_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")

    n = int(n)
    m = int(m)
    L_phase = int(L_phase)
    L_A = int(L_A)

    # generate x array
    x_arr = x_trans_arr(n)

    def _set_phase_yticks(axis):
        # relabel the interior y-ticks (target-register index) with the corresponding phase in radians
        # NOTE(review): labels are computed from each tick's enumeration index, not its position, so they
        # only match when the interior ticks are 0, 1, 2, ... -- confirm for larger m
        locs = axis.get_yticks()
        axis.set_yticks(locs[1:-1], [np.round(bin_to_dec(dec_to_bin(i, m), nint=0) * 2 * np.pi, 2) for i in np.arange(len(locs) - 2)])

    # get PQC state: state_vec is the input-register state, state_vec_full the joint two-register state
    state_vec, state_vec_full = full_encode(n, m, weights_ampl, weights_phase, L_A, L_phase, real_p=real_p, repeat_params=repeat_params, no_UA=no_UA, full_state_vec=True, operators=operators)
    phase = phase_from_state(state_vec)
    full_phase = phase_from_state(state_vec_full)

    # plot amplitude of the two-register state as a heat map
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [2**m, 1]}, constrained_layout=True)
    ax[0].scatter(x_arr, phase, label="QCNN", color="red")

    ax[0].set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    im = ax[0].pcolormesh(x_arr, np.arange(2**m), np.abs(state_vec_full))
    im.set_clim(np.min(np.abs(state_vec_full)), np.max(np.abs(state_vec_full)))
    _set_phase_yticks(ax[0])
    ax[0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)
    # the second axis only reserves space for the colorbar layout
    ax[1].set_visible(False)

    cb = fig.colorbar(im, ax=ax.ravel().tolist(), location='top', orientation='horizontal', pad=0.05)
    cb.ax.tick_params(labelsize=ticksize)

    fig.savefig(os.path.join(out_dir, f"two_reg_ampl{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    # plot phase of the two-register state with a cyclic colour map over [0, 2*pi]
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [2**m, 1]}, constrained_layout=True)
    ax[0].scatter(x_arr, phase, label="QCNN", color="red")

    ax[0].set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    im = ax[0].pcolormesh(x_arr, np.arange(2**m), full_phase, cmap="hsv")
    im.set_clim(0, 2 * np.pi)
    _set_phase_yticks(ax[0])
    ax[0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)
    ax[1].set_visible(False)

    cb = fig.colorbar(im, ax=ax.ravel().tolist(), location='top', orientation='horizontal', pad=0.05)
    cb.ax.tick_params(labelsize=ticksize)

    fig.savefig(os.path.join(out_dir, f"two_reg_phase{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    # plot amplitude (surface height) and phase (surface colour) together in 3D
    X, Y = np.meshgrid(x_arr, np.arange(2**m))
    Z = np.abs(state_vec_full)
    P_norm = full_phase / (2 * np.pi)  # normalise to value in [0,1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": "3d"}, constrained_layout=False)
    surf = ax.plot_surface(X, Y, Z, facecolors=cm.hsv(P_norm), linewidth=0, antialiased=True, shade=False)
    ax.set_xlabel(r"$f$ (Hz)", fontsize=fontsize, labelpad=20)
    ax.set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize, labelpad=20)
    _set_phase_yticks(ax)
    ax.tick_params(axis="both", labelsize=ticksize)

    cb = fig.colorbar(cm.ScalarMappable(cmap=cm.hsv, norm=surf.norm), ax=ax, shrink=0.6, aspect=5)
    cb.ax.tick_params(labelsize=ticksize)
    cb.ax.set_yticks([0, 0.25, 0.5, 0.75, 1], [r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])

    fig.savefig(os.path.join(out_dir, f"two_reg_3D{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    return 0
def benchmark_plots(arg_dict, DIR, show=False, pdf=False):
 32def benchmark_plots(arg_dict,DIR, show=False, pdf=False):  
 33    """
 34    Generates plots visualising the outputs produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()`. 
 35
 36    Arguments:
 37    ---- 
 38    - **arg_dict** : *dict* 
 39
 40        Dictionary containing varialbe information created using `pqcprep.file_tools.compress_args()`. 
 41
 42    - **show** : *boolean* 
 43
 44        If True, display plots. Default is False. 
 45
 46    - **pdf** : *boolean* 
 47
 48        If True, save plots in pdf format. If False, save plots in png format. Default is False. 
 49
 50    - **DIR** : *str*
 51
 52        Parent directory for output files.     
 53
 54    Returns:
 55    ---
 56
 57    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()` (apart from `metrics_<NAME_STR>.npy`)
 58    in the directory `DIR/plots`. 
 59   
 60    """
 61
 62    name_str = vars_to_name_str(arg_dict)
 63    pdf_str = ".pdf" if pdf else ".png"
 64
 65    # data to plot 
 66    arrs =["loss", "mismatch", "grad", "vargrad"]
 67    labels=["Loss", "Mismatch", r"$|\nabla_\theta W|^2$",r"Var($\partial_\theta W$)" ]
 68
 69    for i in np.arange(len(arrs)):
 70        arr = np.load(os.path.join(DIR, "outputs", f"{arrs[i]}{name_str}.npy"))
 71
 72        plt.figure(figsize=figsize)
 73        plt.xlabel("Epoch", fontsize=fontsize)
 74        plt.ylabel(labels[i], fontsize=fontsize)
 75        plt.yscale('log')
 76        plt.tick_params(axis="both", labelsize=ticksize)
 77        plt.scatter(np.arange(len(arr))+1,arr,color="red")
 78
 79        plt.tight_layout()
 80        plt.savefig(os.path.join(DIR, "plots", f"{arrs[i]}{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)
 81
 82        if show:
 83            plt.show()
 84        plt.close()      
 85
 86    # plot mismatch by state 
 87    dic = np.load(os.path.join(DIR, "outputs", f"mismatch_by_state{name_str}.npy"),allow_pickle='TRUE').item()
 88    mismatch = list(dic.values())
 89    x_arr = x_trans_arr(arg_dict["n"])
 90
 91    plt.figure(figsize=figsize)
 92    plt.xlabel(r"$f$ (Hz)", fontsize=fontsize)
 93    plt.ylabel("Mismatch", fontsize=fontsize)
 94    plt.yscale('log')
 95    plt.tick_params(axis="both", labelsize=ticksize)
 96    plt.scatter(x_arr,mismatch,color="red")
 97
 98    plt.tight_layout()
 99    plt.savefig(os.path.join(DIR, "plots", f"mismatch_by_state{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)
100
101    if show:
102        plt.show()
103    plt.close()      
104
105    # plot extracted phase function   
106    mint = arg_dict["mint"] if arg_dict["mint"] != None else arg_dict["m"] 
107    if arg_dict["phase_reduce"]:
108        mint = 0      
109    phase = np.load(os.path.join(DIR, "outputs", f"phase{name_str}.npy"))
110    phase_target_rounded = get_phase_target(m=arg_dict["m"], psi_mode=arg_dict["func_str"], phase_reduce=arg_dict["phase_reduce"], mint=mint)
111    phase_target = psi(np.arange(2**arg_dict["n"]),mode=arg_dict["func_str"])
112
113    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})
114
115    ax[0].plot(x_arr,phase_target, color="black")
116    ax[0].plot(x_arr,phase_target_rounded, color="gray", ls="--")
117    ax[0].scatter(x_arr,phase, color="red")
118
119    ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize)
120    ax[0].tick_params(axis="both", labelsize=ticksize)
121    ax[0].set_xticks([])
122
123    ax[1].scatter(x_arr,phase_target_rounded-phase, color="red")
124    ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize)
125    ax[1].tick_params(axis="both", labelsize=ticksize)
126    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)
127
128    fig.tight_layout()
129    fig.savefig(os.path.join(DIR, "plots", f"phase{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)
130
131    if show:
132        plt.show()
133    plt.close()      
134
135
136    return 0

Generates plots visualising the outputs produced by pqcprep.training_tools.train_QNN() and pqcprep.training_tools.test_QNN().

Arguments:

  • arg_dict : dict

    Dictionary containing variable information created using pqcprep.file_tools.compress_args().

  • show : boolean

    If True, display plots. Default is False.

  • pdf : boolean

    If True, save plots in pdf format. If False, save plots in png format. Default is False.

  • DIR : str

    Parent directory for output files.

Returns:

Saves plots corresponding to each of the files produced by pqcprep.training_tools.train_QNN() and pqcprep.training_tools.test_QNN() (apart from metrics_<NAME_STR>.npy) in the directory DIR/plots.

def benchmark_plots_ampl(arg_dict, DIR, show=False, pdf=False):
def benchmark_plots_ampl(arg_dict, DIR, show=False, pdf=False):
    """
    Generates plots visualising the outputs produced by `pqcprep.training_tools.ampl_train_QNN()`.

    Arguments:
    ----
    - **arg_dict** : *dict*

        Dictionary containing variable information created using `pqcprep.file_tools.compress_args_ampl()`.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    - **DIR** : *str*

        Parent directory for output files.

    Returns:
    ---

    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.ampl_train_QNN()`
    in the directory `DIR/ampl_plots`. Returns 0.
    """
    name_str = vars_to_name_str_ampl(arg_dict)
    pdf_str = ".pdf" if pdf else ".png"

    # training curves: one log-scale scatter plot per saved metric array
    for key, label in zip(["loss", "mismatch"], ["Loss", "Mismatch"]):
        arr = np.load(os.path.join(DIR, "ampl_outputs", f"{key}{name_str}.npy"))

        plt.figure(figsize=figsize)
        plt.xlabel("Epoch", fontsize=fontsize)
        plt.ylabel(label, fontsize=fontsize)
        plt.yscale('log')
        plt.tick_params(axis="both", labelsize=ticksize)
        # epochs are numbered from 1
        plt.scatter(np.arange(len(arr)) + 1, arr, color="red")

        plt.tight_layout()
        plt.savefig(os.path.join(DIR, "ampl_plots", f"{key}{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)

        if show:
            plt.show()
        plt.close()

    # plot amplitude: prepared amplitudes against the L2-normalised target
    x_arr = x_trans_arr(arg_dict["n"])

    ampl_vec = np.abs(np.load(os.path.join(DIR, "ampl_outputs", f"statevec{name_str}.npy")))
    ampl_target = np.array([A(i, mode=arg_dict["func_str"]) for i in x_arr])
    # not in-place: an in-place true division fails if `A` returns integer values
    ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2))

    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})
    ax[0].plot(x_arr, ampl_target, color="black")
    ax[0].scatter(x_arr, ampl_vec, color="red")
    ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    # residuals (target minus prepared amplitude)
    ax[1].scatter(x_arr, ampl_target - ampl_vec, color="red")
    ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(DIR, "ampl_plots", f"amplitude{name_str}{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    return 0

Generates plots visualising the outputs produced by pqcprep.training_tools.ampl_train_QNN().

Arguments:

  • arg_dict : dict

    Dictionary containing variable information created using pqcprep.file_tools.compress_args_ampl().

  • show : boolean

    If True, display plots. Default is False.

  • pdf : boolean

    If True, save plots in pdf format. If False, save plots in png format. Default is False.

  • DIR : str

    Parent directory for output files.

Returns:

Saves plots corresponding to each of the files produced by pqcprep.training_tools.ampl_train_QNN() in the directory DIR/ampl_plots.

def waveform_plots( name_str_phase, name_str_ampl, in_dir, out_dir, comp, no_UA=False, show=False, pdf=False):
def waveform_plots(name_str_phase, name_str_ampl, in_dir, out_dir, comp=None, no_UA=False, show=False, pdf=False):
    """
    Plot the amplitude, phase, and full waveform of a state prepared using QCNNs.

    This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

    Arguments:
    ----

    - **name_str_phase** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding.

    - **name_str_ampl** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation.
        If `no_UA` is True the amplitude QCNN is not applied, but this string is still parsed (e.g. to determine `n` and the target amplitude function)
        and must therefore still be supplied.

    - **in_dir** : *str*

        Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases.

    - **out_dir** : *str*

        Directory for output files.

    - **comp** : *str*, *optional*

        Compare outputs to results from Hayes 2023: if `'GR'` compare to results obtained using the Grover-Rudolph algorithm and if `'QGAN'` compare to results
        obtained using the QGAN. Any other value (including the default None) disables the comparison. The comparison is only shown if `no_UA` is False and the
        name strings specify the `'psi'` phase function and the `'x76'` amplitude function.

    - **no_UA** : *boolean*

        If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.


    Returns:
    ----

    Plots of the amplitude, phase, and full waveform saved in the directory `out_dir`. Returns 0.

    Raises:
    ----

    - **ValueError** : if the name strings cannot be interpreted.
    """

    pdf_str = ".pdf" if pdf else ".png"

    # paths of the weight files (passed on to `full_encode`)
    weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy")
    weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy")

    # extract information from name strings
    phase_reduce = '(PR)' in name_str_phase
    real_p = '(r)' in name_str_phase
    if '(CL)' in name_str_phase:
        repeat_params = "CL"
    elif '(IL)' in name_str_phase:
        repeat_params = "IL"
    elif '(both)' in name_str_phase:
        repeat_params = "both"
    else:
        repeat_params = None

    psi_mode = None
    for mode in ["psi", "linear", "quadratic", "sine"]:
        if name_str_phase.count(mode) == 1:
            psi_mode = mode
            break
    if psi_mode is None:
        raise ValueError("Name string could not be interpreted.")

    func_A = None
    for mode in ["x76", "linear", "uniform"]:
        if name_str_ampl.count(mode) == 1:
            func_A = mode
            break
    if func_A is None:
        raise ValueError("Name string could not be interpreted.")

    # n, L_A and m are read from fixed positions; each may have one or two digits
    n = name_str_ampl[1] if name_str_ampl[2] == "_" else name_str_ampl[1:3]
    if int(n) < 10:
        L_A = name_str_ampl[3] if name_str_ampl[4] == "_" else name_str_ampl[3:5]  # this assumes nint=None
        m = name_str_phase[3] if name_str_phase[4] in ("_", "(") else name_str_phase[3:5]
    else:
        L_A = name_str_ampl[4] if name_str_ampl[5] == "_" else name_str_ampl[4:6]  # this assumes nint=None
        m = name_str_phase[4] if name_str_phase[5] in ("_", "(") else name_str_phase[4:6]

    mint = None
    for i in range(int(m) + 1):
        if name_str_phase.count(f"({i})") == 1:
            mint = i
            break
    s = f"_{n}_{m}({mint})_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")
    # L_phase directly follows `s` and may have one or two digits
    start = name_str_phase.index(s) + len(s)
    L_phase = name_str_phase[start] if name_str_phase[start + 1] == "_" else name_str_phase[start:start + 2]
    s = f"_{n}_{m}({mint})_{L_phase}_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")

    n = int(n)
    m = int(m)
    L_phase = int(L_phase)
    L_A = int(L_A)
    mint = int(mint)

    # generate x array
    x_arr = x_trans_arr(n)

    # calculate target outputs
    ampl_target = np.array([A(i, mode=func_A) for i in x_arr])
    ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2))

    phase_rounded = get_phase_target(m=m, psi_mode=psi_mode, phase_reduce=phase_reduce, mint=mint)
    phase_target = psi(np.arange(2**n), mode=psi_mode)

    h_target = ampl_target * np.exp(2 * 1.j * np.pi * phase_target)
    wave_real_target = np.real(h_target)
    wave_im_target = np.imag(h_target)

    h_target_rounded = ampl_target * np.exp(2 * 1.j * np.pi * phase_rounded)
    wave_real_target_rounded = np.real(h_target_rounded)
    wave_im_target_rounded = np.imag(h_target_rounded)

    # decide whether a comparison to Hayes 2023 can be shown before loading its data
    comp_label = comp if comp in ("GR", "QGAN") else None
    if comp_label is not None and no_UA:
        comp_label = None
        print("Comparison to Hayes 2023 not shown due to incompatible amplitude function.")
    if comp_label is not None and not (psi_mode == "psi" and func_A == "x76"):
        comp_label = None
        print("Comparison to Hayes 2023 not shown due to incompatible amplitude or phase function.")
    show_comp = comp_label is not None

    if show_comp:
        ampl_vec_comp = np.abs(load_data_H23(f"amp_state_{comp_label}"))
        h_comp = load_data_H23(f"full_state_{comp_label}")
        psi_LPF = load_data_H23("psi_LPF_processed")

    # get PQC state
    state_vec = full_encode(n, m, weights_ampl, weights_phase, L_A, L_phase, real_p=real_p, repeat_params=repeat_params, no_UA=no_UA)
    phase = phase_from_state(state_vec)

    ampl_vec = np.abs(state_vec)
    real_wave = np.real(state_vec)
    im_wave = np.imag(state_vec)

    # plot amplitude
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, ampl_target, color="black")
    if show_comp:
        ax[0].scatter(x_arr, ampl_vec_comp, label=comp_label, color="blue")
    ax[0].scatter(x_arr, ampl_vec, label="QCNN", color="red")

    ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize)
    if show_comp:
        ax[0].legend(fontsize=fontsize, loc='upper right')
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    if show_comp:
        ax[1].scatter(x_arr, ampl_target - ampl_vec_comp, label=comp_label, color="blue")
    ax[1].scatter(x_arr, ampl_target - ampl_vec, label="QCNN", color="red")

    ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"amplitude{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    # plot phase (dashed grey: target rounded to the m-qubit phase register)
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, phase_target, color="black")
    ax[0].plot(x_arr, phase_rounded, color="gray", ls="--")
    if show_comp:
        ax[0].scatter(x_arr, psi_LPF, label="LPF", color="blue")
    ax[0].scatter(x_arr, phase, label="QCNN", color="red")

    ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize)
    if show_comp:
        ax[0].legend(fontsize=fontsize, loc='upper right')
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    if show_comp:
        ax[1].scatter(x_arr, phase_target - psi_LPF, label="LPF", color="blue")
    # QCNN residuals are measured against the rounded target it was trained on
    ax[1].scatter(x_arr, phase_rounded - phase, label="QCNN", color="red")

    ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"phase{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    # plot full waveform: real part (left column) and imaginary part (right column)
    fig, ax = plt.subplots(2, 2, figsize=(2 * figsize[0], figsize[1]), gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0, 0].plot(x_arr, wave_real_target, color="black")
    ax[0, 0].plot(x_arr, wave_real_target_rounded, color="gray", ls="--")
    if show_comp:
        ax[0, 0].scatter(x_arr, np.real(h_comp), label="LPF + " + comp_label, color="blue")
    ax[0, 0].scatter(x_arr, real_wave, label="QCNN", color="red")

    ax[0, 0].set_ylabel(r"$\Re[\tilde h(f)]$", fontsize=fontsize)
    if show_comp:
        ax[0, 0].legend(fontsize=fontsize, loc='upper right')
    ax[0, 0].tick_params(axis="both", labelsize=ticksize)
    ax[0, 0].set_xticks([])

    if show_comp:
        ax[1, 0].scatter(x_arr, wave_real_target - np.real(h_comp), label="LPF + " + comp_label, color="blue")
    ax[1, 0].scatter(x_arr, wave_real_target_rounded - real_wave, label="QCNN", color="red")

    ax[1, 0].set_ylabel(r"$\Delta \Re[\tilde h(f)]$", fontsize=fontsize)
    ax[1, 0].tick_params(axis="both", labelsize=ticksize)
    ax[1, 0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    ax[0, 1].plot(x_arr, wave_im_target, color="black")
    ax[0, 1].plot(x_arr, wave_im_target_rounded, color="gray", ls="--")
    if show_comp:
        ax[0, 1].scatter(x_arr, np.imag(h_comp), label="LPF + " + comp_label, color="blue")
    ax[0, 1].scatter(x_arr, im_wave, label="QCNN", color="red")

    ax[0, 1].set_ylabel(r"$\Im[\tilde h(f)]$", fontsize=fontsize)
    if show_comp:
        ax[0, 1].legend(fontsize=fontsize, loc='upper right')
    ax[0, 1].tick_params(axis="both", labelsize=ticksize)
    ax[0, 1].set_xticks([])

    if show_comp:
        ax[1, 1].scatter(x_arr, wave_im_target - np.imag(h_comp), label="LPF + " + comp_label, color="blue")
    ax[1, 1].scatter(x_arr, wave_im_target_rounded - im_wave, label="QCNN", color="red")

    ax[1, 1].set_ylabel(r"$\Delta \Im[\tilde h(f)]$", fontsize=fontsize)
    ax[1, 1].tick_params(axis="both", labelsize=ticksize)
    ax[1, 1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    fig.savefig(os.path.join(out_dir, f"waveform{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    return 0

Plot the amplitude, phase, and full waveform of a state prepared using QCNNs.

This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

Arguments:

  • name_str_phase : str

    Name string (produced using pqcprep.file_tools.vars_to_name_str()) corresponding to the weights of the QCNN used for phase encoding.

  • name_str_ampl : str

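    Name string (produced using pqcprep.file_tools.vars_to_name_str_ampl()) corresponding to the weights of the QCNN used for amplitude preparation. If no_UA is True the amplitude QCNN is not applied, but this string is still parsed (e.g. to determine n and the target amplitude function) and must therefore still be supplied.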

  • in_dir : str

    Directory containing the input files (as specified by name_str_phase and name_str_ampl). The directory is expected to contain a weights file for both cases.

  • out_dir : str

    Directory for output files.

  • comp : str, optional

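    Compare outputs to results from Hayes 2023: if 'GR' compare to results obtained using the Grover-Rudolph algorithm and if 'QGAN' compare to results obtained using the QGAN. Any other value disables the comparison. The comparison is only shown if no_UA is False and the name strings specify the 'psi' phase function and the 'x76' amplitude function.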

  • no_UA : boolean

    If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

  • show : boolean

    If True, display plots. Default is False.

  • pdf : boolean

    If True, save plots in pdf format. If False, save plots in png format. Default is False.

Returns:

Plots of the amplitude, phase, and full waveform saved in the directory out_dir.

def two_register_plots( name_str_phase, name_str_ampl, in_dir, out_dir, operators='QRQ', no_UA=False, show=False, pdf=False):
def _set_phase_register_yticks(ax, m):
    """
    Label the integer y-ticks of `ax` (rows = basis states 0 .. 2**m - 1 of the phase register) with the phase each state encodes.
    """
    from matplotlib.ticker import MaxNLocator

    # force integer tick positions so every tick corresponds to an actual register state
    ax.yaxis.set_major_locator(MaxNLocator(nbins="auto", integer=True))
    locs = [int(round(loc)) for loc in ax.get_yticks() if 0 <= loc <= 2**m - 1]
    # each label is computed from the tick's own position (not from its index in the tick list)
    ax.set_yticks(locs, [np.round(bin_to_dec(dec_to_bin(loc, m), nint=0) * 2 * np.pi, 2) for loc in locs])


def _plot_two_register_heatmap(x_arr, phase, values, m, clim, path, show, cmap=None):
    """
    Draw `values` (rows: phase-register states, columns: `x_arr`) as a colour mesh with a colourbar on top and save it to `path`.
    """
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [2**m, 1]}, constrained_layout=True)
    ax[0].scatter(x_arr, phase, label="QCNN", color="red")

    ax[0].set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    im = ax[0].pcolormesh(x_arr, np.arange(2**m), values, cmap=cmap)
    im.set_clim(*clim)
    _set_phase_register_yticks(ax[0], m)
    ax[0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)
    # the hidden second row only occupies layout space (height ratio 1 vs 2**m)
    ax[1].set_visible(False)

    cb = fig.colorbar(im, ax=ax.ravel().tolist(), location='top', orientation='horizontal', pad=0.05)
    cb.ax.tick_params(labelsize=ticksize)

    fig.savefig(path, bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()


def two_register_plots(name_str_phase, name_str_ampl, in_dir, out_dir, operators="QRQ", no_UA=False, show=False, pdf=False):
    """
    Plot the two-register statevector resulting from applying amplitude and phase preparation QCNNs.

    This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

    Arguments:
    ----

    - **name_str_phase** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding.

    - **name_str_ampl** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation.
        If `no_UA` is True the amplitude QCNN is not applied, but this string is still parsed (e.g. to determine `n`) and must therefore still be supplied.

    - **in_dir** : *str*

        Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases.

    - **out_dir** : *str*

        Directory for output files.

    - **operators** : *str*

        Which operators to apply to the registers. Options are `'QRQ'` for full phase extraction, `'RQ'` for partial phase extraction, and `'Q'` for function
        evaluation only. Default is `'QRQ'`.

    - **no_UA** : *boolean*

        If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.


    Returns:
    ----

    Plots of the amplitude and phase of the two-register state saved in the directory `out_dir`. Returns 0.

    Raises:
    ----

    - **ValueError** : if the name strings cannot be interpreted.
    """
    from matplotlib.colors import Normalize

    pdf_str = ".pdf" if pdf else ".png"

    # paths of the weight files (passed on to `full_encode`)
    weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy")
    weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy")

    # extract information from name strings
    real_p = '(r)' in name_str_phase
    if '(CL)' in name_str_phase:
        repeat_params = "CL"
    elif '(IL)' in name_str_phase:
        repeat_params = "IL"
    elif '(both)' in name_str_phase:
        repeat_params = "both"
    else:
        repeat_params = None

    psi_mode = None
    for mode in ["psi", "linear", "quadratic", "sine"]:
        if name_str_phase.count(mode) == 1:
            psi_mode = mode
            break
    if psi_mode is None:
        raise ValueError("Name string could not be interpreted.")

    func_A = None
    for mode in ["x76", "linear", "uniform"]:
        if name_str_ampl.count(mode) == 1:
            func_A = mode
            break
    if func_A is None:
        raise ValueError("Name string could not be interpreted.")

    # n, L_A and m are read from fixed positions; each may have one or two digits
    n = name_str_ampl[1] if name_str_ampl[2] == "_" else name_str_ampl[1:3]
    if int(n) < 10:
        L_A = name_str_ampl[3] if name_str_ampl[4] == "_" else name_str_ampl[3:5]  # this assumes nint=None
        m = name_str_phase[3] if name_str_phase[4] in ("_", "(") else name_str_phase[3:5]
    else:
        L_A = name_str_ampl[4] if name_str_ampl[5] == "_" else name_str_ampl[4:6]  # this assumes nint=None
        m = name_str_phase[4] if name_str_phase[5] in ("_", "(") else name_str_phase[4:6]

    mint = None
    for i in range(int(m) + 1):
        if name_str_phase.count(f"({i})") == 1:
            mint = i
            break
    s = f"_{n}_{m}({mint})_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")
    # L_phase directly follows `s` and may have one or two digits
    start = name_str_phase.index(s) + len(s)
    L_phase = name_str_phase[start] if name_str_phase[start + 1] == "_" else name_str_phase[start:start + 2]
    s = f"_{n}_{m}({mint})_{L_phase}_"
    if name_str_phase.count(s) != 1:
        raise ValueError("Name string could not be interpreted.")

    n = int(n)
    m = int(m)
    L_phase = int(L_phase)
    L_A = int(L_A)

    # generate x array
    x_arr = x_trans_arr(n)

    # get PQC state (reduced state and full two-register state)
    state_vec, state_vec_full = full_encode(n, m, weights_ampl, weights_phase, L_A, L_phase, real_p=real_p, repeat_params=repeat_params, no_UA=no_UA, full_state_vec=True, operators=operators)
    phase = phase_from_state(state_vec)
    full_phase = phase_from_state(state_vec_full)
    ampl_full = np.abs(state_vec_full)

    # plot amplitude
    _plot_two_register_heatmap(x_arr, phase, ampl_full, m, (np.min(ampl_full), np.max(ampl_full)),
                               os.path.join(out_dir, f"two_reg_ampl{pdf_str}"), show)

    # plot phase (cyclic colour map over [0, 2*pi])
    _plot_two_register_heatmap(x_arr, phase, full_phase, m, (0, 2 * np.pi),
                               os.path.join(out_dir, f"two_reg_phase{pdf_str}"), show, cmap="hsv")

    # plot amplitude (surface height) and phase (surface colour) together
    X, Y = np.meshgrid(x_arr, np.arange(2**m))
    P_norm = full_phase / (2 * np.pi)  # normalise to value in [0,1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": "3d"}, constrained_layout=False)
    ax.plot_surface(X, Y, ampl_full, facecolors=cm.hsv(P_norm), linewidth=0, antialiased=True, shade=False)
    ax.set_xlabel(r"$f$ (Hz)", fontsize=fontsize, labelpad=20)
    ax.set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize, labelpad=20)
    _set_phase_register_yticks(ax, m)
    ax.tick_params(axis="both", labelsize=ticksize)

    # the surface is coloured via `facecolors`, so the colourbar gets an explicit [0, 1] norm matching P_norm
    cb = fig.colorbar(cm.ScalarMappable(cmap=cm.hsv, norm=Normalize(vmin=0, vmax=1)), ax=ax, shrink=0.6, aspect=5)
    cb.ax.tick_params(labelsize=ticksize)
    cb.ax.set_yticks([0, 0.25, 0.5, 0.75, 1], [r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])

    fig.savefig(os.path.join(out_dir, f"two_reg_3D{pdf_str}"), bbox_inches='tight', dpi=500)

    if show:
        plt.show()
    plt.close()

    return 0

Plot the two-register statevector resulting from applying amplitude and phase preparation QCNNs.

This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

Arguments:

  • name_str_phase : str

    Name string (produced using pqcprep.file_tools.vars_to_name_str()) corresponding to the weights of the QCNN used for phase encoding.

  • name_str_ampl : str

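    Name string (produced using pqcprep.file_tools.vars_to_name_str_ampl()) corresponding to the weights of the QCNN used for amplitude preparation. If no_UA is True the amplitude QCNN is not applied, but this string is still parsed (e.g. to determine n) and must therefore still be supplied.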

  • in_dir : str

    Directory containing the input files (as specified by name_str_phase and name_str_ampl). The directory is expected to contain a weights file for both cases.

  • out_dir : str

    Directory for output files.

  • operators : str

    Which operators to apply to the registers. Options are 'QRQ' for full phase extraction, 'RQ' for partial phase extraction, and 'Q' for function evaluation only. Default is 'QRQ'.

  • no_UA : boolean

    If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

  • show : boolean

    If True, display plots. Default is False.

  • pdf : boolean

    If True, save plots in pdf format. If False, save plots in png format. Default is False.

Returns:

Plots of the amplitude and phase of the two-register state saved in the directory out_dir.