Writing a custom Repurposer

Xfer implements and supports two kinds of Repurposers:

  • Meta-model Repurposer - this uses the source model to extract features and then fits a meta-model to the features
  • Neural network Repurposer - this modifies the source model to create a target model

Below are examples of creating a custom Repurposer of each kind.

Setup

First, import the relevant modules, define the data iterators and load a source model.

In [1]:
import warnings
warnings.filterwarnings("ignore")

import logging
logging.disable(logging.WARNING)

import xfer

import os
import glob
import mxnet as mx
import random
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import classification_report

# Seed Python's built-in RNG. The train/test split is seeded separately through random_state in
# get_iterators_from_folder. NOTE(review): MXNet's own RNG (presumably used for weight initialisation
# during fine-tuning) is not seeded here — consider mx.random.seed(1) if reproducible results are needed.
random.seed(1)
In [2]:
def get_iterators_from_folder(data_dir, train_size=0.6, batchsize=10, label_name='softmax_label', data_name='data',
                              random_state=1, data_shape=(3, 224, 224)):
    """
    Create train/test image iterators from data stored in a folder with the following structure:
    /data_dir
        /class1
            class1_img1
            class1_img2
            ...
            class1_imgN
        /class2
            class2_img1
            class2_img2
            ...
            class2_imgN
        ...
        /classN

    Every sub-directory of data_dir is one class and every file inside it is an image of that class.
    Images are divided into train and test sets with a stratified shuffle split.

    :param str data_dir: Root directory containing one sub-directory per class
    :param train_size: Size of the training split, as accepted by sklearn's StratifiedShuffleSplit
    :param int batchsize: Batch size of both iterators
    :param str label_name: Name of the label provided by the iterators
    :param str data_name: Name of the data provided by the iterators
    :param int random_state: Seed for the train/test split
    :param tuple data_shape: (channels, height, width) of the images produced by the iterators
    :return: (train_iterator, test_iterator, train_labels, test_labels, id2label, label2id)
    :raises ValueError: If data_dir does not exist or contains no images
    """
    # assert dir exists
    if not os.path.isdir(data_dir):
        raise ValueError('Directory not found: {}'.format(data_dir))

    # Each sub-directory is a class; sort so files are always collected in the same order
    class_dirs = sorted(d for d in glob.glob(os.path.join(data_dir, '*')) if os.path.isdir(d))
    fnames = []
    labels = []
    for class_dir in class_dirs:
        # get all the image filenames and labels
        images = sorted(glob.glob(os.path.join(class_dir, '*')))
        fnames += images
        labels += [os.path.basename(class_dir)] * len(images)
    if not fnames:
        raise ValueError('No images found in: {}'.format(data_dir))

    # Create the label <-> id mapping from the *sorted* class names. Enumerating a set would give an
    # order that depends on string hash randomisation, so ids could change between Python sessions.
    id2label = dict(enumerate(sorted(set(labels))))
    label2id = {label: idx for idx, label in id2label.items()}

    # get indices of train and test (only the first generated split is used)
    sss = StratifiedShuffleSplit(n_splits=2, test_size=None, train_size=train_size, random_state=random_state)
    train_indices, test_indices = next(sss.split(labels, labels))

    # create imglist ([label_id, filename] pairs) and label lists for training and test
    label_ids = [label2id[label] for label in labels]
    train_img_list = [[label_ids[idx], fnames[idx]] for idx in train_indices]
    test_img_list = [[label_ids[idx], fnames[idx]] for idx in test_indices]
    train_labels = [label_ids[idx] for idx in train_indices]
    test_labels = [label_ids[idx] for idx in test_indices]

    # make iterators (path_root='' because the filenames already include data_dir)
    train_iterator = mx.image.ImageIter(batchsize, data_shape, imglist=train_img_list, label_name=label_name,
                                        data_name=data_name, path_root='')
    test_iterator = mx.image.ImageIter(batchsize, data_shape, imglist=test_img_list, label_name=label_name,
                                       data_name=data_name, path_root='')

    return train_iterator, test_iterator, train_labels, test_labels, id2label, label2id
