Download the files for Lab 5A from the following links:
We recommend that you use Google Colab, as training will be faster on the GPU.
To enable the GPU on Colab, go to Edit / Notebook settings / Hardware accelerator / select T4 GPU
Instructions on how to download and use Jupyter Notebooks can be found here. You can find a static version of the notebook below.
Lab 5: Time Series and Transformers
0. Environmental Setup
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from datetime import datetime
from torch.utils.data import DataLoader, Dataset
import math
torch.set_default_dtype(torch.float64)
plt.style.use('default')
plt.rcParams['font.size'] = '14'
if torch.cuda.is_available():
device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
device = torch.device("mps:0")
else:
device = torch.device("cpu")
print(f"Using device: {device}")
Using device: cuda:0
1 Time Series
We define a time series $X$ as a collection of $T+1$ vectors $x_t \in \reals^n$ indexed by a time index $t=0,1,\ldots, T$. There are several tasks that we may want to perform on a time series, but the prototypical example is the prediction of the entry $x_T$ at time $T$ when given the history of the series between times $0$ and $T-1$, $$ X_T = x_{0:T-1} = \big[\, x_0, x_1, \ldots, x_{T-1} \,\big].\tag{1} $$ This task is illustrated in Figure 1 for $T=10$. The time series is observed between times $t=0$ and $t=T-1 = 9$. The value at time $T=10$ is unobserved. Our goal is to predict it.
This is a goal that we can formulate as a machine learning task. Given the history of the time series between times $0$ and $T-1$, we introduce a learning parameterization $\Phi$ with parameters $H$ to produce estimates of the time series at time $T$, $$ \hat{x_T} = \Phi (\, X_T, \, H \, ).\tag{2} $$ These estimates can be compared to the true value of the time series $x_T$ to formulate a training cost that we then optimize to find the optimal set of parameters. That is, we go through the usual steps of: (i) Acquiring data for several time series. This yields a set of $U$ histories $X_u$ and corresponding time $T$ values $x_{uT}$. (ii) Introducing a loss function $l(\hat{x_T},x_T)$ measuring the fit between the time series value $x_T$ and its prediction $\hat{x_T}$. (iii) Formulating the empirical risk minimization (ERM) problem, $$ H^* ~=~ \argmin_H \frac{1}{U}\sum_{u=0}^{U-1} l(\, \Phi (\, X_{uT}, \, H),\, x_{uT}). \tag{3} $$ In (3), the index $u$ denotes several different time series. This is not quite how time series work. In reality, we are given a single time series that extends for $T+U$ units of time and the “different” time series are actually different windows of the same time series, $$ X_{uT} = x_{u:u+T-1} = \big[\, x_{u}, x_{u+1}, \ldots, x_{u+T-1} \,\big], \qquad x_{uT} = x_{u+T} .\tag{4} $$ Thus, out of a single time series we extract a number of training samples: each considers time $u$ as the starting point of a new sequence of length $T$ from which we want to predict the value of the sequence at time $u+T$. Our first task is to construct the dataset in (4) when given a time series.
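Before we build a PyTorch dataset for this in Task 1, the following minimal sketch (using the imports from Section 0) illustrates the windowing in (4) on a toy tensor. The names toy_series, toy_windows and toy_targets are ours and purely illustrative.
# Minimal illustration of (4): extract pairs (X_uT, x_uT) from a single long toy series.
n_toy, T_toy, U_toy = 2, 5, 4
toy_series = torch.arange(n_toy * (T_toy + U_toy), dtype=torch.float64).reshape(n_toy, T_toy + U_toy)
toy_windows = [toy_series[:, u:u + T_toy] for u in range(U_toy)]  # histories X_uT, each of shape (n, T)
toy_targets = [toy_series[:, u + T_toy] for u in range(U_toy)]    # values x_uT to be predicted
print(toy_windows[0].shape, toy_targets[0].shape)                 # torch.Size([2, 5]) torch.Size([2])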
Task 1: Load data
In this lab we work with weather data. We are given a time series with $T+U = 52,696$ entries, each of which contains various descriptors of the weather at a different time of a different day. Each entry consists of twelve weather indicators such as humidity, atmospheric pressure, and temperature.
Load the data from the lab’s page and plot component “T (degC)” of the time series as a function of time. This is the average temperature during each time interval.
Separate this time series into two parts. The first part contains 70% of the values and the second part contains the remaining 30%. Use these two time series to extract $U=36,787$ samples of the form in (4) for a training set and to extract $U=15,709$ samples of the form in (4) for a test set. In both cases, use $T=100$.
# Load the weather data
raw_data = np.genfromtxt(
'weather.csv',
delimiter=",",
skip_header=1,
dtype=str,
)
# Split the date strings and numerical data
date_strs = raw_data[:, 0].astype(str) # First column: date strings
data = torch.from_numpy(raw_data[:, 1:].astype(float)) # Rest: numerical data
# Convert date strings to datetime objects
timesteps = [datetime.strptime(date, "%d.%m.%Y %H:%M:%S") for date in date_strs]  # date format inferred from the dataset's date strings
# Print the shape information
print(f"\nNumber of samples: {data.shape[0]}")
print(f"Number of variables: {data.shape[1]}\n")
# x_t includes 12 features
feature_dim = 12
properties = [1, 3, 4, 5, 6, 7, 8, 9, 12, 13, 16, 19] # use the properties in the selected columns
data = data[:, [properties[i] for i in range(feature_dim)]].float().transpose(0,1).to(device)
n,T = data.shape
Number of samples: 52696
Number of variables: 21
# Visualize average temperature
plt.plot(timesteps, data[0,:].cpu().numpy()) # T(degC) is in the first column of data with selected properties
plt.title('Temperature vs. time')
plt.xlabel('Date')
plt.ylabel('degC')
plt.xticks(rotation=45)
plt.show()
# Split into testing and training datasets
train_len = int(0.7*T) # use the first 70% of data to train
train_data = data[:,0:train_len]
# normalize
train_data = (train_data - train_data.mean(dim=1, keepdim=True)) / train_data.std(dim=1, keepdim=True)
test_data = data[:,train_len:] # use the remaining 30% of data to test
test_data = (test_data - test_data.mean(dim=1, keepdim=True)) / test_data.std(dim=1, keepdim=True)
print(f"train_data shape: {train_data.shape}")
print(f"test_data shape: {test_data.shape}")
train_data shape: torch.Size([12, 36887])
test_data shape: torch.Size([12, 15809])
T = 100
batch_size = 1024
# Split x into windows
class WeatherTSDataset(Dataset):
def __init__(self, data, T):
self.data = data
self.T = T
def __len__(self):
"""this is equivalent to U - 1"""
u_plus_t = self.data.shape[1]
return (u_plus_t - self.T - 1)
def __getitem__(self, u):
"""
        This function returns a window of size T and the target value at the next time step.
"""
x_window = self.data[:,u:(u + self.T)]
y = self.data[:,u + self.T]
return x_window, y
train_dataset = WeatherTSDataset(train_data, T)
test_dataset = WeatherTSDataset(test_data, T)
print(f"train_dataset length: {len(train_dataset)}")
print(f"test_dataset length: {len(test_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
train_dataset length: 36786
test_dataset length: 15708
# Define a plotting function for visualization after training
properties_label = ['T (degC)', 'Tdew (degC)', 'rh (%)', 'VPmax (mbar)', 'VPact (mbar)', 'VPdef (mbar)',
                    'sh (g/kg)', 'H2OC (mmol/mol)', 'max. wv (m/s)', 'wd (deg)', 'SWDR (W/m^2)', 'Tlog (degC)']  # labels for the selected columns (middle entries reconstructed from the dataset header)
def plot_compare_prediction(predictions, actual, title):
"""This function will be used during evaluation to plot the predictions and actual values of the time series."""
fig, axes = plt.subplots(3, 4, figsize=(24, 16))
axes = axes.flatten()
time = timesteps[:predictions.shape[0]]
for i in range(feature_dim):
        axes[i].plot(time, predictions[:, i], alpha=0.5, label='Predicted')
        axes[i].plot(time, actual[:, i], alpha=0.5, label='Actual')
        axes[i].set_title(properties_label[i])
        axes[i].set_xlabel('Date')
        axes[i].set_ylabel('Values')
        for tick in axes[i].get_xticklabels():
            tick.set_rotation(45)
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper right')
    fig.suptitle(title)
    plt.subplots_adjust(hspace=0.6, wspace=0.3)
plt.show()
1.1 A More Precise Definition of Time Series
We began Section 1 by describing time series as a collection of $T+1$ vectors $x_t \in \reals^n$. A more common definition of a time series is that of a set of vectors $x_t \in \reals^n$ that extends from time $t=0$ to infinity. At any point in time $t$ our goal is to predict the value of $x_t$ given the whole process’s history $x_{0:t-1}$. In practice, values in the distant past are considered irrelevant for the estimation of $x_t$. We therefore introduce a window of length $T$ and consider the history of the time series starting at time $t-T$. Formally, we define the windowed history $$ X_{t} = x_{t-T:t-1} = \big[\, x_{t-T}, x_{t-T+1}, \ldots, x_{t-1} \,\big],\tag{5} $$ and consider a learning parameterization that maps $X_{t}$ to predictions $$ \hat{x_t} = \Phi \big(\, X_{t}, \, H \, \big). \tag{6} $$ This is an equivalent description of the history and parameterization in (1) and (2). It is just that instead of starting at time $0$ to predict at time $T$ as in (1) and (2), we start at an arbitrary time $t-T$ to predict at time $t$.
This more accurate description of a time series is important during execution. The trained model $\Phi(X, H^*)$ is executed on a rolling basis. At any time $t$ we make predictions by executing the model $\Phi(X_t, H^*)$ with the history window $X_t$ as defined in (5). After observing $x_t$ — at which point the problem of predicting $x_t$ becomes moot — we advance time to $t+1$, update the history window and execute the model $\Phi(X_{t+1}, H^*)$ to make a prediction of the value of the time series at time $t+1$.
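The following is a minimal sketch of this rolling execution, assuming a trained model that maps a history window of shape (n, T), with a leading batch dimension, to a prediction of shape (n,), which is the interface used later in this lab. The function observe_next is a hypothetical placeholder for whatever produces the next observation.
def rolling_predict(model, history, observe_next, n_steps):
    """Execute a trained predictor on a rolling basis.
    history: tensor of shape (n, T) holding the current window X_t.
    observe_next: hypothetical callable returning the true x_t once it becomes available.
    """
    predictions = []
    for _ in range(n_steps):
        with torch.no_grad():
            x_hat = model(history.unsqueeze(0)).squeeze(0)  # predict x_t from X_t
        predictions.append(x_hat)
        x_true = observe_next()  # x_t is observed; predicting it is now moot
        # Advance the window: drop the oldest entry and append the new observation.
        history = torch.cat((history[:, 1:], x_true.unsqueeze(1)), dim=1)
    return torch.stack(predictions, dim=1)  # shape (n, n_steps)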
Henceforth, we work with the definition of a time series as a sequence of $T$ vectors $X=x_{0:T-1}$ with the goal of predicting $x_T$. There are fewer indices involved and the notation is less cumbersome. But we keep in mind that our trained models are to be executed on a rolling basis on an indefinite time series.
2 Attention Layers
Attention layers create representations of the entries $x_t$ of a time series $X = [x_0, \ldots,x_T]$ that depend on context. This is done by constructing vectors $y_t$ that are linear combinations of all of the entries of the time series weighted by importance coefficients. I.e., for a certain matrix $M$ and similarity function $\text{d}(\cdot,\cdot)$, we compute the vector $$ y_t ~=~ \sum_{u=0}^{T} \text{d}(x_t,x_u) M x_u .\tag{7} $$ The collection of vectors $y_t$ forms another time series $Y = [y_0, \ldots,y_T]$. The construction of the time series $Y$ is such that its entries $y_t$ depend on all other entries of the time series. For this reason we call it a contextual representation. The purpose of the importance coefficients $\text{d}(x_t,x_u)$ is for the representation $y_t$ to be most affected by the time series vectors $x_u$ that are deemed most relevant to $x_t$.
The importance coefficients $\text{d}(x_t,x_u)$ are called attention coefficients.
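As a reference point before we specialize $\text{d}$ and $M$, here is a naive, loop-based sketch of (7); the function and variable names are ours. The vectorized expressions used in the rest of the lab compute the same quantities far more efficiently.
def contextual_representation(X, M, d):
    """Naive implementation of (7): y_t = sum_u d(x_t, x_u) M x_u.
    X: tensor of shape (n, T+1) whose columns are the vectors x_t.
    M: mixing matrix. d: similarity function taking two vectors.
    """
    Y = torch.zeros_like(M @ X)
    for t in range(X.shape[1]):
        for u in range(X.shape[1]):
            Y[:, t] += d(X[:, t], X[:, u]) * (M @ X[:, u])
    return Y

# Example with the plain inner product as the similarity function.
X_toy = torch.randn(4, 6)
Y_toy = contextual_representation(X_toy, torch.eye(4), lambda a, b: torch.dot(a, b))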
2.1 Attention Coefficients
To accomplish this we rely on the attention coefficients
$$
B_{tu} ~=~ \langle Qx_t, Kx_u \rangle
~=~ (Qx_t)^T (Kx_u) . \tag{8}
$$
Attention is just a way of measuring the similarity between the components $x_t$ and $x_u$ of the time series. Indeed, if we make $Q = K = I$ the attention coefficient reduces to the inner product between the time series’s components, $B_{tu}=\langle x_t, x_u \rangle $. This inner product is a standard measure of similarity between vectors.
The incorporation of $Q$ and $K$ in (8) introduces learnable coefficients that may yield more relevant measures of similarity. The matrices $Q$ and $K$ have $n$ columns — which is the number of entries of each of the time series vectors $x_t$ — and $m$ rows. In general, $m \ll n$ because we know that inner products are more meaningful in low dimensional spaces.
The coefficients $B_{tu}$ can be arranged into row vectors $b_t$ that include all of the attention coefficients associated with time $t$. It follows from (8) and the definition of the time series matrix $X = [x_0, \ldots,x_T]$ that this vector of attention coefficients can be computed as $$ b_{t} ~=~ (Qx_t)^T (KX) .\tag{9} $$ The computation of the attention vector $b_{t}$ is represented in Figure 4. We begin with the time series represented in its matrix form $X$ and isolate a specific time index $t$. This is the vector $x_t$ in Figure 4. To compute attention coefficients we multiply $x_t$ by the query matrix $Q$. This multiplication yields the query vector $Qx_t$ for this particular component of the time series. In parallel, each of the vectors $x_u$ of the time series is multiplied by the key matrix $K$. This results in the calculation of the key vectors $Kx_u$, which are the columns of the key matrix $KX$. Although not required, the number of rows of the query and key vectors is (much) smaller than the number of rows of the time series. The attention coefficients in (8) are the result of computing the inner products between the query vector and the columns of the key matrix.
The attention coefficients in (8) can be further grouped into an attention matrix $B$. Operating from the definition of the attention vectors in (9) we can see that this matrix is given by $$ B ~=~ (QX)^T (KX) .\tag{10} $$ This computation is illustrated in Figure 5. The time series matrix $X$ is multiplied by the query matrix $Q$ and the key matrix $K$. These multiplications result in the computation of the queries $QX$ and the keys $KX$. The attention matrix $B$ is the outer product $(QX)^T (KX)$ between queries and keys.
Notice that there are a large number of attention coefficients but they are generated by a relatively small number of parameters. Indeed, there is a total of $(T+1)^2$ attention coefficients when we operate with a time series with $T+1$ vectors, whereas the query and key matrices have only $m\times n$ coefficients each. For the values we use later in this lab ($T=100$, $m=3$, and $n=12$), this is $101^2 = 10{,}201$ attention coefficients generated by just $2 \times 3 \times 12 = 72$ parameters.
As is the case with convolutions in time, graphs, and images, the matrix expression in (10) is the one that we use for implementations. The scalar and vector expressions in (8) and (9) are valuable to understand attention but not used in implementations.
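A quick numerical check that the matrix expression (10) reproduces the scalar coefficients in (8), on random matrices with the shapes used in this lab (the names here are illustrative):
n_chk, m_chk, T_chk = 12, 3, 5                    # 5 time steps, i.e. T + 1 = 5
X_chk = torch.randn(n_chk, T_chk)
Q_chk, K_chk = torch.randn(m_chk, n_chk), torch.randn(m_chk, n_chk)
B_chk = (Q_chk @ X_chk).T @ (K_chk @ X_chk)       # matrix form (10), shape (T+1, T+1)
t, u = 1, 3                                       # scalar form (8) for one pair (t, u)
B_tu = torch.dot(Q_chk @ X_chk[:, t], K_chk @ X_chk[:, u])
print(torch.allclose(B_chk[t, u], B_tu))          # True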
Task 2: Attention
Implement a Pytorch module for the similarity operation in (10). The query and key matrices are attributes of this class. The forward method should compute the attention matrix $B$.
# We define a linear attention class that computes the attention matrix B,
# which is equivalent to the outer product (QX).T(KX) between queries and keys
# Input: m (the dimension of the Q and K matrices)
# n (the number of features, n=12 in our case)
# Output: B (the attention matrix)
class LinearAttention(nn.Module):
"""
We define a linear attention class that computes the attention matrix B,
    which is equivalent to the outer product (QX).T(KX) between queries and keys.
Args:
m (int): The dimension of the Q and K matrices.
n (int): The number of features, n=12 in our case.
"""
def __init__(self, m, n):
super(LinearAttention, self).__init__()
self.Q = nn.Parameter(torch.randn(m, n))
self.K = nn.Parameter(torch.randn(m, n))
self.initialize_parameters()
def initialize_parameters(self):
"""Initialize the values of the learnable parameter matrices.
Kaiming uniform is just a type of random initialization, you don't need to
worry about it. It is a good default initialization for linear layers.
"""
nn.init.kaiming_uniform_(self.Q, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.K, a=math.sqrt(5))
def forward(self, X):
"""
Compute the linear attention matrix B.
Args:
X (torch.Tensor): The input sequence.
Returns:
B (torch.Tensor): The linear attention matrix.
"""
QX = torch.matmul(self.Q, X)
KX = torch.matmul(self.K, X)
# Transposes with more than two dimensions need specification on the dimensions to transpose.
QX_t = QX.transpose(1,2)
B = QX_t@KX
return B
# Example sequence of a batch of 2 examples with 12 features and 5 time steps
X_dummy = torch.randn(2,12,5).to(device)
# Testing the module
linear_att = LinearAttention(m=3,n=12).to(device)
B = linear_att(X_dummy)
# B is the linear attention matrix, of size (5,5), for each example.
display(B.shape)
B
torch.Size([2, 5, 5])
tensor([[[-0.0504, 0.4186, -0.3042, 0.0318, -0.1312], [ 0.2227, 1.0347, -0.9597, -0.1263, -0.1839], [-0.3236, 0.1610, 0.2798, -0.0793, 0.0967], [-0.3085, 1.8316, -1.3670, 0.3024, -0.7205], [-0.0983, 0.0552, 0.1935, -0.1672, 0.1704]], [[ 0.1915, 0.1118, 0.1169, -0.4661, 0.3297], [-0.1832, 0.0207, 0.0268, 0.6762, -0.3522], [-0.5165, 0.0935, 0.0256, -0.1771, 0.1697], [-0.1208, 0.0812, 0.0794, 0.2828, -0.0961], [-1.0151, 0.0670, -0.0760, -0.5436, 0.3589]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>)
2.2 Nonlinear Attention
The similarity coefficients in (8) are what we call a linear attention mechanism. Nonlinear attention mechanisms post process linear attention coefficients with a nonlinear function.
The most common choice of nonlinearity is a function we call a softmax. For a given vector $b\in\reals^{T+1}$ the softmax is the vector $a = \text{sm}(b)$ with components, $$ a_{u} = \frac{\exp(b_u)}{ \sum_{u'=0}^{T}\exp(b_{u'})} \quad\Leftrightarrow\quad \text{sm}(b) = \frac{\exp(b)}{ \mathbb{1}^T \exp(b)} ,\tag{11} $$ where in the second equality we define the vector of exponentials $\exp(b) := [\exp(b_0); \ldots; \exp(b_T)]$ and the vector of all ones $\mathbb{1} := [1; \ldots; 1]$. As per (11), the softmax entry $a_u$ is the exponential $\exp(b_u)$ of component $u$ of the vector $b$ normalized by the sum of the exponentials $\exp(b_{u'})$ of all components of $b$. We point out that this definition is similar but not identical to the definition of the softmax function we used to introduce the cross entropy loss in Lab 2C.
With this definition we can now define softmax similarity coefficients as the application of the softmax function in (11) to the linear similarity vector in (9), $$ a_t ~=~ \text{sm} \Big(\, b_t \,\Big) ~=~ \text{sm} \Big(\, (Qx_t)^T (KX) \,\Big) ~=~ \frac{\exp\Big(\, (Qx_t)^T (KX) \,\Big)} {\mathbb{1}^T \exp\Big(\, (Qx_t)^T (KX) \,\Big)} ~.\tag{12} $$ Observe that the definition of the softmax in (11) is such that the sum of the entries of the softmax vector is normalized, $\mathbb{1}^T\text{sm}(b) = 1$. As a particular case, the sum of the similarity coefficients $a_t$ in (12) is $\mathbb{1}^Ta_t = 1$. We can then think of the softmax similarity coefficients in (12) as a nonlinear normalization of the attention coefficients in (9).
This observation is important because it makes it plain that (12) is similar to a pointwise nonlinearity. Indeed, if we use $A_{tu}$ to denote the entries of $a_t$, it follows from the definitions in (8), (9) and (12) that $$ A_{tu} ~=~ \frac{\exp(B_{tu})}{ \sum_{u'=0}^{T}\exp(B_{tu'})}.\tag{13} $$ Thus, the similarity coefficient $A_{tu}$ is obtained by applying an exponential pointwise nonlinearity to the linear similarity coefficient $B_{tu}$, followed by a normalization. The purpose of the exponential nonlinearity is to magnify the difference between different similarity coefficients.
Similarly to (10), we can group all attention coefficients into a matrix $$ A = \text{sm} \Big(\, (QX)^T (KX) \, \Big), \tag{14} $$ where the softmax function normalizes along the rows of $(QX)^T (KX)$. I.e., the rows of the similarity matrix $A$ are the vectors $a_t$ defined in (12).
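As a sanity check, applying F.softmax along the last dimension of a linear attention matrix implements exactly the row-wise normalization in (13)-(14); the toy names below are ours.
B_toy = torch.randn(5, 5)                                                  # a linear attention matrix as in (10)
A_toy = F.softmax(B_toy, dim=-1)                                           # softmax along rows, as in (14)
A_manual = torch.exp(B_toy) / torch.exp(B_toy).sum(dim=-1, keepdim=True)   # (13) written out
print(torch.allclose(A_toy, A_manual))                                     # True
print(A_toy.sum(dim=-1))                                                   # every row sums to 1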
Task 3: Softmax Attention
Implement a Pytorch module for the similarity operation in (14). The query and key matrices are attributes of this class. The forward method should compute the attention matrix $A$.
# We define a softmax attention module to perform softmax on the attention matrix B.
# Input: m (the dimension of the Q and K matrices)
# n (the number of features, n=12 in our case)
# Output: A (softmax of the attention matrix B)
class SoftmaxAttention(nn.Module):
"""
We define a softmax attention module to perform softmax on the attention matrix B.
Args:
m (int): The dimension of the Q and K matrices.
n (int): The number of features, n=12 in our case.
"""
def __init__(self, m, n):
super(SoftmaxAttention, self).__init__()
self.Q = nn.Parameter(torch.randn(m, n))
self.K = nn.Parameter(torch.randn(m, n))
self.initialize_parameters()
def initialize_parameters(self):
"""
Initialize the values of the learnable parameter matrices.
Kaiming uniform is just a type of random initialization, you don't need to
worry about it. It is a good default initialization for linear layers.
"""
nn.init.kaiming_uniform_(self.Q, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.K, a=math.sqrt(5))
def forward(self, X):
"""
Compute the attention matrix A.
Args:
X (torch.Tensor): The input sequence.
Returns:
A (torch.Tensor): The attention matrix with softmax applied.
"""
QX = torch.matmul(self.Q, X)
KX = torch.matmul(self.K, X)
# Transposes with more than two dimensions need specification on the dimensions to transpose.
QX_t = QX.transpose(1,2)
B = QX_t@KX
# Normalizing on the column dimension of each example, each a_t (row) sums to 1
A = F.softmax(B, dim=-1)
return A
# Example sequence of a batch of 2 examples with 12 features and 5 time steps
X_dummy = torch.randn(2,12,5).to(device)
# Testing the module
softmax_att = SoftmaxAttention(m=3,n=12).to(device)
A = softmax_att(X_dummy)
# Verification: The rows of A sum to 1.
A.sum(dim=2)
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000], [1.0000, 1.0000, 1.0000, 1.0000, 1.0000]], device='cuda:0', grad_fn=<SumBackward1>)
2.3 Contextual Representations
$$ \def\bm{\mathbf} $$The similarity coefficients in (12)-(14) are used to create a contextual representation of $x_t$, $$ \bm{z}_t = \sum_{u=0}^{T} (\bm{V} \bm{x}_u) A_{tu}.\tag{15} $$ This contextual representation is a linear combination of all the vectors in the time series multiplied by a matrix $\bm{V}$ and scaled by the similarity coefficients $A_{tu}$. The matrix $\bm{V} \in \mathbb{R}^{m \times n}$ projects onto a lower dimensional space. The dimensions of $\bm{V}$ are the same as the dimensions of $\bm{Q}$ and $\bm{K}$.
It is easy to see that the expression in (15) is equivalent to $$ \bm{z}_t = \bm{V} \bm{X} \bm{a}_{t}^T.\tag{16} $$ This operation is explained in Figure 6. The time series $\bm{X}$ is multiplied by the value matrix $\bm{V}$. This produces an alternative representation of the time series given by the product $\bm{V}\bm{X}$. This representation is of lower dimensionality. Instead of having vectors $\bm{x}_u \in \mathbb{R}^n$ associated with each point in time $u$, we have vectors $\bm{V}\bm{x}_u \in \mathbb{R}^m$. We choose $m \ll n$. The low dimensional contextual representation $\bm{z}_t$ is obtained by weighting each vector $\bm{V} \bm{x}_u$ by the attention coefficient $A_{tu}$ and summing over all times $u$. Equivalently, we obtain $\bm{z}_t$ as the product $\bm{V}\bm{X} \bm{a}_{t}^T$ shown in (16). We say that this representation is contextual because $\bm{z}_t$ depends on vectors $\bm{x}_u$ that have been deemed similar to $\bm{x}_t$ by the attention coefficient $A_{tu}$.
The operation in (16) can also be represented in matrix form. The matrix $\bm{Z}$ with columns $\bm{z}_t$ is given by $$ \bm{Z} = \bm{V}\bm{X}\bm{A}^T \tag{17} $$ This operation is represented in Figure 7. The top part of Figure 7 is the same as the top part of Figure 6. We are constructing a lower dimensional representation $\bm{V}\bm{X}$ of the time series. We then produce the contextual representation $\bm{Z}$ by multiplying $\bm{V}\bm{X}$ with the attention matrix transpose $\bm{A}^T$.
The representations in (17) are of dimension $m \ll n$. We complete an attention layer with a dimensional recovery step. This is done by multiplication with the transpose of a matrix $\bm{W} \in \mathbb{R}^{m \times n}$. We can write this operation in terms of individual contextual vectors $\bm{z}_t$, $$ \bm{y}_t ~=~ \bm{W}^T \bm{z}_t ~=~ \bm{W}^T \bm{V} \bm{X} \bm{a}_{t}^T,\tag{18} $$ or in terms of the matrix $\bm{Z}$ with all of the contextual vectors, $$ \bm{Y} ~=~ \bm{W}^T \bm{Z} ~=~ \bm{W}^T \bm{V} \bm{X} \bm{A}^T. \tag{19} $$ The vectors $\bm{y}_t$ are contextual representations of the time series that have the same dimensionality as the components of the time series. The operations in (18) and (19) are illustrated in Figures 8 and 9.
The representation $\bm{Y}$ is the output of an attention layer.
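Before implementing the full layer in Task 4, the following shape-by-shape sketch of (14)-(19) on random matrices may help keep the dimensions straight (no batch dimension here; the names are illustrative):
n_ex, m_ex, Tp1 = 12, 3, 101                                # T + 1 time steps, m << n
X_ex = torch.randn(n_ex, Tp1)
Q_ex, K_ex, V_ex, W_ex = (torch.randn(m_ex, n_ex) for _ in range(4))
A_ex = F.softmax((Q_ex @ X_ex).T @ (K_ex @ X_ex), dim=-1)   # (T+1, T+1), rows sum to 1, as in (14)
Z_ex = V_ex @ X_ex @ A_ex.T                                 # (m, T+1), contextual representation (17)
Y_ex = W_ex.T @ Z_ex                                        # (n, T+1), dimension recovery (19)
print(Z_ex.shape, Y_ex.shape)                               # torch.Size([3, 101]) torch.Size([12, 101])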
2.4 Softmax Attention Layers
As it follows from the discussions in Sections 2.2 and 2.3 a softmax attention layer has two distinct operations. The first operation is the computation of the softmax attention coefficients, $$ A ~=~ \text{sm}\left( (QX)^T (KX) \right)\tag{20} $$ This is Equation (14) repeated here for reference. The second operation is the computation of the contextual representation, $$ Y ~=~ W^T \, V \, X \, A^T ~=~ W^T \, V \, X \, \left[\, \text{sm} \Big(\, (Q\, X)^T (K\, X) \, \Big) \,\right]^T .\tag{21} $$ This is Equation (19) repeated here for reference.
The parameters of the attention layer are the matrices $Q$, $K$, $V$ and $W$. All of these matrices have $m$ rows and $n$ columns with $m \ll n$. Having intermediate representations of the smaller dimension $m$ is important: it keeps the number of parameters low and, as we argued in Section 2.1, inner products are more meaningful in low dimensional spaces.
The expressions in (14) and (21) are what you should use to implement and analyze attention layers. However, it is sometimes instructive to keep in mind the definition of the attention vectors $$ a_t ~=~ \text{sm} \Big(\, (Qx_t)^T (KX) \,\Big) ,\tag{22} $$ and the expanded expression for the computation of the contextual representation $$ y_t ~=~ W^T \sum_{u=0}^{T} A_{tu} V x_u.\tag{23} $$ Equation (22) is a repetition of (12) and equation (23) is a combination of (18) and (15). Notice that the expression in (23) has the same form as the conceptual expression in (7) with $\text{d}(x_t,x_u) = A_{tu}$ and $M =W^TV$.
It is common to postprocess the contextual representation further, but to do so without further mixing of different time components. The simplest thing we can do is to add a pointwise nonlinearity to (21) so that the output is $$ Y_\sigma ~=~ \sigma\Big(\, Y \,\Big) ~=~ \sigma\Big(\,W^T V X A^T \,\Big) .\tag{24} $$ It is also not uncommon to postprocess each $y_t$ with a fully connected neural network (FCNN). This is not unwise because the dimensionality of $y_t$ is not too large and we will use the same FCNN for all times $t$. We will not do this here (a minimal sketch is shown below for reference only).
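For reference only, since this lab does not use it, here is a minimal sketch of the per-time-step FCNN postprocessing mentioned above; the hidden width is an arbitrary choice of ours, and the same network is shared across all times $t$.
class PerTimeFCNN(nn.Module):
    """Apply the same small fully connected network to every time step y_t."""
    def __init__(self, n, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n, hidden), nn.ReLU(), nn.Linear(hidden, n))

    def forward(self, Y):  # Y: (B, n, T)
        # nn.Linear acts on the last dimension, so move time to dimension 1 and back.
        return self.net(Y.transpose(1, 2)).transpose(1, 2)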
Task 4&5: Attention Layer and Nonlinearity
Code a Pytorch module that implements an attention layer. The matrices $Q$, $K$, $V$ and $W$ are parameters of this module. The forward method of this module takes a time series $X$ as an input and produces the time series $Y$ as an output. The module receives $m$ and $n$ as initialization parameters.
Modify the module of Task 4 to include a pointwise nonlinear operation. You can choose your favorite nonlinearity here, but we suggest you implement a relu.
class AttentionLayer(nn.Module):
"""
We define a softmax attention layer module to
compute the attention matrix A.
Args:
m (int): The dimension of the Q and K matrices.
n (int): The number of features, n=12 in our case.
"""
def __init__(self, m, n):
super(AttentionLayer, self).__init__()
self.Q = nn.Parameter(torch.empty(m, n) )
self.K = nn.Parameter(torch.empty(m, n))
self.V = nn.Parameter(torch.empty(m, n))
self.W = nn.Parameter(torch.empty(n, m))
self.nonlinearity = nn.ReLU()
self.initialize_parameters()
def initialize_parameters(self):
"""
Initialize the values of the learnable parameter matrices.
Kaiming uniform is just a type of random initialization, you don't need to
worry about it. It is a good default initialization for linear layers.
"""
nn.init.kaiming_uniform_(self.Q, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.K, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
def forward(self, X):
"""
Computes the Attention layer computation, which involves
computing the matrix A, then reweighting XV by the attention coefficients.
Finally, it applies a nonlinearity to the result.
"""
B, n, T = X.shape
QX = torch.matmul(self.Q, X) #(B, m, T)
KX = torch.matmul(self.K, X) #(B, m, T)
VX = torch.matmul(self.V, X) #(B, m, T)
QX_t = QX.transpose(1,2)
# Compute linear attention
B = QX_t @ (KX) # (B, T, T)
# Softmax attention.
# Normalizing on the column dimension of each example, each a_t (row) sums to 1
A = F.softmax(B, dim=2) # (B, T, T)
        Z = torch.matmul(VX, A.transpose(1, 2)) #(B, m, T); Z = V X A^T, as in (17)
# Feedforward block.
Y_l = self.nonlinearity(Z)
X_l = self.nonlinearity(torch.matmul(self.W, Y_l)) #(B, n, T)
return X_l
# Example sequence of a batch of 1 example with 12 features and 100 time steps
X_dummy = torch.randn(1, n, 100).to(device)
# Testing the module
attention_layer = AttentionLayer(m=3, n=12).to(device)
output = attention_layer(X_dummy)
print(output.shape)
print(output[0,0,:])
torch.Size([1, 12, 100]) tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0109, 0.0220, 0.0000, 0.0000, 0.0000, 0.0010, 0.0000, 0.0000, 0.0000, 0.0477, 0.0000, 0.0000, 0.0000, 0.0358, 0.0000, 0.0000, 0.0000, 0.0041, 0.0000, 0.0000, 0.0012, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0495, 0.0000, 0.0000, 0.0000, 0.0000, 0.0389, 0.0000, 0.0000, 0.0136, 0.0137, 0.0000, 0.0132, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0329, 0.0024, 0.0000, 0.0000, 0.0000, 0.0170, 0.0117, 0.0258, 0.0000, 0.0014, 0.0230, 0.0110, 0.0410, 0.0000, 0.0096, 0.0392, 0.0000, 0.0168, 0.0000, 0.0457, 0.0000, 0.0000, 0.0000, 0.0438, 0.0000, 0.0000, 0.0000, 0.0033, 0.0069, 0.0000, 0.0083, 0.0298, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0027, 0.0000, 0.0000, 0.0013, 0.0000, 0.0000, 0.0304, 0.0000, 0.0000, 0.0000], device='cuda:0', grad_fn=<SliceBackward0>)
3 Transformers
A transformer is a layered architecture where each layer is an attention layer. Formally, this is a composition of operations defined by the recursion $$ A_l ~=~ \text{sm} \Big(\, (Q_l X_{l-1})^T (K_l X_{l-1}) \, \Big),\tag{25} $$ $$ Y_l ~=~ W_l^T \, V_l \, X_{l-1} \, A_l^T,\tag{26} $$ $$ X_l ~=~ \sigma\Big( Y_l \,\Big) .\tag{27} $$ This recursion is initialized with $X_0 = X$ and is repeated $L$ times, where $L$ is the number of layers of the transformer. This is analogous to the composition of layers in convolutional and graph neural networks.
For future reference, define the tensor $A = [Q_l, K_l, V_l, W_l]$ grouping the query, key, value, and dimension recovery matrices of all layers. With this definition we write the output of a transformer as $$ \Phi (X, A) ~=~ Y_L \tag{28} $$ In (28), $X$ is the time series that we input to the transformer and $A$ contains the trainable parameters. The output $\Phi (X, A)$ is another time series with the same number of time components $T+1$ and vectors of the same dimension $n$. Its entries $y_t$ are representations of time $t$ that depend on the context of the whole time series.
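Since each layer contributes the four $m \times n$ matrices $Q_l$, $K_l$, $V_l$ and $W_l$, the parameter tensor $A$ has $4Lmn$ trainable coefficients; for the suggested $L=3$, $m=3$ and $n=12$ this is $4 \cdot 3 \cdot 3 \cdot 12 = 432$. Once the Transformer module of Task 6 below is defined, this is easy to verify:
# Assumes the Transformer class built in Task 6 below.
model_chk = Transformer(m=3, n=12, L=3)
print(sum(p.numel() for p in model_chk.parameters()), 4 * 3 * 3 * 12)  # both equal 432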
4 Time Series Prediction
We use a transformer to predict the next entry of a time series, as described in Section 1. To do so, observe that the output of the transformer $Y$ is an $n \times (T+1)$ matrix representing the time series $X$, which is also an $n \times (T+1)$ matrix. This is mismatched to a problem in which the input is a time series with $T$ components $[x_0,\ldots,x_{T-1}]$ and the output is a prediction $\hat{x}_T$ of the value at time $T$. To sort out this mismatch, consider the average $\bar{x} = X \mathbf{1} / T$ of the time series values and define the input of the transformer as the time series $$ \tilde{X} = \Big[ \, X, \, \bar{x}\, \Big].\tag{29} $$ In the time series $\tilde{X}$ we append the mean $\bar{x}$ to the given time series $X$. The idea is that $\bar{x}$ is a naive prediction of the time series entry at time $T$.
We can now use a transformer to refine this estimate. We do that by reading the transformer output at the appended position $T$ and declaring it to be our estimate of the weather data, $$ \hat{x}_{T} = \Big[ \, \Phi (\tilde{X}, A) \, \Big]_T.\tag{30} $$ An alternative approach is to process the time series $X$ without appending the naive estimate $\bar{x}$. This gives as an output a time series with $T$ components. In this case we declare that the estimate $\hat{x}_{T}$ is the average of the outputs of the transformer over all times, $$ \hat{x}_{T} = \frac{1}{T} \Big[ \, \Phi (X, A) \, \Big] \mathbf{1}.\tag{31} $$ Notice that (29)-(30) and (31) are similar approaches. In (29)-(30) we compute an average before running the transformer and in (31) we compute an average after running the transformer.
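The notebook below follows (29)-(30) by appending the mean token inside the Transformer module. For completeness, here is a minimal sketch of the alternative (31), reusing the AttentionLayer module from Task 5 and averaging the transformer outputs over time; it is a sketch, not the implementation used later.
class AveragingTransformer(nn.Module):
    """Variant implementing (31): run the transformer on X and average its outputs over time."""
    def __init__(self, m, n, L):
        super().__init__()
        self.attention_layers = nn.ModuleList([AttentionLayer(m, n) for _ in range(L)])

    def forward(self, X):  # X: (B, n, T), no mean token appended
        X_l = X
        for attention_layer in self.attention_layers:
            X_l = attention_layer(X_l)
        return X_l.mean(dim=2)  # average over all times, as in (31)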
Task 6&7: Transformer
Code a Pytorch module to implement a Transformer as specified by (25)-(27). This implementation can leverage the implementation of the attention layer in Task 5.
Use the data in Task 1 to train a transformer for weather prediction. You can choose either of the approaches in (29)-(30) or (31). The parameters of the transformer are your choice. We suggest that you use $L=3$ layers and $m=3$ for your intermediate representations. Use a mean squared loss and evaluate train and test performance.
class Transformer(nn.Module):
"""
An implementation of the transformer, using the mean of the time series as the prediction for the next time step.
Args:
m (int): The dimension of the Q and K matrices.
n (int): The number of features, n=12 in our case.
L (int): The number of layers.
"""
def __init__(self, m, n, L):
super(Transformer, self).__init__()
self.attention_layers = nn.ModuleList([AttentionLayer(m, n) for _ in range(L)])
def forward(self, X):
"""
The forward pass of the transformer, stacks L attention layers
(the input is the time series X, and the output is the last time step of the transformer)
Args:
X (torch.Tensor): The input sequence.
Returns:
X_l (torch.Tensor): The output of the transformer.
"""
B, n, T = X.shape
# Compute the mean token to append to the sequence.
x_tilde = X.mean(dim=2, keepdim=True) # mean over the time dimension
# Concatenate the mean token to the sequence along the time dimension.
X_tilde = torch.cat((X, x_tilde), dim=-1)
# X_l has shape (B, n, T+1)
X_l = X_tilde
for attention_layer in self.attention_layers:
X_l = attention_layer(X_l)
# Output the last vector.
return X_l[:,:,-1]
# Example sequence of a batch of 1 example with 12 features and 100 time steps
X_dummy = torch.randn(1, n, 100).to(device)
# Testing the module
transformer = Transformer(m=3, n=12, L=6).to(device)
X_l = transformer(X_dummy)
print(X_dummy.shape)
print(X_l.shape)
torch.Size([1, 12, 100]) torch.Size([1, 12])
# Training
n_epochs = 50
m = 6
n = feature_dim
L = 3
T = T
H = 3
estimator = Transformer(m, n, L).float().to(device)
optimizer = torch.optim.SGD(estimator.parameters(), lr=1e-3)
loss = nn.MSELoss()
estimator.train()
train_loss = []
for epoch in range(n_epochs): # Iterate over n_epochs epochs
for x_batch, y_batch in train_loader: # Iterate over all batches in the dataset
# (Step i) Load the data. These commands send the data to the GPU memory.
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)
# (Step ii) Compute the gradients. We use automated differentiation.
optimizer.zero_grad() # Gradient reset to indicate where the backward computation stops.
yHat = estimator(x_batch) # Call the neural network.
mse = loss(yHat,y_batch) # Call the loss functions.
        mse.backward() # Compute gradients moving backwards until the gradient reset.
# (Step iii) Update parameters by taking an SGD (or other optimizer) step.
optimizer.step()
train_loss.append(mse.item())
    if epoch % 5 == 0: print(f"Epoch {epoch}/{n_epochs} Loss: {train_loss[-1]}")
# End of batch loop.
print(train_loss[-1]) # Print training loss.
Epoch 0/50 Loss: 0.9979864358901978
Epoch 5/50 Loss: 1.0151433944702148
Epoch 10/50 Loss: 1.0232248306274414
Epoch 15/50 Loss: 0.9938992857933044
Epoch 20/50 Loss: 1.0198849439620972
Epoch 25/50 Loss: 1.034855604171753
Epoch 30/50 Loss: 0.9846301078796387
Epoch 35/50 Loss: 1.0373992919921875
Epoch 40/50 Loss: 1.025139331817627
Epoch 45/50 Loss: 1.034329891204834
0.9852505326271057
def evaluate(test_loader, estimator):
"""
Computes the test loss and plots the predictions and true values of the time series.
Args:
test_loader (DataLoader): The test data loader.
estimator (nn.Module): The trained transformer model.
Returns:
avg_test_loss (float): The average test loss.
"""
estimator.eval()
predictions_t = []
actual_t = []
loss = nn.MSELoss()
with torch.no_grad():
test_loss_acc = 0.
num_batches = 0
for inputs, targets in test_loader:
inputs, targets = inputs.to(device), targets.to(device)
X_l = estimator(inputs)
test_loss = loss(X_l, targets).item()
test_loss_acc += test_loss
predictions_t.extend(X_l[:, :].tolist())
actual_t.extend(targets[:, :].tolist())
num_batches += 1
# Calculate average losses
avg_test_loss = test_loss_acc / num_batches
print(f'Average Test Loss: {avg_test_loss:.8f}')
# Visualize output
predictions_t = np.array(predictions_t)
actual_t = np.array(actual_t)
plot_compare_prediction(predictions_t, actual_t, 'Transformer')
return avg_test_loss
evaluate(test_loader, estimator)
Average Test Loss: 0.96636721
0.9663672093302011
5 Multihead Attention
As we did with CNNs and GNNs, we also incorporate multiple features. Features in transformers are called heads, $$ A_l^h ~=~ \text{sm} \Big(\, (Q_l^h X_{l-1})^T (K_l^h X_{l-1}) \, \Big),\tag{32} $$ $$ Y_l^h ~=~ W_l^T \, V_l^h \, X_{l-1} \, {A_l^h}^T ,\tag{33} $$ $$ X_l ~=~ \sigma\bigg( \sum_{h=1}^H Y_l^h \,\bigg) .\tag{34} $$ A (minor) difference between multihead transformers and neural networks with multiple features is that the outputs of attention layers always have a single feature. The multiple features $Y_l^h$ generated by different heads are added at the output of each layer to produce the layer’s output $X_l$.
Task 8: Multi-head Layer
Code a Pytorch module to implement a multihead transformer as specified by (32)-(34). This implementation can leverage the implementation of the attention layer in Task 5.
class MultiHeadLayer(nn.Module):
"""
An implementation of the multihead attention layer.
The difference between AttentionLayer and this class is,
now Q,K,V are matrices of shape (H, m, n), and the attention matrix B is of shape (H, T, T)
(one attention feature per head)
Args:
        m (int): The dimension of the Q, K, and V matrices.
        n (int): The number of features, n=12 in our case.
        H (int): The number of heads.
"""
def __init__(self, m, n, H):
super(MultiHeadLayer, self).__init__()
self.m = m
self.H = H
self.Q = nn.Parameter(torch.empty(H, m, n))
self.K = nn.Parameter(torch.empty(H, m, n))
self.V = nn.Parameter(torch.empty(H, m, n))
self.W = nn.Parameter(torch.empty(n, m))
self.nonlinearity = nn.ReLU()
self.initialize_parameters()
def initialize_parameters(self):
"""
Initialize the values of the learnable parameter matrices.
Kaiming uniform is just a type of random initialization, you don't need to
worry about it. It is a good default initialization for linear layers.
"""
nn.init.kaiming_uniform_(self.Q, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.K, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))
def forward(self, X):
"""
The forward pass of the multihead attention layer, analogous to the one in the
AttentionLayer class. The main difference is that we need to make sure that the
        matrix multiplications account for the new head dimension.
Args:
X (torch.Tensor): The input sequence.
Returns:
X_l (torch.Tensor): The output of the multihead attention layer.
"""
B, n, T = X.shape # X: (B, n, T)
# Expand X to include the head dimension
X_expanded = X.unsqueeze(1) # (B, 1, n, T)
# Compute QX, KX, VX for each head
# The unsqueeze is used to add the head dimension to the matrices,
# because they are of shape (H, m, n), and we need to multiply them
# with X_expanded of shape (B, 1, n, T)
QX = torch.matmul(self.Q.unsqueeze(0), X_expanded) # (B, H, m, T)
KX = torch.matmul(self.K.unsqueeze(0), X_expanded) # (B, H, m, T)
VX = torch.matmul(self.V.unsqueeze(0), X_expanded) # (B, H, m, T)
# Transpose QX for multiplication
QX_t = QX.transpose(-2, -1) # (B, H, T, m)
# Compute attention scores B per head
B_matrix = torch.matmul(QX_t, KX) # (B, H, T, T)
# Compute attention weights A per head
A = F.softmax(B_matrix, dim=-1) # (B, H, T, T)
        # Compute Z per head: Z^h = V^h X (A^h)^T, as in (33)
        Z = torch.matmul(VX, A.transpose(-2, -1)) # (B, H, m, T)
        # Sum the heads, as in (34)
        Z = Z.sum(dim=1) # (B, m, T)
# Continue with feed-forward network
Y_l = torch.matmul(self.W, Z) # (B, n, T)
        X_l = X + self.nonlinearity(Y_l) # (B, n, T); note the residual connection added on top of (34)
return X_l
# Test code
layer = MultiHeadLayer(m=3, n=12, H=4).to(device)
X_dummy = torch.randn(1, n, 100).to(device)
output = layer(X_dummy)
print(output.shape)
print(output[0,0,:])
torch.Size([1, 12, 100]) tensor([-0.4146, 1.2016, -0.2843, -0.1861, 1.7256, -1.3594, -0.8331, -1.4487, 0.0303, 0.4614, -0.9792, 0.8990, 2.2749, -1.7884, -1.8909, -0.9293, -0.3365, -0.1339, 1.7545, 0.3676, 0.6645, -0.7899, 0.8880, 0.8931, 0.4194, -0.5105, -1.1982, 0.9511, 0.0121, 1.0617, 2.0614, 1.0274, -2.2922, 0.7783, 0.7714, 1.5638, 1.7037, -0.1530, 0.1866, -0.2623, -1.1237, -0.3733, -0.1903, -1.4724, -2.1897, -0.5852, -1.1204, 0.5195, 0.4742, -1.5120, -0.2626, 0.2744, -1.2994, 1.6559, -0.9548, 1.0801, -0.6603, -0.5715, -0.4242, -0.2085, -0.5027, 1.5344, -0.0070, -0.8301, 0.4626, 0.0340, -1.0499, 0.2514, -0.8558, 0.4166, 0.2071, -0.1422, -0.9235, -0.1930, 0.7839, 0.3962, -0.4783, 0.8936, 0.1352, 0.4671, -0.5731, 1.9169, -0.1338, -0.0838, 0.7217, 0.8771, 0.4322, 1.0234, -0.3090, 0.3151, 0.1281, -0.7404, -0.6522, 1.4577, 0.3399, -0.4781, 0.7331, -1.1684, -0.6576, -0.9867], device='cuda:0', grad_fn=<SliceBackward0>)
class MultiHeadTransformer(nn.Module):
"""
    Multihead Transformer, analogous to the Transformer class in the single head case.
Args:
        m (int): The dimension of the Q, K, and V matrices.
        n (int): The number of features, n=12 in our case.
        L (int): The number of layers.
        H (int): The number of heads.
"""
def __init__(self, m, n, L, H):
super(MultiHeadTransformer, self).__init__()
self.layers = nn.ModuleList([
MultiHeadLayer(m, n, H) for _ in range(L)
])
def forward(self, X):
"""
The forward pass of the multihead transformer, stacks L multihead layers.
This class is essentially the same as the Transformer class, but using the
MultiHeadLayer class instead of the AttentionLayer class.
Args:
X (torch.Tensor): The input sequence.
Returns:
X_l (torch.Tensor): The output of the transformer.
"""
B, n, T = X.shape
# Compute the mean token to append to the sequence.
x_tilde = X.mean(dim=2, keepdim=True) # mean over the time dimension
X_tilde = torch.cat((X, x_tilde), dim=-1)
# X_l has shape (B, n, T+1)
X_l = X_tilde
for layer in self.layers:
X_l = layer(X_l)
# Output the last vector.
return X_l[:,:,-1]
X = torch.randn(2, n,100)
X_l = MultiHeadTransformer(m, n, L=6,H=4)(X)
print(X_l.shape)
torch.Size([2, 12])
Task 9: Train a Multi-head Transformer
Use the data in Task 1 to train a multihead transformer for weather prediction. You can choose either of the approaches in (29)-(30) or (31). The parameters of the transformer are your choice. We suggest that you use $L=3$ layers, $m=3$ for your intermediate representations, and $H=4$ for the number of heads. Use a mean squared loss and evaluate train and test performance.
Train Multihead
# Training
n_epochs = 50
m = 6
n = feature_dim
L = 3
T = T
H = 3
estimator = MultiHeadTransformer(m, n, L, H).float().to(device)
optimizer = torch.optim.SGD(estimator.parameters(), lr=1e-5)
loss = nn.MSELoss()
estimator.train()
train_loss = []
for epoch in range(n_epochs): # Iterate over n_epochs epochs
for x_batch, y_batch in train_loader: # Iterate over all batches in the dataset
# (Step i) Load the data. These commands send the data to the GPU memory.
x_batch = x_batch.to(device)
y_batch = y_batch.to(device)
# (Step ii) Compute the gradients. We use automated differentiation.
optimizer.zero_grad() # Gradient reset to indicate where the backward computation stops.
yHat = estimator(x_batch) # Call the neural network.
mse = loss(yHat,y_batch) # Call the loss functions.
        mse.backward() # Compute gradients moving backwards until the gradient reset.
# (Step iii) Update parameters by taking an SGD (or other optimizer) step.
optimizer.step()
train_loss.append(mse.item())
    if epoch % 5 == 0: print(f"Epoch {epoch}/{n_epochs} Loss: {train_loss[-1]}")
# End of batch loop.
print(train_loss[-1]) # Print training loss.
Epoch 0/50 Loss: 0.5980456471443176
Epoch 5/50 Loss: 0.5956921577453613
Epoch 10/50 Loss: 0.5751610398292542
Epoch 15/50 Loss: 0.6154860258102417
Epoch 20/50 Loss: 0.5880224704742432
Epoch 25/50 Loss: 0.5862999558448792
Epoch 30/50 Loss: 0.5892231464385986
Epoch 35/50 Loss: 0.5942856073379517
Epoch 40/50 Loss: 0.6131274700164795
Epoch 45/50 Loss: 0.585210382938385
0.5798981189727783
plt.plot(train_loss)
[<matplotlib.lines.Line2D at 0x742dec95af90>]
evaluate(test_loader, estimator)
Average Test Loss: 0.56911424
0.5691142380237579