# Source code for bspysmg.utils.plots

"""
Plotting utilities: error histograms and error-vs-output plots for model evaluation, input/output waveform plots, output histograms and IV-curve plots.
"""
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from brainspy.utils.waveform import WaveformManager


def plot_error_hist(targets: np.array,
                    prediction: np.array,
                    error: np.array,
                    mse: np.array,
                    save_dir: str,
                    name: str = "error") -> None:
    """
    Plots and saves a two-panel figure comparing targets with predictions.

    The left panel scatters predicted against true values together with the
    identity line and shows the RMSE in its title. The right panel is a
    500-bin histogram of the errors, zoomed to +/-25% of the largest absolute
    error.

    Parameters
    ----------
    targets : np.array
        Reference data used for training/validation.
    prediction : np.array
        Predictions made by the model. Must have as many elements as
        ``targets``.
    error : np.array
        Error corresponding to each target data point. Must have as many
        elements as ``targets``.
    mse : np.array
        Mean squared error of the predictions (non-negative); its square
        root is reported as the RMSE.
    save_dir : str
        Path of the directory where the figure is saved.
    name : str
        [Optional] File name of the saved figure. Defaults to "error".
    """
    assert targets.size == prediction.size
    assert targets.size == error.size
    assert mse >= 0
    plt.figure()
    # Figure-level title. plt.title() before the first plt.subplot() would
    # attach it to a stray full-figure axes instead of the whole figure.
    plt.suptitle('Predicted vs True values')

    plt.subplot(1, 2, 1)
    plt.plot(targets, prediction, ".")
    plt.xlabel("True Output (nA)")
    plt.ylabel("Predicted Output (nA)")
    # Identity line over the combined range of targets and predictions.
    # Reducing each array separately (instead of concatenating them) also
    # works when their shapes differ, e.g. (N, 1) and (N,).
    min_out = np.minimum(np.min(targets), np.min(prediction))
    max_out = np.maximum(np.max(targets), np.max(prediction))
    plt.plot(np.linspace(min_out, max_out), np.linspace(min_out, max_out), "k")
    plt.title(f"RMSE {np.sqrt(mse)} (nA)")

    plt.subplot(1, 2, 2)
    plt.hist(np.reshape(error, error.size), 500)
    # Show only +/-25% of the largest absolute error. Skip the limits when
    # that is zero (all errors are zero): identical limits make matplotlib
    # warn and collapse the axis.
    x_lim = 0.25 * np.max(np.abs(error))
    if x_lim > 0:
        plt.xlim([-x_lim, x_lim])
    plt.title("Error histogram (nA) ")

    fig_loc = os.path.join(save_dir, name)
    plt.tight_layout()
    plt.savefig(fig_loc, dpi=300)
    plt.close()
def plot_error_vs_output(targets: np.array,
                         error: np.array,
                         save_dir: str,
                         name: str = "error_vs_output") -> None:
    """
    Plots and saves the error of each data point against its target output.

    A horizontal zero-error reference line is drawn across the range of the
    targets.

    Parameters
    ----------
    targets : np.array
        Reference data used for training/validation.
    error : np.array
        Error corresponding to each target data point. Must have as many
        elements as ``targets``.
    save_dir : str
        Path of the directory where the figure is saved.
    name : str
        [Optional] File name of the saved figure. Defaults to
        "error_vs_output".
    """
    assert targets.size == error.size
    plt.figure()
    plt.plot(targets, error, ".")
    # Zero-error reference line from the smallest to the largest target.
    # Two end points draw the same line as a dense linspace, and unlike pairing
    # len(error) x-values with np.zeros_like(error) they do not depend on the
    # shape of ``error``.
    plt.plot([targets.min(), targets.max()], [0, 0])
    plt.title("Error vs Output")
    plt.xlabel("Output (nA)")
    plt.ylabel("Error (nA)")
    fig_loc = os.path.join(save_dir, name)
    plt.tight_layout()
    plt.savefig(fig_loc, dpi=300)
    plt.close()
