pqcprep.plotting_tools
Collection of functions relating to plotting and visualisation.
"""
Collection of functions relating to plotting and visualisation.
"""

import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams, cm

from .psi_tools import x_trans_arr, get_phase_target, psi, A
from .file_tools import vars_to_name_str, vars_to_name_str_ampl
from .resource_tools import load_data_H23
from .phase_tools import full_encode, phase_from_state
from .binary_tools import dec_to_bin, bin_to_dec

# general settings
rcParams['mathtext.fontset'] = 'stix'
rcParams['font.family'] = 'STIXGeneral'
width=0.75
""" @private """
color='black'
""" @private """
fontsize=28
""" @private """
titlesize=32
""" @private """
ticksize=22
""" @private """
figsize=(10,10)
""" @private """

# Resolution (dots per inch) used for every saved figure.
_DPI = 500
# Phase-function identifiers recognised in phase name strings (checked in this order).
_PSI_MODES = ("psi", "linear", "quadratic", "sine")
# Amplitude-function identifiers recognised in amplitude name strings (checked in this order).
_AMPL_MODES = ("x76", "linear", "uniform")


def _finish_figure(fig, path, show):
    """Save `fig` to `path`, optionally display it, then close it."""
    fig.savefig(path, bbox_inches='tight', dpi=_DPI)
    if show:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _plot_epoch_series(arr, label, path, show):
    """Scatter a per-epoch metric `arr` on a logarithmic y-axis and save the figure to `path`."""
    fig = plt.figure(figsize=figsize)
    plt.xlabel("Epoch", fontsize=fontsize)
    plt.ylabel(label, fontsize=fontsize)
    plt.yscale('log')
    plt.tick_params(axis="both", labelsize=ticksize)
    # Epochs are numbered starting from 1.
    plt.scatter(np.arange(len(arr)) + 1, arr, color="red")
    plt.tight_layout()
    _finish_figure(fig, path, show)


def _phase_register_ticklabels(locs, m):
    """Return radian tick labels for the interior ticks `locs[1:-1]` of the m-qubit phase-register axis."""
    # NOTE(review): labels are derived from each tick's position in the list (0, 1, 2, ...)
    # rather than from the tick value itself, so they are only correct when the interior
    # ticks fall on consecutive integers starting at 0 -- confirm for larger m.
    return [np.round(bin_to_dec(dec_to_bin(i, m), nint=0) * 2 * np.pi, 2) for i in np.arange(len(locs) - 2)]


def _parse_name_strings(name_str_phase, name_str_ampl):
    """
    Extract QCNN configuration parameters from a phase and an amplitude name string.

    The parser expects the phase name string to start with `_<n>_<m>(<mint>)_<L_phase>_` and the
    amplitude name string to start with `_<n>_<L_A>_` (i.e. `nint=None`), where each numeric
    field has one or two digits.

    Returns a dict with the keys `n`, `m`, `mint`, `L_phase`, `L_A` (all `int`), `psi_mode`,
    `func_A` (str), `phase_reduce`, `real_p` (bool) and `repeat_params` (str or None).

    Raises `ValueError` if either name string cannot be interpreted.
    """
    phase_reduce = '(PR)' in name_str_phase
    real_p = '(r)' in name_str_phase
    repeat_params = None
    for tag in ("CL", "IL", "both"):
        if f"({tag})" in name_str_phase:
            repeat_params = tag
            break

    # A function identifier must occur exactly once to be accepted unambiguously.
    psi_mode = next((mode for mode in _PSI_MODES if name_str_phase.count(mode) == 1), None)
    if psi_mode is None:
        raise ValueError("Name string could not be interpreted.")
    func_A = next((mode for mode in _AMPL_MODES if name_str_ampl.count(mode) == 1), None)
    if func_A is None:
        raise ValueError("Name string could not be interpreted.")

    # Single-digit fields are followed directly by "_" (or "(" for m); otherwise two
    # characters are read. A two-digit n shifts every later field by one position.
    n = name_str_ampl[1] if name_str_ampl[2] == "_" else name_str_ampl[1:3]
    offset = 0 if int(n) < 10 else 1
    a, b = 3 + offset, 4 + offset
    L_A = name_str_ampl[a] if name_str_ampl[b] == "_" else name_str_ampl[a:b + 1]  # assumes nint=None
    m = name_str_phase[a] if name_str_phase[b] in ("_", "(") else name_str_phase[a:b + 1]

    mint = None
    for i in range(int(m) + 1):
        if name_str_phase.count(f"({i})") == 1:
            mint = i
            break

    prefix = f"_{n}_{m}({mint})_"
    if name_str_phase.count(prefix) != 1:
        raise ValueError("Name string could not be interpreted.")
    start = len(prefix)
    # L_phase may have one or two digits; in the latter case read the full two-character slice.
    L_phase = name_str_phase[start] if name_str_phase[start + 1] == "_" else name_str_phase[start:start + 2]
    if name_str_phase.count(f"{prefix}{L_phase}_") != 1:
        raise ValueError("Name string could not be interpreted.")

    return {
        "n": int(n),
        "m": int(m),
        "mint": int(mint),
        "L_phase": int(L_phase),
        "L_A": int(L_A),
        "psi_mode": psi_mode,
        "func_A": func_A,
        "phase_reduce": phase_reduce,
        "real_p": real_p,
        "repeat_params": repeat_params,
    }


