T-REX: Transformer-RNN Execution for Multi-Token Language Modeling
Because generating one token per forward pass is dumb
This is a very preliminary test of an idea for improving the transformer architecture. It is not meant to be an airtight or scholarly publication, just a minimal demonstration of the idea. Implementation details can be found in this colab notebook.
If you had asked me back in the days of GPT-1 whether, in 2025, we would still be using the transformer, scaled up but mostly unchanged, I would have said it was unlikely. Its ability to generate fluent text, reason about complex topics, and even write code has moved the goalposts on what we can expect from scaling up deep learning. Yet beneath this impressive performance lies an architectural design that is unnatural. These models generate output one token at a time, and for each new token the entire multi-billion-parameter model performs a full forward pass.
First, there is the matter of computational inefficiency. A full forward pass is a resource-intensive process. The current paradigm allocates this entire computational budget to predict each token, regardless of its complexity. Generating a simple comma or a closing parenthesis requires the same effort as generating a key term in a line of precise reasoning.
Second, this architecture forces a conflation of abstraction levels. The model must simultaneously "think" about the abstract structure of a solution and "write" the next token. Humans typically form a concept or an idea, then compose a sequence of words to express it, then execute a motor program that actually types, writes, or speaks those words. By entangling high-level reasoning with low-level execution at the token level, we may be artificially limiting the model's representational capacity and its ability to reason about more abstract structures.
Third, this entanglement leads directly to opacity. Because the model's internal "algorithm" is interwoven with its token predictions across billions of parameters, understanding how a solution is derived becomes difficult.
Neural Program Synthesis
The core idea of Neural Program Synthesis (NPS) is to decouple levels of abstraction. Here, this is achieved through a two-stage process performed by two specialized components:
The "Compiler" (A Transformer): This module serves as the high-level reasoner. It receives a problem as input (for example, the string "7669+10132="). Its forward pass output is not a token from the answer, but rather a "neural program": the complete set of weights for a small, Recurrent Neural Network (RNN). In essence, it compiles the problem into an executable procedure.
The "Execution Engine" (An RNN): This lightweight network acts as a deterministic processor. It is dynamically programmed with the weights synthesized by the Transformer. Once programmed, it becomes a purpose-built RNN that runs to generate the final output sequence, "17801", one token at a time.
In this model, the Transformer forms a high-level plan. The lightweight RNN then handles the sequential execution of that plan until it produces the <END> token. The tokens generated by the RNN are then fed back into the transformer, which in turn produces the weights of the next RNN. Because the cost of running the RNN is negligible relative to the transformer, the scheme effectively converts some output tokens into input tokens, which are cheaper and faster to process. A minimal sketch of this loop is shown below.
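To make the division of labor concrete, here is a minimal PyTorch sketch of the whole loop. Everything in it is an illustrative assumption rather than the notebook's actual implementation: the layer sizes, the vocabulary, the flat weight layout, and the names `Compiler`, `execute`, and `generate` are all made up for this sketch.

```python
import torch
import torch.nn as nn

HIDDEN, VOCAB, END = 32, 13, 12            # assumed: RNN size; digits 0-9, '+', '=', <END>
N_RNN_PARAMS = HIDDEN * HIDDEN + HIDDEN * VOCAB + VOCAB * HIDDEN + HIDDEN + VOCAB

class Compiler(nn.Module):
    """The "compiler": a transformer that maps problem tokens to a flat RNN weight vector."""
    def __init__(self, d_model=256):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=4)
        self.to_weights = nn.Linear(d_model, N_RNN_PARAMS)

    def forward(self, tokens):                        # tokens: (1, seq_len) of int ids
        h = self.encoder(self.embed(tokens))
        return self.to_weights(h.mean(dim=1)).squeeze(0)   # (N_RNN_PARAMS,)

def execute(weights, max_steps=16):
    """The "execution engine": unpack the synthesized weights and decode greedily."""
    i = 0
    def take(*shape):
        nonlocal i
        n = 1
        for s in shape:
            n *= s
        block = weights[i:i + n].view(*shape)
        i += n
        return block
    W_hh, W_xh = take(HIDDEN, HIDDEN), take(HIDDEN, VOCAB)
    W_hy, b_h, b_y = take(VOCAB, HIDDEN), take(HIDDEN), take(VOCAB)
    h, x, out = torch.zeros(HIDDEN), torch.zeros(VOCAB), []
    for _ in range(max_steps):                        # cheap steps: no transformer involved
        h = torch.tanh(W_hh @ h + W_xh @ x + b_h)
        tok = int((W_hy @ h + b_y).argmax())
        if tok == END:
            break
        out.append(tok)
        x = torch.zeros(VOCAB)
        x[tok] = 1.0                                  # feed the emitted token back into the RNN
    return out

def generate(compiler, prompt_tokens, max_chunks=4):
    """Outer loop: one transformer pass per chunk, many cheap RNN steps inside each chunk."""
    tokens = list(prompt_tokens)
    for _ in range(max_chunks):
        weights = compiler(torch.tensor([tokens]))    # one forward pass of the compiler
        chunk = execute(weights)                      # multi-token RNN rollout
        if not chunk:                                 # RNN emitted <END> right away: done
            break
        tokens += chunk
    return tokens
```

The point of the sketch is the shape of the computation: one compiler pass buys a whole chunk of output tokens, and the per-token work inside `execute` is just a handful of small matrix-vector products.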
Testing on a toy problem
To validate this framework, I trained an NPS model on a well-defined algorithmic task: multi-digit addition. This domain is a good testbed because it is narrow enough to be learned by a small model, and the output sequences are completely predictable. The goal was not merely to see whether the model could produce the correct answer, but to determine whether the architectural separation would yield a transparent and interpretable computational process.
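For concreteness, training data for such a task can be generated along the following lines. This is a hypothetical sketch; the digit ranges and tokenization used in the notebook may differ.

```python
import random

CHARS = "0123456789+=E"                     # 'E' stands in for the <END> token
TOK = {c: i for i, c in enumerate(CHARS)}

def make_example(max_digits=5):
    """One addition problem: the input is 'a+b=', the target is the sum followed by <END>."""
    a = random.randint(0, 10 ** max_digits - 1)
    b = random.randint(0, 10 ** max_digits - 1)
    prompt = [TOK[c] for c in f"{a}+{b}="]
    target = [TOK[c] for c in f"{a + b}E"]
    return prompt, target

prompt, target = make_example()             # e.g. '7669+10132=' -> '17801E'
```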
The generated RNNs can reliably produce multiple output tokens
The first thing to verify is whether the transformer can effectively solve the addition problem by generating the weights of an RNN. Figure 1 shows the result. The model performs the task almost perfectly.
These results are obtained with a single transformer forward pass. For the example above, decoding "17801" plus the <END> token one token at a time would take six transformer passes, whereas here one pass suffices; for large problems, this can mean 10x fewer forward passes for the same output!
In addition to the efficiency gains, the execution of the generated RNN provides another lens on the algorithm in action. By tracking the RNN weights or hidden state at each step of generation, we can observe the mechanics of the synthesized program.
One hypothesis is that the generated RNN is basically a lookup table that spits out digits. An easy way to test this is to compare the hidden state trajectories or generated weights of different expressions that result in the same output. If the weights and representations for "1+99" are identical to those for "50+50", that is a strong indication that the generated RNN is not sensitive to the type of computation that needs to be done, but is just a way to implement multi-token output generation with temporal dependencies between tokens.
We can compute the pairwise Euclidean distance between the RNN weights of all expressions leading to the same number. Here is one example for the sum 123. You can clearly see that, for this example, the transformer creates very different sets of weights for expressions that require a carry from the tens position to the hundreds position.
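As a sketch, the distance matrix can be computed along these lines, assuming a helper `synthesize_weights(expr)` (a hypothetical name, not from the notebook) that returns the flat RNN weight vector the transformer produces for an expression as a NumPy array:

```python
import numpy as np

def equisum_expressions(total):
    """All 'a+b=' strings with a + b == total."""
    return [f"{a}+{total - a}=" for a in range(total + 1)]

def pairwise_weight_distances(exprs, synthesize_weights):
    """Euclidean distance between the flat RNN weight vectors of every pair of expressions."""
    W = np.stack([synthesize_weights(e) for e in exprs])      # (n_exprs, n_params)
    diff = W[:, None, :] - W[None, :, :]
    return np.linalg.norm(diff, axis=-1)                      # (n_exprs, n_exprs)

# D = pairwise_weight_distances(equisum_expressions(123), synthesize_weights)
# Expressions needing a carry from the tens into the hundreds place should show up
# as a distinct block in D, as in the figure above.
```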
To see the big picture, let's use the variance of the set of weights corresponding to equisum expressions, computed for many different final sums. As you can see, some sums produce RNN weights that are more sensitive to the particular input expression than others.
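The same per-sum summary in code (again assuming the hypothetical `synthesize_weights` helper):

```python
import numpy as np

def weight_variance_for_sum(total, synthesize_weights):
    """Mean per-parameter variance of the RNN weights across all expressions summing to `total`."""
    exprs = [f"{a}+{total - a}=" for a in range(total + 1)]
    W = np.stack([synthesize_weights(e) for e in exprs])
    return float(W.var(axis=0).mean())

# variance_by_sum = {s: weight_variance_for_sum(s, synthesize_weights) for s in range(1, 200)}
```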
To see which features of the input cause different RNN weights to be produced, we can look at t-SNE plots in which points are colored according to some input attribute.
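In code, this is something along the following lines, using scikit-learn's TSNE and the number of carries as one example attribute (the `synthesize_weights` helper is again hypothetical):

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def carry_count(a, b):
    """Number of carry operations when adding a and b digit by digit."""
    carries, carry = 0, 0
    while a or b or carry:
        carry = 1 if (a % 10 + b % 10 + carry) > 9 else 0
        carries += carry
        a, b = a // 10, b // 10
    return carries

def tsne_by_carries(pairs, synthesize_weights):
    """2-D t-SNE of the synthesized RNN weights, colored by the number of carries."""
    W = np.stack([synthesize_weights(f"{a}+{b}=") for a, b in pairs])
    xy = TSNE(n_components=2, perplexity=30).fit_transform(W)
    plt.scatter(xy[:, 0], xy[:, 1], c=[carry_count(a, b) for a, b in pairs], s=8)
    plt.colorbar(label="number of carries")
    plt.show()
```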
We can see a rich organization of the RNN weights based on attributes of the input expressions. This suggests that the transformer generates distinct “programs” for expressions that require a different set of arithmetic manipulations.
To verify that the difference in weights is functionally relevant, we can also look at the hidden RNN activations. To do that, we can plot the activations of a large set of random expressions to serve as "background points", and then overlay the trajectories of different expressions that lead to the same final sum and compare them. For example, the following plot shows two trajectories, one for '17000+801' and one for '16000+1801'. Although both sum to '17801', their trajectories seem to diverge after the output digit '7': the hidden RNN representations at the '8' token seem to occupy two different regions of the activation space associated with the digit '8' at the hundreds place.
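The plotting procedure itself can be sketched as follows, assuming a hypothetical `hidden_trajectory(expr)` helper that runs the synthesized RNN on an expression and returns its hidden state at every output step as a `(steps, hidden_size)` array. The 2-D projection here is PCA; the actual plots may use a different projection.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_trajectories(background_exprs, highlighted_exprs, hidden_trajectory):
    """Grey cloud of hidden states from many random expressions, with a few trajectories overlaid."""
    background = np.concatenate([hidden_trajectory(e) for e in background_exprs])
    pca = PCA(n_components=2).fit(background)
    bg = pca.transform(background)
    plt.scatter(bg[:, 0], bg[:, 1], c="lightgrey", s=4)
    for expr in highlighted_exprs:                    # e.g. ['17000+801=', '16000+1801=']
        traj = pca.transform(hidden_trajectory(expr))
        plt.plot(traj[:, 0], traj[:, 1], marker="o", label=expr)
    plt.legend()
    plt.show()
```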
We can also plot more expressions that lead to the same sum:
You can see that, as with the RNN weights, the trajectories of hidden states contain more information than just the identity of the output digits. First, there is a clear clustering based on digit position and identity. Beyond that, one can find, as in the analysis of the weights, a sensitivity to computationally relevant features of the input, such as the number of carries, the balance of the expression, and so on.
These findings show that the generated RNN is often not a mere lookup table. It behaves as a state machine whose weights and activations correlate with interpretable aspects of the input expression.
Future work
This work is just a proof of concept on a toy problem. While promising, it leaves many questions unanswered, such as how the transformer's representations change relative to those of a normal token-by-token transformer. The most important question is whether this approach can scale to frontier-level systems. I hope this essay is intriguing enough to motivate people with more GPUs to try and scale it up.
Conclusions
The Neural Program Synthesis framework enforces a separation between high-level reasoning and low-level execution. This creates a model that is not only more efficient but also inherently more interpretable.
This architectural paradigm aligns desirable interpretability properties with economically beneficial properties like efficiency. It is a pragmatic way to steer AI development toward safety without ignoring capitalistic forces.
Check out this colab notebook for a standalone training script for the same network I used to create the results of this post.