def plot_waves(inputs: np.array,
               outputs: np.array,
               input_no: int,
               output_no: int,
               batch: int,
               legend: np.array,
               save_directory: str) -> None:
    """
    Plots and saves the input and output waves of one batch.

    The figure is always written to ``<save_directory>/example_batch``, so it
    is overwritten on every call. It is meant for checking the device
    input/output relationship batch by batch.

    Parameters
    ----------
    inputs : np.array
        Input data used for training/validation of the model.
    outputs : np.array
        Output generated by the model corresponding to the inputs.
    input_no : int
        Number of input (activation) electrodes; the first ``input_no``
        entries of ``legend`` label the input traces.
    output_no : int
        Number of output electrodes; the last ``output_no`` entries of
        ``legend`` label the output traces. If it is 0, no legend is drawn
        for the outputs.
    batch : int
        The current batch number, shown in the figure title.
    legend : np.array
        Headers of the text file containing the data used for training,
        validation and testing.
    save_directory : str
        Path of the directory where the figure is saved.
    """
    plt.figure()
    plt.suptitle(f'I/O data for batch {batch}')

    plt.subplot(211)
    plt.plot(inputs)
    plt.ylabel('Inputs (V)')
    plt.xlabel('Time points (a.u.)')
    plt.legend(legend[:input_no])

    plt.subplot(212)
    plt.plot(outputs)
    plt.ylabel('Outputs (nA)')
    # legend[-0:] is the *whole* legend, so slicing with output_no == 0 would
    # label the outputs with every header (inputs included).
    if output_no > 0:
        plt.legend(legend[-output_no:])
    plt.xlabel('Time points (a.u.)')

    plt.tight_layout()
    plt.savefig(os.path.join(save_directory, 'example_batch'))
    plt.close()
def output_hist(outputs: np.array,
                data_dir: str,
                bins: int = 100,
                show: bool = False) -> None:
    """
    Saves, and optionally shows, the histogram of the model outputs.

    The figure is saved as ``<data_dir>/output_distribution``.

    Parameters
    ----------
    outputs : np.array
        Output generated by the model.
    data_dir : str
        Path of the directory where the figure is saved.
    bins : int
        [Optional] Number of bins for the histogram. Defaults to 100.
    show : bool
        [Optional] If set to true, also displays the histogram.
    """
    plt.figure()
    plt.title("Output Histogram")
    plt.hist(outputs, bins=bins)
    plt.ylabel("Counts")
    plt.xlabel("Raw output (nA)")
    # Save before showing: after a blocking plt.show() returns, the figure
    # window has been closed and saving at that point can write a blank image.
    plt.savefig(os.path.join(data_dir, "output_distribution"))
    if show:
        plt.show()
    plt.close()
# def plot(self, x, y): # for i in range(np.shape(y)[1]): # plt.figure() # plt.plot(x) # plt.plot(y) # plt.show()
def iv_plot(result: np.array,
            inputs: np.array,
            input_electrode: int,
            data_dir: str = ".",
            save_plot: bool = None,
            show_plot: bool = False) -> None:
    """
    Plots the IV curve of one electrode and optionally saves and/or shows it.

    The curve is drawn on the current matplotlib axes. No new figure is
    created and the figure is not closed, so repeated calls draw onto the same
    axes. The curve carries a label with the electrode number, but no legend
    is drawn here.

    Parameters
    ----------
    result : np.array
        Output current values (nA), plotted on the y-axis.
    inputs : np.array
        Input voltages (V), plotted on the x-axis.
    input_electrode : int
        Electrode number, used in the curve label.
    data_dir : str
        [Optional] Directory where the plot is saved. Defaults to ".".
    save_plot : bool
        [Optional] If true, saves the plot as ``<data_dir>/iv_plot``.
        Defaults to None (not saved).
    show_plot : bool
        [Optional] If true, displays the plot.
    """
    plt.plot(inputs,
             result,
             label='IV Curve for electrode ' + str(input_electrode))
    plt.xlabel('Voltage (V)')
    plt.ylabel('Current (nA)')
    # Test truthiness rather than "is not None": save_plot=False must not save.
    if save_plot:
        plt.savefig(os.path.join(data_dir, "iv_plot"))
    if show_plot:
        plt.show()
def _input_trace(signal, setup: dict, drop_first_plateau: bool):
    """
    Returns the input signal of one activation electrode as it is plotted by
    ``multi_iv_plot``.

    If ``setup['average_io_point_difference']`` is truthy, ``signal`` is
    returned unchanged. Otherwise it is converted with brainspy's
    ``WaveformManager.points_to_plateaus`` (slope length 0, plateau length
    ``int(readout_sampling_frequency / activation_sampling_frequency)``) and
    returned as a NumPy array. When ``drop_first_plateau`` is true, the first
    ``plateau_length`` samples are removed.
    """
    if setup['average_io_point_difference']:
        return signal
    plateau_length = int(setup['readout_sampling_frequency'] /
                         setup['activation_sampling_frequency'])
    waveform_manager = WaveformManager({
        'slope_length': 0,
        'plateau_length': plateau_length
    })
    trace = waveform_manager.points_to_plateaus(
        torch.tensor(signal)).detach().cpu().numpy()
    return trace[plateau_length:] if drop_first_plateau else trace