In [3]:
dataset = 'test_images'  # options are: 'test_sketches', 'test_images_sketch', 'mnist-50', 'test_images' or your own data.

# label_name must match the label name the source model is loaded with ('prob_label')
train_iterator, test_iterator, train_labels, test_labels, id2label, label2id = get_iterators_from_folder(
    dataset, train_size=0.6, batchsize=4, label_name='prob_label', random_state=1)

# Derive the class count from the data instead of hard-coding it, so it stays correct for every dataset option
num_classes = len(id2label)
In [4]:
# Download vgg19 (trained on imagenet): the epoch-0 parameters and the symbol definition.
# The list is the cell's last expression, so the notebook displays the downloaded file names.
path = 'http://data.mxnet.io/models/imagenet/'
[mx.test_utils.download(path + 'vgg/vgg19' + suffix) for suffix in ('-0000.params', '-symbol.json')]
Out[4]:
['vgg19-0000.params', 'vgg19-symbol.json']
In [5]:
# Load the downloaded vgg19 checkpoint (prefix 'vgg19', epoch 0); it is the source model for repurposing below.
# The label name matches the label_name given to the data iterators.
source_model = mx.mod.Module.load(prefix='vgg19', epoch=0, label_names=['prob_label'])

Custom Meta-model Repurposer

We will create a new Repurposer that uses the KNN algorithm as a meta-model. The resulting Meta-model Repurposer will classify the features extracted by the neural network source model.

In [6]:
from sklearn.neighbors import KNeighborsClassifier

Definition

In [7]:
class KNNRepurposer(xfer.MetaModelRepurposer):
    """
    Meta-model Repurposer that classifies source-model features with a k-nearest-neighbours classifier.

    Features are extracted from feature_layer_names of source_model by the parent class. The KNN-specific
    constructor arguments are forwarded unchanged to sklearn's KNeighborsClassifier.
    """
    def __init__(self, source_model: mx.mod.Module, feature_layer_names, context_function=mx.context.cpu, num_devices=1,
                 n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=-1):
        # Call init() of parent
        super().__init__(source_model, feature_layer_names, context_function, num_devices)

        # Initialise parameters specific to the KNN algorithm
        self.n_neighbors = n_neighbors
        self.weights = weights
        self.algorithm = algorithm
        self.leaf_size = leaf_size
        self.p = p
        self.metric = metric
        self.metric_params = metric_params
        self.n_jobs = n_jobs

    # Define function that takes a set of features and labels and returns a trained model.
    # feature_indices_per_layer is a dictionary which gives the feature indices which correspond
    # to each layer's features.
    def _train_model_from_features(self, features, labels, feature_indices_per_layer=None):
        knn_model = KNeighborsClassifier(n_neighbors=self.n_neighbors,
                                         weights=self.weights,
                                         algorithm=self.algorithm,
                                         leaf_size=self.leaf_size,
                                         p=self.p,
                                         metric=self.metric,
                                         metric_params=self.metric_params,
                                         n_jobs=self.n_jobs)  # previously stored but never used
        knn_model.fit(features, labels)
        return knn_model

    # Define a function that predicts the class probability given features
    def _predict_probability_from_features(self, features):
        return self.target_model.predict_proba(features)

    # Define a function that predicts the class label given features
    def _predict_label_from_features(self, features):
        return self.target_model.predict(features)

    # In order to make your repurposer serialisable, you will need to implement functions
    # which convert your model's parameters to a dictionary.
    def get_params(self):
        """
        This function should return a dictionary of all the parameters of the repurposer that
        are in the repurposer constructor arguments.
        """
        param_dict = super().get_params()
        param_dict['n_neighbors'] = self.n_neighbors
        param_dict['weights'] = self.weights
        param_dict['algorithm'] = self.algorithm
        param_dict['leaf_size'] = self.leaf_size
        param_dict['p'] = self.p
        param_dict['metric'] = self.metric
        param_dict['metric_params'] = self.metric_params
        param_dict['n_jobs'] = self.n_jobs
        return param_dict

    # Some repurposers will need a get_attributes() and set_attributes() to get and set the parameters
    # of the repurposer that are not in the constructor argument. An example is shown below:

    # def get_attributes(self):
    #     """
    #     This function should return a dictionary of all the parameters of the repurposer that
    #     are NOT in the constructor arguments.
    #
    #     This function does not need to be defined if the repurposer has no specific attributes.
    #     """
    #     param_dict = super().get_attributes()
    #     param_dict['example_attribute'] = self.example_attribute
    #     return param_dict

    # def set_attributes(self, input_dict):
    #     super().set_attributes(input_dict)
    #     self.example_attribute  = input_dict['example_attribute']

    def serialize(self, file_prefix):
        """
        Saves repurposer (excluding source model) to file_prefix.json.
        This method converts the repurposer to dictionary and saves as a json.

        NOTE: this is a template. repurposer_keys, utils, target_model_to_dict and target_model_from_dict
        are not defined in this notebook. They must be imported or implemented (e.g. a function converting the
        fitted KNeighborsClassifier to JSON-serialisable data and back) before serialize()/deserialize() work.

        :param str file_prefix: Prefix to save file with
        """
        output_dict = {}
        output_dict[repurposer_keys.PARAMS] = self.get_params()
        # JSON-serialisable representation of the fitted target model
        output_dict[repurposer_keys.TARGET_MODEL] = target_model_to_dict(self.target_model)
        output_dict.update(self.get_attributes())

        utils.save_json(file_prefix, output_dict)

    def deserialize(self, input_dict):
        """
        Uses dictionary to set attributes of repurposer

        :param dict input_dict: Dictionary containing values for attributes to be set to
        """
        self.set_attributes(input_dict)  # Set attributes of the repurposer from input_dict
        # Rebuild the target model from the representation written by serialize()
        self.target_model = target_model_from_dict(input_dict[repurposer_keys.TARGET_MODEL])

