xfer.model_handler.ModelHandler
-
class xfer.model_handler.ModelHandler(module, context_function=<function cpu>, num_devices=1, data_name='data')
Bases: object
Class for model manipulation and feature extraction.
Parameters:
Methods
__init__ – Initialize self.
add_layer_bottom – Add layer to input of model.
add_layer_top – Add layer to output of model.
drop_layer_bottom – Remove layers from input of model.
drop_layer_top – Remove layers from output of model.
get_layer_names_matching_type – Return names of layers of specified type.
get_layer_output – Extract features from a data iterator using the model.
get_layer_parameters – Get list of layer parameters associated with the given layer names.
get_layer_type – Return type of named layer.
get_module – Return MXNet Module using the model symbol and parameters.
save_symbol – Serialise model symbol graph.
update_sym – Update symbol attribute, layer names, and layer types dict, and clean parameters.
visualize_net – Display computational graph of model.
Attributes
layer_names – Get list of names of model layers.
-
drop_layer_top(num_layers_to_drop=1)
Remove layers from output of model.
Parameters: num_layers_to_drop (int) – Number of layers to remove from model output.
-
drop_layer_bottom(num_layers_to_drop=1)
Remove layers from input of model.
Parameters: num_layers_to_drop (int) – Number of layers to remove from model input.
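To make the direction of each drop concrete, here is a minimal pure-Python sketch that models the network as a list of layer names ordered from input (index 0) to output (index -1). It is not Xfer's implementation and does not use MXNet; `drop_top`, `drop_bottom`, and the refusal to drop every layer are assumptions made for illustration.

```python
from typing import List

# Hypothetical model of the layer stack: index 0 is the input side,
# index -1 is the output side. Not Xfer's implementation.


def drop_top(layers: List[str], num_layers_to_drop: int = 1) -> List[str]:
    """Remove num_layers_to_drop layers from the output end of the stack."""
    # Assumption: at least one layer must remain.
    if num_layers_to_drop < 0 or num_layers_to_drop >= len(layers):
        raise ValueError("num_layers_to_drop must leave at least one layer")
    # Slice with len(...) - n rather than [:-n] so that n == 0 keeps everything.
    return layers[: len(layers) - num_layers_to_drop]


def drop_bottom(layers: List[str], num_layers_to_drop: int = 1) -> List[str]:
    """Remove num_layers_to_drop layers from the input end of the stack."""
    if num_layers_to_drop < 0 or num_layers_to_drop >= len(layers):
        raise ValueError("num_layers_to_drop must leave at least one layer")
    return layers[num_layers_to_drop:]


layers = ["conv1", "relu1", "fc1", "softmax"]
print(drop_top(layers, 2))  # ['conv1', 'relu1']
print(drop_bottom(layers))  # ['relu1', 'fc1', 'softmax']
```

A typical transfer-learning use is `drop_layer_top` to strip the original classifier head before attaching a new one.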
-
add_layer_top(layer_list)
Add layers to output of model. For example, if model layers = (layer1, layer2, layer3) and layer_list = [layerA, layerB], the result is model layers = (layer1, layer2, layer3, layerA, layerB).
Parameters: layer_list (list(mx.symbol)) – List of MXNet symbol layers to be added to model output.
-
add_layer_bottom(layer_list)
Add layers to input of model. For example, if model layers = (layer1, layer2, layer3) and layer_list = [layerA, layerB], the result is model layers = (layerA, layerB, layer1, layer2, layer3).
Parameters: layer_list (list(mx.symbol)) – List of MXNet symbol layers to be added to model input.
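The ordering rules of `add_layer_top` and `add_layer_bottom` can be reproduced with a pure-Python sketch in which layers are plain strings rather than `mx.symbol` objects. `add_top` and `add_bottom` are hypothetical helpers that only mirror the documented examples; they are not Xfer code.

```python
from typing import List, Tuple

# Hypothetical model: a network is a tuple of layer names, input first.


def add_top(model: Tuple[str, ...], layer_list: List[str]) -> Tuple[str, ...]:
    """Append layer_list after the current output layer, keeping list order."""
    return model + tuple(layer_list)


def add_bottom(model: Tuple[str, ...], layer_list: List[str]) -> Tuple[str, ...]:
    """Place layer_list before the current input layer, keeping list order."""
    return tuple(layer_list) + model


model = ("layer1", "layer2", "layer3")
print(add_top(model, ["layerA", "layerB"]))
# ('layer1', 'layer2', 'layer3', 'layerA', 'layerB')
print(add_bottom(model, ["layerA", "layerB"]))
# ('layerA', 'layerB', 'layer1', 'layer2', 'layer3')
```

In both cases layer_list is read input-to-output, so for `add_layer_bottom` its first element becomes the new input layer.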
-
get_module(iterator, fixed_layer_parameters=None, random_layer_parameters=None)
Return MXNet Module using the model symbol and parameters.
Parameters:
Returns: MXNet module
Return type: mx.module.Module
-
get_layer_type(layer_name)
Return type of named layer.
Parameters: layer_name (str) – Name of layer being inspected.
Returns: Layer type
Return type: str
-
get_layer_names_matching_type(layer_type)
Return names of layers of specified type.
Parameters: layer_type (str) – Layer type to match; the names of all layers of this type are returned.
Returns: Names of layers with specified type
Return type: list(str)
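A minimal sketch of how the two type lookups relate, assuming the handler keeps an ordered name-to-type table for its layers. The table contents and the standalone functions are hypothetical; the type strings are MXNet operator names.

```python
from collections import OrderedDict
from typing import Dict, List

# Hypothetical layer-name -> layer-type table, ordered from input to output.
layer_types: Dict[str, str] = OrderedDict(
    [("conv1", "Convolution"), ("relu1", "Activation"),
     ("conv2", "Convolution"), ("fc1", "FullyConnected")]
)


def get_layer_type(layer_name: str) -> str:
    """Return the type of the named layer (KeyError in this sketch if absent)."""
    return layer_types[layer_name]


def get_layer_names_matching_type(layer_type: str) -> List[str]:
    """Return the names of all layers of layer_type, in model order."""
    return [name for name, t in layer_types.items() if t == layer_type]


print(get_layer_type("fc1"))                         # FullyConnected
print(get_layer_names_matching_type("Convolution"))  # ['conv1', 'conv2']
```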
-
get_layer_output(data_iterator, layer_names)
Extract features from a data iterator using the model. Returns an ordered dictionary mapping each layer name to a numpy array of the (flattened) features extracted at that layer, together with a list of labels.
Parameters:
Returns: Ordered dictionary of features ({layer_name: features}) and list of labels. Layer names in the ordered dictionary follow the same order as the input list layer_names.
Return type: OrderedDict[str, numpy.array], list(int)
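The return contract (keys in the order of layer_names, one flattened feature row per sample, labels collected alongside) can be illustrated with a stdlib-only sketch. Xfer presumably stores each layer's features as a numpy array; here rows are plain lists, and `forward`, `toy_forward`, and the batch format are hypothetical stand-ins for running the MXNet model over the iterator.

```python
from collections import OrderedDict
from typing import Callable, Dict, Iterable, List, Sequence, Tuple

# Hypothetical stand-ins: a batch is (samples, labels), and `forward` returns,
# for every layer name, one nested-list activation per sample in the batch.
Batch = Tuple[List[float], List[int]]


def flatten(x) -> List[float]:
    """Recursively flatten a nested list/tuple into a flat list of floats."""
    if isinstance(x, (list, tuple)):
        out: List[float] = []
        for item in x:
            out.extend(flatten(item))
        return out
    return [float(x)]


def get_layer_output(batches: Iterable[Batch], layer_names: Sequence[str],
                     forward: Callable[[List[float]], Dict[str, list]]):
    """Return (OrderedDict{layer_name: rows of flattened features}, labels)."""
    # Keys are created up front, so they follow the order of layer_names.
    features = OrderedDict((name, []) for name in layer_names)
    labels: List[int] = []
    for samples, batch_labels in batches:
        activations = forward(samples)
        for name in layer_names:
            # One flattened feature row per sample.
            features[name].extend(flatten(a) for a in activations[name])
        labels.extend(batch_labels)
    return features, labels


def toy_forward(samples):
    # Pretend "pool" yields a 2x2 map and "fc" a 3-vector for each sample.
    return {"pool": [[[s, s], [s, s]] for s in samples],
            "fc": [[s, 2 * s, 3 * s] for s in samples]}


batches = [([1.0, 2.0], [0, 1]), ([3.0], [1])]
feats, labels = get_layer_output(batches, ["fc", "pool"], toy_forward)
print(list(feats))       # ['fc', 'pool']
print(feats["pool"][0])  # [1.0, 1.0, 1.0, 1.0]
print(labels)            # [0, 1, 1]
```

Flattening makes features from convolutional and fully connected layers uniform (one vector per sample), which is what a downstream classifier trained on extracted features expects.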
-
get_layer_parameters(layer_names)
Get list of layer parameters associated with the given layer names.
Parameters: layer_names (list(str)) – List of layer names.
Returns: List of layer parameters
Return type: list(str)
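One plausible way to map layer names to parameter names relies on MXNet's symbol convention of naming a layer's arguments `<layer_name>_<param>` (for example `fc1_weight` and `fc1_bias`). The sketch below assumes that convention and a hypothetical argument list; it is not necessarily how Xfer resolves parameters.

```python
from typing import Iterable, List

# Hypothetical argument list of a small MXNet-style symbol.
ARG_NAMES = ["data", "conv1_weight", "conv1_bias", "fc1_weight", "fc1_bias",
             "fc10_weight", "fc10_bias", "softmax_label"]


def get_layer_parameters(layer_names: List[str],
                         arg_names: Iterable[str] = ARG_NAMES) -> List[str]:
    """Return parameter names owned by the given layers, in argument order."""
    if not isinstance(layer_names, list):
        raise TypeError("layer_names must be a list of str")
    wanted = set(layer_names)
    # Strip the final "_<suffix>" to recover the owning layer name, so that
    # "fc10_weight" is attributed to "fc10" and not to "fc1".
    return [arg for arg in arg_names if arg.rsplit("_", 1)[0] in wanted]


print(get_layer_parameters(["fc1"]))  # ['fc1_weight', 'fc1_bias']
```

Matching on the owning layer name rather than a bare string prefix avoids confusing layers such as `fc1` and `fc10`.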
-
visualize_net()
Display computational graph of model.
-
save_symbol(model_name)
Serialise model symbol graph.
Parameters: model_name (str) – Prefix of the output file name (model_name-symbol.json).
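A symbol JSON file holds only the graph, not the trained weights. The `model_name-symbol.json` name matches the convention used by MXNet's `mx.model.save_checkpoint`/`load_checkpoint`, which pair `prefix-symbol.json` with `prefix-%04d.params`. The helper below is a hypothetical sketch of that naming scheme only; it does not call MXNet.

```python
from typing import Tuple


def checkpoint_file_names(model_name: str, epoch: int = 0) -> Tuple[str, str]:
    """Return (symbol_file, params_file) following MXNet's checkpoint naming."""
    # Graph only: the file save_symbol is documented to produce.
    symbol_file = "%s-symbol.json" % model_name
    # Weights: written separately, e.g. by mx.model.save_checkpoint.
    params_file = "%s-%04d.params" % (model_name, epoch)
    return symbol_file, params_file


print(checkpoint_file_names("resnet_tl"))
# ('resnet_tl-symbol.json', 'resnet_tl-0000.params')
```

Choosing the same model_name for both files keeps the pair loadable with a single prefix.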
-
update_sym(new_symbol)
Update the symbol attribute, the layer names, and the layer types dict, and clean the parameters.
Parameters: new_symbol (mx.symbol.Symbol) – Symbol with which to update the ModelHandler.
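The bookkeeping behind `update_sym` can be sketched with a toy handler, assuming that "clean parameters" means discarding stored parameters whose owning layer no longer exists in the new symbol, and that parameters follow MXNet's `<layer_name>_<param>` naming. `TinyHandler` is hypothetical, and the "symbol" is just an ordered list of (name, type) pairs.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple


class TinyHandler:
    """Hypothetical stand-in for the state that update_sym keeps in sync."""

    def __init__(self, layers: List[Tuple[str, str]], params: Dict[str, list]):
        self.params = dict(params)
        self.update_sym(layers)

    def update_sym(self, new_layers: List[Tuple[str, str]]) -> None:
        # 1. Replace the stored "symbol" (here an ordered list of layers).
        self.symbol = list(new_layers)
        # 2. Rebuild the derived name -> type dict and the layer-name list.
        self.layer_type_dict = OrderedDict(new_layers)
        self.layer_names = list(self.layer_type_dict)
        # 3. Assumed cleaning rule: keep only parameters of surviving layers.
        self.params = {k: v for k, v in self.params.items()
                       if k.rsplit("_", 1)[0] in self.layer_type_dict}


h = TinyHandler([("conv1", "Convolution"), ("fc1", "FullyConnected")],
                {"conv1_weight": [0.1], "fc1_weight": [0.2], "fc1_bias": [0.0]})
h.update_sym([("conv1", "Convolution")])  # e.g. after dropping the top layer
print(h.layer_names)     # ['conv1']
print(sorted(h.params))  # ['conv1_weight']
```

Routing every graph edit through one update method keeps the symbol, the derived layer tables, and the stored parameters consistent with each other.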
-