def benchmark_plots(arg_dict,DIR, show=False, pdf=False):
    """
    Generates plots visualising the outputs produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()`.

    Arguments:
    ----
    - **arg_dict** : *dict*

        Dictionary containing variable information created using `pqcprep.file_tools.compress_args()`.

    - **DIR** : *str*

        Parent directory for output files. Input arrays are read from `DIR/outputs`.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    Returns:
    ---

    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()` (apart from `metrics_<NAME_STR>.npy`)
    in the directory `DIR/plots`. Returns 0.

    """
    name_str = vars_to_name_str(arg_dict)
    pdf_str = ".pdf" if pdf else ".png"
    out_dir = os.path.join(DIR, "outputs")
    plot_dir = os.path.join(DIR, "plots")

    # per-epoch training metrics
    arrs = ["loss", "mismatch", "grad", "vargrad"]
    labels = ["Loss", "Mismatch", r"$|\nabla_\theta W|^2$", r"Var($\partial_\theta W$)"]
    for key, label in zip(arrs, labels):
        arr = np.load(os.path.join(out_dir, f"{key}{name_str}.npy"))
        _plot_epoch_series(arr, label, os.path.join(plot_dir, f"{key}{name_str}{pdf_str}"), show)

    # plot mismatch by state
    # NOTE: this file holds a pickled dict; np.load with allow_pickle=True must only be used on trusted outputs.
    dic = np.load(os.path.join(out_dir, f"mismatch_by_state{name_str}.npy"), allow_pickle=True).item()
    # Assumes the dict's insertion order matches the ordering of x_arr -- TODO confirm.
    mismatch = list(dic.values())
    x_arr = x_trans_arr(arg_dict["n"])

    fig = plt.figure(figsize=figsize)
    plt.xlabel(r"$f$ (Hz)", fontsize=fontsize)
    plt.ylabel("Mismatch", fontsize=fontsize)
    plt.yscale('log')
    plt.tick_params(axis="both", labelsize=ticksize)
    plt.scatter(x_arr, mismatch, color="red")
    plt.tight_layout()
    _finish_figure(fig, os.path.join(plot_dir, f"mismatch_by_state{name_str}{pdf_str}"), show)

    # plot extracted phase function
    mint = arg_dict["mint"] if arg_dict["mint"] is not None else arg_dict["m"]
    if arg_dict["phase_reduce"]:
        mint = 0
    phase = np.load(os.path.join(out_dir, f"phase{name_str}.npy"))
    phase_target_rounded = get_phase_target(m=arg_dict["m"], psi_mode=arg_dict["func_str"], phase_reduce=arg_dict["phase_reduce"], mint=mint)
    phase_target = psi(np.arange(2**arg_dict["n"]), mode=arg_dict["func_str"])

    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, phase_target, color="black")
    ax[0].plot(x_arr, phase_target_rounded, color="gray", ls="--")
    ax[0].scatter(x_arr, phase, color="red")
    ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    # residual with respect to the rounded (register-resolution) target
    ax[1].scatter(x_arr, phase_target_rounded - phase, color="red")
    ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    _finish_figure(fig, os.path.join(plot_dir, f"phase{name_str}{pdf_str}"), show)

    return 0


def benchmark_plots_ampl(arg_dict,DIR, show=False, pdf=False):
    """
    Generates plots visualising the outputs produced by `pqcprep.training_tools.ampl_train_QNN()`.

    Arguments:
    ----
    - **arg_dict** : *dict*

        Dictionary containing variable information created using `pqcprep.file_tools.compress_args_ampl()`.

    - **DIR** : *str*

        Parent directory for output files. Input arrays are read from `DIR/ampl_outputs`.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    Returns:
    ---

    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.ampl_train_QNN()`
    in the directory `DIR/ampl_plots`. Returns 0.
    """
    name_str = vars_to_name_str_ampl(arg_dict)
    pdf_str = ".pdf" if pdf else ".png"
    out_dir = os.path.join(DIR, "ampl_outputs")
    plot_dir = os.path.join(DIR, "ampl_plots")

    # per-epoch training metrics
    arrs = ["loss", "mismatch"]
    labels = ["Loss", "Mismatch"]
    for key, label in zip(arrs, labels):
        arr = np.load(os.path.join(out_dir, f"{key}{name_str}.npy"))
        _plot_epoch_series(arr, label, os.path.join(plot_dir, f"{key}{name_str}{pdf_str}"), show)

    # plot amplitude
    x_arr = x_trans_arr(arg_dict["n"])

    ampl_vec = np.abs(np.load(os.path.join(out_dir, f"statevec{name_str}.npy")))
    ampl_target = np.array([A(i, mode=arg_dict["func_str"]) for i in x_arr])
    # Out-of-place division: an in-place `/=` fails if A returns integers (integer dtype array).
    ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2))

    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})
    ax[0].plot(x_arr, ampl_target, color="black")
    ax[0].scatter(x_arr, ampl_vec, color="red")
    ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])
    ax[1].scatter(x_arr, ampl_target - ampl_vec, label="QCNN", color="red")
    ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    _finish_figure(fig, os.path.join(plot_dir, f"amplitude{name_str}{pdf_str}"), show)

    return 0


