6. Probability Readout
The predictions $\hat{\mathbf{x}}_{u+T}$ in
$$
\hat{\mathbf{x}}_{u+T} = \Big[ \, \Phi (\mathbf{X}_u, \mathcal{A}) \, \Big] \mathbf{1}
$$
may have a small MSE when compared to the observed words $\mathbf{x}_{u+T}$, but they are not a good strategy for estimating the next word. This is because $\hat{\mathbf{x}}_{u+T}$ need not be a valid word. Indeed, it most likely will not be a valid word.
Recall that word $\mathbf{e}_i$ is represented by the eigenvector encoding
$$
\mathbf{x}_i = \mathbf{V}_n^T \mathbf{e}_i.
$$
Since there are a total of $c$ words in our corpus, there are a total of $c$ vectors $\mathbf{x}_i$ that represent valid words. The vectors at the output of the transformer are unlikely to coincide with any of these $c$ vectors, and the estimate $\hat{\mathbf{x}}_{u+T}$ is just as unlikely to do so unless we manage to drive the train and test MSEs to zero.
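This observation is easy to check numerically. The sketch below uses a hypothetical eigenvector matrix $\mathbf{V}_n$ with placeholder sizes $c = 1000$ and $n = 64$; it builds the encodings $\mathbf{x}_i = \mathbf{V}_n^T \mathbf{e}_i$ and verifies that an arbitrary vector in $\mathbb{R}^n$, such as a transformer output, matches none of them.

```python
import torch

# Hypothetical sizes: c valid words and n-dimensional eigenvector encodings.
c, n = 1000, 64

# V_n stands in for the matrix that collects the n eigenvectors (shape c x n); random here.
V_n = torch.randn(c, n)

# Eigenvector encoding of word i: x_i = V_n^T e_i, which is the i-th row of V_n.
i = 42
e_i = torch.zeros(c)
e_i[i] = 1.0
x_i = V_n.T @ e_i
assert torch.allclose(x_i, V_n[i])

# A generic vector in R^n almost surely coincides with none of the c valid
# encodings: its distance to the closest code word is strictly positive.
x_hat = torch.randn(n)
distances = torch.cdist(x_hat.unsqueeze(0), V_n).squeeze(0)  # shape (c,)
print(distances.min() > 0)  # True with probability one
```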
To solve this problem, we must force the readout to be a valid word. We do that with a readout layer whose output is a vector of $c$ probabilities, one for each of the $c$ words in the corpus. This readout layer is a softmax applied to the output of a fully connected layer that acts on the output of the transformer,
$$
\boldsymbol{\pi} (\mathbf{X}) = \text{sm} \Big[\, \mathbf{A} \, \text{vec} \big( \Phi (\mathbf{X}, \mathcal{A})\big) \, \Big] .
$$
The matrix $\mathbf{A}$ is a trainable parameter with $nT$ columns and $c$ rows. After applying the softmax normalization, the entries of the output $\boldsymbol{\pi}(\mathbf{X})$ add up to one and can be interpreted as a set of probabilities that dictate the likelihood of the next word in the sequence. The $i$th entry $\boldsymbol{\pi}_i(\mathbf{X})$ is the predicted probability that the next word is $\mathbf{e}_i$.
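As a concrete illustration, a minimal PyTorch sketch of this readout layer could look as follows. The class name `ProbabilityReadout` and the dimensions are assumptions, and the transformer output $\Phi(\mathbf{X}, \mathcal{A})$ is taken to be an $n \times T$ matrix per sequence.

```python
import torch
import torch.nn as nn

class ProbabilityReadout(nn.Module):
    """Readout pi(X) = sm[ A vec( Phi(X, A_transformer) ) ] sketched in PyTorch."""

    def __init__(self, n: int, T: int, c: int):
        super().__init__()
        # The trainable matrix A has c rows and nT columns (no bias term in the text).
        self.A = nn.Linear(n * T, c, bias=False)

    def forward(self, phi_out: torch.Tensor) -> torch.Tensor:
        # phi_out: transformer output Phi(X, A_transformer) of shape (batch, n, T).
        z = phi_out.flatten(start_dim=1)          # vec(Phi(X, A)): shape (batch, n * T)
        return torch.softmax(self.A(z), dim=-1)   # c probabilities that sum to one

# Example usage with hypothetical dimensions.
readout = ProbabilityReadout(n=64, T=32, c=1000)
pi = readout(torch.randn(8, 64, 32))              # shape (8, 1000); each row sums to 1
```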
We refer to the probabilities in $\boldsymbol{\pi} (\mathbf{X})$ as a policy. To train this policy, we minimize the cross-entropy loss between the true word at time $u+T$ and the probabilities $\boldsymbol{\pi}(\mathbf{X}_u)$,
$$
\mathcal{A}^*, \mathbf{A}^* = \arg\min_{\mathcal{A},\, \mathbf{A}} ~ -\,\frac{1}{C}~\sum_{u=0}^{C-1} ~ \big(\mathbf{e}_{u+T}\big)^T \big( \log \boldsymbol{\pi}(\mathbf{X}_u) \big) .
$$
Notice that in this loss the vector $\mathbf{e}_{u+T}$ is the index encoding of the word at time $u+T$. This is a vector with all zeros except for a 1 at the entry that corresponds to the index of the word observed at time $u+T$. It is therefore a valid probability vector that we can incorporate into a cross-entropy comparison with $\boldsymbol{\pi}(\mathbf{X}_u)$.
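To see that this is the familiar cross-entropy, note that the inner product of the one-hot vector $\mathbf{e}_{u+T}$ with $\log \boldsymbol{\pi}(\mathbf{X}_u)$ selects the log-probability assigned to the observed word. The sketch below, with made-up numbers, checks that this matches PyTorch's built-in cross-entropy.

```python
import torch
import torch.nn.functional as F

c = 1000                                      # hypothetical corpus size
pi = torch.softmax(torch.randn(c), dim=-1)    # some policy pi(X_u)
target = 7                                    # index of the word observed at time u + T

# Index encoding e_{u+T}: all zeros except a 1 at the observed word.
e = F.one_hot(torch.tensor(target), num_classes=c).float()

# Cross-entropy term -(e_{u+T})^T log pi(X_u): minus the log-probability of the observed word.
loss_manual = -(e * torch.log(pi)).sum()

# F.cross_entropy expects logits and applies a softmax internally; since
# softmax(log pi) = pi, passing log(pi) recovers the same value.
loss_builtin = F.cross_entropy(torch.log(pi).unsqueeze(0), torch.tensor([target]))
assert torch.allclose(loss_manual, loss_builtin)
```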
Further notice that the optimization is joint over the trainable parameters $\mathcal{A}$ of the transformer and the readout matrix $\mathbf{A}$. These two parameters are implicit in the cross-entropy loss.
They appear because $\boldsymbol{\pi} (\mathbf{X}_u)$ depends on both $\mathbf{A}$ and $\mathcal{A}$. To make this dependence explicit, we instantiate $\mathbf{X} = \mathbf{X}_u$ in the definition of the readout $\boldsymbol{\pi} (\mathbf{X})$ and substitute the result into the cross-entropy loss to write
$$
\mathcal{A}^*, \mathbf{A}^* = \arg\min_{\mathcal{A},\, \mathbf{A}} ~ -\,\frac{1}{C}~\sum_{u=0}^{C-1} ~ \Big[\mathbf{e}_{u+T}\Big]^T
\bigg[ \log \text{sm}
\Big[\, \mathbf{A} \,
\text{vec} \big(\,
\Phi (\mathbf{X}_u, \mathcal{A}) \, \big) \, \Big]\, \bigg] .
$$
We solve this empirical risk minimization (ERM) problem to predict the next word in a sequence of text. This prediction is based on observing a history of length $T$ that is processed by a transformer
$$
\mathbf{X}_\ell = \mathbf{X}_{\ell-1} + \sigma\bigg( \sum_{h=1}^H \mathbf{Y}_\ell^h \,\bigg) ,
$$
with a probability readout layer
$$
\boldsymbol{\pi} (\mathbf{X}) = \text{sm} \Big[\, \mathbf{A} \, \text{vec} \big( \Phi (\mathbf{X}, \mathcal{A})\big) \, \Big] .
$$
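A minimal PyTorch sketch of this joint ERM might look as follows. The `transformer` module below is only a stand-in for $\text{vec}(\Phi(\cdot, \mathcal{A}))$ and the data are synthetic; the point is that a single optimizer updates the transformer parameters $\mathcal{A}$ and the readout matrix $\mathbf{A}$ together, with `nn.CrossEntropyLoss` implementing the averaged $-(\mathbf{e}_{u+T})^T \log \boldsymbol{\pi}(\mathbf{X}_u)$ terms.

```python
import torch
import torch.nn as nn

# Hypothetical dimensions: encoding size n, context length T, c words, C training windows.
n, T, c, C = 64, 32, 1000, 512

# Stand-in for vec(Phi(., A_transformer)): any module mapping (batch, n, T) -> (batch, n*T).
transformer = nn.Sequential(nn.Flatten(), nn.Linear(n * T, n * T))
readout_A = nn.Linear(n * T, c, bias=False)   # the readout matrix A

# Synthetic training pairs: contexts X_u and indices of the words observed at time u + T.
X = torch.randn(C, n, T)
targets = torch.randint(0, c, (C,))

# Joint optimization over the transformer parameters and the matrix A.
optimizer = torch.optim.Adam(
    list(transformer.parameters()) + list(readout_A.parameters()), lr=1e-3
)
loss_fn = nn.CrossEntropyLoss()               # averages -(e_{u+T})^T log sm[.] over the batch

for step in range(5):
    optimizer.zero_grad()
    logits = readout_A(transformer(X))        # A vec(Phi(X_u, A)); the softmax is inside the loss
    loss = loss_fn(logits, targets)           # (1/C) sum_u -(e_{u+T})^T log pi(X_u)
    loss.backward()
    optimizer.step()
```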
Different from the readout strategy in
$$
\hat{\mathbf{x}}_{u+T} = \Big[ \, \Phi (\mathbf{X}_u, \mathcal{A}) \, \Big] \mathbf{1}
$$
and the training procedure in
$$
\mathcal{A}^* = \arg\min_{\mathcal{A}} \frac{1}{C}~\sum_{u=0}^{C-1} ~ \Big\| \, \Phi \big(\, \mathbf{X}_{u}, \, \mathcal{A} \, \big) \mathbf{1} - \mathbf{x}_{u+T} \,\Big \|^2 \, ,
$$
the cross-entropy ERM problem produces parameters $\mathcal{A}^*$ and $\mathbf{A}^*$ that map directly to predictions of actual words.
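Once $\mathcal{A}^*$ and $\mathbf{A}^*$ are available, reading out an actual word amounts to evaluating the policy and either picking its most likely entry or sampling from it. The sketch below uses untrained stand-ins for $\Phi$ and $\mathbf{A}$, so all names and dimensions are placeholders.

```python
import torch
import torch.nn as nn

# Untrained stand-ins for the transformer Phi and the readout matrix A.
n, T, c = 64, 32, 1000
transformer = nn.Sequential(nn.Flatten(), nn.Linear(n * T, n * T))
readout_A = nn.Linear(n * T, c, bias=False)

# One observed context X_u of T words with n-dimensional encodings.
X_u = torch.randn(1, n, T)

with torch.no_grad():
    pi = torch.softmax(readout_A(transformer(X_u)), dim=-1)   # policy pi(X_u), shape (1, c)

# The readout maps to an actual word: take the most likely index ...
next_word_index = pi.argmax(dim=-1)
# ... or draw a sample from the policy, a common choice when generating text.
sampled_index = torch.multinomial(pi, num_samples=1)
```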