xfer.NeuralNetworkRepurposer

class xfer.NeuralNetworkRepurposer(source_model: mxnet.module.module.Module, context_function=<function cpu>, num_devices=1, batch_size=64, num_epochs=5, optimizer='sgd', optimizer_params=None)
Bases: xfer.repurposer.Repurposer
Base class for repurposers that create a target neural network from a source neural network through Transfer Learning, in two steps:
- Transfer: transfer the layer architecture and weights from the source neural network.
- Learn: train the target neural network, adapting the transferred network to a new data set.
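The Transfer step can be illustrated without MXNet. The sketch below is not xfer's implementation: it assumes, purely for illustration, that a network is a plain list of weight matrices (each stored as input rows by output columns) whose last entry is the source output layer. The hypothetical helper transfer_layers copies every layer except that output layer and appends a new, randomly initialized output layer sized for the target classes.

```python
import copy
import random


def transfer_layers(source_layers, num_target_classes, seed=0):
    """Hypothetical Transfer step on a toy network.

    source_layers: list of weight matrices (lists of rows, shape inputs x outputs);
    the last matrix is the source output layer and is discarded.
    Returns a new list of layers for the target network.
    """
    # Deep-copy the transferred layers so that later training of the target
    # network never mutates the source network's weights.
    transferred = copy.deepcopy(source_layers[:-1])
    # Output width of the last transferred layer = inputs to the new output layer.
    num_features = len(transferred[-1][0])
    rng = random.Random(seed)
    # New output layer: num_features rows, one column per target class,
    # initialized with small random values.
    new_output = [[rng.uniform(-0.01, 0.01) for _ in range(num_target_classes)]
                  for _ in range(num_features)]
    return transferred + [new_output]
```

The Learn step then trains this target network on the new data set; copying rather than aliasing the weights keeps the source network unchanged.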
Parameters:
- source_model (mxnet.mod.Module) – Source neural network to do transfer learning from
- context_function (function) – MXNet context function that provides the device type context (for example mxnet.cpu or mxnet.gpu). Default: mxnet.cpu
- num_devices (int) – Number of devices to use to train the target neural network. Default: 1
- batch_size (int) – Size of the data batches used to train the target neural network. Default: 64
- num_epochs (int) – Number of epochs used to train the target neural network. Default: 5
- optimizer (str) – Name of the MXNet optimizer used to train the target neural network. Default: 'sgd'
- optimizer_params (dict(str, float)) – Parameters for the MXNet optimizer used to train the target neural network. Default: {'learning_rate': 1e-3} (used when None is passed)
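The constructor defaults above can be read as a small resolution step. In this sketch, which is not xfer's code, optimizer_params=None is replaced by the documented default, and one device context is created per device by calling context_function with a device index, in the style commonly used with mxnet.gpu(i). Whether xfer builds its context list exactly this way is an assumption; the function resolve_training_config is hypothetical, and cpu is a stand-in for mxnet.cpu so that the example runs without MXNet.

```python
def cpu(device_id=0):
    """Hypothetical stand-in for mxnet.cpu: returns a readable tag, not a Context."""
    return "cpu(%d)" % device_id


def resolve_training_config(context_function=cpu, num_devices=1, batch_size=64,
                            num_epochs=5, optimizer='sgd', optimizer_params=None):
    """Hypothetical helper that applies the documented constructor defaults."""
    # None means "use the documented default"; otherwise copy the caller's dict
    # so the configuration does not share a mutable object with the caller.
    if optimizer_params is None:
        optimizer_params = {'learning_rate': 1e-3}
    else:
        optimizer_params = dict(optimizer_params)
    # Assumption: one context per device, obtained by passing the device index.
    contexts = [context_function(i) for i in range(num_devices)]
    return {
        'contexts': contexts,
        'batch_size': batch_size,
        'num_epochs': num_epochs,
        'optimizer': optimizer,
        'optimizer_params': optimizer_params,
    }
```

For example, resolve_training_config(num_devices=2) yields two contexts, 'cpu(0)' and 'cpu(1)', together with the default SGD learning rate of 1e-3.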
Methods
- __init__ – Initialize self.
- deserialize – Use a dictionary to set attributes of the repurposer.
- get_params – Get the parameters of the repurposer that appear in the constructor.
- predict_label – Predict labels for test data using target_model (the repurposed neural network).
- predict_probability – Predict probabilities for test data using target_model (the repurposed neural network).
- repurpose – Train a neural network by transferring layers/weights from source_model.
- save_repurposer – Serialize the repurposed model (source_model, target_model and supporting info) and save it under the given model name and directory.
- serialize – Serialize the repurposer to a dictionary.
repurpose(train_iterator: mxnet.io.io.DataIter)
Train a neural network by transferring layers/weights from source_model. Sets self.target_model to the repurposed neural network.
Parameters: train_iterator (mxnet.io.DataIter) – Training data iterator used to train the target neural network
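Conceptually, the Learn phase of repurpose loops over the training data for num_epochs, in batches of batch_size, applying an optimizer such as SGD. The pure-Python sketch below stands in for that loop under stated assumptions: iterate_batches is a hypothetical substitute for an mxnet.io.DataIter, only a new linear softmax output layer is trained on fixed input features, and plain SGD uses the learning rate from optimizer_params. xfer itself delegates training to MXNet, so this is an illustration of the technique, not its implementation.

```python
import math


def iterate_batches(data, labels, batch_size):
    """Hypothetical stand-in for mxnet.io.DataIter: yields (data, label) batches."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size], labels[start:start + batch_size]


def softmax(scores):
    # Subtract the max score for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def train_head(data, labels, num_classes, batch_size=64, num_epochs=5,
               learning_rate=1e-3):
    """Mini-batch SGD on a linear softmax layer: weights[feature][class], bias[class]."""
    num_features = len(data[0])
    weights = [[0.0] * num_classes for _ in range(num_features)]
    bias = [0.0] * num_classes
    for _ in range(num_epochs):
        for batch_x, batch_y in iterate_batches(data, labels, batch_size):
            grad_w = [[0.0] * num_classes for _ in range(num_features)]
            grad_b = [0.0] * num_classes
            for x, y in zip(batch_x, batch_y):
                scores = [bias[c] + sum(x[f] * weights[f][c] for f in range(num_features))
                          for c in range(num_classes)]
                probs = softmax(scores)
                for c in range(num_classes):
                    # Gradient of cross-entropy w.r.t. score c: p_c - 1[c == y].
                    delta = probs[c] - (1.0 if c == y else 0.0)
                    grad_b[c] += delta
                    for f in range(num_features):
                        grad_w[f][c] += delta * x[f]
            # Average the gradient over the batch and take one SGD step.
            n = len(batch_x)
            for c in range(num_classes):
                bias[c] -= learning_rate * grad_b[c] / n
                for f in range(num_features):
                    weights[f][c] -= learning_rate * grad_w[f][c] / n
    return weights, bias
```

The returned (weights, bias) pair plays the role that self.target_model plays in the real class.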
predict_probability(test_iterator: mxnet.io.io.DataIter)
Predict probabilities for test data using target_model (the repurposed neural network).
Parameters: test_iterator (mxnet.io.DataIter) – Test data iterator to return predictions for
Returns: Predicted probabilities
Return type: numpy.ndarray
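Prediction walks through every test batch and collects one row of class probabilities per sample. This sketch assumes a layout of one row per sample and one column per class, and a linear softmax output layer like the one in a toy training loop; the real method returns a numpy.ndarray computed by MXNet, whereas this hypothetical predict_probability returns a list of lists so that it runs with the standard library only.

```python
import math


def softmax(scores):
    # Subtract the max score for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_probability(weights, bias, test_batches):
    """Return one probability row per test sample, concatenated over all batches.

    weights[feature][class] and bias[class] describe a hypothetical linear output layer;
    test_batches is an iterable of batches, each a list of feature vectors.
    """
    rows = []
    for batch in test_batches:
        for x in batch:
            scores = [bias[c] + sum(x[f] * weights[f][c] for f in range(len(x)))
                      for c in range(len(bias))]
            rows.append(softmax(scores))
    return rows
```

Each returned row sums to 1, so it can be read as a probability distribution over the target classes.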
predict_label(test_iterator: mxnet.io.io.DataIter)
Predict labels for test data using target_model (the repurposed neural network).
Parameters: test_iterator (mxnet.io.DataIter) – Test data iterator to return predictions for
Returns: Predicted labels
Return type: numpy.ndarray
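A common way to turn per-class probabilities into labels is to take the index of the largest probability in each row (with NumPy, np.argmax(probabilities, axis=1)). This page does not state that predict_label is derived from predict_probability this way, so treat the link as an assumption; the helper labels_from_probabilities is hypothetical.

```python
def labels_from_probabilities(probabilities):
    """Return the arg-max class index of each row; ties resolve to the first index."""
    # max() returns the first maximal element, so equal probabilities pick the lowest class.
    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]
```

For example, rows [0.1, 0.7, 0.2] and [0.9, 0.1] map to labels 1 and 0.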
serialize(file_prefix)
Serialize the repurposer to a dictionary.
Returns: Dictionary describing the repurposer
Return type: dict
deserialize(input_dict)
Use a dictionary to set attributes of the repurposer.
Parameters: input_dict (dict) – Dictionary mapping attribute names to the values they should be set to
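serialize and deserialize act as a pair: one produces a dictionary describing the repurposer, the other uses such a dictionary to set attributes. The sketch below shows a JSON round trip with a hypothetical ToyRepurposer that models only a few of the plain constructor values documented above. It deliberately leaves out source_model, context_function, the trained target_model and the file_prefix argument, whose role this page does not describe.

```python
import json


class ToyRepurposer:
    """Hypothetical stand-in modeling a few documented constructor values only."""

    def __init__(self, batch_size=64, num_epochs=5, optimizer='sgd', optimizer_params=None):
        self.batch_size = batch_size
        self.num_epochs = num_epochs
        self.optimizer = optimizer
        self.optimizer_params = ({'learning_rate': 1e-3} if optimizer_params is None
                                 else dict(optimizer_params))

    def serialize(self):
        # Plain values only, so the result can be written with json.dumps.
        return {'batch_size': self.batch_size, 'num_epochs': self.num_epochs,
                'optimizer': self.optimizer,
                'optimizer_params': dict(self.optimizer_params)}

    def deserialize(self, input_dict):
        # Each key names an attribute; its value is what the attribute is set to.
        for name, value in input_dict.items():
            setattr(self, name, value)
```

Usage: serialize one instance, write it with json.dumps, read it back with json.loads, and pass the result to deserialize on a fresh instance; both instances then serialize to the same dictionary.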
save_repurposer(model_name, model_directory='', save_source_model=None)
Serialize the repurposed model (source_model, target_model and supporting info) and save it to disk under model_name in model_directory.
Parameters:
- model_name (str) – Name under which to save the repurposer.
- model_directory (str) – Directory in which to save the repurposer.
- save_source_model (bool) – Whether to also save the source model. If None, the class default is used (MetaModelRepurposer: True, NeuralNetworkRepurposer: False).
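The save_source_model flag follows a "None means class default" rule: an explicit True or False always wins, and None falls back to the per-class default documented above. The sketch below shows that rule with a hypothetical plan_save helper that only computes what would be saved. The file layout <model_directory>/<model_name>.json is an assumption for illustration, not xfer's actual on-disk format.

```python
import os

# Per-class defaults for save_source_model, as documented above.
SAVE_SOURCE_MODEL_DEFAULTS = {'MetaModelRepurposer': True,
                              'NeuralNetworkRepurposer': False}


def plan_save(model_name, model_directory='', save_source_model=None,
              repurposer_class_name='NeuralNetworkRepurposer'):
    """Return (metadata_path, include_source_model) without touching the disk."""
    if save_source_model is None:
        # None means "use the default of the repurposer class".
        save_source_model = SAVE_SOURCE_MODEL_DEFAULTS[repurposer_class_name]
    # Hypothetical layout: <model_directory>/<model_name>.json; '' is the working directory.
    metadata_path = os.path.join(model_directory, model_name + '.json')
    return metadata_path, bool(save_source_model)
```

With the defaults, a NeuralNetworkRepurposer skips the source model unless save_source_model=True is passed, which keeps saved files small when the source network can be reloaded from elsewhere.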