xfer.model_handler.ModelHandler

class xfer.model_handler.ModelHandler(module, context_function=<function cpu>, num_devices=1, data_name='data')

Bases: object

Class for model manipulation and feature extraction.

Parameters:
  • module (mx.module.Module) – MXNet module to be manipulated.
  • context_function (function) – MXNet context function.
  • num_devices (int) – Number of devices to run process on.
  • data_name (str) – Name of input layer of model.

Methods

__init__ Initialize self.
add_layer_bottom Add layer to input of model.
add_layer_top Add layer to output of model.
drop_layer_bottom Remove layers from input of model.
drop_layer_top Remove layers from output of model.
get_layer_names_matching_type Return names of layers of specified type.
get_layer_output Extract features from a data iterator using the model.
get_layer_parameters Get list of layer parameters associated with the given layer names.
get_layer_type Return type of named layer.
get_module Return MXNet Module using the model symbol and parameters.
save_symbol Serialise model symbol graph.
update_sym Update the symbol attribute, layer names and layer-type dictionary, and clean parameters.
visualize_net Display computational graph of model.

Attributes

layer_names Get list of names of model layers.
drop_layer_top(num_layers_to_drop=1)

Remove layers from output of model.

Parameters:num_layers_to_drop (int) – Number of layers to remove from model output.
drop_layer_bottom(num_layers_to_drop=1)

Remove layers from input of model.

Parameters:num_layers_to_drop (int) – Number of layers to remove from model input.
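To make the direction of the two drop methods concrete, the sketch below models a network as a plain Python list of layer names ordered from input (bottom) to output (top). It is an illustrative model only, not xfer's implementation, which works on the MXNet symbol graph; the helper names and the rule that at least one layer must remain are assumptions of the sketch.

```python
from typing import List


def _check(layers: List[str], num_layers_to_drop: int) -> None:
    # Assumption of this sketch (not stated in the docs): at least one layer must remain.
    if not 0 <= num_layers_to_drop < len(layers):
        raise ValueError("num_layers_to_drop must leave at least one layer")


def drop_top(layers: List[str], num_layers_to_drop: int = 1) -> List[str]:
    """Hypothetical model of drop_layer_top: remove layers from the output end."""
    _check(layers, num_layers_to_drop)
    return layers[: len(layers) - num_layers_to_drop]


def drop_bottom(layers: List[str], num_layers_to_drop: int = 1) -> List[str]:
    """Hypothetical model of drop_layer_bottom: remove layers from the input end."""
    _check(layers, num_layers_to_drop)
    return layers[num_layers_to_drop:]


layers = ["conv1", "relu1", "fc1", "softmax"]
print(drop_top(layers, 2))     # ['conv1', 'relu1']
print(drop_bottom(layers, 1))  # ['relu1', 'fc1', 'softmax']
```

A typical transfer-learning use is dropping the top classifier layers so that the remaining network acts as a feature extractor.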
add_layer_top(layer_list)

Add layer to output of model. For example, with model layers = (layer1, layer2, layer3) and layer_list = [layerA, layerB], the result is model layers = (layer1, layer2, layer3, layerA, layerB).

Parameters:layer_list (list(mx.symbol)) – List of MXNet symbol layers to be added to model output.
add_layer_bottom(layer_list)

Add layer to input of model. For example, with model layers = (layer1, layer2, layer3) and layer_list = [layerA, layerB], the result is model layers = (layerA, layerB, layer1, layer2, layer3).

Parameters:layer_list (list(mx.symbol)) – List of MXNet symbol layers to be added to model input.
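The ordering in the two examples above can be reproduced with plain list concatenation. This is a hypothetical model on layer names rather than MXNet symbols; note that in both cases layerA stays ahead of layerB, so the order within layer_list is preserved.

```python
from typing import List, Sequence


def add_top(layers: Sequence[str], layer_list: Sequence[str]) -> List[str]:
    """Hypothetical model of add_layer_top: new layers go after the current output."""
    return list(layers) + list(layer_list)


def add_bottom(layers: Sequence[str], layer_list: Sequence[str]) -> List[str]:
    """Hypothetical model of add_layer_bottom: new layers go before the current input."""
    return list(layer_list) + list(layers)


model = ("layer1", "layer2", "layer3")
new = ["layerA", "layerB"]
print(add_top(model, new))     # ['layer1', 'layer2', 'layer3', 'layerA', 'layerB']
print(add_bottom(model, new))  # ['layerA', 'layerB', 'layer1', 'layer2', 'layer3']
```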
get_module(iterator, fixed_layer_parameters=None, random_layer_parameters=None)

Return MXNet Module using the model symbol and parameters.

Parameters:
  • iterator (mxnet.io.DataIter) – MXNet iterator to be used with model.
  • fixed_layer_parameters (list(str)) – List of layer parameters to keep fixed.
  • random_layer_parameters (list(str)) – List of layer parameters to randomise.
Returns:

MXNet module

Return type:

mx.module.Module
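The docs do not spell out how the two parameter lists are applied. One plausible reading, sketched below purely as an assumption, is that names in random_layer_parameters are left out of the pretrained values loaded into the module (so an initialiser fills them), while names in fixed_layer_parameters are excluded from gradient updates; MXNet's Module constructor accepts such a list through its fixed_param_names argument. The helper name is hypothetical, and plain lists of floats stand in for MXNet NDArrays so the sketch runs without MXNet.

```python
from typing import Dict, List, Optional, Tuple


def partition_params(
    pretrained: Dict[str, List[float]],
    fixed_layer_parameters: Optional[List[str]] = None,
    random_layer_parameters: Optional[List[str]] = None,
) -> Tuple[Dict[str, List[float]], List[str]]:
    """Hypothetical helper: decide which pretrained values to load and which to freeze.

    Returns (params_to_load, frozen_names).
    """
    fixed = list(fixed_layer_parameters or [])
    randomised = set(random_layer_parameters or [])
    # Parameters to randomise are simply not loaded, leaving them to an initialiser.
    params_to_load = {
        name: value for name, value in pretrained.items() if name not in randomised
    }
    # Frozen names would be handed to the module so the optimiser skips them.
    return params_to_load, fixed


pretrained = {
    "conv1_weight": [0.1, 0.2],
    "conv1_bias": [0.0],
    "fc1_weight": [0.3],
    "fc1_bias": [0.05],
}
to_load, frozen = partition_params(
    pretrained,
    fixed_layer_parameters=["conv1_weight", "conv1_bias"],
    random_layer_parameters=["fc1_weight", "fc1_bias"],
)
print(sorted(to_load))  # ['conv1_bias', 'conv1_weight']
print(frozen)           # ['conv1_weight', 'conv1_bias']
```

This matches the usual fine-tuning pattern of keeping early feature layers frozen while retraining a freshly initialised head.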

get_layer_type(layer_name)

Return type of named layer.

Parameters:layer_name (str) – Name of layer being inspected.
Returns:Layer type
Return type:str
get_layer_names_matching_type(layer_type)

Return names of layers of specified type.

Parameters:layer_type (str) – Type of the layers to return.
Returns:Names of layers with specified type
Return type:list(str)
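Both type lookups can be pictured as queries on an ordered mapping from layer name to layer type. The sketch assumes such a dictionary; the layer_types name and the MXNet-style operator names used as types are illustrative, and this is not xfer's code.

```python
from collections import OrderedDict
from typing import List

# Hypothetical name -> type mapping, ordered from input to output.
layer_types = OrderedDict(
    [
        ("conv1", "Convolution"),
        ("relu1", "Activation"),
        ("conv2", "Convolution"),
        ("fc1", "FullyConnected"),
    ]
)


def get_type(layer_name: str) -> str:
    """Model of get_layer_type; in this sketch an unknown name raises KeyError."""
    return layer_types[layer_name]


def names_matching_type(layer_type: str) -> List[str]:
    """Model of get_layer_names_matching_type: filter while preserving model order."""
    return [name for name, kind in layer_types.items() if kind == layer_type]


print(get_type("fc1"))                     # FullyConnected
print(names_matching_type("Convolution"))  # ['conv1', 'conv2']
```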
get_layer_output(data_iterator, layer_names)

Extract features from a data iterator using the model. Returns a dictionary mapping each layer name to a NumPy array of the extracted (flattened) features.

Parameters:
  • data_iterator (mxnet.io.DataIter) – Iterator containing input data.
  • layer_names (list(str)) – List of names of layers to extract features from.
Returns:

Ordered dictionary of features ({layer_name: features}) and list of labels. Layer names in the ordered dictionary follow the same order as the input list layer_names.

Return type:

OrderedDict[str, numpy.ndarray], list(int)
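"Flattened" here means each sample's feature map becomes one row of the returned array. The NumPy sketch below reproduces the documented return structure from made-up per-batch outputs: each layer's batches of shape (batch, C, H, W) are reshaped to (batch, C*H*W) and stacked, keys follow the order of layer_names, and labels are collected into a list of ints. The helper name and data are hypothetical, and how xfer treats a padded final batch is not modelled.

```python
from collections import OrderedDict
from typing import Dict, List, Tuple

import numpy as np


def collect_features(
    batches: List[Tuple[Dict[str, np.ndarray], np.ndarray]],
    layer_names: List[str],
) -> Tuple[Dict[str, np.ndarray], List[int]]:
    """Hypothetical helper mirroring the documented return structure.

    Each element of `batches` is (outputs_by_layer, labels) for one batch.
    """
    per_layer = OrderedDict((name, []) for name in layer_names)  # keeps input order
    labels: List[int] = []
    for outputs, batch_labels in batches:
        for name in layer_names:
            out = outputs[name]
            per_layer[name].append(out.reshape(out.shape[0], -1))  # one row per sample
        labels.extend(int(x) for x in batch_labels)
    features = OrderedDict((name, np.vstack(rows)) for name, rows in per_layer.items())
    return features, labels


rng = np.random.default_rng(0)
batches = [
    ({"conv1": rng.random((4, 2, 3, 3)), "fc1": rng.random((4, 10))}, np.array([0, 1, 1, 0])),
    ({"conv1": rng.random((2, 2, 3, 3)), "fc1": rng.random((2, 10))}, np.array([1, 0])),
]
features, labels = collect_features(batches, ["fc1", "conv1"])
print(list(features))           # ['fc1', 'conv1']
print(features["conv1"].shape)  # (6, 18)
print(labels)                   # [0, 1, 1, 0, 1, 0]
```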

get_layer_parameters(layer_names)

Get list of layer parameters associated with the given layer names.

Parameters:layer_names (list(str)) – List of layer names.
Returns:List of layer parameters
Return type:list(str)
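MXNet conventionally names a layer's parameters by appending a suffix to the layer name, such as conv1_weight and conv1_bias, or bn1_gamma, bn1_beta, bn1_moving_mean and bn1_moving_var for batch normalisation. Assuming the model follows that convention, the lookup can be sketched as an exact match on layer name plus known suffix; exact matching avoids picking up conv10_weight when asking for conv1. The helper, the suffix list and the parameter names are hypothetical, not a description of xfer's code.

```python
from typing import Iterable, List

# Parameter suffixes MXNet commonly uses (assumption: the model follows this convention).
KNOWN_SUFFIXES = ("weight", "bias", "gamma", "beta", "moving_mean", "moving_var")


def parameters_for_layers(param_names: Iterable[str], layer_names: List[str]) -> List[str]:
    """Hypothetical model of get_layer_parameters using MXNet-style names."""
    wanted = {f"{layer}_{suffix}" for layer in layer_names for suffix in KNOWN_SUFFIXES}
    # Keep the model's own parameter order in the result.
    return [name for name in param_names if name in wanted]


params = [
    "conv1_weight", "conv1_bias", "conv10_weight",
    "bn1_gamma", "bn1_beta", "bn1_moving_mean", "bn1_moving_var",
    "fc1_weight", "fc1_bias",
]
print(parameters_for_layers(params, ["conv1", "bn1"]))
# ['conv1_weight', 'conv1_bias', 'bn1_gamma', 'bn1_beta', 'bn1_moving_mean', 'bn1_moving_var']
```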
visualize_net()

Display computational graph of model.

save_symbol(model_name)

Serialise model symbol graph.

Parameters:model_name (str) – Prefix to file name (model_name-symbol.json).
layer_names

Get list of names of model layers.

Returns:List of layer names
Return type:list(str)
update_sym(new_symbol)

Update the symbol attribute, layer names and layer-type dictionary, and clean parameters.

Parameters:new_symbol (mx.symbol.Symbol) – Symbol with which to update ModelHandler.
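The docs do not define "clean parameters". One reasonable interpretation, assumed in this sketch, is discarding stored parameters that the new symbol no longer references, for example after a classifier layer has been dropped. MXNet symbols expose list_arguments() and list_auxiliary_states(); plain lists of names stand in for those calls here so the sketch runs without MXNet, and the helper name is hypothetical.

```python
from typing import Dict, List, Tuple


def clean_params(
    arg_params: Dict[str, object],
    aux_params: Dict[str, object],
    symbol_arguments: List[str],
    symbol_aux_states: List[str],
) -> Tuple[Dict[str, object], Dict[str, object]]:
    """Hypothetical helper: keep only parameters the new symbol still uses."""
    args = set(symbol_arguments)
    aux = set(symbol_aux_states)
    kept_args = {k: v for k, v in arg_params.items() if k in args}
    kept_aux = {k: v for k, v in aux_params.items() if k in aux}
    return kept_args, kept_aux


# After dropping a final fc2 layer, its weights are no longer referenced.
arg_params = {"conv1_weight": 1, "conv1_bias": 2, "bn1_gamma": 3, "bn1_beta": 4,
              "fc2_weight": 5, "fc2_bias": 6}
aux_params = {"bn1_moving_mean": 7, "bn1_moving_var": 8}
new_args = ["data", "conv1_weight", "conv1_bias", "bn1_gamma", "bn1_beta"]
new_aux = ["bn1_moving_mean", "bn1_moving_var"]
kept_args, kept_aux = clean_params(arg_params, aux_params, new_args, new_aux)
print(sorted(kept_args))  # ['bn1_beta', 'bn1_gamma', 'conv1_bias', 'conv1_weight']
print(sorted(kept_aux))   # ['bn1_moving_mean', 'bn1_moving_var']
```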