Use

In [8]:
# KNN meta-model on features taken from the source model's 'fc8' layer
repurposerKNN = KNNRepurposer(source_model=source_model, feature_layer_names=['fc8'])
In [9]:
# Extract source-model features for the training images and fit the KNN meta-model to them
repurposerKNN.repurpose(train_iterator)
In [10]:
# Predicted label ids for the test images
results = repurposerKNN.predict_label(test_iterator)
In [11]:
# Compare the KNN predictions against the ground-truth test labels
report = classification_report(test_labels, results)
print(report)
             precision    recall  f1-score   support

          0       1.00      0.50      0.67         2
          1       0.67      1.00      0.80         2
          2       1.00      1.00      1.00         2
          3       1.00      1.00      1.00         2

avg / total       0.92      0.88      0.87         8

Custom Neural Network Repurposer

Now we will define a custom Neural Network Repurposer which performs transfer learning by:

  1. taking the original source neural network and keeping all layers up to and including transfer_layer_name
  2. adding two fully connected layers on top
  3. fine-tuning with all convolutional layers frozen

Definition

In [12]:
class Add2FullyConnectedRepurposer(xfer.NeuralNetworkRepurposer):
    """
    Neural Network Repurposer that builds its target model by:
      1. keeping the source model up to and including transfer_layer_name,
      2. adding a fully connected layer with num_nodes units, a fully connected layer with
         target_class_count units and a softmax output on top,
      3. fine-tuning with all convolutional layers frozen.
    """
    def __init__(self, source_model: mx.mod.Module, transfer_layer_name, num_nodes, target_class_count,
                 context_function=mx.context.cpu, num_devices=1, batch_size=64, num_epochs=5):
        super().__init__(source_model, context_function, num_devices, batch_size, num_epochs)

        # initialise parameters
        self.transfer_layer_name = transfer_layer_name
        self.num_nodes = num_nodes
        self.target_class_count = target_class_count

    def get_params(self):
        """
        Return a dictionary of all the parameters of the repurposer that are in the constructor arguments.
        """
        param_dict = super().get_params()
        param_dict['transfer_layer_name'] = self.transfer_layer_name
        param_dict['num_nodes'] = self.num_nodes
        param_dict['target_class_count'] = self.target_class_count
        return param_dict

    def _get_target_symbol(self, source_model_layer_names):
        # Check if 'transfer_layer_name' is present in source model
        if self.transfer_layer_name not in source_model_layer_names:
            raise ValueError('transfer_layer_name: {} not found in source model'.format(self.transfer_layer_name))

        # Create target symbol by transferring layers from source model up to 'transfer_layer_name'
        transfer_layer_key = self.transfer_layer_name + '_output'  # layer key with output suffix to lookup mxnet symbol group
        source_symbol = self.source_model.symbol.get_internals()
        target_symbol = source_symbol[transfer_layer_key]
        return target_symbol

    # All Neural Network Repurposers must implement this function which takes a training iterator and returns an MXNet Module
    def _create_target_module(self, train_iterator: mx.io.DataIter):
        # Create model handler to manipulate the source model
        model_handler = xfer.model_handler.ModelHandler(self.source_model, self.context_function, self.num_devices)

        # Create target symbol by transferring layers from source model up to and including 'transfer_layer_name'
        target_symbol = self._get_target_symbol(model_handler.layer_names)

        # Update model handler by replacing source symbol with target symbol
        # and cleaning up weights of layers that were not transferred
        model_handler.update_sym(target_symbol)

        # Add two fully connected layers (num_nodes units, then one unit per target class) and a softmax output on top.
        # NOTE(review): there is no activation between the two fully connected layers, so together they act as a
        # single (rank-limited) linear map; consider inserting a non-linearity if xfer's add_layer_top supports it.
        fully_connected_layer1 = mx.sym.FullyConnected(num_hidden=self.num_nodes, name='fc_rep')
        fully_connected_layer2 = mx.sym.FullyConnected(num_hidden=self.target_class_count, name='fc_from_fine_tune_repurposer')
        # MXNet's SoftmaxOutput names its label input '<name>_label', so name the output after the iterator's label
        # with only the trailing '_label' removed (str.replace would also strip inner occurrences).
        label_name = train_iterator.provide_label[0][0]
        suffix = '_label'
        output_name = label_name[:-len(suffix)] if label_name.endswith(suffix) else label_name
        softmax_output_layer = mx.sym.SoftmaxOutput(name=output_name)
        model_handler.add_layer_top([fully_connected_layer1, fully_connected_layer2, softmax_output_layer])

        # Get parameters of the convolutional layers, which are kept fixed during fine-tuning
        conv_layer_names = model_handler.get_layer_names_matching_type('Convolution')
        conv_layer_params = model_handler.get_layer_parameters(conv_layer_names)

        # Create and return target mxnet module using the new symbol and params
        return model_handler.get_module(train_iterator, fixed_layer_parameters=conv_layer_params)

    # To be serialisable, Neural Network Repurposers require get_params (above) and, if they hold state that is
    # not a constructor argument, get_attributes/set_attributes as shown in KNNRepurposer

Use

In [13]:
# instantiate repurposer
repurposer2Fc = Add2FullyConnectedRepurposer(
    source_model=source_model,
    transfer_layer_name='fc7',        # keep the source network up to and including fc7
    num_nodes=64,                     # width of the first added fully connected layer
    target_class_count=num_classes,   # one output per target class
)
In [14]:
# train_iterator was already passed to repurposerKNN.repurpose() above (which presumably consumed it); rewind it first
train_iterator.reset()
# Build the target network and fine-tune it on the training images
repurposer2Fc.repurpose(train_iterator)
In [15]:
# test_iterator was already used by repurposerKNN.predict_label() above; rewind it (as is done for train_iterator)
# so this prediction does not depend on whether predict_label resets the iterator itself
test_iterator.reset()
results = repurposer2Fc.predict_label(test_iterator)
In [16]:
# Compare the fine-tuned network's predictions against the ground-truth test labels
report = classification_report(test_labels, results)
print(report)
             precision    recall  f1-score   support

          0       1.00      0.50      0.67         2
          1       1.00      1.00      1.00         2
          2       1.00      1.00      1.00         2
          3       0.67      1.00      0.80         2

avg / total       0.92      0.88      0.87         8