bspysmg.model.training

Module containing functions for training a surrogate model in PyTorch, taking into account the error in nanoamperes.

Functions

default_train_step(model, dataloader, ...)

Performs the training step of a model within a single epoch and returns the current loss and current trained model.
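For orientation, a single-epoch training step in PyTorch often looks like the sketch below. This is not the bspysmg implementation: the name `sketch_train_step`, its `criterion` and `optimizer` parameters, and the toy data are assumptions made for illustration, since the real signature is abbreviated above.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def sketch_train_step(model, dataloader, criterion, optimizer):
    """Hypothetical single-epoch training step (not the bspysmg code)."""
    model.train()  # training-mode behaviour for layers such as dropout
    running_loss = 0.0
    for inputs, targets in dataloader:
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        loss = criterion(model(inputs), targets)
        loss.backward()        # backpropagate the batch loss
        optimizer.step()       # update the weights
        running_loss += loss.item() * inputs.shape[0]
    # Return the model and the average loss per sample over the epoch
    return model, running_loss / len(dataloader.dataset)


# Toy data (assumption): 7 inputs, 1 output
torch.manual_seed(0)
x = torch.rand(64, 7)
y = x.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(x, y), batch_size=16)
net = nn.Linear(7, 1)
opt = torch.optim.Adam(net.parameters(), lr=1e-2)
net, epoch_loss = sketch_train_step(net, loader, nn.MSELoss(), opt)
```

Weighting each batch loss by its size keeps the epoch average correct even when the last batch is smaller than the others.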

default_val_step(model, dataloader, criterion)

Performs the validation step of a model within a single epoch and returns the validation loss.
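A validation step usually mirrors the training step but disables gradient tracking and leaves the weights untouched. The sketch below assumes that pattern; `sketch_val_step` and the toy data are hypothetical and are not taken from bspysmg.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def sketch_val_step(model, dataloader, criterion):
    """Hypothetical single-epoch validation step (not the bspysmg code)."""
    model.eval()  # evaluation-mode behaviour for layers such as dropout
    total = 0.0
    with torch.no_grad():  # no gradients are needed for validation
        for inputs, targets in dataloader:
            loss = criterion(model(inputs), targets)
            total += loss.item() * inputs.shape[0]
    # Average loss per sample over the validation set
    return total / len(dataloader.dataset)


# Toy check (assumption): a zeroed linear layer predicts exactly zero
val_net = nn.Linear(3, 1)
nn.init.zeros_(val_net.weight)
nn.init.zeros_(val_net.bias)
val_loader = DataLoader(
    TensorDataset(torch.rand(32, 3), torch.zeros(32, 1)), batch_size=8
)
val_loss = sketch_val_step(val_net, val_loader, nn.MSELoss())
```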

generate_surrogate_model(configs[, ...])

Loads the training and validation datasets from the npz file specified in the data/dataset_paths field of the configs dictionary.
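The loading step can be illustrated with NumPy alone. The sketch below assumes that `configs["data"]["dataset_paths"]` holds a list of npz paths (training first, validation second) and that each archive stores arrays under the keys `inputs` and `outputs`. Both the list layout and the key names are assumptions; check them against your own files.

```python
import os
import tempfile

import numpy as np


def sketch_load_datasets(configs):
    """Hypothetical loader (not the bspysmg code)."""
    datasets = []
    for path in configs["data"]["dataset_paths"]:
        with np.load(path) as archive:
            # "inputs" and "outputs" are assumed key names
            datasets.append((archive["inputs"], archive["outputs"]))
    return datasets


# Write two small archives so the sketch can be run end to end
tmp_dir = tempfile.mkdtemp()
paths = []
for name, n_samples in (("train.npz", 100), ("val.npz", 20)):
    path = os.path.join(tmp_dir, name)
    np.savez(path, inputs=np.zeros((n_samples, 7)), outputs=np.zeros((n_samples, 1)))
    paths.append(path)

configs = {"data": {"dataset_paths": paths}}
train_set, val_set = sketch_load_datasets(configs)
```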

init_seed(configs)

Initializes a random seed for training.
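Seeding typically covers every random number generator that training touches. The sketch below seeds Python, NumPy and PyTorch from one value; the `seed` key in `configs` and the fallback to a random seed are assumptions, not documented bspysmg behaviour.

```python
import random

import numpy as np
import torch


def sketch_init_seed(configs):
    """Hypothetical seeding helper (not the bspysmg code)."""
    seed = configs.get("seed")  # "seed" is an assumed key name
    if seed is None:
        # No seed given: draw one so that the run can still be reproduced later
        seed = random.randint(0, 2**32 - 1)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed


used_seed = sketch_init_seed({"seed": 42})
first_draw = torch.rand(3)
sketch_init_seed({"seed": 42})
second_draw = torch.rand(3)  # same values as first_draw
```

Returning the seed makes it possible to log it alongside the results, which matters when the value was drawn at random.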

postprocess(dataloader, model, criterion, ...)

Plots the error against the output and an error histogram for the given dataset, and saves them to the specified directory.

to_device(inputs)

Copies input tensors from the CPU to the GPU device for processing.
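A common way to write such a helper is to pick the device once and fall back to the CPU when no GPU is present. The sketch below assumes that behaviour; `sketch_to_device` and `DEVICE` are hypothetical names, and the fallback is not confirmed by the description above.

```python
import torch

# Use the GPU when one is available, otherwise stay on the CPU (assumption)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sketch_to_device(inputs):
    """Hypothetical device helper (not the bspysmg code)."""
    return inputs.to(DEVICE)


batch = sketch_to_device(torch.zeros(4, 7))
```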

train_loop(model, info_dict, dataloaders, ...)

Performs the training of a model and returns the trained model, the training loss, and the validation loss.
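Putting the pieces together, a training loop alternates a training pass and a validation pass for each epoch and records both losses. The sketch below is a generic PyTorch loop under assumed names and parameters (`sketch_train_loop`, `criterion`, `optimizer`, `epochs`, and a two-element `dataloaders` list); it does not reproduce the actual `train_loop` signature, which also takes `info_dict`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def sketch_train_loop(model, dataloaders, criterion, optimizer, epochs):
    """Hypothetical loop (not the bspysmg code); dataloaders = [train, val]."""
    train_loader, val_loader = dataloaders
    train_losses, val_losses = [], []
    for _ in range(epochs):
        # Training pass: update the weights batch by batch
        model.train()
        total = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            total += loss.item() * inputs.shape[0]
        train_losses.append(total / len(train_loader.dataset))

        # Validation pass: measure the loss without touching the weights
        model.eval()
        total = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                total += criterion(model(inputs), targets).item() * inputs.shape[0]
        val_losses.append(total / len(val_loader.dataset))
    return model, train_losses, val_losses


# Toy regression problem (assumption): y = 2x + 1
torch.manual_seed(0)
x_train, x_val = torch.rand(64, 1), torch.rand(16, 1)
loaders = [
    DataLoader(TensorDataset(x_train, 2 * x_train + 1), batch_size=16),
    DataLoader(TensorDataset(x_val, 2 * x_val + 1), batch_size=16),
]
loop_net = nn.Linear(1, 1)
sgd = torch.optim.SGD(loop_net.parameters(), lr=0.1)
loop_net, train_hist, val_hist = sketch_train_loop(
    loop_net, loaders, nn.MSELoss(), sgd, epochs=20
)
```

Keeping one loss value per epoch in each list makes it easy to plot the two curves afterwards and spot overfitting.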