Practical 4: Regular Group Convolutions

Authors: David Knigge

Introduction

In this notebook, we will implement regular group convolutional networks from scratch, using only PyTorch primitives. The goal is to get familiar with the practical considerations involved in actually implementing these networks.

You will be asked to fill in some gaps yourself. The pieces of code you are expected to fill in are surrounded by demarcations, like so:

a = 1
b = 2

# We calculate c = a + b

### YOUR CODE STARTS HERE ###
c = ...
### AND ENDS HERE ###

And you’d be expected to fill in something along the lines of c = a + b for c = .... We’ve included some assertions to test for correctness of resulting tensor shapes.

We’ve also scattered some questions throughout the notebook to test your understanding of the implementation and concepts used. Please include your answers to these questions (and any other observations you deem relevant!) in your report.

If you’d like a refresher on the lecture, here we give a brief overview of the operations we are going to work with and implement. These will be treated more extensively below. For simplicity of notation, here we assume each CNN layer consists of only a single channel. Questions and feedback may be forwarded to David Knigge; d.m.knigge@uva.nl.

Brief recap on CNNs

Conventional CNNs make use of the convolution operator, here defined over \(\mathbb{R}^2\) for a signal \(f:\mathbb{R}^2 \rightarrow \mathbb{R}\) and a kernel \(k: \mathbb{R}^2 \rightarrow \mathbb{R}\) at \(\mathbf{x}\in \mathbb{R}^2\):

\[(f * k) (\mathbf{x}) = \int_{\mathbb{R}^2} f(\tilde{\mathbf{x}})k(\tilde{\mathbf{x}} - \mathbf{x}) \text{d}\tilde{\mathbf{x}},\]

As we can see, the convolution operation comes down to an inner product of the function \(f\) and a shifted kernel \(k\).

Sidenote: In reality CNNs implement a discretised version of this operation:
\[\begin{split}\begin{aligned} (f*k) (\mathbf{x}) &= \sum_{\mathbf{\tilde{x}} \in \mathbb{Z}^2} f(\mathbf{\tilde{x}})k(\mathbf{\tilde{x}}-\mathbf{x})\Delta\mathbf{\tilde{x}}\\ &= \sum_{\mathbf{\tilde{x}} \in \mathbb{Z}^2} f(\mathbf{\tilde{x}})k(\mathbf{\tilde{x}}-\mathbf{x}), \end{aligned}\end{split}\]
where, since pixels in an image are generally evenly spaced, we set \(\Delta \mathbf{\tilde{x}}=1\). For this recap we stay in the continuous domain for simplicity.

In convolution layers, like PyTorch’s Conv2d, the above operation is carried out for every \(\mathbf{x} \in \mathbb{Z}^2\) (limited, of course, to the domain over which the image is defined). Because the same set of weights is used throughout the input, the output of this operation is equivariant to translations (in the discretised setting, to whole-pixel shifts from \(\mathbb{Z}^2\)). Furthermore, \(f\) and \(k\) usually consist of a number of channels, which are all summed over.
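A minimal numerical check of this translation equivariance is sketched below. It rests on two assumptions: a hypothetical helper `circular_conv` uses circular (wrap-around) padding so that shifts near the border wrap around exactly, and the shift is by whole pixels. With zero padding, the equality would only hold away from the image boundary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def circular_conv(f, k):
    """Cross-correlation with circular padding: output has the input's size,
    and shifts wrap around exactly at the borders."""
    pad = k.shape[-1] // 2
    return F.conv2d(F.pad(f, (pad, pad, pad, pad), mode="circular"), k)

f = torch.randn(1, 3, 8, 8, dtype=torch.float64)  # input with 3 channels
k = torch.randn(4, 3, 3, 3, dtype=torch.float64)  # 4 output channels, 3x3 kernels

shift = (2, -3)  # translation in (rows, columns), whole pixels
lhs = circular_conv(torch.roll(f, shifts=shift, dims=(-2, -1)), k)  # convolve the shifted input
rhs = torch.roll(circular_conv(f, k), shifts=shift, dims=(-2, -1))  # shift the convolved output
print(torch.allclose(lhs, rhs))  # True
```

Note that the sum over the 3 input channels happens inside conv2d and does not affect equivariance, since every channel is shifted identically.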

In this tutorial, we will use PyTorch’s torch.nn.functional.conv2d() function to perform this summation at every position in the input feature map. This saves us from having to implement the convolution operation ourselves.
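To make the link between the summation and conv2d concrete, the sketch below evaluates \(\sum_{\tilde{\mathbf{x}}} f(\tilde{\mathbf{x}})k(\tilde{\mathbf{x}}-\mathbf{x})\) with explicit loops for a single-channel input (no padding, stride 1) and compares it with torch.nn.functional.conv2d. Here the kernel is indexed by the offset \(\tilde{\mathbf{x}}-\mathbf{x}\) from the top-left corner of each window, an indexing convention assumed only for this illustration. The agreement reflects that conv2d does not flip the kernel: strictly speaking it computes a cross-correlation, which matches the inner-product form of the integral above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
f = torch.randn(6, 6, dtype=torch.float64)  # single-channel signal on a 6x6 pixel grid
k = torch.randn(3, 3, dtype=torch.float64)  # 3x3 kernel, k[a, b] = k(x_tilde - x) with offset (a, b)

# Direct evaluation of sum_{x_tilde} f(x_tilde) k(x_tilde - x) at every valid output position x.
out_manual = torch.zeros(4, 4, dtype=torch.float64)
for i in range(4):
    for j in range(4):
        for a in range(3):
            for b in range(3):
                out_manual[i, j] += f[i + a, j + b] * k[a, b]

# conv2d expects inputs [batch, channels, H, W] and weights [out_channels, in_channels, K, K].
out_torch = F.conv2d(f[None, None], k[None, None])[0, 0]
print(torch.allclose(out_manual, out_torch))  # True
```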

Brief recap on GCNNs

In regular group convolutions, the goal is to build CNNs that are not only equivariant to translations \(\mathbb{R}^2\), but also to another (usually broader) group of interest \(G\). We focus specifically on groups which are combinations of translations \(\mathbb{R}^2\) and some group of interest \(H\). In this tutorial we restrict ourselves to the group of 90 degree rotations in 2D: the cyclic group of order 4, \(H=C_4\).

We will operate on 2D images, which are generally defined on \(\mathbb{R}^2\). To construct a network which can track under which pose (read: transformation from a group \(G=\mathbb{R}^2\rtimes H\)) a feature in the input occurs, the first step is to transfer our signal to a domain in which the same feature under different poses is disentangled. This happens through the lifting convolution, which maps our input signal \(f_{in}:\mathbb{R}^2\rightarrow \mathbb{R}\) to a feature map on the group \(f_{out}:G\rightarrow \mathbb{R}\). For a signal and kernel \(f,k\) both defined on \(\mathbb{R}^2\), and a group element \(g=(\mathbf{x}, h) \in G=\mathbb{R}^2 \rtimes H\):

\[(f *_{\text{lifting}} k) (g) = \int_{\mathbb{R}^2} f(\tilde{\mathbf{x}})k_h(\tilde{\mathbf{x}} - \mathbf{x}) \,{\rm d}\tilde{\mathbf{x}}.\]

Where \(k_h\) is the kernel \(k:\mathbb{R}^2 \rightarrow \mathbb{R}\) transformed under the regular representation \(\mathcal{L}_h\) of a group element \(h \in H\); \(k_h = \frac{1}{| h|}\mathcal{L}_{h}[k]\).

Sidenote: The factor \(\frac{1}{| h|}\), with \(|h|\) the determinant of the matrix representation of \(h\) in \(\mathbb{R}^2\), accounts for a possible change in volume on \(\mathbb{R}^2\) caused by \(h\). Working with the cyclic group, we don’t encounter this problem (the determinant of a rotation matrix is 1, so volumes on \(\mathbb{R}^2\) are invariant to rotations), but if you’d like to implement equivariance to, for example, the dilation group, this becomes important.
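To see where the factor comes from, assume \(h\) acts on \(\mathbb{R}^2\) through an invertible matrix with positive determinant \(|h|\). Substituting \(\mathbf{y} = h^{-1}\mathbf{x}\), so that \({\rm d}\mathbf{x} = |h|\,{\rm d}\mathbf{y}\), gives

\[\int_{\mathbb{R}^2} \mathcal{L}_h[k](\mathbf{x})\,{\rm d}\mathbf{x} = \int_{\mathbb{R}^2} k(h^{-1}\mathbf{x})\,{\rm d}\mathbf{x} = |h|\int_{\mathbb{R}^2} k(\mathbf{y})\,{\rm d}\mathbf{y} \quad\Longrightarrow\quad \int_{\mathbb{R}^2} k_h(\mathbf{x})\,{\rm d}\mathbf{x} = \int_{\mathbb{R}^2} k(\mathbf{y})\,{\rm d}\mathbf{y},\]

so the normalised kernel \(k_h = \frac{1}{|h|}\mathcal{L}_h[k]\) keeps the total mass of \(k\). For a dilation by a factor \(s\), for instance, the matrix is \(s I\) and \(|h| = s^2\); for any rotation \(|h| = 1\) and the factor disappears.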

Next, now that we have a feature map defined on the group; \(f_{out}:G\rightarrow \mathbb{R}\), we apply group convolutions, extending the convolution operation to an integral over the entire group \(G\);

\[\begin{split}\begin{aligned} (f *_{\mathrm{group}} k) (g) &=\int_G f(\tilde{g})k(g^{-1} \cdot \tilde{g}) {\rm d}\tilde{g} \\ &=\int_{\mathbb{R}^2}\int_H f(\tilde{\mathbf{x}}, \tilde{h})\mathcal{L}_{\mathbf{x}}\mathcal{L}_{h}k(\tilde{\mathbf{x}}, \tilde{h})\dfrac{1}{|h|} \,{\rm d}\mathbf{\tilde{x}}\,{\rm d}\tilde{h}\\ &=\int_{\mathbb{R}^2}\int_H f(\tilde{\mathbf{x}},\tilde{h})k({h^{-1}}(\tilde{\mathbf{x}}-\mathbf{x}), h^{-1}\cdot \tilde{h})\dfrac{1}{|h|} \,{\rm d}\mathbf{\tilde{x}}\,{\rm d}\tilde{h}. \end{aligned}\end{split}\]

The main difference with the lifting convolution is that the signal and kernel \(f,k\) are now both functions \(G\rightarrow \mathbb{R}\), and the integral reflects this by extending over the entire group \(G\). Other than that, there is little difference!
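For \(H=C_4\), this group convolution can be sketched without any interpolation, because every \(90°\) rotation maps the pixel grid onto itself. The sketch below is an illustration under that assumption, not the interpolation-based implementation this notebook builds. A hypothetical `group_conv_c4` applies \(\mathcal{L}_h\) to the kernel by rotating its spatial axes with torch.rot90 and cyclically shifting its group axis with torch.roll (the \(h^{-1}\cdot\tilde{h}\) term), then folds the (input channel, group) axes together so that one conv2d call performs the sum over both \(\mathbb{R}^2\) and \(H\). It assumes square inputs, no padding, a [batch, channels, group, height, width] layout, and that group index \(r\) stands for \(r\) applications of torch.rot90 with k=1.

```python
import torch
import torch.nn.functional as F

def group_conv_c4(f, k):
    """Group convolution on Z^2 x C4 (sketch).

    f: [B, C_in, 4, N, N] feature map on the group
    k: [C_out, C_in, 4, K, K] kernel on the group
    returns: [B, C_out, 4, N-K+1, N-K+1]
    """
    C_out, C_in, G, K, _ = k.shape
    kernels = []
    for r in range(G):
        # L_h on a kernel over the group: rotate the spatial axes r times
        # and cyclically shift the group axis by r steps.
        k_r = torch.roll(torch.rot90(k, r, dims=(-2, -1)), shifts=r, dims=2)
        # Fold (input channel, group) into one axis so conv2d sums over both.
        kernels.append(k_r.reshape(C_out, C_in * G, K, K))
    kernels = torch.cat(kernels, dim=0)  # [G * C_out, C_in * G, K, K]
    B, _, _, N, _ = f.shape
    out = F.conv2d(f.reshape(B, C_in * G, N, N), kernels)
    M = out.shape[-1]
    # Split the stacked axis into (group, channel) and put the group axis after the channels.
    return out.view(B, G, C_out, M, M).transpose(1, 2)

def rotate_group_fmap(y):
    # Action of the 90-degree generator on a feature map over Z^2 x C4.
    return torch.roll(torch.rot90(y, 1, dims=(-2, -1)), shifts=1, dims=2)

torch.manual_seed(0)
f = torch.randn(1, 2, 4, 9, 9, dtype=torch.float64)
k = torch.randn(3, 2, 4, 3, 3, dtype=torch.float64)

out = group_conv_c4(f, k)
print(out.shape)  # torch.Size([1, 3, 4, 7, 7])
print(torch.allclose(group_conv_c4(rotate_group_fmap(f), k), rotate_group_fmap(out)))  # True
```

The final check transforms the input by rotating it spatially and shifting its group axis by one step; the output transforms in exactly the same way, which is the equivariance the integral over \(G\) is meant to provide.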

After a number of such group convolutional layers, we will want to ultimately obtain a representation that is invariant to the group action. We can do this by performing a projection which collapses our function defined over \(G\) to a single point, with an operation that is invariant to the group action (summing, averaging, max, min).
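To illustrate why such pooling yields invariance, the sketch below assumes \(H=C_4\) and a feature map stored as [batch, channels, group, height, width]. A 90 degree rotation is modelled as rotating the spatial axes with torch.rot90 and cyclically shifting the group axis with torch.roll; the hypothetical `project` takes a max over the group axis followed by a mean over space. This layout and action are assumptions made for the example, not necessarily the conventions used later in the notebook.

```python
import torch

torch.manual_seed(0)
# A random feature map on Z^2 x C4: [batch, channels, |H| = 4, height, width]
y = torch.randn(2, 5, 4, 7, 7)

# Transformed copy under a 90-degree rotation: spatial axes rotate,
# group axis shifts cyclically by one position.
y_rot = torch.roll(torch.rot90(y, 1, dims=(-2, -1)), shifts=1, dims=2)

def project(t):
    # Max over the group axis, then mean over the spatial axes -> [batch, channels]
    return t.max(dim=2).values.mean(dim=(-2, -1))

print(torch.allclose(project(y), project(y_rot)))  # True
```

The max is unaffected by the cyclic shift of the group axis, and the spatial mean is unaffected by the rotation, so the pooled descriptor is identical for both inputs.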

After this short refresher, let’s get to coding!

Installing and importing some useful packages

Here we install and import some libraries that we will use throughout this tutorial. We use PyTorch as our deep learning framework of choice. Note that for ease of model training and tracking, we additionally make use of PyTorch Lightning.

[1]:
## Standard libraries
import os
import numpy as np
import math
from PIL import Image
from types import SimpleNamespace
from functools import partial

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0

## PyTorch
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
## Torchvision
import torchvision
from torchvision.datasets import MNIST
from torchvision import transforms
## PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    # Quote the version specifier so the shell does not treat ">" as output redirection.
    !pip3 install "pytorch-lightning>=1.6" --quiet
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
[2]:
# Path to the folder where the datasets are downloaded (e.g. MNIST)
DATASET_PATH = "../data"
os.makedirs(DATASET_PATH, exist_ok=True)
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/practical4"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

As an example image, we will use the popular image of paprikas. If you are working on Google Colab, you need to download this image, which we do below:

[3]:
import urllib.request
from urllib.error import HTTPError
# Github URL where the image is stored
base_url = "https://raw.githubusercontent.com/phlippe/asci_cbl_practicals/main/assets/"
# Files to download
files = ["paprika.tiff"]

# For each file, check whether it already exists. If not, try downloading it.
for file_name in files:
    file_path = os.path.join(DATASET_PATH, file_name)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("Something went wrong. Please try to download the file manually from the Github, or ask one of your TAs with the following error message:\n", e)

Part 1: Group theory

1.1 What is a group?

To start off, we recap some of the group theoretical preliminaries. Recall that a group is defined by a tuple \((G, \cdot)\), where \(G\) is a set of group elements and \(\cdot\) is a binary operation (the group product) which tells us how elements \(g \in G\) combine. The group product \(\cdot\) needs to satisfy:

  1. Closure. \(G\) is closed under \(\cdot\); for all \(g_1, g_2 \in G\) we have \(g_1 \cdot g_2 \in G\).

  2. Identity. There exists an identity element \(e\) s.t. for each \(g \in G\), we have \(e \cdot g = g \cdot e = g\).

  3. Inverse. For every element \(g \in G\) there is an element \(g^{-1} \in G\) s.t. \(g \cdot g^{-1} = g^{-1} \cdot g = e\).

  4. Associativity. For every set of elements \(g_1, g_2, g_3 \in G\), we have \((g_1 \cdot g_2) \cdot g_3 = g_1 \cdot (g_2 \cdot g_3)\).
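For a small finite group, these four axioms can be checked exhaustively. The sketch below encodes each element of \(C_4\) as an integer \(r\in\{0,1,2,3\}\) standing for a rotation by \(r\cdot 90°\). This encoding is chosen only for the example (later in the notebook the elements are parameterised by angles), and `compose` and `inverse` are hypothetical helpers.

```python
from itertools import product as cartesian

ORDER = 4
elements = list(range(ORDER))  # r stands for a rotation by r * 90 degrees
identity = 0

def compose(a, b):
    # Group product: add rotations, wrapping around after a full turn.
    return (a + b) % ORDER

def inverse(a):
    return (-a) % ORDER

closure = all(compose(a, b) in elements for a, b in cartesian(elements, repeat=2))
has_identity = all(compose(identity, g) == g == compose(g, identity) for g in elements)
has_inverse = all(compose(g, inverse(g)) == identity == compose(inverse(g), g) for g in elements)
associative = all(compose(compose(a, b), c) == compose(a, compose(b, c))
                  for a, b, c in cartesian(elements, repeat=3))
print(closure, has_identity, has_inverse, associative)  # True True True True
```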

The group can act on functions defined on \(\mathbb{R}^2\), which we can instantiate through the regular representation \(\mathcal{L}_g^{G\rightarrow \mathbb{R}^2}\). For simplicity, we write \(\mathcal{L}_g\). It is given by:

\[\mathcal{L}_g f (\mathbf{x}) = f(g^{-1} \cdot \mathbf{x})\]

Here we write the action of \(g^{-1}\) on \(\mathbf{x}\) as \(g^{-1}\cdot \mathbf{x}\). This is where the regular group convolution gets its name: from its use of the regular representation to transform the kernels \(k\) used throughout the network.
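The sketch below, using NumPy and hypothetical helpers `rotation_matrix` and `regular_rep`, illustrates two consequences of this definition for planar rotations: applying \(\mathcal{L}_b\) and then \(\mathcal{L}_a\) equals \(\mathcal{L}_{a+b}\) (the representation respects the group product), and the inverse in \(f(g^{-1}\cdot\mathbf{x})\) makes \(\mathcal{L}_\theta\) move a function forward by \(\theta\). The test functions `f` and `bump` are arbitrary choices.

```python
import numpy as np

def rotation_matrix(theta):
    return np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

def regular_rep(theta, f):
    """Return L_theta f, i.e. the function x -> f(R_theta^{-1} x) = f(R_{-theta} x)."""
    return lambda x: f(rotation_matrix(-theta) @ x)

# Hypothetical test functions on R^2.
f = lambda x: x[0] + 2.0 * x[1] ** 2                                # not rotation invariant
bump = lambda x: np.exp(-np.sum((x - np.array([1.0, 0.0])) ** 2))  # peaks at (1, 0)

# 1) Representation property: L_a (L_b f) = L_{a+b} f.
x = np.array([0.3, -1.2])
a, b = np.pi / 2, np.pi
lhs = regular_rep(a, regular_rep(b, f))(x)
rhs = regular_rep(a + b, f)(x)
print(np.isclose(lhs, rhs))  # True

# 2) Thanks to the inverse, the bump at (1, 0) ends up at (0, 1) after a 90-degree rotation.
print(np.isclose(regular_rep(np.pi / 2, bump)(np.array([0.0, 1.0])), 1.0))  # True
```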

1.2 Implementing a group in python

Let’s start out with a base class in which we outline which functions we are going to need when working with groups in our setting. As we’re going to use torch in our implementation of group convolutional neural networks, let us implement the group as a torch module as well.

We first specify a base class GroupBase in which we specify all necessary properties and operations that we need in our treatment of group convolutional neural networks. The idea is that, in our implementation of group convolutions, implementing these functions is necessary and sufficient for extending group convolutional neural networks to other groups. In other words, if you’d like to implement group convolutions equivariant to a new group you find interesting, just inherit from this base class, implement its methods, and you’re good to go. (In practice this only works faithfully for discrete, compact groups.)

[ ]:
class GroupBase(torch.nn.Module):

    def __init__(self, dimension, identity):
        """ Implements a group.

        @param dimension: Dimensionality of the group (number of dimensions in the basis of the algebra).
        @param identity: Identity element of the group.
        """
        super().__init__()
        self.dimension = dimension
        self.register_buffer('identity', torch.Tensor(identity))

    def elements(self):
        """ Obtain a tensor containing all group elements in this group.

        """
        raise NotImplementedError()

    def product(self, h, h_prime):
        """ Defines group product on two group elements.

        @param h: Group element 1
        @param h_prime: Group element 2
        """
        raise NotImplementedError()

    def inverse(self, h):
        """ Defines inverse for group element.

        @param h: A group element.
        """
        raise NotImplementedError()

    def left_action_on_R2(self, h_batch, x_batch):
        """ Group action of an element from the subgroup H on a vector in R2. For efficiency we
        implement this batchwise.

        @param h_batch: Group elements from H.
        @param x_batch: Vectors in R2.
        """
        raise NotImplementedError()

    def left_action_on_H(self, h_batch, h_prime_batch):
        """ Group action of elements of H on other elements in H itself. Comes down to group product.
        For efficiency we implement this batchwise. Each element in h_batch is applied to each element
        in h_prime_batch.

        @param h_batch: Group elements from H.
        @param h_prime_batch: Other group elements in H.
        """
        raise NotImplementedError()

    def matrix_representation(self, h):
        """ Obtain a matrix representation in R^2 for an element h.

        @param h: Group element
        """
        raise NotImplementedError()

    def determinant(self, h):
        """ Calculate the determinant of the representation of a group element
        h.

        @param h: A group element.
        """
        raise NotImplementedError()

    def normalize_group_elements(self, h):
        """ Map the group elements to an interval [-1, 1]. We use this to create
        a standardized input for obtaining weights over the group.

        @param h: A group element.
        """
        raise NotImplementedError()

1.3 Implementing the cyclic group \(\rm C_4\)

As an example, let’s discuss a relatively simple group: the group of all \(90°\) rotations of the plane, otherwise known as the cyclic group \(\rm C_4\). Note:

  • The set of group elements of \(C_4\) is given by \(G := \{ e, g, g^2, g^3\}\). We can parameterise these group elements using rotation angles \(\theta\), i.e. \(e=0, g=\frac{1}{2}\pi, g^2 = \pi\), …

  • The group product is then given by \(g \cdot g':= \theta + \theta' \mod 2 \pi\).

  • The inverse is given by: \(g^{-1} = -\theta \mod 2\pi\).

  • The group \(C_4\) acts on the 2D Euclidean plane \(\mathbb{R}^2\) through the rotation matrix

    \[\begin{split}R_{\theta} = \begin{bmatrix} \cos(\theta) & -\sin(\theta) \\ \sin(\theta) & \cos(\theta) \end{bmatrix}.\end{split}\]

    This gives us the regular representation \(\mathcal{L}_\theta\) on functions \(f\) defined over \(\mathbb{R}^2\):

    \[\mathcal{L}_{\theta} f(\mathbf{x}) = f(R_{-\theta\mod2\pi}\mathbf{x}).\]
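A few properties of \(R_\theta\) can be checked numerically. This NumPy sketch uses a hypothetical helper `R`, independent of the CyclicGroup class below. It checks that every element of \(C_4\) has determinant 1, that composing elements corresponds to multiplying matrices, that four \(90°\) rotations give the identity, and that \(R_{\pi/2}\) maps the unit \(x\)-vector to the unit \(y\)-vector, i.e. rotates counterclockwise.

```python
import numpy as np

def R(theta):
    return np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

angles = [0.0, np.pi / 2, np.pi, 3 * np.pi / 2]  # the four elements of C4

# Rotations preserve area: det R_theta = 1 for every element.
print([round(float(np.linalg.det(R(t))), 6) for t in angles])  # [1.0, 1.0, 1.0, 1.0]
# The group product corresponds to matrix multiplication: R_{pi/2} R_{pi} = R_{3pi/2}.
print(np.allclose(R(np.pi / 2) @ R(np.pi), R(3 * np.pi / 2)))  # True
# Applying the generator four times returns to the identity.
print(np.allclose(np.linalg.matrix_power(R(np.pi / 2), 4), np.eye(2)))  # True
# R_{pi/2} rotates counterclockwise: (1, 0) is mapped to (0, 1).
print(np.allclose(R(np.pi / 2) @ np.array([1.0, 0.0]), [0.0, 1.0]))  # True
```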

Let’s implement this group!

[ ]:
class CyclicGroup(GroupBase):

    def __init__(self, order):
        super().__init__(
            dimension=1,
            identity=[0.]
        )

        assert order > 1
        self.order = torch.tensor(order)

    def elements(self):
        """ Obtain a tensor containing all group elements in this group.

        """
        return torch.linspace(
            start=0,
            end=2 * np.pi * float(self.order - 1) / float(self.order),
            steps=self.order,
            device=self.identity.device
        )

    def product(self, h1, h2):
        """ Defines group product on two group elements of the cyclic group C4.

        @param h1: Group element 1
        @param h2: Group element 2
        """

        # As we directly parameterize the group by its rotation angles, this
        # will be a simple addition. Don't forget the closure property though!

        ## YOUR CODE STARTS HERE ##
        product = ...
        ## AND ENDS HERE ##

        return product

    def inverse(self, h):
        """ Defines group inverse for an element of the cyclic group C4.

        @param h: Group element
        """

        # Implement the inverse operation. Keep the closure property in mind!

        ## YOUR CODE STARTS HERE ##
        inverse = ...
        ## AND ENDS HERE ##

        return inverse

    def left_action_on_R2(self, batch_h, batch_x):
        """ Group action of an element g on a set of vectors in R2.

        @param batch_h: Tensor of group elements. [num_elements]
        @param batch_x: Tensor of vectors in R2. [2, spatial_x, spatial_y]
        """
        # Create a tensor containing representations of each of the group
        # elements in the input. Creates a tensor of size [num_elements, 2, 2].

        ## YOUR CODE STARTS HERE ##
        batched_rep = ...
        ## AND ENDS HERE ##

        # Transform the r2 input grid with each representation to end up with
        # a transformed grid of dimensionality [num_group_elements, spatial_x,
        # spatial_y, 2]. Note the order of the dimensions!

        # Recall that we are working with a left-regular representation,
        # meaning we transform vectors in R^2 through left-matrix multiplication.

        ## YOUR CODE STARTS HERE ##
        out = torch.einsum(...)
        ## AND ENDS HERE ##

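        # Afterwards, we swap the two coordinate values with a roll along the final
        # dimension: grid_sample expects the final dimension to be ordered
        # (width, height), whereas our meshgrid-based grids store the coordinate
        # along the first spatial axis (rows) first.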
        return out.roll(shifts=1, dims=-1)

    def left_action_on_H(self, batch_h, batch_h_prime):
        """ Group action of an element h on a set of group elements in H.
        Nothing more than a batchwise group product.

        @param batch_h: Tensor of group elements.
        @param batch_h_prime: Tensor of group elements to apply group product to.
        """
        # The elements in batch_h work on the elements in batch_h_prime directly,
        # through the group product. Each element in batch_h is applied to each element
        # in batch_h_prime.
        transformed_batch_h = self.product(batch_h.repeat(batch_h_prime.shape[0], 1),
                                           batch_h_prime.unsqueeze(-1))
        return transformed_batch_h

    def matrix_representation(self, h):
        """ Obtain a matrix representation in R^2 for an element h.

        @param h: A group element.
        """
        ## YOUR CODE STARTS HERE ##
        representation = ...
        ## AND ENDS HERE ##

        return representation.to(self.identity.device)

    def normalize_group_elements(self, h):
        """ Normalize values of group elements to range between -1 and 1.
        The group elements range from 0 to 2pi * (self.order - 1) / self.order,
        so we divide by this largest element, multiply by 2 (giving the range
        [0, 2]) and subtract 1.

        @param h: A group element.
        """
        largest_elem = 2 * np.pi * (self.order - 1) / self.order

        return (2*h / largest_elem) - 1.
[ ]:
# Some tests to verify our implementation.
c4 = CyclicGroup(order=4)
e, g1, g2, g3 = c4.elements()

assert c4.product(e, g1) == g1 and c4.product(g1, g2) == g3
assert c4.product(g1, c4.inverse(g1)) == e

assert torch.allclose(c4.matrix_representation(e), torch.eye(2))
assert torch.allclose(c4.matrix_representation(g2), torch.tensor([[-1, 0], [0, -1]]).float(), atol=1e-6)

assert torch.allclose(c4.left_action_on_R2([g1], torch.tensor([[[0.]], [[1.]]])), torch.tensor([[[0., -1.]]]), atol=1e-7)

1.4 Visualizing the group action

Let’s play around with the group implementation we have just created! To obtain pixel values for the transformed image, we use pytorch’s grid_sample function (see documentation).

[ ]:
img = Image.open(os.path.join(DATASET_PATH, "paprika.tiff"))

img_tensor = transforms.ToTensor()(img)

img
[ ]:
# This creates a grid of the pixel locations in our image
img_grid_R2 = torch.stack(torch.meshgrid(
    torch.linspace(-1, 1, img_tensor.shape[-1]),
    torch.linspace(-1, 1, img_tensor.shape[-2]),
))

# [2, 512, 512] since our image is 2 dimensional and has a width and height of
# 512 pixels
img_grid_R2.shape
[ ]:
# let's create the group of rotations by multiples of 90 degrees
c4 = CyclicGroup(order=4)
e, g1, g2, _ = c4.elements()
[ ]:
# Create a counterclockwise rotation of 270 degrees using only e, g1 and g2.

## YOUR CODE STARTS HERE ##
g3 = ...
## AND ENDS HERE ##

assert g3 == c4.elements()[-1]

# Transform the image grid we just created with the matrix representation of
# this group element. Note that we implemented this batchwise, so we add a dim.
transformed_grid = c4.left_action_on_R2(c4.inverse(g3).unsqueeze(0), img_grid_R2)

As we’ll be using it extensively throughout this tutorial, let’s take a closer look at grid_sample. From the PyTorch docs:

Currently, only spatial (4-D) and volumetric (5-D) input are supported.

In the spatial (4-D) case, for input with shape \((N,C,H_{in},W_{in})\) and grid with shape \((N,H_\text{out},W_\text{out},2)\), the output will have shape \((N,C,H_\text{out},W_\text{out})\).

Parameters:

input (Tensor) – input of shape \((N, C, H_\text{in}, W_\text{in})\) (4-D case) or \((N, C, D_\text{in}, H_\text{in}, W_\text{in})\) (5-D case).

For now, we’re working in the 4-D setting. We’re going to transform a single image with a single transformation, so in our case \(N=1\), the image has 3 channels so \(C=3\), and the height and width we just saw were both \(H_{\rm in}=W_{\rm in}=512\).

grid (Tensor) – flow-field of shape \((N, H_\text{out}, W_\text{out}, 2)\) (4-D case) or \((N, D_\text{out}, H_\text{out}, W_\text{out}, 3)\) (5-D case)

Remember, this is the shape we matched transformed_grid to in CyclicGroup.left_action_on_R2. We therefore expect the transformed grid to currently be of shape \((1, 512, 512, 2)\).

[ ]:
transformed_grid.shape
[ ]:
# Let's set up the sampling method with some fixed parameters that we'll use throughout.
grid_sample = partial(
    torch.nn.functional.grid_sample,
    padding_mode='zeros',
    align_corners=True,
    mode="bilinear"
)
[ ]:
# This function samples an input tensor based on a grid using interpolation.
# It is implemented for batchwise operations, so we add a batch dimension to our input image.
transformed_img = grid_sample(img_tensor.unsqueeze(0), transformed_grid)

# if we turn this back into a PIL image we can see the result of our transformation!
transforms.ToPILImage()(transformed_img[0])
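The roll at the end of CyclicGroup.left_action_on_R2 is easy to get wrong, so here is a small sketch of why it is there. It assumes the same grid_sample settings as above (bilinear, zero padding, align_corners=True) and a 5x5 toy image; the names `grid_rc`, `grid_xy` and `sample` are illustrative. Sampling with coordinates ordered (column, row), i.e. (x, y), reproduces the image, while passing them in (row, column) order transposes it.

```python
import torch
import torch.nn.functional as F

H = W = 5
img = torch.arange(H * W, dtype=torch.float64).view(1, 1, H, W)  # [N, C, H, W]

# Normalized pixel coordinates; with indexing="ij" the first output varies along rows.
rows, cols = torch.meshgrid(torch.linspace(-1, 1, H, dtype=torch.float64),
                            torch.linspace(-1, 1, W, dtype=torch.float64), indexing="ij")
grid_rc = torch.stack([rows, cols], dim=-1).unsqueeze(0)  # [1, H, W, 2], last dim = (row, col)
grid_xy = grid_rc.roll(shifts=1, dims=-1)                 # last dim = (col, row) = (x, y)

def sample(grid):
    return F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=True)

print(torch.allclose(sample(grid_xy), img))                    # True: identity resampling
print(torch.allclose(sample(grid_rc), img.transpose(-2, -1)))  # True: (row, col) order transposes
```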

Part 2: Group Equivariant Convolutional Networks

As discussed in the lecture, regular group convolutional neural networks consist of three main elements: the lifting convolution, the group convolution and the projection operation. We treat these in order.

2.1 Lifting convolution

First is the lifting convolution, which disentangles features at any spatial location in the input \(f_{in}\) under transformations of \(H\). You may think of this as registering at all locations, for a given feature \(e\), the occurrences of transformed versions of this feature \(\mathcal{L}_h(e)\), for \(h \in H\). (Instead of \(\mathcal{L}_h(e)\), we sometimes write \(h \cdot e\) to denote the action of \(h\) on \(e\).) The lifting convolution thus maps from \(\mathbb{R}^2\) to \(G= \mathbb{R}^2 \rtimes H\). As a result, our lifted feature map \(f_{out}\) has, besides the usual spatial dimensions, one or more additional group dimensions (dependent on the dimensionality of \(H\)).

[Figure: intuitionsepgconvs.png]

For example, take the group of 90 deg rotations as \(H\), and let’s say \(e\) is a squiggle of sorts, of which we have three occurrences in our input feature map \(f_{in}\). Two are under a 90 deg rotation, \(\theta_{90}\cdot e\), and one is in its canonical orientation, \(\theta_0 \cdot e\). A lifting convolution with a kernel \(k\) which exactly matches the feature \(e\) would produce responses at different offsets along the group dimension of the feature map: one in the spatial feature map corresponding to the group element \(\theta_0\), and two in the spatial feature map corresponding to \(\theta_{90}\). See the above figure for an intuition.

2.1.1 Overview

How do we get our convolution operation to pick up features under different transformations of \(H\)? Intuitively, it is not that different from the convolution operation as we are familiar with in CNNs. There, we enable the extraction of features at any location by sharing the same kernel over all spatial positions.

[Figure: gconvs-convolution.drawio (1).png]

From our current group theoretic perspective, we interpret this as applying all possible translations \(\boldsymbol{x} \in \mathbb{R}^2\) to our kernel \(k\) and recording the response we get when we take the inner product of the input \(f_{in}\) with this transformed kernel \(\mathcal{L}_{\mathbf{x}} (k)\). This operation starts with a feature map defined on \(\mathbb{R}^2\) and also yields a feature map defined over \(\mathbb{R}^2\). See above.

[Figure: liftingconv.png]

Now that we additionally want to register features under different group actions \(\mathcal{L}_h\) for \(h \in H\), we can do so by simply also transforming \(k\) with the action of every group element and recording the results. For example, in case of the rotation group \({\rm C_4}\), we not only translate, but additionally rotate the kernel \(k\) by all possible 90 deg rotations, and record the responses for the resulting transformed kernels!
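For \(H=C_4\) specifically, rotating the kernel by all four \(90°\) rotations can be done exactly with torch.rot90, without interpolation. The sketch below uses a hypothetical `lift_c4` and assumes square inputs, no padding and a [batch, channels, group, height, width] output layout. It stacks the four rotated filter banks along the output-channel axis so that a single conv2d call handles all translations, and then checks the expected equivariance: rotating the input rotates every output map and shifts the group axis by one step. This is the weight-permutation shortcut, not the interpolation approach implemented below.

```python
import torch
import torch.nn.functional as F

def lift_c4(f, k):
    """Lifting convolution for H = C4 (sketch).

    f: [B, C_in, N, N] signal on the pixel grid
    k: [C_out, C_in, K, K] kernel on the pixel grid
    returns: [B, C_out, 4, N-K+1, N-K+1] feature map on Z^2 x C4
    """
    # One rotated copy of the whole filter bank per group element,
    # stacked along the output-channel axis: [4 * C_out, C_in, K, K].
    kernels = torch.cat([torch.rot90(k, r, dims=(-2, -1)) for r in range(4)], dim=0)
    out = F.conv2d(f, kernels)  # conv2d takes care of all translations
    B, _, M, _ = out.shape
    # Split the stacked axis into (group, channel), then put the group axis after the channels.
    return out.view(B, 4, k.shape[0], M, M).transpose(1, 2)

torch.manual_seed(0)
f = torch.randn(1, 3, 9, 9, dtype=torch.float64)
k = torch.randn(5, 3, 3, 3, dtype=torch.float64)

out = lift_c4(f, k)
out_of_rotated = lift_c4(torch.rot90(f, 1, dims=(-2, -1)), k)
# Rotating the input rotates every output map and shifts the group axis by one step.
expected = torch.roll(torch.rot90(out, 1, dims=(-2, -1)), shifts=1, dims=2)
print(out.shape)                                 # torch.Size([1, 5, 4, 7, 7])
print(torch.allclose(out_of_rotated, expected))  # True
```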

2.1.2 Implementing the lifting convolution kernel

Let’s get to programming. First, we need to define a kernel \(k\) which we can transform under arbitrary group actions with \(\mathcal{L}_h\). When working with images, a convolution kernel is generally defined as a set of independently sampled weights \(W\) defined over an equidistant discretisation of \(\mathbb{R}^2\) (pixels are evenly spaced).

Recall that we can express the group action of a group \(H\) on functions (such as kernels \(k\)) defined over \(\mathbb{R}^2\) through the regular representation \(\mathcal{L}_h\). The regular representation transforms the function \(k\) through a transformation of the domain of the function \(k\). In other words, the regular representation transforms the grid over which the kernel \(k\) is defined, to obtain the values for the transformed function \(\mathcal{L}_h(k)\).

As such, to define a kernel \(k\) which we can transform with the regular representation of a group \(H\), we need to construct a grid over which the kernel values are defined. We can then transform this grid by the action of each group element \(h \in H\) to obtain a set of grids corresponding to transformed kernels for each of the group elements of \(H\). Let’s get to work!

Notes:

  • In implementing the actual lifting and group convolution operations, we will make use of PyTorch’s 2D convolution (torch.nn.functional.conv2d). This simplifies our life a lot, since it takes care of translating the kernels \(k\) over all input locations. Hence we do not need to implement the action of the translation group (\(\mathcal{L}_{\mathbf{x}}\)) ourselves, and our operations will still be translation equivariant! Making our operations compatible with conv2d requires a small amount of trickery, but we will get to that later.

[ ]:
class LiftingKernelBase(torch.nn.Module):

    def __init__(self, group, kernel_size, in_channels, out_channels):
        """ Implements a base class for the lifting kernel. Stores the R^2 grid
        over which the lifting kernel is defined and its transformed copies
        under the action of a group H.

        """
        super().__init__()
        self.group = group

        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Create spatial kernel grid. These are the coordinates on which our
        # kernel weights are defined.
        self.register_buffer("grid_R2", torch.stack(torch.meshgrid(
            torch.linspace(-1, 1, self.kernel_size),
            torch.linspace(-1, 1, self.kernel_size),
        )).to(self.group.identity.device))

        # Transform the grid by the elements in this group.
        self.register_buffer("transformed_grid_R2", self.create_transformed_grid_R2())

    def create_transformed_grid_R2(self):
        """Transform the created grid by the group action of each group element.
        This yields a grid (over H) of spatial grids (over R2). In other words,
        a list of grids, each index of which is the original spatial grid transformed by
        a corresponding group element in H.

        """
        # Obtain all group elements.

        ## YOUR CODE STARTS HERE ##
        group_elements = ...
        ## AND ENDS HERE ##

        # Transform the grid defined over R2 with the sampled group elements.
        # Recall how the left-regular representation acts on the domain of a
        # function on R2! (Hint: look closely at the equation given under 1.3)

        ## YOUR CODE STARTS HERE ##
        transformed_grid = ...
        ## AND ENDS HERE ##

        return transformed_grid


    def sample(self, sampled_group_elements):
        """ Sample convolution kernels for a given number of group elements

        arguments should include:
        :param sampled_group_elements: the group elements over which to sample
            the convolution kernels

        should return:
        :return kernels: filter bank extending over all input channels,
            containing kernels transformed for all output group elements.
        """
        raise NotImplementedError()

[ ]:
# Let's check whether our implementation works correctly. First we inspect the
# shape of our transformed grids to assess whether this is correct.
lifting_kernel_base = LiftingKernelBase(
    group=CyclicGroup(order=4),
    kernel_size=2,
    in_channels=1,
    out_channels=1
)

assert lifting_kernel_base.transformed_grid_R2.shape == torch.Size([4, 2, 2, 2])
[ ]:
# Let's visualize the transformed kernel grids!
lifting_kernel_base = LiftingKernelBase(
    group=CyclicGroup(order=4),
    kernel_size=7,
    in_channels=1,
    out_channels=1
)

transformed_grid_R2 = lifting_kernel_base.transformed_grid_R2

# The grid has a shape of [num_group_elements, kernel_size, kernel_size, dim_fmap_domain(R^2)]
transformed_grid_R2.shape
[ ]:
plt.rcParams['figure.figsize'] = [12, 3]

# Create [group_elements] figures
fig, ax = plt.subplots(1, transformed_grid_R2.shape[0])

# Fold both spatial dimensions into a single dimension
transformed_grid_R2 = transformed_grid_R2.reshape(transformed_grid_R2.shape[0],
                                                  transformed_grid_R2.shape[1]*transformed_grid_R2.shape[2],
                                                  2).numpy()

# Visualize the transformed kernel grids. We mark the same corner point by a blue 'x' in all grids as a reference point.
for group_elem in range(transformed_grid_R2.shape[0]):
    ax[group_elem].scatter(transformed_grid_R2[group_elem, 0, 0], transformed_grid_R2[group_elem, 0, 1], marker='x', c='b')
    ax[group_elem].scatter(transformed_grid_R2[group_elem, 1:, 0], transformed_grid_R2[group_elem, 1:, 1], c='r')

fig.text(0.5, 0.04, 'Group elements', ha='center')
plt.show()

If your code is correctly implemented, you should see the counter-clockwise rotation action happening!

At this point we have a set of grids transformed under the operation of the group \(H\). We now need to decide how we are going to sample kernel values at the grid points in each of these grids. This is where the first major hurdle in applying GCNNs occurs.

Whereas conventional CNNs get away with sharing the same set of weights over all spatial positions (which is due to the fact that we translate the kernel only with steps that match whole pixel-distances), for arbitrary groups \(H\) we may require kernel values for grid points that lie off the pixel grid of the kernel under its canonical transformation.

Of course, we could (and in fact will) use interpolation to obtain kernel values for grid locations between pixel locations, but this may be limiting in expressivity and will introduce interpolation artefacts!

Notes:

  • When we are implementing equivariance for the group of 90 deg rotations, \(H=C_4\), we could of course get away without using interpolation, since all transformed grids share the same locations. We could implement the group action of this particular group through a permutation of the weights. But we’d like to be more general than that in this tutorial, so we’ll go with interpolation instead! Luckily, PyTorch’s grid_sample function lets us sample an input on an arbitrary grid, so we will use it for interpolation.
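To see concretely why \(C_4\) is the easy case, the sketch below (NumPy, with hypothetical helpers `kernel_grid`, `rotate` and `lies_on_grid`) builds a 3x3 kernel grid on \([-1,1]^2\), similar to the one used in this notebook, rotates it, and checks whether every rotated point coincides with an original grid point.

```python
import numpy as np

def kernel_grid(kernel_size):
    lin = np.linspace(-1, 1, kernel_size)
    rows, cols = np.meshgrid(lin, lin, indexing="ij")
    return np.stack([cols.ravel(), rows.ravel()], axis=1)  # [kernel_size**2, 2] points (x, y)

def rotate(points, theta):
    R = np.array([[np.cos(theta), -np.sin(theta)],
                  [np.sin(theta),  np.cos(theta)]])
    return points @ R.T  # applies R to every row vector

def lies_on_grid(points, grid):
    # For every point, check whether some grid point coincides with it (up to float error).
    matches = np.isclose(points[:, None, :], grid[None, :, :]).all(axis=-1)
    return bool(matches.any(axis=1).all())

grid = kernel_grid(3)
print(lies_on_grid(rotate(grid, np.pi / 2), grid))  # True: 90-degree rotations permute grid points
print(lies_on_grid(rotate(grid, np.pi / 4), grid))  # False: e.g. corner (1, 1) maps to (0, sqrt(2))
```

For 45 degree rotations, the rotated points fall between (or outside) the original pixel locations, which is exactly where interpolated kernel values are needed.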

[ ]:
class InterpolativeLiftingKernel(LiftingKernelBase):

    def __init__(self, group, kernel_size, in_channels, out_channels):
        super().__init__(group, kernel_size, in_channels, out_channels)

        # Create and initialise a set of weights, we will interpolate these
        # to create our transformed spatial kernels.
        self.weight = torch.nn.Parameter(torch.zeros((
            self.out_channels,
            self.in_channels,
            self.kernel_size,
            self.kernel_size
        ), device=self.group.identity.device))

        # Initialize weights using kaiming uniform initialisation.
        torch.nn.init.kaiming_uniform_(self.weight.data, a=math.sqrt(5))

    def sample(self):
        """ Sample convolution kernels for a given number of group elements

        should return:
        :return kernels: filter bank extending over all input channels,
            containing kernels transformed for all output group elements.
        """
        # First, we fold the output channel dim into the input channel dim;
        # this allows us to transform the entire filter bank in one go using the
        # torch grid_sample function.

        # Next, we want a transformed set of weights for each group element so
        # we repeat the set of spatial weights along the output group axis.

        ## YOUR CODE STARTS HERE ##
        weight = ...
        ## AND ENDS HERE ##

        # Check whether the weight has the expected shape.
        assert weight.shape == torch.Size((
            self.group.elements().numel(),
            self.out_channels * self.in_channels,
            self.kernel_size,
            self.kernel_size
        ))

        # Sample the transformed kernels.
        transformed_weight = grid_sample(
            weight,
            self.transformed_grid_R2
        )

        # Separate input and output channels.
        transformed_weight = transformed_weight.view(
            self.group.elements().numel(),
            self.out_channels,
            self.in_channels,
            self.kernel_size,
            self.kernel_size
        )

        # Put the output channel dimension before the output group dimension.
        transformed_weight = transformed_weight.transpose(0, 1)

        return transformed_weight
[ ]:
ik = InterpolativeLiftingKernel(
    group=CyclicGroup(order=4),
    kernel_size=5,
    in_channels=2,
    out_channels=1
)

weights = ik.sample()

Let’s visualize the weights we sampled from our lifting convolution kernel!

[ ]:
plt.rcParams['figure.figsize'] = [10, 5]

# Pick an output channel to visualize
out_channel_idx = 0

# Create [in_channels, group_elements] figures
fig, ax = plt.subplots(weights.shape[2], weights.shape[1])

for in_channel in range(weights.shape[2]):
    for group_elem in range(weights.shape[1]):
        ax[in_channel, group_elem].imshow(
            weights[out_channel_idx, group_elem, in_channel, :, :].detach().numpy()
        )

fig.text(0.5, 0.04, 'Group elements', ha='center')
fig.text(0.04, 0.5, 'Input channels', va='center', rotation='vertical')

plt.show()

As we can see, the spatial kernel rotates under the action of the rotation group elements!

2.1.3 Implementing the lifting convolution

Finally, we can implement the lifting convolution operation! This class should take as input a feature map defined over \(\mathbb{R}^2\), and spit out a feature map over \(\mathbb{R}^2\rtimes H\), where features under different transformations \(h \in H\) are disentangled along the \(H\) axis!

Notes:

  • To avoid writing our own convolution routine, and to leverage PyTorch’s highly optimized 2D convolution (torch.nn.functional.conv2d), we use some neat tricks in our lifting (and group) convolution classes. Normally, a convolution layer applies a set of \(n\) spatial kernels throughout the input, where \(n\) is the number of output channels of the convolution operation. We now also have num_group_elem transformed versions of each of these kernels, and we want to apply all of them everywhere in the input. We can trick PyTorch into doing this by having it treat every transformation of the same spatial kernel as a separate output channel. To do this, we simply reshape our set of [out_channels, num_group_elem, in_channels, kernel_size, kernel_size] kernels into a set of [out_channels x num_group_elem, in_channels, kernel_size, kernel_size] kernels. See below!

  • As mentioned, an additional benefit of using PyTorch’s 2D convolution is that we are not required to obtain translated kernels \(\mathcal{L}_{\mathbf{x}}(k)\) ourselves; PyTorch takes care of that!
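The reshape trick can be checked in isolation. This minimal sketch assumes \(H=C_4\). It builds the rotated kernels exactly with torch.rot90 instead of the notebook’s interpolated grid_sample kernels. The function name c4_lifting_conv is hypothetical.

The sketch makes two checks:

- It compares a single conv2d call on the reshaped [out_channels x 4, in_channels, k, k] filter bank against one conv2d call per group element.
- It checks the expected equivariance: rotating the input by 90 degrees rotates every output feature map and cyclically shifts the group axis by one step.

The direction of that shift follows from torch.rot90’s rotation direction and would be reversed under the opposite convention.

```python
import torch
import torch.nn.functional as F

def c4_lifting_conv(x, weight):
    """Hypothetical C4 lifting convolution with exact 90-degree kernel rotations.

    x:      [batch, in_channels, height, width]
    weight: [out_channels, in_channels, k, k]
    returns [batch, out_channels, 4, height', width']
    """
    out_c, in_c, k, _ = weight.shape
    # One rotated copy of the filter bank per group element:
    # [out_channels, 4, in_channels, k, k].
    kernels = torch.stack(
        [torch.rot90(weight, r, dims=(-2, -1)) for r in range(4)], dim=1
    )
    # Reshape trick: every rotated kernel becomes an extra output channel.
    y = F.conv2d(x, kernels.reshape(out_c * 4, in_c, k, k))
    # Separate the channel and group axes again.
    return y.view(y.shape[0], out_c, 4, y.shape[-2], y.shape[-1])

torch.manual_seed(0)
x = torch.randn(2, 3, 12, 12)
weight = torch.randn(8, 3, 5, 5)
out = c4_lifting_conv(x, weight)

# The reshape trick matches one conv2d call per group element.
for r in range(4):
    ref = F.conv2d(x, torch.rot90(weight, r, dims=(-2, -1)))
    assert torch.allclose(out[:, :, r], ref, atol=1e-4)

# Equivariance: rotating the input rotates each output map spatially and
# cyclically shifts the group axis by one step.
out_rot = c4_lifting_conv(torch.rot90(x, 1, dims=(-2, -1)), weight)
expected = torch.roll(torch.rot90(out, 1, dims=(-2, -1)), shifts=1, dims=2)
print(out.shape, torch.allclose(out_rot, expected, atol=1e-4))
```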

[ ]:
class LiftingConvolution(torch.nn.Module):

    def __init__(self, group, in_channels, out_channels, kernel_size):
        super().__init__()

        self.kernel = InterpolativeLiftingKernel(
            group=group,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels
        )

    def forward(self, x):
        """ Perform lifting convolution

        @param x: Input sample [batch_dim, in_channels, spatial_dim_1,
            spatial_dim_2]
        @return: Function on a homogeneous space of the group
            [batch_dim, out_channels, num_group_elements, spatial_dim_1,
            spatial_dim_2]
        """

        # Obtain convolution kernels transformed under the group.

        ## YOUR CODE STARTS HERE ##
        conv_kernels = ...
        ## AND ENDS HERE ##

        # Apply lifting convolution. Note that using a reshape we can fold the
        # group dimension of the kernel into the output channel dimension. We
        # treat every transformed kernel as an additional output channel. This
        # way we can use pytorch's conv2d function!

        # Question: Do you see why we (can) do this?

        ## YOUR CODE STARTS HERE ##
        x = torch.nn.functional.conv2d(
            input=x,
            weight=conv_kernels.reshape(
                ...
            ),
        )
        ## AND ENDS HERE ##

        # Reshape [batch_dim, out_channels * num_group_elements, spatial_dim_1,
        # spatial_dim_2] into [batch_dim, out_channels, num_group_elements,
        # spatial_dim_1, spatial_dim_2], separating channel and group
        # dimensions.
        x = x.view(
            -1,
            self.kernel.out_channels,
            self.kernel.group.elements().numel(),
            x.shape[-2],
            x.shape[-1]
        )

        return x

[ ]:
lifting_conv = LiftingConvolution(
    group=CyclicGroup(order=4),
    kernel_size=5,
    in_channels=3,
    out_channels=8
)

2.2 Group convolution

Now that we have a way to obtain feature maps defined over the group \(G\) from input functions defined over \(\mathbb{R}^2\), let’s move on to implementing a convolutional layer that operates entirely on the group \(G = \mathbb{R}^2 \rtimes H\). Note that the input feature map at this stage, \(f_{in}\), has, besides the usual spatial dimensions defined over \(\mathbb{R}^2\), one or more additional group dimensions defined over \(H\). As such, the group convolution operation, \(*_{group}\), maps a function on the group, \(f_{in}\), to another function on the group, \(f_{out}\).

2.2.1 Overview

Since the input to the group convolution layer now contains additional group dimensions, \(f_{in}: \mathbb{R}^2 \rtimes H \rightarrow \mathbb{R}\), we need to convolve it with a kernel \(k_{group}\) that is also defined over the entire group, \(k_{group}: \mathbb{R}^2 \rtimes H \rightarrow \mathbb{R}\). This is in contrast to the lifting convolution, where \(k_{lifting}\) was only defined over the spatial domain \(\mathbb{R}^2\). You could think of this kernel \(k_{group}\) as a stack of spatial kernels, a separate one for each group element \(h \in H\). Importantly, since \(k_{group}\) is now also defined on a grid over \(H\), the group \(H\) also acts on this \(H\)-axis of \(k_{group}\). For example, in our case of \(H=C_4\), elements \(\theta \in C_4\) not only rotate the spatial domain of the kernel, but also translate it along the group axis. Hence, applying a group element \(\theta \in C_4\) leads to a twist-shift of \(k_{group}\) along the group axis. See below! (Remember that \(H=C_4\) is periodic.)

[Figure: twist-shift of a group convolution kernel under the action of \(C_4\)]

Apart from this difference, the group convolution operator works in much the same way as the lifting operator. We again transform the kernel \(k_{group}\) with the actions of \(H\) and \(\mathbb{R}^2\), and compute the inner product of each transformed kernel with the input. Again, we let PyTorch do all the work for the translation group. See below for an intuition.

[Figure: intuition for the group convolution]
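For \(H=C_4\), the twist-shift can be written without any interpolation, because 90-degree rotations permute pixels and the group axis has exactly four entries. The sketch below uses a hypothetical helper, c4_twist_shift, which is not part of the notebook. It acts on a single-channel group kernel of shape [4, k, k]: it rotates every spatial slice with torch.rot90 and cyclically shifts the group axis with torch.roll.

It then checks two properties that follow from the periodicity of \(C_4\):

- Applying the generator four times is the identity.
- Applying it twice equals applying the 180-degree rotation once.

Which direction counts as a positive shift is a convention of this sketch.

```python
import torch

def c4_twist_shift(kernel, r):
    """Hypothetical helper: apply the C4 element r (rotation by r * 90 degrees)
    to a group kernel of shape [4, k, k], one spatial slice per element h'."""
    rotated = torch.rot90(kernel, r, dims=(-2, -1))  # action on R^2
    return torch.roll(rotated, shifts=r, dims=0)     # action on H (periodic)

torch.manual_seed(0)
k_group = torch.randn(4, 3, 3)

# Applying the generator four times returns the original kernel.
k = k_group
for _ in range(4):
    k = c4_twist_shift(k, 1)
print(torch.equal(k, k_group))  # True

# Applying r=1 twice equals applying r=2 once.
print(torch.equal(c4_twist_shift(c4_twist_shift(k_group, 1), 1),
                  c4_twist_shift(k_group, 2)))  # True
```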

2.2.2 Implementing the group convolution kernel

Again, let’s define a kernel \(k\) which we can transform with the group action. Now, our kernel grid is not only defined over \(\mathbb{R}^2\), but additionally over \(H\).

Notes:

  • Since the grid over \(H\) is made up of elements \(h'\in H\), transforming the grid over \(H\) with (another) group element \(h \in H\) amounts to applying the group product of \(h\) with each grid element \(h'\).

  • Because we are working with semidirect product groups \(\mathbb{R}^2 \rtimes H\), we can transform the \(\mathbb{R}^2\) and \(H\) dimensions of our grids separately before combining them into a shared grid over \(\mathbb{R}^2 \rtimes H\)!
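The action on the grid over \(H\) is easy to inspect for \(C_N\) if we represent element \(j\) by the integer \(j\) (a rotation by \(2\pi j/N\)). The group product then becomes addition modulo \(N\). This integer representation, and the helper names c_n_product and c_n_inverse, are assumptions of the sketch; the notebook’s CyclicGroup may store angles instead. Transforming the grid with every \(h\) via \(h^{-1}h'\) yields cyclic shifts of the original grid. This is the closure property that guarantees we never land between grid points along \(H\).

```python
import torch

order = 4
# Assumed integer representation: element j of C_N is a rotation by 2*pi*j/N.
grid_H = torch.arange(order)

def c_n_product(h, h_prime, order):
    """Group product of C_N in the integer representation: addition mod N."""
    return torch.remainder(h + h_prime, order)

def c_n_inverse(h, order):
    """Inverse element in C_N: negation mod N."""
    return torch.remainder(-h, order)

# Row h holds the grid over H transformed by h, i.e. h^{-1} h' for every h'.
transformed_grid_H = torch.stack(
    [c_n_product(c_n_inverse(h, order), grid_H, order) for h in grid_H]
)
print(transformed_grid_H)

# Closure: every row is a cyclic shift of the original grid, so no entry
# ever falls between grid points along the group axis.
for h in range(order):
    assert torch.equal(transformed_grid_H[h], torch.roll(grid_H, shifts=h))

# Rescaled to [-1, 1], the entries coincide with the sample positions of a
# length-N axis when grid_sample is called with align_corners=True.
normalized = 2 * transformed_grid_H / (order - 1) - 1
print(normalized[0])
```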

[ ]:
class GroupKernelBase(torch.nn.Module):

    def __init__(self, group, kernel_size, in_channels, out_channels):
        """ Implements base class for the group convolution kernel. Stores grid
        defined over the group R^2 \rtimes H and it's transformed copies under
        all elements of the group H.

        """
        super().__init__()
        self.group = group

        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Create a spatial kernel grid
        self.register_buffer("grid_R2", torch.stack(torch.meshgrid(
            torch.linspace(-1, 1, self.kernel_size),
            torch.linspace(-1, 1, self.kernel_size),
        )).to(self.group.identity.device))

        # The kernel grid now also extends over the group H, as our input
        # feature maps contain an additional group dimension
        self.register_buffer("grid_H", self.group.elements())

        self.register_buffer("transformed_grid_R2xH", self.create_transformed_grid_R2xH())

    def create_transformed_grid_R2xH(self):
        """Transform the created grid over R^2 \rtimes H by the group action of
        each group element in H.

        This yields a set of grids over the group. In other words, a list of
        grids, each index of which is the original grid over G transformed by
        a corresponding group element in H.
        """
        # Sample the group H.

        ## YOUR CODE STARTS HERE ##
        group_elements = ...
        ## AND ENDS HERE ##

        # Transform the grid defined over R2 with the sampled group elements.

        ## YOUR CODE STARTS HERE ##
        transformed_grid_R2 = ...
        ## AND ENDS HERE ##

        # Transform the grid defined over H with the sampled group elements.

        ## YOUR CODE STARTS HERE ##
        transformed_grid_H = ...
        ## AND ENDS HERE ##

        # Rescale values to the range [-1, 1], as required by torch's
        # grid_sample function.
        transformed_grid_H = self.group.normalize_group_elements(transformed_grid_H)

        # Create a combined grid as the product of the grids over R2 and H
        # repeat R2 along the group dimension, and repeat H along the spatial dimension
        # to create a [output_group_elem, num_group_elements, kernel_size, kernel_size, 3] grid
        transformed_grid = torch.cat(
            (
                transformed_grid_R2.view(
                    group_elements.numel(),
                    1,
                    self.kernel_size,
                    self.kernel_size,
                    2,
                ).repeat(1, group_elements.numel(), 1, 1, 1),
                transformed_grid_H.view(
                    group_elements.numel(),
                    group_elements.numel(),
                    1,
                    1,
                    1,
                ).repeat(1, 1, self.kernel_size, self.kernel_size, 1, )
            ),
            dim=-1
        )
        return transformed_grid


    def sample(self, sampled_group_elements):
        """ Sample convolution kernels for a given number of group elements

        arguments should include:
        :param sampled_group_elements: the group elements over which to sample
            the convolution kernels

        should return:
        :return kernels: filter bank extending over all input channels,
            containing kernels transformed for all output group elements.
        """
        raise NotImplementedError()

Let’s get some intuition for what is happening with our grid when we apply the group action of \(H\) to it. First we will inspect the action on \(\mathbb{R}^2\).

[ ]:
group_kernel_base = GroupKernelBase(
    group=CyclicGroup(order=4),
    kernel_size=7,
    in_channels=1,
    out_channels=1
)

# Sample the group.
group_elements = group_kernel_base.group.elements()

# Transform the grid defined over R2 with the sampled group elements.
transformed_grid_R2 = group_kernel_base.group.left_action_on_R2(
    group_kernel_base.group.inverse(group_elements),
    group_kernel_base.grid_R2
)
[ ]:
plt.rcParams['figure.figsize'] = [10, 3]

# Create [group_elements] figures.
fig, ax = plt.subplots(1, transformed_grid_R2.shape[0])

# Fold both spatial dimensions into a single dimension.
transformed_grid_R2 = transformed_grid_R2.reshape(transformed_grid_R2.shape[0],
                                                  transformed_grid_R2.shape[1]*transformed_grid_R2.shape[2],
                                                  2).numpy()

# Visualize the transformed kernel grids. We mark the same corner point by a blue 'x' in all grids as a reference point.
for group_elem in range(transformed_grid_R2.shape[0]):
    ax[group_elem].scatter(transformed_grid_R2[group_elem, 0, 0], transformed_grid_R2[group_elem, 0, 1], marker='x', c='b')
    ax[group_elem].scatter(transformed_grid_R2[group_elem, 1:, 0], transformed_grid_R2[group_elem, 1:, 1], c='r')

fig.text(0.5, 0.04, 'Group elements', ha='center')

plt.show()

As we can see, this part of the grid, and what happens to it, is identical to the grid we saw in the lifting convolution. Note, however, that these are only the spatial dimensions of the grid over which a group convolution kernel is defined. Let’s move on to the grid over \(H\).

[ ]:
plt.rcParams['figure.figsize'] = [5.5, 3]
# Transform the grid defined over H with the sampled group elements.
transformed_grid_H = group_kernel_base.group.left_action_on_H(
    group_kernel_base.group.inverse(group_elements), group_kernel_base.grid_H
)

# Create [group_elements] figures.
fig, ax = plt.subplots(1, transformed_grid_H.shape[0])

# Visualize the transformed kernel grids. We mark the same grid point by a blue 'x' in all grids as a reference point.
for group_elem in range(transformed_grid_H.shape[0]):
    ax[group_elem].scatter(torch.zeros_like(transformed_grid_H[group_elem, 0]), transformed_grid_H[group_elem, 0], marker='x', c='b')
    ax[group_elem].scatter(torch.zeros_like(transformed_grid_H[group_elem, 1:]), transformed_grid_H[group_elem, 1:], c='r')
    # Remove the xticks, our group is 1D!
    ax[group_elem].set_xticks([])

fig.text(0.5, 0.04, 'Group elements', ha='center')

plt.show()

In our current setting, \(H\) is one-dimensional, and as we can see, transforming the grid over \(H\) with all group elements of \(H\) leads to a translation over the group. Next, let’s see what happens when we combine these grids.

[ ]:
plt.rcParams['figure.figsize'] = [10, 3]
transformed_grid_R2xH = group_kernel_base.transformed_grid_R2xH

# Create [group_elements] figures.
fig, ax = plt.subplots(1, transformed_grid_H.shape[0], subplot_kw=dict(projection='3d'))

# Flatten spatial and group grid dimensions.
transformed_grid_R2xH = transformed_grid_R2xH.reshape(transformed_grid_R2xH.shape[0],
                                                      transformed_grid_R2xH.shape[1] * transformed_grid_R2xH.shape[2] * transformed_grid_R2xH.shape[3],
                                                      transformed_grid_R2xH.shape[4])

# Visualize the transformed kernel grids. We mark the same row by a blue 'x' in all grids as reference point.
for group_elem in range(transformed_grid_R2xH.shape[0]):
    ax[group_elem].scatter(transformed_grid_R2xH[group_elem, 0:7, 0],
                           transformed_grid_R2xH[group_elem, 0:7, 1],
                           transformed_grid_R2xH[group_elem, 0:7, 2],
                           marker='x',
                           c='b')
    ax[group_elem].scatter(transformed_grid_R2xH[group_elem, 7:, 0],
                           transformed_grid_R2xH[group_elem, 7:, 1],
                           transformed_grid_R2xH[group_elem, 7:, 2],
                           c='r')

fig.text(0.5, 0.04, 'Group elements', ha='center')

plt.show()

As we can see, under the application of group elements \(h' \in H\) the grid defined over \(\mathbb{R}^2 \rtimes H\) rotates over the spatial dimensions, and shifts along the group dimension!

Let’s now implement the group kernel using interpolation as well.

Notes:

  • Luckily, grid_sample also supports 3D inputs, so we can continue using it.

  • For multidimensional groups \(H\) the following implementation won’t work, as this would require kernels defined on grids with dimensionality > 3, which grid_sample does not support. To resolve this, one could implement the sampling of the weights along the \(H\) dimension using a translation of the weight matrix along the \(H\) dimensions, and only interpolate over the spatial dimensions. This is possible because we don’t end up between grid points along the group dimension \(H\) (remember the closure constraint of the group product?).

  • The \(C_4\) group exhibits periodicity along the group axis, so our kernels should too. Although we correctly implemented the group product to reflect this, grid_sample doesn’t know about the periodicity of the weights in its interpolation. This shouldn’t be a problem: again because of the closure constraint, we should always end up exactly on grid points along the group axis, so no interpolation is necessary in that direction. In practice, because of the way grid_sample is implemented, we may still encounter some small interpolation artefacts.
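The alternative mentioned in the second note is to shift the weights along \(H\) and interpolate only over the spatial dimensions. For \(C_N\) it can be sketched as follows. This is a hypothetical function, sample_cn_group_kernels, not the notebook’s InterpolativeGroupKernel:

- It assumes a weight of shape [out_channels, in_channels, N, k, k].
- It uses torch.roll for the exact group shift.
- It builds each spatial rotation with torch.nn.functional.affine_grid and grid_sample (align_corners=True).

Its output uses the same [out_channels, num_group_elem, in_channels, num_group_elem, k, k] layout as the sample method of InterpolativeGroupKernel. The rotation direction is a convention of the sketch.

```python
import math
import torch
import torch.nn.functional as F

def sample_cn_group_kernels(weight, order):
    """Hypothetical sketch: sample C_N group kernels by rolling along the
    group axis (exact) and interpolating only over the spatial axes.

    weight:  [out_channels, in_channels, order, k, k]
    returns: [out_channels, order, in_channels, order, k, k]
    """
    out_c, in_c, g, k, _ = weight.shape
    assert g == order, "the weight's group axis must have length N"
    kernels = []
    for j in range(order):
        angle = 2 * math.pi * j / order
        # Action on H: a cyclic shift along the input group axis, no interpolation.
        w = torch.roll(weight, shifts=j, dims=2)
        # Action on R^2: rotate every spatial slice with bilinear interpolation.
        w = w.reshape(out_c * in_c * g, 1, k, k)
        theta = torch.tensor(
            [[math.cos(angle), -math.sin(angle), 0.0],
             [math.sin(angle), math.cos(angle), 0.0]],
            dtype=weight.dtype,
        ).repeat(w.shape[0], 1, 1)
        grid = F.affine_grid(theta, list(w.shape), align_corners=True)
        w = F.grid_sample(w, grid, mode="bilinear",
                          padding_mode="zeros", align_corners=True)
        kernels.append(w.reshape(out_c, in_c, g, k, k))
    # Stack the output group axis after the output channel axis.
    return torch.stack(kernels, dim=1)

torch.manual_seed(0)
weight = torch.randn(2, 3, 4, 5, 5)
kernels = sample_cn_group_kernels(weight, order=4)
print(kernels.shape)
```

For \(N=4\), the interpolated rotations land exactly on pixel centres. Every kernels[:, j] should therefore match torch.rot90(torch.roll(weight, j, dims=2), j, dims=(-2, -1)) up to floating-point error. For other \(N\), the spatial rotation involves genuine bilinear interpolation.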

[ ]:
class InterpolativeGroupKernel(GroupKernelBase):

    def __init__(self, group, kernel_size, in_channels, out_channels):
        super().__init__(group, kernel_size, in_channels, out_channels)

        # Create and initialise a set of weights, we will interpolate these
        # to create our transformed spatial kernels. Note that our weight
        # now also extends over the group H.

        ## YOUR CODE STARTS HERE ##
        self.weight = torch.nn.Parameter(torch.zeros((
            ...
        ), device=self.group.identity.device))
        ## AND ENDS HERE ##

        # Initialize weights using kaiming uniform initialisation.
        torch.nn.init.kaiming_uniform_(self.weight.data, a=math.sqrt(5))

    def sample(self):
        """ Sample convolution kernels for a given number of group elements

        should return:
        :return kernels: filter bank extending over all input channels,
            containing kernels transformed for all output group elements.
        """
        # First, we fold the output channel dim into the input channel dim;
        # this allows us to transform the entire filter bank in one go using the
        # torch grid_sample function.

        # Next, we want a transformed set of weights for each group element so
        # we repeat the set of spatial weights along the output group axis.

        ## YOUR CODE STARTS HERE ##
        weight = ...
        ## AND ENDS HERE ##

        assert weight.shape == torch.Size((
            self.group.elements().numel(),
            self.out_channels * self.in_channels,
            self.group.elements().numel(),
            self.kernel_size,
            self.kernel_size
        ))

        # Sample the transformed kernels using the grid_sample function.
        transformed_weight = grid_sample(
            weight,
            self.transformed_grid_R2xH,
        )

        # Separate input and output channels. Note we now have a notion of
        # input and output group dimensions in our weight matrix!
        transformed_weight = transformed_weight.view(
            self.group.elements().numel(), # Output group elements (like in the lifting convolution)
            self.out_channels,
            self.in_channels,
            self.group.elements().numel(), # Input group elements (due to the additional dimension of our feature map)
            self.kernel_size,
            self.kernel_size
        )

        # Put the output channel dimension before the output group dimension.
        transformed_weight = transformed_weight.transpose(0, 1)

        return transformed_weight
[ ]:
ik = InterpolativeGroupKernel(
    group=CyclicGroup(order=4),
    kernel_size=5,
    in_channels=2,
    out_channels=8
)
[ ]:
weights = ik.sample()
weights.shape

Let’s visualize the sampled group convolution kernels! We visualize our 3D kernels in 2D by folding the input group dimension into the first spatial dimension. This yields a 2D flattened version of the 3D group convolution kernel, in which the spatial kernels for the different input group elements are stacked vertically. Each channel goes from [num_group_elem, kernel_size, kernel_size] to [num_group_elem x kernel_size, kernel_size].

To clearly see what happens to the group convolution kernel under transformation by the group \(H\), we outline in red the spatial kernel corresponding to the first input group element. For each subsequent transformation, the red outline tracks where this spatial kernel ends up. See below!

[ ]:
plt.rcParams['figure.figsize'] = [10, 10]

# For ease of viewing, we fold the input group dimension into the spatial x dimension
weights_t = weights.view(
    weights.shape[0],
    weights.shape[1],
    weights.shape[2],
    weights.shape[3] * weights.shape[4],
    weights.shape[5]
)

# pick an output channel to visualize
out_channel_idx = 0

# create [in_channels, group_elements] figures
fig, ax = plt.subplots(weights.shape[2], weights.shape[1])

for in_channel in range(weights.shape[2]):
    for group_elem in range(weights.shape[1]):
        ax[in_channel, group_elem].imshow(
            weights_t[out_channel_idx, group_elem, in_channel, :, :].detach()
        )

        # Outline the spatial kernel corresponding to the first group element under canonical transformation
        rect = matplotlib.patches.Rectangle(
            (-0.5, group_elem * weights_t.shape[-1] - 0.5), weights_t.shape[-1], weights_t.shape[-1], linewidth=5, edgecolor='r', facecolor='none')
        ax[in_channel, group_elem].add_patch(rect)

fig.text(0.5, 0.04, 'Group elements', ha='center')
fig.text(0.04, 0.5, 'Input channels / input group elements', va='center', rotation='vertical')

plt.show()

We see the same twist-shift motion as we saw with the kernel grids!

2.2.3 Implementing the group convolution

The next step is implementing the group convolution operation.

Notes:

  • We would still like to use PyTorch’s 2D convolution, but we’re now faced with an additional problem: the group dimension in the input feature map. Luckily, we can resolve this problem in much the same way. Normally, a 2D convolution layer integrates over a local neighbourhood of all input channels; we would now additionally like to integrate over the entire group. Thus, we can simply treat the group dimensions in the input feature map as additional channel dimensions! We achieve this by folding our input group dimension into the input channel dimension: \(f_{in}\) is reshaped from [batch, in_channels, num_group_elem, spatial_1, spatial_2] into [batch, in_channels x num_group_elem, spatial_1, spatial_2].

  • To match this, and to apply the same trick as we did in the lifting convolution to get results for each separate group element in the output, we also reshape our kernel from [out_channels, num_group_elem, in_channels, num_group_elem, kernel_size, kernel_size] to [out_channels x num_group_elem, in_channels x num_group_elem, kernel_size, kernel_size]. See below!
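The folding described above can be verified numerically. The sketch below uses random tensors in place of the notebook’s sampled kernels, and all sizes are arbitrary. It compares one conv2d call on the folded input and filter bank against an explicit double loop. For every output group element, the loop sums 2D convolutions over all input group elements.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, in_c, out_c, G, k, n = 2, 3, 5, 4, 3, 10

x = torch.randn(batch, in_c, G, n, n)           # feature map on R^2 x H
kernels = torch.randn(out_c, G, in_c, G, k, k)  # [out, out_group, in, in_group, k, k]

# Fold the input group axis into the input channels and the output group
# axis into the output channels, then run a single 2D convolution.
y = F.conv2d(
    x.reshape(batch, in_c * G, n, n),
    kernels.reshape(out_c * G, in_c * G, k, k),
)
y = y.view(batch, out_c, G, y.shape[-2], y.shape[-1])

# Reference: for each output group element, sum the 2D convolutions over
# every input group element.
ref = torch.zeros_like(y)
for g_out in range(G):
    for g_in in range(G):
        ref[:, :, g_out] += F.conv2d(x[:, :, g_in], kernels[:, g_out, :, g_in])

print(y.shape, torch.allclose(y, ref, atol=1e-4))
```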

[ ]:
class GroupConvolution(torch.nn.Module):

    def __init__(self, group, in_channels, out_channels, kernel_size):
        super().__init__()

        self.kernel = InterpolativeGroupKernel(
            group=group,
            kernel_size=kernel_size,
            in_channels=in_channels,
            out_channels=out_channels
        )

    def forward(self, x):
        """ Perform group convolution

        @param x: Input sample [batch_dim, in_channels, group_dim, spatial_dim_1,
            spatial_dim_2]
        @return: Function on a homogeneous space of the group
            [batch_dim, out_channels, num_group_elements, spatial_dim_1,
            spatial_dim_2]
        """

        # We now fold the group dimensions of our input into the input channel
        # dimension.

        ## YOUR CODE STARTS HERE ##
        x = x.reshape(
            ...
        )
        ## AND ENDS HERE ##

        # We obtain convolution kernels transformed under the group.

        ## YOUR CODE STARTS HERE ##
        conv_kernels = ...
        ## AND ENDS HERE ##

        # Apply group convolution, note that the reshape folds the 'output' group
        # dimension of the kernel into the output channel dimension, and the
        # 'input' group dimension into the input channel dimension.

        # Question: Do you see why we (can) do this?

        ## YOUR CODE STARTS HERE ##
        x = ...
        ## AND ENDS HERE ##

        # Reshape [batch_dim, out_channels * num_group_elements, spatial_dim_1,
        # spatial_dim_2] into [batch_dim, out_channels, num_group_elements,
        # spatial_dim_1, spatial_dim_2], separating channel and group
        # dimensions.
        x = x.view(
            -1,
            self.kernel.out_channels,
            self.kernel.group.elements().numel(),
            x.shape[-2],
            x.shape[-1],
        )

        return x

2.3 Projection to obtain invariance and tying everything together

Up until now, our feature maps equivary with the group action of \(\mathbb{R}^2 \rtimes H\); our feature maps are defined over \(\mathbb{R}^2 \rtimes H\). Usually in a CNN, a series of convolutional layers build a representation, which is followed by a (number of) linear layer(s). To create a GCNN using our lifting and group convolution operations that is fully invariant to the action of the group, we must apply a projection operation invariant to the action of the group to our feature map, to reduce its dimensionality from [batch, channels, num_group_elem, spatial_1, spatial_2] to [batch, channels] or even . The representation that we obtain then is fully invariant to the group. This representation is pushed through a final linear layer to yield a classification.

Below, we build a small GCNN from our implemented PyTorch modules.

Notes:

  • We use a mean-pooling operation to pool over group and spatial dimensions, but we could also use max or min pooling, or any other operation invariant to the group.
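To see why pooling gives invariance, the sketch below applies a hypothetical C4 action to a random feature map of shape [batch, channels, 4, height, width]. The action, act_c4, is torch.rot90 on the spatial axes plus torch.roll along the group axis. It is the exact counterpart of the interpolated transformations used above.

Both AdaptiveAvgPool3d(1) and a max over the group and spatial axes only permute the values they aggregate, so their outputs are unchanged. This guarantee covers only the 90-degree rotations in \(C_4\); for arbitrary rotation angles, invariance is only approximate.

```python
import torch

def act_c4(f, r):
    """Hypothetical exact C4 action on a feature map [B, C, 4, H, W]:
    rotate every spatial map by r * 90 degrees and shift the group axis by r."""
    return torch.roll(torch.rot90(f, r, dims=(-2, -1)), shifts=r, dims=2)

torch.manual_seed(0)
f = torch.randn(2, 6, 4, 9, 9)
f_rot = act_c4(f, 1)

# Mean pooling over the group and spatial axes, as with AdaptiveAvgPool3d(1).
pool = torch.nn.AdaptiveAvgPool3d(1)
z, z_rot = pool(f).flatten(start_dim=1), pool(f_rot).flatten(start_dim=1)

# Max pooling over the same axes is invariant as well.
m, m_rot = f.amax(dim=(2, 3, 4)), f_rot.amax(dim=(2, 3, 4))

print(z.shape)  # [batch, channels]
print(torch.allclose(z, z_rot, atol=1e-5), torch.equal(m, m_rot))
```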

[ ]:
from torch.nn import AdaptiveAvgPool3d


class GroupEquivariantCNN(torch.nn.Module):

    def __init__(self, group, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):
        super().__init__()

        # Create the lifting convolution.

        ## YOUR CODE STARTS HERE ##
        self.lifting_conv = ...
        ## AND ENDS HERE ##

        # Create a set of group convolutions.
        self.gconvs = torch.nn.ModuleList()

        ## YOUR CODE STARTS HERE ##
        for i in range(num_hidden):
            self.gconvs.append(
                ...
            )
        ## AND ENDS HERE ##

        # Create the projection layer. Hint: check the import at the top of
        # this cell.

        ## YOUR CODE STARTS HERE ##
        self.projection_layer = ...
        ## AND ENDS HERE ##

        # And a final linear layer for classification.
        self.final_linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):

        # Lift and disentangle features in the input.
        x = self.lifting_conv(x)
        x = torch.nn.functional.layer_norm(x, x.shape[-4:])
        x = torch.nn.functional.relu(x)

        # Apply group convolutions.
        for gconv in self.gconvs:
            x = gconv(x)
            x = torch.nn.functional.layer_norm(x, x.shape[-4:])
            x = torch.nn.functional.relu(x)

        # To obtain invariance, apply average pooling over the group and spatial dims.
        # flatten (rather than squeeze) keeps the batch dimension when batch size is 1.
        x = self.projection_layer(x).flatten(start_dim=1)

        x = self.final_linear(x)
        return x

To compare, let’s create a more or less identical CNN. The only difference here is that this network consists of regular convolution operations.

[ ]:
class CNN(torch.nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, num_hidden, hidden_channels):
        super().__init__()

        self.first_conv = torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=kernel_size
        )

        self.convs = torch.nn.ModuleList()
        for i in range(num_hidden):
            self.convs.append(
                torch.nn.Conv2d(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    kernel_size=kernel_size
                )
            )

        self.final_linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x):

        x = self.first_conv(x)
        x = torch.nn.functional.layer_norm(x, x.shape[-3:])
        x = torch.nn.functional.relu(x)

        for conv in self.convs:
            x = conv(x)
            x = torch.nn.functional.layer_norm(x, x.shape[-3:])
            x = torch.nn.functional.relu(x)

        # Apply average pooling over remaining spatial dimensions.
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(start_dim=1)

        x = self.final_linear(x)
        return x

Part 3: Experimenting with our implementation

Note that, as in the previous practicals, we additionally use PyTorch Lightning for ease of model training and tracking.

3.1 Generalization to the group action

To show the generalization capabilities of regular group convolutional networks, we will train our models on the MNIST training set, but evaluate them on an augmented version of the MNIST test set in which each image is rotated by a random angle drawn uniformly from \([0, 2\pi]\).

[ ]:
# We normalize the training data.
train_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                  torchvision.transforms.Normalize((0.1307,), (0.3081,))
                                                  ])

# To demonstrate the generalization capabilities our rotation equivariant layers bring, we apply a random
# rotation between 0 and 360 deg to the test set.
test_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                 torchvision.transforms.RandomRotation(
                                                     [0, 360],
                                                     torchvision.transforms.InterpolationMode.BILINEAR,
                                                     fill=0),
                                                 torchvision.transforms.Normalize((0.1307,), (0.3081,))
                                                 ])

# We demonstrate our models on the MNIST dataset.
train_ds = torchvision.datasets.MNIST(root=DATASET_PATH, train=True, transform=train_transform, download=True)
test_ds = torchvision.datasets.MNIST(root=DATASET_PATH, train=False, transform=test_transform)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=64, shuffle=False)

# Set the random seed for reproducibility.
pl.seed_everything(12)

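Let’s visualize some training images. The top row shows them with the training transform applied, and the bottom row shows the same images with the test transform applied. As we can see, the test transform randomly rotates each image.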

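NUM_IMAGES = 4
# Top row: training images with the training transform applied.
images = [train_ds[idx][0] for idx in range(NUM_IMAGES)]
# Bottom row: the same raw training images, passed through the randomly rotating test transform.
raw_images = [Image.fromarray(train_ds.data[idx].numpy()) for idx in range(NUM_IMAGES)]
rotated_images = [test_transform(img) for img in raw_images]

img_grid = torchvision.utils.make_grid(torch.stack(images + rotated_images, dim=0), nrow=NUM_IMAGES, normalize=True, pad_value=0.5)
# Move channels last, as expected by plt.imshow.
img_grid = img_grid.permute(1, 2, 0)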

plt.figure(figsize=(8,8))
plt.title("MNIST training images: training transform (top) vs. test transform (bottom).")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()
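class DataModule(pl.LightningModule):
    # Despite its name, this is a LightningModule that wraps the model, loss and optimizer;
    # it does not load data the way a pl.LightningDataModule would.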

    def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer. Stored in self.hparams, but currently not used: configure_optimizers always creates AdamW.
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """
        super().__init__()
        # Export the hyperparameters to a YAML file and create the "self.hparams" namespace.
        self.save_hyperparameters()
        # Create model
        self.model = create_model(model_name, model_hparams)
        # Create loss module
        self.loss_module = nn.CrossEntropyLoss()

    def forward(self, imgs):
        return self.model(imgs)

    def configure_optimizers(self):
        # AdamW is Adam with decoupled weight decay (see https://arxiv.org/pdf/1711.05101.pdf for details).
        optimizer = optim.AdamW(
            self.parameters(), **self.hparams.optimizer_hparams)
        return [optimizer], []

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Logs the accuracy per epoch to tensorboard (weighted average over batches)
        self.log('train_acc', acc, on_step=False, on_epoch=True)
        self.log('train_loss', loss)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches)
        self.log<