def multi_iv_plot(configs, inputs, output, save_plot=None, show_plot=True):
    """
    Plots the IV curves of several devices, one figure per device. Devices
    can be the DNPU device or a surrogate model.

    Each figure has a 2x4 grid of panels:

    - The first seven panels plot the output of experiments "IV1" ... "IV7"
      against the corresponding input signal. Masked activation electrodes
      get an empty panel labelled "Channel Masked".
    - The eighth panel overlays the input signals of all experiments.

    Parameters
    ----------
    configs : dict
        Configuration of the IV measurement. Keys read here:

        1. devices: list
            Names of the devices (A, B, C, D, etc.) involved in the
            experiment.
        2. driver: dict
            ``driver['instruments_setup']`` must contain
            'average_io_point_difference'. When that is false, it must also
            contain 'readout_sampling_frequency' and
            'activation_sampling_frequency'.
            ``driver['instruments_setup'][dev]['activation_channel_mask']``
            is the activation mask of each device; entries equal to 1 mark
            electrodes that are plotted.
    inputs : dict
        Input signals, indexed as ``inputs[experiment][device][idx]``, where
        ``idx`` counts only the unmasked activation electrodes.
    output : dict
        Output currents (nA), indexed as ``output[experiment][device]``.
    save_plot : str or None
        If None, the plots are not saved. Otherwise each device figure is
        saved to this path. The same path is reused for every device, so with
        several devices only the last device's figure remains on disk.
    show_plot : bool
        Whether to show the figures; plt.show() is called once per device.
    """
    ylabeldist = -10
    cmap = plt.get_cmap("tab10")
    for dev in configs['devices']:
        fig, axs = plt.subplots(2, 4)
        fig.suptitle('Device ' + dev + ' - Input voltage vs Output current')
        setup = configs["driver"]['instruments_setup']
        mask = setup[dev]["activation_channel_mask"]
        for i in range(2):
            for j in range(4):
                exp_index = j + i * 4
                ax = axs[i, j]
                if exp_index < 7:
                    exp = "IV" + str(exp_index + 1)
                    if mask[exp_index] == 1:
                        # Position of this electrode among the unmasked ones,
                        # i.e. its index into inputs[exp][dev].
                        masked_idx = sum(mask[:exp_index + 1]) - 1
                        temp = _input_trace(inputs[exp][dev][masked_idx],
                                            setup,
                                            drop_first_plateau=False)
                        ax.plot(temp, output[exp][dev], color=cmap(exp_index))
                        ax.set_ylabel('output (nA)', labelpad=ylabeldist)
                        ax.set_xlabel('input (V)', labelpad=1)
                        ax.xaxis.grid(True)
                        ax.yaxis.grid(True)
                    else:
                        ax.plot([])
                        ax.set_xlabel('Channel Masked')
                else:
                    # Last panel: every experiment's input signal.
                    # NOTE(review): assumes the z-th key of ``inputs`` belongs
                    # to activation electrode z (keys ordered IV1..IV7), as the
                    # mask is indexed by z -- confirm against the caller.
                    for z, key in enumerate(inputs.keys()):
                        if mask[z] == 1:
                            masked_idx = sum(mask[:z + 1]) - 1
                            temp = _input_trace(inputs[key][dev][masked_idx],
                                                setup,
                                                drop_first_plateau=True)
                            # Label with the experiment key itself; the former
                            # "IV" + str(z) was off by one w.r.t. "IV1".."IV7".
                            ax.plot(temp, label=key, color=cmap(z))
                    ax.set_ylabel('input (V)')
                    ax.set_xlabel('points', labelpad=1)
                    ax.set_title("Input input_signal")
                    ax.xaxis.grid(True)
                    ax.yaxis.grid(True)
                    ax.legend()
        fig.subplots_adjust(hspace=0.3, wspace=0.35)
        if save_plot is not None:
            fig.savefig(save_plot)
        if show_plot:
            plt.show()