def waveform_plots(name_str_phase, name_str_ampl, in_dir, out_dir, comp=None, no_UA=False, show=False, pdf=False):
    """
    Plot the amplitude, phase, and full waveform of a state prepared using QCNNs.

    This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

    Arguments:
    ----

    - **name_str_phase** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding.

    - **name_str_ampl** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation.
        The QCNN weights are not used if `no_UA` is True, but the string is still parsed for `n`, `L_A` and the amplitude function.

    - **in_dir** : *str*

        Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases.

    - **out_dir** : *str*

        Directory for output files.

    - **comp** : *str*, *optional*

        Compare outputs to results from Hayes 2023: if `'GR'` compare to results obtained using the Grover-Rudolph algorithm and if `'QGAN'` compare to results
        obtained using the QGAN. Any other value (default None) disables the comparison.

    - **no_UA** : *boolean*

        If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. Default is False.


    Returns:
    ----

    Plots of the amplitude, phase, and full waveform saved in the directory `out_dir`. Returns 0.

    Raises `ValueError` if the name strings cannot be interpreted.
    """
    pdf_str = ".pdf" if pdf else ".png"

    # read in given files
    weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy")
    weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy")

    # extract information from name strings
    cfg = _parse_name_strings(name_str_phase, name_str_ampl)
    n, m, mint = cfg["n"], cfg["m"], cfg["mint"]
    psi_mode, func_A = cfg["psi_mode"], cfg["func_A"]

    # generate x array
    x_arr = x_trans_arr(n)

    # calculate target outputs
    ampl_target = np.array([A(i, mode=func_A) for i in x_arr])
    ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2))

    phase_rounded = get_phase_target(m=m, psi_mode=psi_mode, phase_reduce=cfg["phase_reduce"], mint=mint)
    phase_target = psi(np.arange(2**n), mode=psi_mode)

    h_target = ampl_target * np.exp(2 * 1.j * np.pi * phase_target)
    h_target_rounded = ampl_target * np.exp(2 * 1.j * np.pi * phase_rounded)

    # Hayes 2023 comparison: only meaningful for the same amplitude/phase functions.
    use_comp = comp in ("GR", "QGAN")
    if use_comp and no_UA:
        use_comp = False
        print("Comparison to Hayes 2023 not shown due to incompatible amplitude function.")
    if use_comp and not (psi_mode == "psi" and func_A == "x76"):
        use_comp = False
        print("Comparison to Hayes 2023 not shown due to incompatible amplitude or phase function.")
    if use_comp:
        comp_label = comp
        ampl_vec_comp = np.abs(load_data_H23(f"amp_state_{comp}"))
        h_comp = load_data_H23(f"full_state_{comp}")
        psi_LPF = load_data_H23("psi_LPF_processed")

    # get PQC state
    state_vec = full_encode(n, m, weights_ampl, weights_phase, cfg["L_A"], cfg["L_phase"], real_p=cfg["real_p"], repeat_params=cfg["repeat_params"], no_UA=no_UA)
    phase = phase_from_state(state_vec)
    ampl_vec = np.abs(state_vec)

    # plot amplitude
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, ampl_target, color="black")
    if use_comp:
        ax[0].scatter(x_arr, ampl_vec_comp, label=comp_label, color="blue")
    ax[0].scatter(x_arr, ampl_vec, label="QCNN", color="red")
    ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize)
    if use_comp:
        ax[0].legend(fontsize=fontsize, loc='upper right')
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    if use_comp:
        ax[1].scatter(x_arr, ampl_target - ampl_vec_comp, label=comp_label, color="blue")
    ax[1].scatter(x_arr, ampl_target - ampl_vec, label="QCNN", color="red")
    ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    _finish_figure(fig, os.path.join(out_dir, f"amplitude{pdf_str}"), show)

    # plot phase
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, phase_target, color="black")
    ax[0].plot(x_arr, phase_rounded, color="gray", ls="--")
    if use_comp:
        ax[0].scatter(x_arr, psi_LPF, label="LPF", color="blue")
    ax[0].scatter(x_arr, phase, label="QCNN", color="red")
    ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize)
    if use_comp:
        ax[0].legend(fontsize=fontsize, loc='upper right')
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    # LPF residual is taken against the exact target, the QCNN residual against the rounded one.
    if use_comp:
        ax[1].scatter(x_arr, phase_target - psi_LPF, label="LPF", color="blue")
    ax[1].scatter(x_arr, phase_rounded - phase, label="QCNN", color="red")
    ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    _finish_figure(fig, os.path.join(out_dir, f"phase{pdf_str}"), show)

    # plot full waveform: real part in the left column, imaginary part in the right column
    fig, ax = plt.subplots(2, 2, figsize=(2 * figsize[0], figsize[1]), gridspec_kw={'height_ratios': [1.5, 1]})
    columns = [
        (0, np.real, r"$\Re[\tilde h(f)]$", r"$\Delta \Re[\tilde h(f)]$"),
        (1, np.imag, r"$\Im[\tilde h(f)]$", r"$\Delta \Im[\tilde h(f)]$"),
    ]
    for col, part, ylabel, delta_ylabel in columns:
        ax[0, col].plot(x_arr, part(h_target), color="black")
        ax[0, col].plot(x_arr, part(h_target_rounded), color="gray", ls="--")
        if use_comp:
            ax[0, col].scatter(x_arr, part(h_comp), label="LPF + " + comp_label, color="blue")
        ax[0, col].scatter(x_arr, part(state_vec), label="QCNN", color="red")
        ax[0, col].set_ylabel(ylabel, fontsize=fontsize)
        if use_comp:
            ax[0, col].legend(fontsize=fontsize, loc='upper right')
        ax[0, col].tick_params(axis="both", labelsize=ticksize)
        ax[0, col].set_xticks([])

        if use_comp:
            ax[1, col].scatter(x_arr, part(h_target) - part(h_comp), label="LPF + " + comp_label, color="blue")
        ax[1, col].scatter(x_arr, part(h_target_rounded) - part(state_vec), label="QCNN", color="red")
        ax[1, col].set_ylabel(delta_ylabel, fontsize=fontsize)
        ax[1, col].tick_params(axis="both", labelsize=ticksize)
        ax[1, col].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    _finish_figure(fig, os.path.join(out_dir, f"waveform{pdf_str}"), show)

    return 0


def _plot_two_register_map(x_arr, phase, values, m, clim, path, show, cmap=None):
    """
    Colour-map the 2D array `values` against `x_arr` and the 2**m phase-register states,
    overlay the extracted phase `phase`, add a horizontal colour bar and save to `path`.
    """
    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [2**m, 1]}, constrained_layout=True)
    ax[0].scatter(x_arr, phase, label="QCNN", color="red")
    ax[0].set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    im = ax[0].pcolormesh(x_arr, np.arange(2**m), values, cmap=cmap)
    im.set_clim(*clim)
    locs = ax[0].get_yticks()
    ax[0].set_yticks(locs[1:-1], _phase_register_ticklabels(locs, m))
    ax[0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)
    # the lower axis is not used for plotting
    ax[1].set_visible(False)

    cb = fig.colorbar(im, ax=ax.ravel().tolist(), location='top', orientation='horizontal', pad=0.05)
    cb.ax.tick_params(labelsize=ticksize)
    _finish_figure(fig, path, show)


