Binary Image Denoising

Jiaolong Xu (GitHub ID: Jiaolong)

This notebook was written during GSoC 2014. Thanks to Shell Hu and Thoralf Klein for taking the time to help me with this project!

This notebook illustrates how to use the Shogun structured output learning framework for binary image denoising. The task is modeled as a pairwise factor graph with graph cuts inference, and the model parameters are learned by a structured output SVM (SOSVM) using a stochastic gradient descent (SGD) solver.

Introduction

I recommend [1] for a good introduction to structured learning and prediction in computer vision, and [4] is one of the founding publications on learning structured models. In the following, I give an explicit example of structured output prediction for binary image denoising.

Given a noisy black/white image $\textbf{x}$ of size $m \times n$, the task of denoising is to predict the original binary image $\textbf{y}$. We flatten the image into a long vector $\textbf{y} = [y_0, \dots, y_{m\times n - 1}]$, where $y_i \in \{0, 1\}$ is the value of pixel $i$. In this work, we aim to learn a model from a set of noisy input images and their ground truth binary images, i.e., supervised learning. One could instead learn a binary classifier for each pixel, but this has two drawbacks. First, we may not need every single pixel to be classified correctly: misclassifying a single pixel is not as bad as misclassifying the whole image. Second, we lose all context, e.g., the relation between a pixel and its neighbors. Structured prediction here means predicting the entire binary image $\textbf{y}$, i.e., a labeling of an $m \times n$ grid graph. The output space $\mathcal{Y}$ is the set of all possible binary images of size $m \times n$. The prediction can be formulated as follows:

$$ \hat{\textbf{y}} = \underset{\textbf{y} \in \mathcal{Y}}{\operatorname{argmax}} f(\textbf{x},\textbf{y}), \tag{1} $$

where $f(\textbf{x},\textbf{y})$ is the compatibility function, which measures how well $\textbf{y}$ fits $\textbf{x}$. There are three main challenges in structured learning and prediction:

  • Choosing a parametric form of $f(\textbf{x},\textbf{y})$
  • Solving $\underset{\textbf{y} \in \mathcal{Y}}{\operatorname{argmax}} f(\textbf{x},\textbf{y})$
  • Learning parameters for $f(\textbf{x},\textbf{y})$ to minimize a loss

In this work, the compatibility function is a sum of unary and pairwise potentials and can be written as:

$$ f(\textbf{x},\textbf{y}) = \sum_i \textbf{w}_i(y_i)'\phi_i(\textbf{x}) + \sum_{(i,j) \in \mathcal{E}} \textbf{w}_{ij}(y_i,y_j)'\phi_{ij}(\textbf{x}), \tag{2} $$

where $\mathcal{E}$ is the set of neighboring pixel pairs, $\textbf{w}_i(y_i)$ and $\textbf{w}_{ij}(y_i,y_j)$ are the unary and pairwise parameters selected by the labels, and $\phi_i(\textbf{x})$ and $\phi_{ij}(\textbf{x})$ are the unary and pairwise features, respectively. Equation (2) is linear in the parameters and can be written as a dot product of a global parameter vector $\textbf{w}$ and a joint feature vector $\Phi(\textbf{x},\textbf{y})$, i.e., $f(\textbf{x},\textbf{y}) = \textbf{w}'\Phi(\textbf{x}, \textbf{y})$. The global parameter $\textbf{w}$ is a collection of unary and pairwise parameters. The joint feature $\Phi(\textbf{x}, \textbf{y})$ maps local features, e.g., pixel values at each location, to the positions of the global feature vector selected by $\textbf{y}$. In a factor graph model, parameters are associated with factor types and shared by all factors of the same type; in this notebook, all unary factors share one parameter set and all pairwise factors share another.
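To make the dot-product form concrete, here is a minimal NumPy sketch of $\Phi(\textbf{x},\textbf{y})$ for this model. It assumes a layout with the unary block first (for each label, the sum of $\phi_i(\textbf{x})$ over pixels with that label) and the pairwise block second (counts of neighboring label pairs, since the pairwise feature is the constant $1$); Shogun's internal ordering of $\textbf{w}$ may differ. The function `joint_feature` and the toy data are hypothetical. The check confirms that summing the local potentials of Eq. (2) gives the same score as $\textbf{w}'\Phi(\textbf{x},\textbf{y})$.

```python
import numpy as np

def joint_feature(unary_feats, labels, edges, num_states=2):
    """Sketch of Phi(x, y) for the pairwise binary model (assumed layout).

    unary_feats: (num_pixels, dim_feat) array of phi_i(x)
    labels: (num_pixels,) integer array of y_i in {0, ..., num_states - 1}
    edges: iterable of (i, j) pixel pairs; the pairwise feature is the constant 1
    """
    dim_feat = unary_feats.shape[1]
    # unary block: row k sums phi_i(x) over all pixels labeled k
    psi_unary = np.zeros((num_states, dim_feat))
    for i, y_i in enumerate(labels):
        psi_unary[y_i] += unary_feats[i]
    # pairwise block: entry (k, l) counts edges whose endpoints are labeled (k, l)
    psi_pair = np.zeros((num_states, num_states))
    for i, j in edges:
        psi_pair[labels[i], labels[j]] += 1.0
    return np.concatenate([psi_unary.ravel(), psi_pair.ravel()])

# hypothetical 1x3 "image": noisy values z_i with unary features [z_i, 1]
z = np.array([0.9, 0.2, 0.7])
feats = np.column_stack([z, np.ones_like(z)])
edges = [(0, 1), (1, 2)]
y = np.array([1, 0, 1])

rng = np.random.default_rng(0)
w = rng.normal(size=8)  # 2*2 unary + 2*2 pairwise parameters
w_unary, w_pair = w[:4].reshape(2, 2), w[4:].reshape(2, 2)

# f(x, y) as a sum of local potentials, as in Eq. (2)
f_local = sum(w_unary[y[i]] @ feats[i] for i in range(len(y)))
f_local += sum(w_pair[y[i], y[j]] for i, j in edges)
# the same score as a single dot product w' Phi(x, y)
f_dot = w @ joint_feature(feats, y, edges)
print(np.isclose(f_local, f_dot))  # True
```

The label acts as a selector: it decides which block of $\textbf{w}$ each local feature is multiplied with, which is how a function that is linear in $\textbf{w}$ can still depend on $\textbf{y}$.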

As mentioned before, the output space $\mathcal{Y}$ is usually finite but very large. In our case, it contains all $2^{mn}$ binary images of size $m \times n$. Finding the ${\operatorname{argmax}}$ in such a large space by exhaustive search is not practical. To perform the maximization over $\textbf{y}$ efficiently, a popular approach is to formulate it as energy minimization, e.g., in conditional random fields (CRFs). In this work, we implemented graph cuts [5] for efficient inference. We also implemented max-product LP relaxation inference and tree max-product inference. However, the latter is limited to tree-structured graphs, whereas image denoising uses a grid graph.
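For intuition about why exhaustive search does not scale, the sketch below solves Eq. (1) by brute force on a $2 \times 2$ image with a hypothetical toy compatibility function `toy_f` (agreement with the noisy input minus a penalty `beta` for each pair of disagreeing neighbors). Even this tiny image requires scoring all $2^4 = 16$ labelings; a $50 \times 50$ image has $2^{2500}$.

```python
import itertools
import numpy as np

def brute_force_argmax(score, num_pixels):
    """Solve Eq. (1) by enumerating all 2**num_pixels binary labelings."""
    best_y, best_val, num_evaluated = None, -np.inf, 0
    for y in itertools.product([0, 1], repeat=num_pixels):
        y = np.array(y)
        val = score(y)
        num_evaluated += 1
        if val > best_val:
            best_y, best_val = y, val
    return best_y, best_val, num_evaluated

# hypothetical toy compatibility function on a 2x2 image (pixels in row-major order):
# reward agreement with the noisy input z, penalize neighboring pixels that disagree
z = np.array([0.9, 0.8, 0.3, 0.6])
edges = [(0, 1), (2, 3), (0, 2), (1, 3)]
beta = 0.3

def toy_f(y):
    agreement = np.sum(y * z + (1 - y) * (1 - z))
    disagreements = sum(y[i] != y[j] for i, j in edges)
    return agreement - beta * disagreements

y_hat, f_hat, n = brute_force_argmax(toy_f, num_pixels=4)
print(y_hat, round(f_hat, 2), n)  # [1 1 1 1] 2.6 16
# a 50x50 image has 2**2500 labelings, a number with 753 decimal digits
print(len(str(2 ** 2500)))  # 753
```

The smoothness term flips pixel 2 (noisy value 0.3) to label 1, which an independent per-pixel decision would not do.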

The parameters are learned by regularized risk minimization, where the risk is defined by a user-provided loss function $\Delta(\mathbf{y},\mathbf{\hat{y}})$. We use the Hamming loss in this experiment, normalized by the number of pixels (see the loss weights in the code below). The empirical risk is defined in terms of the surrogate structured hinge loss $\mathcal{L}_i(\mathbf{w}) = \max_{\mathbf{y} \in \mathcal{Y}} \Delta(\mathbf{y}_i,\mathbf{y}) - \mathbf{w}' [\Phi(\mathbf{x}_i,\mathbf{y}_i) - \Phi(\mathbf{x}_i,\mathbf{y})]$. The training objective is given by

$$ \min_{\mathbf{w}} \frac{\lambda}{2} ||\mathbf{w}||^2 + \frac{1}{N} \sum_{i=1}^N \mathcal{L}_i(\mathbf{w}). \tag{3} $$
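The following sketch evaluates the normalized Hamming loss, the structured hinge loss $\mathcal{L}_i(\mathbf{w})$, and the objective of Eq. (3) by exhaustive search over $\mathcal{Y}$, which is only feasible for a handful of pixels. For brevity it assumes a simplified, unary-only $\Phi$ (per-label sums of the pixel features); the helper names and the toy data are hypothetical. With $\mathbf{w} = \mathbf{0}$, the hinge loss equals the largest possible loss, obtained by flipping every pixel.

```python
import itertools
import numpy as np

def hamming_loss(y_true, y):
    """Normalized Hamming loss: fraction of mislabeled pixels."""
    return np.mean(y_true != y)

def unary_joint_feature(feats, y, num_states=2):
    """Simplified Phi(x, y) with unary terms only (illustrative assumption)."""
    psi = np.zeros((num_states, feats.shape[1]))
    for i, y_i in enumerate(y):
        psi[y_i] += feats[i]
    return psi.ravel()

def hinge_loss(w, feats, y_true):
    """L_i(w) = max_y Delta(y_true, y) - w'[Phi(x, y_true) - Phi(x, y)], by exhaustive search."""
    phi_true = unary_joint_feature(feats, y_true)
    return max(
        hamming_loss(y_true, np.array(y))
        - w @ (phi_true - unary_joint_feature(feats, np.array(y)))
        for y in itertools.product([0, 1], repeat=len(y_true))
    )

def objective(w, data, lam):
    """Training objective of Eq. (3): lambda/2 ||w||^2 + average hinge loss."""
    return 0.5 * lam * (w @ w) + np.mean([hinge_loss(w, f, y) for f, y in data])

# hypothetical training set: three 4-pixel "images" with features [z_i, 1]
rng = np.random.default_rng(0)
data = []
for _ in range(3):
    y = rng.integers(0, 2, size=4)
    z = np.clip(y + rng.normal(scale=0.3, size=4), 0, 1)
    data.append((np.column_stack([z, np.ones(4)]), y))

w0 = np.zeros(4)
print(objective(w0, data, lam=1e-4))  # 1.0: with w = 0, flipping every pixel gives Delta = 1
```

Choosing $\mathbf{y} = \mathbf{y}_i$ inside the maximum shows that $\mathcal{L}_i(\mathbf{w}) \ge 0$ for every $\mathbf{w}$.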

Create binary denoising dataset

In [1]:
import os
SHOGUN_DATA_DIR=os.getenv('SHOGUN_DATA_DIR', '../../../data')
import numpy as np
import numpy.random

We define an Example class for the training and testing examples.

In [2]:
class Example:
    """ Example class.
    
        Member variables:
        id: id of the example
        im_bw: original binary image
        im_noise: original image with noise
        feature: feature vector of each pixel
        labels: binary labels of each pixel
    """
    
    def __init__(self, id, im_bw, im_noise, feats, labels):
        self.id = id
        self.im_bw = im_bw
        self.im_noise = im_noise
        self.feature = feats
        self.labels = labels

In the following, we create a toy dataset. Similar to [2], we generate random images, then smooth and round them to obtain the true (discrete) output values.

In [3]:
import scipy.ndimage
# generate random binary images
im_size = np.array([50, 50], np.int32)
num_images = 30

ims_bw = []
for i in range(num_images):
    im_rand = np.random.random_sample(im_size)
    # smooth the random image, then round it to {0, 1}
    im_bw = np.round(scipy.ndimage.gaussian_filter(im_rand, sigma=3))
    ims_bw.append(im_bw)

Next, noise is added to the binary images. Following [3], the noisy images are generated as $z_i = x_i(1-t_i^n) + (1-x_i)t_i^n$, where $x_i$ is the true binary label and $t_i \in [0,1]$ is a uniformly distributed random value. Here, $n \in (1, \infty)$ is the noise level, where lower values correspond to more noise.
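As a quick worked check of the noise model: for $x_i = 1$ we get $z_i = 1 - t_i^n$ and for $x_i = 0$ we get $z_i = t_i^n$, so in both cases $|z_i - x_i| = t_i^n$. With $t_i$ uniform on $[0, 1]$, the mean deviation is $\int_0^1 t^n\,dt = 1/(n+1)$, e.g. $1/3$ for $n = 2$. The simulation below, using a hypothetical `add_noise` helper, confirms that larger $n$ means less noise.

```python
import numpy as np

def add_noise(x, n, rng):
    """Noise model z = x (1 - t^n) + (1 - x) t^n with t ~ Uniform[0, 1)."""
    t = rng.random(x.shape)
    return x * (1 - t ** n) + (1 - x) * t ** n

rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=1_000_000).astype(np.float64)  # random binary "pixels"
mean_abs_dev = {}
for n in [1.5, 2, 4, 8]:
    z = add_noise(x, n, rng)
    mean_abs_dev[n] = np.mean(np.abs(z - x))
    # |z - x| = t^n for both x = 0 and x = 1, so E|z - x| = 1 / (n + 1)
    print(n, round(mean_abs_dev[n], 3), round(1 / (n + 1), 3))
```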

In this experiment, we use only two unary features: the noisy input value at the pixel and a constant $1$, i.e., $\textbf{u}(i) = [z_i, 1]$.

In [4]:
# define noise level
noise_level = 2

# initialize an empty list
example_list = []

for i in range(len(ims_bw)):
    im_bw = ims_bw[i]
    # add noise to the binary image
    t = np.random.random_sample(im_bw.shape)
    im_noise = im_bw*(1-t**noise_level) + (1-im_bw)*(t**noise_level)
    
    # create 2-d unary features
    c1 = np.ravel(im_noise)
    c2 = np.ones(im_noise.size, np.int32)
    feats = np.column_stack([c1, c2])
    # we use pixel-level labels
    # so just flatten the original binary image into a vector
    labels = np.ravel(im_bw)
    example = Example(i, im_bw, im_noise, feats, labels)
    example_list.append(example)

Now we create a function to visualize our examples.

In [5]:
import matplotlib.pyplot as plt
%matplotlib inline
In [6]:
def plot_example(example):
    """ Plot example."""
    
    fig, plots = plt.subplots(1, 2, figsize=(12, 4))
    plots[0].matshow(example.im_bw, cmap=plt.get_cmap('Greys'))
    plots[0].set_title('Binary image')
    plots[1].matshow(example.im_noise, cmap=plt.get_cmap('Greys'))
    plots[1].set_title('Noise image')
    
    for p in plots:
        p.set_xticks(())
        p.set_yticks(())

    plt.show()
In [7]:
# plot an example
plot_example(example_list[9])

Build Factor Graph Model

In [8]:
from modshogun import Factor, TableFactorType, FactorGraph
from modshogun import FactorGraphObservation, FactorGraphLabels, FactorGraphFeatures
from modshogun import FactorGraphModel, GRAPH_CUT, LP_RELAXATION
from modshogun import MAPInference

We define a 'make_grid_edges' function to compute the indices of neighboring pixel pairs. We use a grid graph with a 4-neighborhood in our experiment.
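The indexing convention matters here: the features and labels are flattened with `np.ravel` (row-major order), so pixel `(r, c)` of an image with `width` columns has index `r * width + c`. The sketch below uses a hypothetical helper `grid_edges_4` that builds only the 4-neighborhood edges, prints the edge list of a small $2 \times 3$ grid, and checks that a $50 \times 50$ grid has $50 \cdot 49 + 49 \cdot 50 = 4900$ edges, i.e., 4900 pairwise factors per image.

```python
import numpy as np

def grid_edges_4(height, width):
    """4-neighborhood edges of a height x width grid, pixels indexed in row-major order."""
    inds = np.arange(height * width).reshape(height, width)
    right = np.c_[inds[:, :-1].ravel(), inds[:, 1:].ravel()]  # (r, c) to (r, c+1)
    down = np.c_[inds[:-1, :].ravel(), inds[1:, :].ravel()]   # (r, c) to (r+1, c)
    return np.vstack([right, down])

print(grid_edges_4(2, 3).tolist())
# [[0, 1], [1, 2], [3, 4], [4, 5], [0, 3], [1, 4], [2, 5]]
print(len(grid_edges_4(50, 50)))  # 4900
```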

In [9]:
def make_grid_edges(grid_w, grid_h, neighborhood=4):
    """ Create grid edge lists.
    
        Args:
            grid_w: width of the grid
            grid_h: height of the grid
            neighborhood: neighborhood of the node (4 or 8)
        
        Returns:
            edge list of the grid graph
    """
    if neighborhood not in [4, 8]:
        raise ValueError("neighborhood can only be '4' or '8', got %s" % repr(neighborhood))
    # pixel (r, c) gets index r * grid_w + c, matching np.ravel on an image of shape (grid_h, grid_w)
    inds = np.arange(grid_w * grid_h).reshape([grid_h, grid_w])
    inds = inds.astype(np.int64)
    right = np.c_[inds[:, :-1].ravel(), inds[:, 1:].ravel()]
    down = np.c_[inds[:-1, :].ravel(), inds[1:, :].ravel()]
    edges = [right, down]
    if neighborhood == 8:
        upright = np.c_[inds[1:, :-1].ravel(), inds[:-1, 1:].ravel()]
        downright = np.c_[inds[:-1, :-1].ravel(), inds[1:, 1:].ravel()]
        edges.extend([upright, downright])
        
    return np.vstack(edges)
In [10]:
# in this experiment, we use fixed image size
im_w = example_list[0].im_bw.shape[1]
im_h = example_list[0].im_bw.shape[0]
# we compute the indices of the pairwise nodes
edge_list = make_grid_edges(im_w, im_h)

For binary denoising, we define two types of factors:

  • unary factor: the unary factor type defines unary potentials that capture the appearance likelihood of each pixel. We use very simple unary features in this experiment: the pixel value and a constant $1$. Since there are $2$ labels and $2$ features, the unary parameter has size $2 \times 2 = 4$.
  • pairwise factor: the pairwise factor type defines pairwise potentials between each pair of neighboring pixels. The feature of each pairwise factor is the constant $1$, and there are no additional edge features. For the pairwise factors, there are $2 \times 2 = 4$ parameters, one for each pair of labels.

Putting all parameters together, the global parameter vector $\mathbf{w}$ has length $8$.

In [11]:
def define_factor_type(num_status, dim_feat):
    """ Define factor type.

        Args:
            num_status: number of states (labels).
            dim_feat: dimension of the unary node feature

        Returns:
            ftype_unary: unary factor type
            ftype_pair: pairwise factor type
    """

    # unary, type id = 0
    cards_u = np.array([num_status], np.int32) # cardinalities
    w_u = np.zeros(num_status*dim_feat, np.float64)
    ftype_unary = TableFactorType(0, cards_u, w_u)

    # pairwise, type id = 1
    cards_p = np.array([num_status, num_status], np.int32)
    w_p = np.zeros(num_status*num_status, np.float64)
    ftype_pair = TableFactorType(1, cards_p, w_p)

    return ftype_unary, ftype_pair
In [12]:
# define factor types
ftype_unary, ftype_pair = define_factor_type(num_status=2, dim_feat=2)
In [13]:
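def prepare_factor_graph_model(example_list, ftype_unary, ftype_pair, edge_list, num_status = 2, dim_feat = 2):
    """ Prepare factor graph model data.

        Args:
            example_list: the examples
            ftype_unary: unary factor type
            ftype_pair: pairwise factor type
            edge_list: edge list of the grid graph
            num_status: number of states (labels)
            dim_feat: dimension of the unary features (overwritten by the actual feature dimension)

        Returns:
            feats_fg: factor graph features
            labels_fg: factor graph labels
    """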

    num_samples = len(example_list)
    # Initialize factor graph features and labels
    feats_fg = FactorGraphFeatures(num_samples)
    labels_fg = FactorGraphLabels(num_samples)

    # Iterate over all the examples
    for i in range(num_samples):
        example = example_list[i]
        feats = example.feature
        num_var = feats.shape[0]
        dim_feat = feats.shape[1]
        # Initialize factor graph
        cards = np.array([num_status]*num_var, np.int32) # cardinalities
        fg = FactorGraph(cards)

        # add unary
        for u in range(num_var):
            data_u = np.array(feats[u,:], np.float64)
            inds_u = np.array([u], np.int32)
            factor_u = Factor(ftype_unary, inds_u, data_u)
            fg.add_factor(factor_u)
        
        # add pairwise
        for p in range(edge_list.shape[0]):
            data_p = np.array([1.0], np.float64)
            inds_p = np.array(edge_list[p,:], np.int32)
            factor_p = Factor(ftype_pair, inds_p, data_p)
            fg.add_factor(factor_p)
        
        # add factor graph feature
        feats_fg.add_sample(fg)
        # add factor graph label
        labels = example.labels.astype(np.int32)
        assert(labels.shape[0] == num_var)
        loss_weight = np.array([1.0/num_var]*num_var)
        f_obs = FactorGraphObservation(labels, loss_weight)
        labels_fg.add_label(f_obs)

    return feats_fg, labels_fg

We split the samples into training and testing sets (10 training and 20 testing images) and convert their features and labels into factor graph features and labels.

In [14]:
num_train_samples = 10

examples_train = example_list[:num_train_samples]
examples_test = example_list[num_train_samples:]

# create features and labels for the factor graph model
(feats_train, labels_train) = prepare_factor_graph_model(examples_train, ftype_unary, ftype_pair, edge_list)
(feats_test, labels_test) = prepare_factor_graph_model(examples_test, ftype_unary, ftype_pair, edge_list)

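In this experiment, we use graph cuts as the inference algorithm, i.e., to solve Eq. (1). For binary labels, graph cuts finds the exact solution when the pairwise potentials are submodular (e.g., when they favor equal labels for neighboring pixels); in general, it is an approximate method. Please refer to [5] for more details.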
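To illustrate the idea behind graph cuts inference, the sketch below (not Shogun's implementation) minimizes the energy $E(\textbf{y}) = -f(\textbf{x},\textbf{y})$ of a binary model, assuming non-negative unary costs and a Potts pairwise term $\beta\,[y_i \ne y_j]$ with $\beta \ge 0$, which is submodular. Each pixel is connected to a source (label 0) and a sink (label 1) so that every $s$-$t$ cut costs exactly $E(\textbf{y})$ of the corresponding labeling; the minimum cut is therefore the exact minimizer. The max-flow routine is a plain Edmonds-Karp implementation, all function names and the toy problem are hypothetical, and the result is checked against exhaustive search on a $3 \times 3$ grid.

```python
import itertools
from collections import deque

import numpy as np

def max_flow(cap, s, t):
    """Edmonds-Karp max flow on a dict-of-dicts of residual capacities (modified in place).

    Returns the flow value (equal to the min-cut value) and the set of nodes
    reachable from s in the final residual graph, i.e. the source side of a min cut.
    """
    flow = 0.0
    while True:
        # breadth-first search for a shortest augmenting path
        parent = {s: None}
        queue = deque([s])
        while queue and t not in parent:
            u = queue.popleft()
            for v, c in cap[u].items():
                if c > 1e-12 and v not in parent:
                    parent[v] = u
                    queue.append(v)
        if t not in parent:
            return flow, set(parent)
        # collect the path edges and the bottleneck capacity
        path, v = [], t
        while parent[v] is not None:
            path.append((parent[v], v))
            v = parent[v]
        delta = min(cap[u][v] for u, v in path)
        # push flow along the path and update residual capacities
        for u, v in path:
            cap[u][v] -= delta
            cap[v][u] = cap[v].get(u, 0.0) + delta
        flow += delta

def graph_cut_binary(unary, edges, beta):
    """Exact minimizer of E(y) = sum_i unary[i, y_i] + beta * #{(i, j): y_i != y_j}.

    Assumes unary >= 0 and beta >= 0 (Potts pairwise term, which is submodular).
    Pixels on the source side get label 0, pixels on the sink side get label 1.
    """
    num_pixels = unary.shape[0]
    s, t = "s", "t"
    cap = {node: {} for node in list(range(num_pixels)) + [s, t]}
    for i in range(num_pixels):
        cap[s][i] = float(unary[i, 1])  # cut iff pixel i takes label 1
        cap[i][t] = float(unary[i, 0])  # cut iff pixel i takes label 0
    for i, j in edges:
        cap[i][j] = cap[i].get(j, 0.0) + beta  # cut iff y_i = 0 and y_j = 1
        cap[j][i] = cap[j].get(i, 0.0) + beta  # cut iff y_i = 1 and y_j = 0
    cut_value, source_side = max_flow(cap, s, t)
    labels = np.array([0 if i in source_side else 1 for i in range(num_pixels)])
    return labels, cut_value

def energy(labels, unary, edges, beta):
    unary_part = sum(unary[i, y_i] for i, y_i in enumerate(labels))
    return unary_part + beta * sum(labels[i] != labels[j] for i, j in edges)

# hypothetical 3x3 toy problem, pixels in row-major order
rng = np.random.default_rng(0)
height, width = 3, 3
inds = np.arange(height * width).reshape(height, width)
edges = [(int(a), int(b)) for a, b in zip(inds[:, :-1].ravel(), inds[:, 1:].ravel())]
edges += [(int(a), int(b)) for a, b in zip(inds[:-1, :].ravel(), inds[1:, :].ravel())]
z = rng.random(height * width)         # noisy intensities in [0, 1)
unary = np.column_stack([z, 1.0 - z])  # label 0 costs z, label 1 costs 1 - z
beta = 0.3

y_gc, e_gc = graph_cut_binary(unary, edges, beta)
e_best = min(energy(np.array(y), unary, edges, beta)
             for y in itertools.product([0, 1], repeat=height * width))
print(np.isclose(e_gc, e_best), np.isclose(energy(y_gc, unary, edges, beta), e_best))  # True True
```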

In [15]:
# inference algorithm
infer_alg = GRAPH_CUT
#infer_alg = LP_RELAXATION
In [16]:
# create model and register factor types
model = FactorGraphModel(feats_train, labels_train, infer_alg)
model.add_factor_type(ftype_unary)
model.add_factor_type(ftype_pair)

Learning parameters with structured output SVM

We use StochasticSOSVM to learn the parameter vector $\textbf{w}$ by minimizing the objective in Eq. (3).
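A generic stochastic subgradient method for Eq. (3) picks one training example, solves the loss-augmented problem $\hat{\mathbf{y}} = \operatorname{argmax}_{\mathbf{y}} \Delta(\mathbf{y}_i,\mathbf{y}) + \mathbf{w}'\Phi(\mathbf{x}_i,\mathbf{y})$, and steps along the subgradient $\lambda\mathbf{w} + \Phi(\mathbf{x}_i,\hat{\mathbf{y}}) - \Phi(\mathbf{x}_i,\mathbf{y}_i)$. The sketch below shows one such step on a two-pixel toy example, with exhaustive loss-augmented search and a simplified unary-only $\Phi$. The helper names, the step size `eta`, and the data are hypothetical, and Shogun's StochasticSOSVM may differ in details such as its step-size schedule and averaging.

```python
import itertools
import numpy as np

def loss_augmented_argmax(w, joint_feature, x, y_true):
    """argmax_y Delta(y_true, y) + w' Phi(x, y), by exhaustive search (tiny problems only)."""
    candidates = [np.array(y) for y in itertools.product([0, 1], repeat=len(y_true))]
    return max(candidates, key=lambda y: np.mean(y != y_true) + w @ joint_feature(x, y))

def ssg_step(w, phi_true, phi_hat, lam, eta):
    """One stochastic subgradient step on Eq. (3) using a single training example.

    A subgradient of lambda/2 ||w||^2 + L_i(w) is lambda * w + Phi(x_i, y_hat) - Phi(x_i, y_i).
    """
    return w - eta * (lam * w + phi_hat - phi_true)

def unary_phi(x, y):
    # simplified unary-only joint feature (illustrative): per-label sum of pixel features
    psi = np.zeros((2, x.shape[1]))
    for i, y_i in enumerate(y):
        psi[y_i] += x[i]
    return psi.ravel()

x = np.array([[0.9, 1.0], [0.2, 1.0]])  # features [z_i, 1] for two pixels
y_true = np.array([1, 0])
w = np.zeros(4)
y_hat = loss_augmented_argmax(w, unary_phi, x, y_true)  # with w = 0: flip every pixel
w = ssg_step(w, unary_phi(x, y_true), unary_phi(x, y_hat), lam=1e-4, eta=0.5)
print(y_hat, w)
```

After this step, the weight on $z_i$ becomes positive for label 1 and negative for label 0, so brighter pixels are pushed toward label 1. Because the normalized Hamming loss decomposes over pixels, the loss-augmented problem has the same graph structure as Eq. (1), so the same inference algorithm can be used for it on full-size images.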

In [17]:
from modshogun import StochasticSOSVM
import time

# Training with Stochastic Gradient Descent
sgd = StochasticSOSVM(model, labels_train, True, True)
sgd.set_num_iter(300)
sgd.set_lambda(0.0001)

# train
t0 = time.time()
sgd.train()
t1 = time.time()
w_sgd = sgd.get_w()
print('SGD took %s seconds.' % (t1 - t0))
SGD took 178.910168171 seconds.
In [18]:
def evaluation(labels_pr, labels_gt, model):
    """ Evaluation

        Args:
            labels_pr: predicted labels
            labels_gt: ground truth labels
            model: factor graph model

        Returns:
            ave_loss: average loss
    """

    num_samples = labels_pr.get_num_labels()
    acc_loss = 0.0
    for i in range(num_samples):
        y_pred = labels_pr.get_label(i)
        y_truth = labels_gt.get_label(i)
        # accumulate the per-sample structured loss (normalized Hamming loss here)
        acc_loss = acc_loss + model.delta_loss(y_truth, y_pred)

    ave_loss = acc_loss / num_samples

    return ave_loss
In [19]:
# training error
labels_train_pr = sgd.apply()
ave_loss = evaluation(labels_train_pr, labels_train, model)
print('SGD: Average training error is %.4f' % ave_loss)
SGD: Average training error is 0.0756
In [20]:
def plot_primal_trainError(sosvm, name = 'SGD'):
    """ Plot primal objective values and training errors."""
    
    primal_val = sosvm.get_helper().get_primal_values()
    train_err = sosvm.get_helper().get_train_errors()
    
    fig, plots = plt.subplots(1, 2, figsize=(12,4))
    
    # primal vs passes
    plots[0].plot(range(primal_val.size), primal_val, label=name)
    plots[0].set_xlabel('effective passes')
    plots[0].set_ylabel('primal objective')
    plots[0].set_title('whole training progress')
    plots[0].legend(loc=1)
    plots[0].grid(True)
    
    # training error vs passes
    plots[1].plot(range(train_err.size), train_err, label=name)
    plots[1].set_xlabel('effective passes')
    plots[1].set_ylabel('training error')
    plots[1].set_title('effective passes')
    plots[1].legend(loc=1)
    plots[1].grid(True)

    plt.show()
In [21]:
# plot primal objective values and training errors at each pass
plot_primal_trainError(sgd)

Testing results

In [22]:
# Testing error
sgd.set_features(feats_test)
sgd.set_labels(labels_test)

labels_test_pr = sgd.apply()
ave_loss = evaluation(labels_test_pr, labels_test, model)
print('SGD: Average testing error is %.4f' % ave_loss)
SGD: Average testing error is 0.0774
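For comparison, a hypothetical baseline (not part of the original experiment) that ignores context and labels each pixel by thresholding its noisy value at 0.5 errs exactly when $t_i^n > 0.5$, so its expected error rate is $1 - 0.5^{1/n} \approx 0.293$ for $n = 2$, several times the test error of the learned pairwise model printed above. The simulation below checks this formula under the same noise model.

```python
import numpy as np

# Hypothetical baseline: threshold each noisy value z_i at 0.5, ignoring all context.
# A pixel is wrong iff |z_i - x_i| = t_i^n > 0.5, i.e. t_i > 0.5**(1/n),
# so the expected error rate is 1 - 0.5**(1/n).
rng = np.random.default_rng(0)
noise_level = 2
x = rng.integers(0, 2, size=1_000_000).astype(np.float64)
t = rng.random(x.shape)
z = x * (1 - t ** noise_level) + (1 - x) * t ** noise_level
empirical = np.mean((z > 0.5) != (x == 1))
analytic = 1 - 0.5 ** (1 / noise_level)
print(round(empirical, 3), round(analytic, 3))
```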
In [23]:
def plot_results(example, y_pred):
    """ Plot the noisy input, ground truth and prediction for one example."""
    
    im_pred = y_pred.reshape(example.im_bw.shape)
    
    fig, plots = plt.subplots(1, 3, figsize=(12, 4))
    
    plots[0].matshow(example.im_noise, cmap=plt.get_cmap('Greys'))
    plots[0].set_title('noise input')
    plots[1].matshow(example.im_bw, cmap=plt.get_cmap('Greys'))
    plots[1].set_title('ground truth labels')
    plots[2].matshow(im_pred, cmap=plt.get_cmap('Greys'))
    plots[2].set_title('predicted labels')
    
    for p in plots:
        p.set_xticks(())
        p.set_yticks(())

    plt.show()
In [24]:
import matplotlib.pyplot as plt
%matplotlib inline

# plot one example
i = 8

# get predicted output
y_pred = FactorGraphObservation.obtain_from_generic(labels_test_pr.get_label(i)).get_data()

# plot results
plot_results(examples_test[i], y_pred)
In [25]:
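def plot_results_more(examples, labels_pred, num_samples=10):
    """ Plot the noisy input, ground truth and prediction for the first num_samples examples."""
    
    # squeeze=False keeps a 2-D array of axes even when num_samples == 1
    fig, plots = plt.subplots(num_samples, 3, figsize=(12, 4*num_samples), squeeze=False)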
    
    for i in range(num_samples):
        example = examples[i]
        
        # get predicted output
        y_pred = FactorGraphObservation.obtain_from_generic(labels_pred.get_label(i)).get_data()
        im_pred = y_pred.reshape(example.im_bw.shape)
    
        plots[i][0].matshow(example.im_noise, cmap=plt.get_cmap('Greys'))
        plots[i][0].set_title('noise input')
        plots[i][0].set_xticks(())
        plots[i][0].set_yticks(())
        
        plots[i][1].matshow(example.im_bw, cmap=plt.get_cmap('Greys'))
        plots[i][1].set_title('ground truth labels')
        plots[i][1].set_xticks(())
        plots[i][1].set_yticks(())
        
        plots[i][2].matshow(im_pred, cmap=plt.get_cmap('Greys'))
        plots[i][2].set_title('predicted labels')
        plots[i][2].set_xticks(())
        plots[i][2].set_yticks(())

    plt.show()
In [26]:
plot_results_more(examples_test, labels_test_pr, num_samples=5)

References

[1] Nowozin, S., & Lampert, C. H. Structured learning and prediction in computer vision. Foundations and Trends® in Computer Graphics and Vision, 6(3–4), 185-365, 2011.

[3] Domke, J. Learning graphical model parameters with approximate marginal inference. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(10), 2454-2467, 2013.

[4] Tsochantaridis, I., Hofmann, T., Joachims, T., & Altun, Y. Support vector machine learning for interdependent and structured output spaces. ICML 2004.

[5] Boykov, Y., Veksler, O., & Zabih, R. Fast approximate energy minimization via graph cuts. IEEE Transactions on Pattern Analysis and Machine Intelligence, 23(11), 1222-1239, 2001.