Unsupervised decoding of HMMs from transformer residual streams

Motivation

This was a weekend project that turned out to work surprisingly well. I train small transformers on data generated by Hidden Markov Models (HMMs), and then try to extract the full HMM (the transition and emission matrices), as well as the belief states (the distribution over which state the HMM is in given past observations), directly from the transformer residual stream, without knowing the HMM the transformer was trained on.
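To make the setup concrete, here is a minimal sketch of sampling training sequences from a known HMM. The matrix values and function names are placeholders for illustration, not the actual mess3 parameters or code used in this repo.

```python
import numpy as np

def sample_hmm(transition, emission, seq_len, rng):
    """Sample one observation sequence from an HMM.
    transition: (S, S) row-stochastic, P(next state | state)
    emission:   (S, V) row-stochastic, P(token | state)"""
    num_states, vocab_size = emission.shape
    state = rng.integers(num_states)  # could also sample from the stationary distribution
    tokens = np.empty(seq_len, dtype=np.int64)
    for t in range(seq_len):
        tokens[t] = rng.choice(vocab_size, p=emission[state])
        state = rng.choice(num_states, p=transition[state])
    return tokens

rng = np.random.default_rng(0)
T = np.array([[0.8, 0.1, 0.1],   # placeholder values, not the mess3 parameters
              [0.1, 0.8, 0.1],
              [0.1, 0.1, 0.8]])
E = np.array([[0.9, 0.05, 0.05],
              [0.05, 0.9, 0.05],
              [0.05, 0.05, 0.9]])
data = np.stack([sample_hmm(T, E, 256, rng) for _ in range(1024)])  # training batch
```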

The core idea is to assume that a transformer trained on an HMM also acts as a perfect predictor of the HMM (given enough capacity and training data). In that case it must somehow represent optimal belief states over the HMM's hidden states, and those representations can then be used to optimize for the HMM parameters via SGD.
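As a rough sketch of what such an optimization could look like (my reconstruction for illustration, not necessarily the exact objective used in this repo): a linear probe maps residual stream vectors onto the belief simplex, and the probe is trained jointly with the HMM transition and emission matrices so that the probed beliefs both predict the next token and stay consistent with the Bayesian belief update.

```python
import torch
import torch.nn.functional as F

S, V, D = 3, 3, 64             # hidden states, vocab size, residual width (assumed)
probe = torch.nn.Linear(D, S)  # residual stream vector -> belief logits
T_logits = torch.nn.Parameter(torch.zeros(S, S))
E_logits = torch.nn.Parameter(torch.zeros(S, V))
opt = torch.optim.Adam(list(probe.parameters()) + [T_logits, E_logits], lr=1e-2)

def step(resid, tokens):
    """resid: (B, L, D) residual activations; tokens: (B, L) input token ids.
    Convention: belief[:, t] is read as the distribution over the hidden state
    that emits token t+1."""
    T = T_logits.softmax(-1)            # (S, S) row-stochastic transition matrix
    E = E_logits.softmax(-1)            # (S, V) row-stochastic emission matrix
    belief = probe(resid).softmax(-1)   # (B, L, S) probed belief states

    # 1) the belief's predictive distribution should explain the next token
    pred = belief[:, :-1] @ E           # (B, L-1, V)
    nll = F.nll_loss(pred.clamp_min(1e-8).log().flatten(0, 1),
                     tokens[:, 1:].flatten())

    # 2) beliefs should follow the Bayesian update b' ∝ (b * E[:, x]) @ T
    lik = E[:, tokens[:, 1:]].permute(1, 2, 0)        # (B, L-1, S)
    updated = (belief[:, :-1] * lik) @ T
    updated = updated / updated.sum(-1, keepdim=True).clamp_min(1e-8)
    consistency = (updated - belief[:, 1:]).pow(2).mean()

    loss = nll + consistency
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```

After training, `T_logits.softmax(-1)` and `E_logits.softmax(-1)` are the recovered transition and emission matrices (up to a permutation of the hidden states), and the probe outputs are the recovered belief states.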

The insight that HMM-trained transformers indeed implement and represent optimal belief states, and that this optimal belief-state structure exists in the residual stream, came from this amazing LessWrong post (I'm not affiliated with the authors in any way), without which I wouldn't have gotten the idea for how to approach this problem.

Some thoughts

Currently, I can infer the HMM transition and emission matrices and the belief states for the mess3 process from a 1-layer decoder-only transformer trained on that process, using only the residual stream activations for a large batch of process sequences, the process observations/emissions (the tokens fed to the transformer to obtain those activations), and the number of HMM hidden states.
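Collecting those activations can be done with a plain PyTorch forward hook; a minimal sketch is below. The module attribute passed as `block` (e.g. something like `model.blocks[0]`) is hypothetical and depends on how the transformer is implemented.

```python
import torch

def collect_residual_stream(model, token_batch, block):
    """Run the model on a batch of sequences and grab the residual stream after
    `block` via a forward hook (assumes the block returns the residual tensor
    directly rather than a tuple)."""
    activations = []

    def hook(module, inputs, output):
        activations.append(output.detach())

    handle = block.register_forward_hook(hook)
    with torch.no_grad():
        model(token_batch)              # (B, L) batch of process sequences
    handle.remove()
    return torch.cat(activations)       # (B, L, D) residual stream activations
```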

Since I assume the transformer to be a perfect predictor of the process anyway, the training dataset might not even be needed: by definition, you can just sample a new dataset from the trained transformer.
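A minimal sketch of that sampling step, assuming the model maps a `(B, t)` batch of token ids to `(B, t, V)` next-token logits:

```python
import torch

@torch.no_grad()
def sample_from_model(model, batch_size, seq_len, vocab_size, device="cpu"):
    """Autoregressively sample fresh sequences from the trained transformer,
    so no data from the original HMM is required."""
    # start from a random first token; a dedicated BOS token would also work
    tokens = torch.randint(vocab_size, (batch_size, 1), device=device)
    for _ in range(seq_len - 1):
        logits = model(tokens)
        probs = logits[:, -1].softmax(-1)
        next_tok = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_tok], dim=1)
    return tokens
```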

Also, it seems possible to infer the number of HMM hidden states from a PCA/SVD of the residual stream activations, or maybe even by just allocating a larger HMM than necessary and letting the optimization procedure find the smallest HMM that fits.
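The intuition behind the PCA/SVD idea: beliefs over S hidden states live on an (S - 1)-dimensional simplex, so if the transformer embeds them roughly linearly, the centered activations should have about S - 1 significant directions. A sketch (the threshold is an arbitrary choice, and this only works if the rest of the residual stream doesn't add many extra dimensions):

```python
import torch

def estimate_num_states(resid, rel_threshold=0.05):
    """Estimate the number of HMM hidden states from the singular value
    spectrum of the residual stream. resid: (B, L, D) activations
    (subsampling rows may be needed for very large batches)."""
    flat = resid.reshape(-1, resid.shape[-1]).float()
    flat = flat - flat.mean(dim=0, keepdim=True)       # center
    svals = torch.linalg.svdvals(flat)                 # singular value spectrum
    effective_rank = int((svals > rel_threshold * svals[0]).sum())
    return effective_rank + 1                          # simplex dim + 1 states
```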

What next?

I was surprised at how well this seems to work for the mess3 process, so I might try to iterate on this quickly in the next few days to get a writeup out. This might be really exciting.

  • Try this on more HMMs
  • Automatic inference of number of hidden states
  • Automatic dataset sampling, so that the method truly requires no data from the original process
  • Clean up the repo
  • Write a blogpost?