def two_register_plots(name_str_phase, name_str_ampl,in_dir, out_dir, operators="QRQ",no_UA=False, show=False, pdf=False):
    """
    Plot the two-register statevector resulting from applying amplitude and phase preparation QCNNs.

    This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

    Arguments:
    ----

    - **name_str_phase** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding.

    - **name_str_ampl** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation.
        The QCNN weights are not used if `no_UA` is True, but the string is still parsed for `n`, `L_A` and the amplitude function.

    - **in_dir** : *str*

        Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases.

    - **out_dir** : *str*

        Directory for output files.

    - **operators** : *str*,

        Which operators to apply to the registers. Options are `'QRQ'` for full phase extraction, `'RQ'` for partial phase extraction, and `'Q'` for function
        evaluation only. Default is `'QRQ'`.

    - **no_UA** : *boolean*

        If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. Default is False.


    Returns:
    ----

    Plots of the amplitude and phase of the two-register state saved in the directory `out_dir`. Returns 0.

    Raises `ValueError` if the name strings cannot be interpreted.
    """
    pdf_str = ".pdf" if pdf else ".png"

    # read in given files
    weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy")
    weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy")

    # extract information from name strings
    cfg = _parse_name_strings(name_str_phase, name_str_ampl)
    n, m = cfg["n"], cfg["m"]

    # generate x array
    x_arr = x_trans_arr(n)

    # get PQC state (reduced state and full two-register state)
    state_vec, state_vec_full = full_encode(n, m, weights_ampl, weights_phase, cfg["L_A"], cfg["L_phase"], real_p=cfg["real_p"], repeat_params=cfg["repeat_params"], no_UA=no_UA, full_state_vec=True, operators=operators)
    phase = phase_from_state(state_vec)
    full_phase = phase_from_state(state_vec_full)
    ampl_full = np.abs(state_vec_full)

    # plot amplitude (modulus) of the two-register state
    _plot_two_register_map(x_arr, phase, ampl_full, m, (np.min(ampl_full), np.max(ampl_full)),
                           os.path.join(out_dir, f"two_reg_ampl{pdf_str}"), show)

    # plot phase of the two-register state
    _plot_two_register_map(x_arr, phase, full_phase, m, (0, 2 * np.pi),
                           os.path.join(out_dir, f"two_reg_phase{pdf_str}"), show, cmap="hsv")

    # plot amplitude (surface height) and phase (surface colour) together
    X, Y = np.meshgrid(x_arr, np.arange(2**m))
    P_norm = full_phase / (2 * np.pi)  # normalise to value in [0,1] for the hsv colour map

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": "3d"}, constrained_layout=False)
    surf = ax.plot_surface(X, Y, ampl_full, facecolors=cm.hsv(P_norm), linewidth=0, antialiased=True, shade=False)
    ax.set_xlabel(r"$f$ (Hz)", fontsize=fontsize, labelpad=20)
    ax.set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize, labelpad=20)
    locs = ax.get_yticks()
    ax.set_yticks(locs[1:-1], _phase_register_ticklabels(locs, m))
    ax.tick_params(axis="both", labelsize=ticksize)

    cb = fig.colorbar(cm.ScalarMappable(cmap=cm.hsv, norm=surf.norm), ax=ax, shrink=0.6, aspect=5)
    cb.ax.tick_params(labelsize=ticksize)
    cb.ax.set_yticks([0, 0.25, 0.5, 0.75, 1], [r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])

    _finish_figure(fig, os.path.join(out_dir, f"two_reg_3D{pdf_str}"), show)

    return 0
def benchmark_plots(arg_dict,DIR, show=False, pdf=False):
    """
    Generates plots visualising the outputs produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()`.

    Arguments:
    ----
    - **arg_dict** : *dict*

        Dictionary containing variable information created using `pqcprep.file_tools.compress_args()`.

    - **DIR** : *str*

        Parent directory for output files. Input arrays are read from `DIR/outputs`.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    Returns:
    ---

    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.train_QNN()` and `pqcprep.training_tools.test_QNN()` (apart from `metrics_<NAME_STR>.npy`)
    in the directory `DIR/plots`. Returns 0.

    """
    name_str = vars_to_name_str(arg_dict)
    pdf_str = ".pdf" if pdf else ".png"
    out_dir = os.path.join(DIR, "outputs")
    plot_dir = os.path.join(DIR, "plots")

    def _finish(fig, fname):
        # Save, optionally display, then close this specific figure so none are leaked.
        fig.savefig(os.path.join(plot_dir, fname), bbox_inches='tight', dpi=500)
        if show:
            plt.show()
        plt.close(fig)

    # per-epoch training metrics
    arrs = ["loss", "mismatch", "grad", "vargrad"]
    labels = ["Loss", "Mismatch", r"$|\nabla_\theta W|^2$", r"Var($\partial_\theta W$)"]
    for key, label in zip(arrs, labels):
        arr = np.load(os.path.join(out_dir, f"{key}{name_str}.npy"))

        fig = plt.figure(figsize=figsize)
        plt.xlabel("Epoch", fontsize=fontsize)
        plt.ylabel(label, fontsize=fontsize)
        plt.yscale('log')
        plt.tick_params(axis="both", labelsize=ticksize)
        # Epochs are numbered starting from 1.
        plt.scatter(np.arange(len(arr)) + 1, arr, color="red")
        plt.tight_layout()
        _finish(fig, f"{key}{name_str}{pdf_str}")

    # plot mismatch by state
    # NOTE: this file holds a pickled dict; np.load with allow_pickle=True must only be used on trusted outputs.
    dic = np.load(os.path.join(out_dir, f"mismatch_by_state{name_str}.npy"), allow_pickle=True).item()
    # Assumes the dict's insertion order matches the ordering of x_arr -- TODO confirm.
    mismatch = list(dic.values())
    x_arr = x_trans_arr(arg_dict["n"])

    fig = plt.figure(figsize=figsize)
    plt.xlabel(r"$f$ (Hz)", fontsize=fontsize)
    plt.ylabel("Mismatch", fontsize=fontsize)
    plt.yscale('log')
    plt.tick_params(axis="both", labelsize=ticksize)
    plt.scatter(x_arr, mismatch, color="red")
    plt.tight_layout()
    _finish(fig, f"mismatch_by_state{name_str}{pdf_str}")

    # plot extracted phase function
    mint = arg_dict["mint"] if arg_dict["mint"] is not None else arg_dict["m"]
    if arg_dict["phase_reduce"]:
        mint = 0
    phase = np.load(os.path.join(out_dir, f"phase{name_str}.npy"))
    phase_target_rounded = get_phase_target(m=arg_dict["m"], psi_mode=arg_dict["func_str"], phase_reduce=arg_dict["phase_reduce"], mint=mint)
    phase_target = psi(np.arange(2**arg_dict["n"]), mode=arg_dict["func_str"])

    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})

    ax[0].plot(x_arr, phase_target, color="black")
    ax[0].plot(x_arr, phase_target_rounded, color="gray", ls="--")
    ax[0].scatter(x_arr, phase, color="red")
    ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])

    # residual with respect to the rounded target
    ax[1].scatter(x_arr, phase_target_rounded - phase, color="red")
    ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    _finish(fig, f"phase{name_str}{pdf_str}")

    return 0
Generates plots visualising the outputs produced by pqcprep.training_tools.train_QNN() and pqcprep.training_tools.test_QNN().
Arguments:
arg_dict : dict
Dictionary containing variable information created using pqcprep.file_tools.compress_args().
show : boolean
If True, display plots. Default is False.
pdf : boolean
If True, save plots in pdf format. If False, save plots in png format. Default is False.
DIR : str
Parent directory for output files.
Returns:
Saves plots corresponding to each of the files produced by pqcprep.training_tools.train_QNN() and pqcprep.training_tools.test_QNN() (apart from metrics_<NAME_STR>.npy)
in the directory DIR/plots.
def benchmark_plots_ampl(arg_dict,DIR, show=False, pdf=False):
    """
    Generates plots visualising the outputs produced by `pqcprep.training_tools.ampl_train_QNN()`.

    Arguments:
    ----
    - **arg_dict** : *dict*

        Dictionary containing variable information created using `pqcprep.file_tools.compress_args_ampl()`.

    - **DIR** : *str*

        Parent directory for output files. Input arrays are read from `DIR/ampl_outputs`.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format. If False, save plots in png format. Default is False.

    Returns:
    ---

    Saves plots corresponding to each of the files produced by `pqcprep.training_tools.ampl_train_QNN()`
    in the directory `DIR/ampl_plots`. Returns 0.
    """
    name_str = vars_to_name_str_ampl(arg_dict)
    pdf_str = ".pdf" if pdf else ".png"
    out_dir = os.path.join(DIR, "ampl_outputs")
    plot_dir = os.path.join(DIR, "ampl_plots")

    def _finish(fig, fname):
        # Save, optionally display, then close this specific figure so none are leaked.
        fig.savefig(os.path.join(plot_dir, fname), bbox_inches='tight', dpi=500)
        if show:
            plt.show()
        plt.close(fig)

    # per-epoch training metrics
    arrs = ["loss", "mismatch"]
    labels = ["Loss", "Mismatch"]
    for key, label in zip(arrs, labels):
        arr = np.load(os.path.join(out_dir, f"{key}{name_str}.npy"))

        fig = plt.figure(figsize=figsize)
        plt.xlabel("Epoch", fontsize=fontsize)
        plt.ylabel(label, fontsize=fontsize)
        plt.yscale('log')
        plt.tick_params(axis="both", labelsize=ticksize)
        # Epochs are numbered starting from 1.
        plt.scatter(np.arange(len(arr)) + 1, arr, color="red")
        plt.tight_layout()
        _finish(fig, f"{key}{name_str}{pdf_str}")

    # plot amplitude
    x_arr = x_trans_arr(arg_dict["n"])

    ampl_vec = np.abs(np.load(os.path.join(out_dir, f"statevec{name_str}.npy")))
    ampl_target = np.array([A(i, mode=arg_dict["func_str"]) for i in x_arr])
    # Out-of-place division: an in-place `/=` fails if A returns integers (integer dtype array).
    ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2))

    fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]})
    ax[0].plot(x_arr, ampl_target, color="black")
    ax[0].scatter(x_arr, ampl_vec, color="red")
    ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize)
    ax[0].tick_params(axis="both", labelsize=ticksize)
    ax[0].set_xticks([])
    ax[1].scatter(x_arr, ampl_target - ampl_vec, label="QCNN", color="red")
    ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize)
    ax[1].tick_params(axis="both", labelsize=ticksize)
    ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)

    fig.tight_layout()
    _finish(fig, f"amplitude{name_str}{pdf_str}")

    return 0
Generates plots visualising the outputs produced by pqcprep.training_tools.ampl_train_QNN().
Arguments:
arg_dict : dict
Dictionary containing variable information created using pqcprep.file_tools.compress_args_ampl().
show : boolean
If True, display plots. Default is False.
pdf : boolean
If True, save plots in pdf format. Default is False.
DIR : str
Parent directory for output files.
Returns:
Saves plots corresponding to each of the files produced by pqcprep.training_tools.ampl_train_QNN()
in the directory DIR/ampl_plots.
217def waveform_plots(name_str_phase, name_str_ampl, in_dir, out_dir,comp, no_UA=False,show=False, pdf=False): 218 """ 219 Plot the amplitude, phase, and full waveform of a state prepared using QCNNs. 220 221 This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case. 222 223 Arguments: 224 ---- 225 226 - **name_str_phase** : *str* 227 228 Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding. 229 230 - **name_str_ampl** : *str* 231 232 Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation. This is ignored if `no_UA` is True. 233 234 - **in_dir** : *str* 235 236 Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases. 237 238 - **out_dir** : *str* 239 240 Directory for output files. 241 242 - **comp** : *str*, *optional* 243 244 Compare outputs to results from Hayes 2023: if `'GR'` compare to results obtained using the Grover-Rudolph algorithm and if `'QGAN'` compare to results 245 obtained using the QGAN. 246 247 - **no_UA** : *boolean* 248 249 If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False. 250 251 - **show** : *boolean* 252 253 If True, display plots. Default is False. 254 255 - **pdf** : *boolean* 256 257 If True, save plots in pdf format. Default is False. 258 259 260 Returns: 261 ---- 262 263 Plots of the amplitude, phase, and full waveform saved in the directory `out_dir`. 
264 """ 265 266 pdf_str = ".pdf" if pdf else ".png" 267 268 # read in given files 269 weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy") 270 weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy") 271 272 # extract information from name strings 273 phase_reduce = '(PR)' in name_str_phase 274 real_p = '(r)' in name_str_phase 275 if '(CL)' in name_str_phase: 276 repeat_params="CL" 277 elif '(IL)' in name_str_phase: 278 repeat_params="IL" 279 elif '(both)' in name_str_phase: 280 repeat_params="both" 281 else: 282 repeat_params=None 283 284 psi_mode=None 285 func_A=None 286 for i in ["psi", "linear", "quadratic", "sine"]: 287 if name_str_phase.count(i)==1: 288 psi_mode=i 289 break 290 if psi_mode==None: 291 raise ValueError("Name string could not be interpreted.") 292 for i in ["x76", "linear", "uniform"]: 293 if name_str_ampl.count(i)==1: 294 func_A=i 295 break 296 if func_A==None: 297 raise ValueError("Name string could not be interpreted.") 298 299 n = name_str_ampl[1] if name_str_ampl[2]=="_" else name_str_ampl[1:3] 300 if int(n) < 10: 301 L_A= name_str_ampl[3] if name_str_ampl[4]=="_" else name_str_ampl[3:5] # this assumes nint=None 302 m = name_str_phase[3] if (name_str_phase[4]=="_" or name_str_phase[4]=="(") else name_str_phase[3:5] 303 else: 304 L_A= name_str_ampl[4] if name_str_ampl[5]=="_" else name_str_ampl[4:6] # this assumes nint=None 305 m = name_str_phase[4] if (name_str_phase[5]=="_" or name_str_phase[5]=="(") else name_str_phase[4:6] 306 307 mint=None 308 for i in np.arange(int(m)+1): 309 if name_str_phase.count(f"({i})")==1: 310 mint=i 311 break 312 s = f"_{n}_{m}({mint})_" 313 if name_str_phase.count(s) != 1 : 314 raise ValueError("Name string could not be interpreted.") 315 else: 316 L_phase = name_str_phase[len(s)] if name_str_phase[len(s)+1]=="_" else name_str_phase[len(s)+2] 317 s = f"_{n}_{m}({mint})_{L_phase}_" 318 if name_str_phase.count(s) != 1 : 319 raise ValueError("Name string could not be interpreted.") 320 
321 n= int(n) 322 m=int(m) 323 L_phase=int(L_phase) 324 L_A=int(L_A) 325 mint=int(mint) 326 327 # generate x array 328 x_arr = x_trans_arr(n) 329 330 # calculate target outputs 331 ampl_target = np.array([A(i, mode=func_A) for i in x_arr]) 332 ampl_target = ampl_target / np.sqrt(np.sum(ampl_target**2)) 333 334 phase_rounded = get_phase_target(m=m, psi_mode=psi_mode, phase_reduce=phase_reduce, mint=mint) 335 phase_target = psi(np.arange(2**n),mode=psi_mode) 336 337 h_target = ampl_target * np.exp(2*1.j*np.pi* phase_target) 338 wave_real_target = np.real(h_target) 339 wave_im_target = np.imag(h_target) 340 341 h_target_rounded = ampl_target * np.exp(2*1.j*np.pi* phase_rounded) 342 wave_real_target_rounded = np.real(h_target_rounded) 343 wave_im_target_rounded = np.imag(h_target_rounded) 344 345 # load Hayes 2023 data for comparison 346 if comp=="GR": 347 ampl_vec_comp=np.abs(load_data_H23("amp_state_GR")) 348 h_comp= load_data_H23("full_state_GR") 349 psi_LPF = load_data_H23("psi_LPF_processed") 350 comp_label="GR" 351 comp=True 352 elif comp=="QGAN": 353 ampl_vec_comp=np.abs(load_data_H23("amp_state_QGAN")) 354 h_comp= load_data_H23("full_state_QGAN") 355 psi_LPF = load_data_H23("psi_LPF_processed") 356 comp_label="QGAN" 357 comp=True 358 else: 359 comp=False 360 361 if comp and no_UA: 362 comp=False 363 print("Comparison to Hayes 2023 not shown due to incompatible amplitude function.") 364 if comp and not (psi_mode=="psi" and func_A=="x76"): 365 comp=False 366 print("Comparison to Hayes 2023 not shown due to incompatible amplitude or phase function.") 367 368 # get PQC state 369 state_vec = full_encode(n,m, weights_ampl, weights_phase, L_A, L_phase,real_p=real_p,repeat_params=repeat_params,no_UA=no_UA) 370 phase = phase_from_state(state_vec) 371 372 ampl_vec = np.abs(state_vec) 373 real_wave =np.real(state_vec) 374 im_wave = np.imag(state_vec) 375 376 # plot amplitude 377 fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]}) 378 379 
ax[0].plot(x_arr,ampl_target, color="black") 380 if comp: 381 ax[0].scatter(x_arr,ampl_vec_comp,label=comp_label, color="blue") 382 ax[0].scatter(x_arr,ampl_vec,label="QCNN", color="red") 383 384 ax[0].set_ylabel(r"$\tilde A(f)$", fontsize=fontsize) 385 if comp: ax[0].legend(fontsize=fontsize, loc='upper right') 386 ax[0].tick_params(axis="both", labelsize=ticksize) 387 ax[0].set_xticks([]) 388 389 if comp: 390 ax[1].scatter(x_arr,ampl_target-ampl_vec_comp,label=comp_label, color="blue") 391 ax[1].scatter(x_arr,ampl_target-ampl_vec,label="QCNN", color="red") 392 393 ax[1].set_ylabel(r"$\Delta \tilde A(f)$", fontsize=fontsize) 394 ax[1].tick_params(axis="both", labelsize=ticksize) 395 ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize) 396 397 fig.tight_layout() 398 fig.savefig(os.path.join(out_dir,f"amplitude{pdf_str}"), bbox_inches='tight', dpi=500) 399 400 if show: 401 plt.show() 402 plt.close() 403 404 # plot phase 405 fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [1.5, 1]}) 406 407 ax[0].plot(x_arr,phase_target, color="black") 408 ax[0].plot(x_arr,phase_rounded, color="gray", ls="--") 409 if comp: 410 ax[0].scatter(x_arr,psi_LPF,label="LPF", color="blue") 411 ax[0].scatter(x_arr,phase,label="QCNN", color="red") 412 413 ax[0].set_ylabel(r"$\Psi (f)$", fontsize=fontsize) 414 if comp: ax[0].legend(fontsize=fontsize, loc='upper right') 415 ax[0].tick_params(axis="both", labelsize=ticksize) 416 ax[0].set_xticks([]) 417 418 if comp: 419 ax[1].scatter(x_arr,phase_target-psi_LPF,label="LPF + ", color="blue") 420 ax[1].scatter(x_arr,phase_rounded-phase,label="QCNN", color="red") 421 422 ax[1].set_ylabel(r"$\Delta \Psi(f)$", fontsize=fontsize) 423 ax[1].tick_params(axis="both", labelsize=ticksize) 424 ax[1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize) 425 426 fig.tight_layout() 427 fig.savefig(os.path.join(out_dir,f"phase{pdf_str}"), bbox_inches='tight', dpi=500) 428 429 if show: 430 plt.show() 431 plt.close() 432 433 # plot full waveform 434 
fig, ax = plt.subplots(2, 2, figsize=(2*figsize[0],figsize[1]), gridspec_kw={'height_ratios': [1.5, 1]}) 435 fig.subplots_adjust() 436 437 ax[0,0].plot(x_arr,wave_real_target, color="black") 438 ax[0,0].plot(x_arr,wave_real_target_rounded, color="gray", ls="--") 439 if comp: 440 ax[0,0].scatter(x_arr,np.real(h_comp),label="LPF + "+comp_label, color="blue") 441 ax[0,0].scatter(x_arr,real_wave,label="QCNN", color="red") 442 443 ax[0,0].set_ylabel(r"$\Re[\tilde h(f)]$", fontsize=fontsize) 444 if comp: ax[0,0].legend(fontsize=fontsize, loc='upper right') 445 ax[0,0].tick_params(axis="both", labelsize=ticksize) 446 ax[0,0].set_xticks([]) 447 448 if comp: 449 ax[1,0].scatter(x_arr,wave_real_target -np.real(h_comp),label="LPF"+comp_label, color="blue") 450 ax[1,0].scatter(x_arr,wave_real_target_rounded -real_wave,label="QCNN", color="red") 451 452 ax[1,0].set_ylabel(r"$\Delta \Re[\tilde h(f)]$", fontsize=fontsize) 453 ax[1,0].tick_params(axis="both", labelsize=ticksize) 454 ax[1,0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize) 455 456 ax[0,1].plot(x_arr,wave_im_target, color="black") 457 ax[0,1].plot(x_arr,wave_im_target_rounded, color="gray", ls="--") 458 if comp: 459 ax[0,1].scatter(x_arr,np.imag(h_comp),label="LPF + "+comp_label, color="blue") 460 ax[0,1].scatter(x_arr,im_wave,label="QCNN", color="red") 461 462 ax[0,1].set_ylabel(r"$\Im[\tilde h(f)]$", fontsize=fontsize) 463 if comp: ax[0,1].legend(fontsize=fontsize, loc='upper right') 464 ax[0,1].tick_params(axis="both", labelsize=ticksize) 465 ax[0,1].set_xticks([]) 466 467 if comp: 468 ax[1,1].scatter(x_arr,wave_im_target -np.imag(h_comp),label="LPF + "+comp_label, color="blue") 469 ax[1,1].scatter(x_arr,wave_im_target_rounded -im_wave,label="QCNN", color="red") 470 471 ax[1,1].set_ylabel(r"$\Delta \Im[\tilde h(f)]$", fontsize=fontsize) 472 ax[1,1].tick_params(axis="both", labelsize=ticksize) 473 ax[1,1].set_xlabel(r"$f$ (Hz)", fontsize=fontsize) 474 475 fig.tight_layout() 476 
fig.savefig(os.path.join(out_dir,f"waveform{pdf_str}"), bbox_inches='tight', dpi=500) 477 478 if show: 479 plt.show() 480 plt.close() 481 482 return 0
Plot the amplitude, phase, and full waveform of a state prepared using QCNNs.
This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.
Arguments:
name_str_phase : str
Name string (produced using pqcprep.file_tools.vars_to_name_str()) corresponding to the weights of the QCNN used for phase encoding.
name_str_ampl : str
Name string (produced using pqcprep.file_tools.vars_to_name_str_ampl()) corresponding to the weights of the QCNN used for amplitude preparation. This is ignored if no_UA is True.
in_dir : str
Directory containing the input files (as specified by name_str_phase and name_str_ampl). The directory is expected to contain a weights file for both cases.
out_dir : str
Directory for output files.
comp : str, optional
Compare outputs to results from Hayes 2023: if 'GR', compare to results obtained using the Grover-Rudolph algorithm; if 'QGAN', compare to results obtained using the QGAN.
no_UA : boolean
If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.
show : boolean
If True, display plots. Default is False.
pdf : boolean
If True, save plots in pdf format. Default is False.
Returns:
Plots of the amplitude, phase, and full waveform saved in the directory out_dir.
def _set_phase_register_yticks(ax, m):
    """
    Relabel the y-axis ticks of a plot whose y-coordinate is the index of the `m`-qubit phase register.

    Only the automatically placed ticks that sit at integer register indices in `[0, 2**m)` are kept, because the
    locator may also place ticks outside the data range. Each kept tick at index `j` is labelled with
    `2*pi*bin_to_dec(dec_to_bin(j, m), nint=0)`, rounded to two decimals, so the label matches the tick's own position.
    """
    locs = [int(loc) for loc in ax.get_yticks() if 0 <= loc < 2**m and float(loc).is_integer()]
    labels = [np.round(bin_to_dec(dec_to_bin(j, m), nint=0)*2*np.pi, 2) for j in locs]
    ax.set_yticks(locs, labels)


def two_register_plots(name_str_phase, name_str_ampl, in_dir, out_dir, operators="QRQ", no_UA=False, show=False, pdf=False):
    """
    Plot the two-register statevector resulting from applying amplitude and phase preparation QCNNs.

    This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.

    Arguments:
    ----

    - **name_str_phase** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str()`) corresponding to the weights of the QCNN used for phase encoding.

    - **name_str_ampl** : *str*

        Name string (produced using `pqcprep.file_tools.vars_to_name_str_ampl()`) corresponding to the weights of the QCNN used for amplitude preparation.
        The name string is parsed even if `no_UA` is True.

    - **in_dir** : *str*

        Directory containing the input files (as specified by `name_str_phase` and `name_str_ampl`). The directory is expected to contain a `weights` file for both cases.

    - **out_dir** : *str*

        Directory for output files.

    - **operators** : *str*

        Which operators to apply to the registers. Options are `'QRQ'` for full phase extraction, `'RQ'` for partial phase extraction, and `'Q'` for function
        evaluation only. Default is `'QRQ'`.

    - **no_UA** : *boolean*

        If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.

    - **show** : *boolean*

        If True, display plots. Default is False.

    - **pdf** : *boolean*

        If True, save plots in pdf format, otherwise in png format. Default is False.


    Returns:
    ----

    Plots of the amplitude and phase of the two-register state saved in the directory `out_dir` (as `two_reg_ampl`, `two_reg_phase` and `two_reg_3D`). Returns 0.

    Raises:
    ----

    - **ValueError** : if the name strings could not be interpreted.
    """
    pdf_str = ".pdf" if pdf else ".png"

    # paths to the stored weights
    weights_phase = os.path.join(in_dir, f"weights{name_str_phase}.npy")
    weights_ampl = os.path.join(in_dir, f"weights{name_str_ampl}.npy")

    # extract information from name strings
    real_p = '(r)' in name_str_phase
    if '(CL)' in name_str_phase:
        repeat_params = "CL"
    elif '(IL)' in name_str_phase:
        repeat_params = "IL"
    elif '(both)' in name_str_phase:
        repeat_params = "both"
    else:
        repeat_params = None

    # the function names are not needed below, but unidentifiable name strings are still rejected
    if not any(name_str_phase.count(f) == 1 for f in ("psi", "linear", "quadratic", "sine")):
        raise ValueError("Name string could not be interpreted.")
    if not any(name_str_ampl.count(f) == 1 for f in ("x76", "linear", "uniform")):
        raise ValueError("Name string could not be interpreted.")

    # a field has a single digit if the character after it is a separator, otherwise two digits
    n = name_str_ampl[1] if name_str_ampl[2] == "_" else name_str_ampl[1:3]
    k = 3 if int(n) < 10 else 4  # index of the field following n
    L_A = name_str_ampl[k] if name_str_ampl[k + 1] == "_" else name_str_ampl[k:k + 2]  # this assumes nint=None
    m = name_str_phase[k] if name_str_phase[k + 1] in ("_", "(") else name_str_phase[k:k + 2]

    mint = None
    for i in range(int(m) + 1):
        if name_str_phase.count(f"({i})") == 1:
            mint = i
            break
    prefix = f"_{n}_{m}({mint})_"
    if name_str_phase.count(prefix) != 1:
        raise ValueError("Name string could not be interpreted.")
    j = len(prefix)
    # a two-digit L_phase needs a two-character slice (indexing j+2 would pick up the separator)
    L_phase = name_str_phase[j] if name_str_phase[j + 1] == "_" else name_str_phase[j:j + 2]
    if name_str_phase.count(f"{prefix}{L_phase}_") != 1:
        raise ValueError("Name string could not be interpreted.")

    n = int(n)
    m = int(m)
    L_phase = int(L_phase)
    L_A = int(L_A)
    mint = int(mint)

    # generate x array
    x_arr = x_trans_arr(n)

    # get PQC state (target register only, and the full two-register state)
    state_vec, state_vec_full = full_encode(n, m, weights_ampl, weights_phase, L_A, L_phase, real_p=real_p,
                                            repeat_params=repeat_params, no_UA=no_UA, full_state_vec=True, operators=operators)
    phase = phase_from_state(state_vec)
    full_phase = phase_from_state(state_vec_full)

    def _plot_register_map(values, clim, stem, cmap=None):
        # colour map of `values` over (frequency, phase-register index) with the extracted phase overlaid
        fig, ax = plt.subplots(2, 1, figsize=figsize, gridspec_kw={'height_ratios': [2**m, 1]}, constrained_layout=True)
        ax[0].scatter(x_arr, phase, label="QCNN", color="red")
        ax[0].set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize)
        ax[0].tick_params(axis="both", labelsize=ticksize)
        im = ax[0].pcolormesh(x_arr, np.arange(2**m), values, cmap=cmap)
        im.set_clim(*clim)
        _set_phase_register_yticks(ax[0], m)
        ax[0].set_xlabel(r"$f$ (Hz)", fontsize=fontsize)
        # the second axes only reserves space; the colorbar is attached to both
        ax[1].set_visible(False)

        cb = fig.colorbar(im, ax=ax.ravel().tolist(), location='top', orientation='horizontal', pad=0.05)
        cb.ax.tick_params(labelsize=ticksize)

        fig.savefig(os.path.join(out_dir, f"{stem}{pdf_str}"), bbox_inches='tight', dpi=500)
        if show:
            plt.show()
        plt.close(fig)

    # plot amplitude
    abs_full = np.abs(state_vec_full)
    _plot_register_map(abs_full, (np.min(abs_full), np.max(abs_full)), "two_reg_ampl")

    # plot phase
    _plot_register_map(full_phase, (0, 2*np.pi), "two_reg_phase", cmap="hsv")

    # plot amplitude (height) and phase (colour) as a 3D surface
    X, Y = np.meshgrid(x_arr, np.arange(2**m))
    Z = abs_full
    P_norm = full_phase / (2*np.pi)  # normalise to value in [0,1]

    fig, ax = plt.subplots(figsize=figsize, subplot_kw={"projection": "3d"}, constrained_layout=False)
    surf = ax.plot_surface(X, Y, Z, facecolors=cm.hsv(P_norm), linewidth=0, antialiased=True, shade=False)
    ax.set_xlabel(r"$f$ (Hz)", fontsize=fontsize, labelpad=20)
    ax.set_ylabel(r"$\tilde{\Psi}(f)$", fontsize=fontsize, labelpad=20)
    _set_phase_register_yticks(ax, m)
    ax.tick_params(axis="both", labelsize=ticksize)

    cb = fig.colorbar(cm.ScalarMappable(cmap=cm.hsv, norm=surf.norm), ax=ax, shrink=0.6, aspect=5)
    cb.ax.tick_params(labelsize=ticksize)
    cb.ax.set_yticks([0, 0.25, 0.5, 0.75, 1], [r'$0$', r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])

    fig.savefig(os.path.join(out_dir, f"two_reg_3D{pdf_str}"), bbox_inches='tight', dpi=500)
    if show:
        plt.show()
    plt.close(fig)

    return 0
Plot the two-register statevector resulting from applying amplitude and phase preparation QCNNs.
This function attempts to extract information about QCNN configurations from name strings which might not be successful in the general case.
Arguments:
name_str_phase : str
Name string (produced using pqcprep.file_tools.vars_to_name_str()) corresponding to the weights of the QCNN used for phase encoding.
name_str_ampl : str
Name string (produced using pqcprep.file_tools.vars_to_name_str_ampl()) corresponding to the weights of the QCNN used for amplitude preparation. This is ignored if no_UA is True.
in_dir : str
Directory containing the input files (as specified by name_str_phase and name_str_ampl). The directory is expected to contain a weights file for both cases.
out_dir : str
Directory for output files.
operators : str
Which operators to apply to the registers. Options are 'QRQ' for full phase extraction, 'RQ' for partial phase extraction, and 'Q' for function evaluation only. Default is 'QRQ'.
no_UA : boolean
If True, a Hadamard transform is applied to the input register instead of preparing the amplitude via a PQC. Default is False.
show : boolean
If True, display plots. Default is False.
pdf : boolean
If True, save plots in pdf format. Default is False.
Returns:
Plots of the amplitude and phase of the two-register state saved in the directory out_dir.