Isolating the contribution of fast-weight target modules in frozen LLMs, on Gemma3-1B.
Original In-Place TTT paper (Feng et al.)
The In-Place Test-Time Training (In-Place TTT) paper trains the base model and TTT modules jointly during continual pretraining (~20B tokens at 32K context on H800s). This leaves an open question: how much of the long-context gain comes from the TTT adapter modules (Conv1D, W_target) learning a useful next-token-prediction target, versus the base model co-adapting to tolerate the dynamic weight updates?
We isolate the first contribution by freezing the base model and training only the TTT adapter modules, then comparing against vanilla Gemma3 on RULER-style long-context tasks.
We use `google/gemma-3-1b-it` (26 layers, hidden size 1152, GeGLU MLPs with intermediate size 6912, 32K context) as the base. We implement In-Place TTT as a drop-in enhancement: a Conv1D + W_target adapter is added to the MLP of the global-attention layers [0, 6, 12, 18, 24], gated on `config.use_ttt`.
We skip continual pretraining entirely and pretrain the adapter only. The base model is fully frozen except for the `down_proj` (W_down) of the TTT layers: that is the weight matrix the per-chunk ΔW updates act on, so it is trained jointly with Conv1D + W_target. Every other parameter (embeddings, attention, gate/up_proj, norms, lm_head, and the `down_proj` of all non-TTT layers) is `requires_grad=False`.
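For concreteness, here is a minimal sketch of that freezing scheme. The parameter-name patterns are assumptions based on the standard Gemma3 module layout; the repo's actual helper is `Gemma3ForCausalLMTTT.freeze_base_model()` (described below).

```python
import torch.nn as nn

# Minimal sketch of the freezing scheme, assuming standard Gemma3 parameter
# names; the repo's actual helper is Gemma3ForCausalLMTTT.freeze_base_model().
def freeze_for_adapter_pretraining(model: nn.Module, ttt_layers=(0, 6, 12, 18, 24)):
    for name, param in model.named_parameters():
        param.requires_grad = (
            "ttt_proj" in name                 # W_target
            or "ttt_conv" in name              # Conv1D
            or any(f"layers.{i}.mlp.down_proj" in name for i in ttt_layers)  # W_down
        )
```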
Primary training corpus: 500k samples of roneneldan/TinyStories (short narratives, ~2M total). Yukang/LongAlpaca-12k is also supported as a long-context variant; note it caps at 12k samples, so we run it for 2 epochs by default to compensate.
Three modes, each run on the same example set:
- `in_context` – vanilla Gemma3; prompt is `[doc, q]`, single forward.
- `ttt_paper` – same `[doc, q]` prompt, but TTT layers update their fast weights chunk-by-chunk during prefill (matches the original paper's eval).
- `ttt_strict` – two-phase: ingest the doc only and snapshot the per-layer ΔW, then answer `q` alone with that snapshot patched in (see the sketch below). The doc is absent at answer time, so the fast weights must substitute for context.
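A hypothetical sketch of the `ttt_strict` two-phase flow (the real logic is the benchmark's `StrictTTTPredictor`; the `delta_w` attribute name here is an assumption, not the repo's API):

```python
import torch

# Hypothetical sketch of ttt_strict; helper attribute names are illustrative.
@torch.no_grad()
def strict_answer(model, tokenizer, doc: str, q: str) -> str:
    # Phase 1: prefill on the document alone. The TTT layers accumulate their
    # per-chunk delta-W fast weights as a side effect of this forward pass.
    doc_ids = tokenizer(doc, return_tensors="pt").input_ids
    model(input_ids=doc_ids)

    # Snapshot the per-layer fast weights.
    ttt_mlps = [l.mlp for l in model.model.layers if hasattr(l.mlp, "delta_w")]
    snapshot = [mlp.delta_w.clone() for mlp in ttt_mlps]

    # Phase 2: patch the snapshot back in and answer q with the doc absent,
    # so the fast weights must substitute for the missing context.
    for mlp, dw in zip(ttt_mlps, snapshot):
        mlp.delta_w.copy_(dw)
    q_ids = tokenizer(q, return_tensors="pt").input_ids
    out = model.generate(q_ids, max_new_tokens=64)
    return tokenizer.decode(out[0, q_ids.shape[1]:], skip_special_tokens=True)
```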
Tasks come from RULER (vt, cwe, fwe) and HELMET ICL/RAG (helmet_trec_coarse, helmet_banking77, helmet_nq, helmet_hotpotqa), tested at 1K / 4K / 8K / 16K / 32K context lengths. Metrics: accuracy, peak GPU memory, latency. See benchmark/README.md.
NOTE: StrictTTTPredictor does not work with the updated HELMET-based evaluation.
```
In-Place-Test-Time-Training/
├── models/hf_gemma3/
│   ├── config_gemma3.py    # Gemma3TTTConfig: subclasses upstream Gemma3TextConfig, adds TTT fields
│   ├── model_gemma3.py     # Gemma3MLPTTT, Gemma3DecoderLayerTTT, Gemma3TextModelTTT, Gemma3ForCausalLMTTT
│   └── test_gemma3.py      # pytest suite: instantiation, forward, generate, save/load round-trip, freeze
├── train/
│   ├── main.py             # training entry point (frozen base + TTT-adapter pretraining)
│   ├── test_main.py        # pytest suite: tokenize, freeze, save, wandb, CLI plumbing
│   └── README.md           # training details, default hyperparameters, Colab walkthrough
├── benchmark/              # eval harness (configs, data_gen, eval, scripts, results/plots)
├── scripts/                # misc utilities (prompt dumping, eval shell helpers)
├── third_party/
│   ├── RULER/              # NVIDIA RULER (submodule)
│   └── HELMET/             # Princeton HELMET (submodule)
├── Makefile                # convenience commands (see `make help`)
├── pyproject.toml          # deps managed by uv
├── LICENSE                 # Apache 2.0
└── NOTICE                  # attribution to HuggingFace, Google (Gemma), Bytedance (TTT reference)
```
model_gemma3.py mirrors upstream transformers.models.gemma3.modeling_gemma3 and adds:
- `TTTLinear`, `TTTConv1d` – marker subclasses of `nn.Linear` / `nn.Conv1d` so `_init_weights` can identify TTT modules unambiguously (avoids shape collisions with `q_proj`/`o_proj`).
- `Gemma3MLPTTT` – Gemma3 MLP with optional `ttt_proj` (W_target) + `ttt_conv` modules, and the chunked TTT update in `forward(x, t=...)`.
- `Gemma3DecoderLayerTTT` – Gemma3 decoder layer, a near-mirror of upstream; the only delta is a `target_states` kwarg threaded into `mlp(...)`.
- `Gemma3PreTrainedModelTTT` – inherits from upstream `Gemma3PreTrainedModel`. Its custom `_init_weights` does diagonal init for `TTTLinear` (near-identity) and zero init for `TTTConv1d` (no-op start), and defers everything else to `super()` so `_is_hf_initialized` skip-flags are honored and loaded checkpoints aren't trampled.
- `Gemma3TextModelTTT`, `Gemma3ForCausalLMTTT` – backbone + LM head. `freeze_base_model()` on the LM keeps only `ttt_proj` (W_target), `ttt_conv`, and the `down_proj` (W_down) of the TTT layers trainable; every other parameter, including `down_proj` on non-TTT layers, is `requires_grad=False`.
When config.use_ttt=False, the TTT branches are skipped entirely and the model behaves identically to upstream Gemma3.
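For intuition, a self-contained sketch of the chunked update follows. It assumes the learned target is a causal depthwise conv plus projection of the layer input and a plain gradient-style delta rule, and it loops over chunks for clarity; the authoritative version is `Gemma3MLPTTT.forward` in `models/hf_gemma3/model_gemma3.py`, which aggregates the per-chunk deltas with a cumsum (Algorithm 1 in the paper).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketch only: the update rule and target-signal wiring here are
# assumptions, not a copy of Gemma3MLPTTT.
class ChunkedTTTMLPSketch(nn.Module):
    def __init__(self, hidden=1152, inter=6912, chunk=2048, ttt_lr=0.3):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)  # W_down: the surface the fast weights modify
        self.ttt_proj = nn.Linear(hidden, hidden, bias=False)  # W_target (diagonal / near-identity init)
        self.ttt_conv = nn.Conv1d(hidden, hidden, 3, padding=2, groups=hidden)  # zero init -> no-op start
        self.chunk, self.ttt_lr = chunk, ttt_lr

    def forward(self, x):  # x: [batch, seq, hidden]
        h = F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)  # GeGLU
        # Learned per-token target: causal depthwise conv + projection.
        tgt = self.ttt_conv(x.transpose(1, 2))[..., : x.shape[1]].transpose(1, 2)
        tgt = self.ttt_proj(tgt)
        delta_w = torch.zeros_like(self.down_proj.weight)  # running fast-weight delta
        out = []
        for s in range(0, x.shape[1], self.chunk):  # the real code parallelizes this via cumsum
            hc, tc = h[:, s : s + self.chunk], tgt[:, s : s + self.chunk]
            yc = F.linear(hc, self.down_proj.weight + delta_w)  # apply current fast weights
            # One gradient-style step pushing this chunk's output toward the target.
            grad = torch.einsum("bch,bci->hi", yc - tc, hc) / hc.shape[1]
            delta_w = delta_w - self.ttt_lr * grad
            out.append(yc)
        return torch.cat(out, dim=1)
```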
```bash
make install    # uv sync --all-groups + RULER submodule + nltk data + PG-essay haystack
make test       # fast tests over models/ and train/ (skips @slow)
make test-slow  # downloads google/gemma-3-1b-it; needs HF auth + Gemma TOU acceptance
```

Construct the TTT model from the upstream checkpoint:

```python
from models.hf_gemma3 import Gemma3ForCausalLMTTT, Gemma3TTTConfig

config = Gemma3TTTConfig.from_pretrained(
    "google/gemma-3-1b-it",
    use_ttt=True,
    ttt_layers=[0, 6, 12, 18, 24],  # global layers in Gemma3-1B
    ttt_chunk=2048,
    ttt_lr=0.3,
)
model = Gemma3ForCausalLMTTT.from_pretrained("google/gemma-3-1b-it", config=config)
model.freeze_base_model()  # ttt_proj, ttt_conv, and down_proj on TTT layers get gradients
```

Reload a trained checkpoint from disk:

```python
model = Gemma3ForCausalLMTTT.from_pretrained("./checkpoints/gemma3-1b-ttt")
```

Or load from the Hub via the Auto classes:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "yourname/gemma3-1b-ttt",
    trust_remote_code=True,  # custom modeling code lives in the Hub repo
)
```

`trust_remote_code=True` is required because `Gemma3ForCausalLMTTT` is not part of upstream `transformers`.
- https://huggingface.co/hungngo04/gemma-3-1b-it-ttt-tinystories-500k
- https://huggingface.co/changminbark/gemma-3-1b-it-ttt-longalpaca-full
train/main.py pretrains the TTT adapters (ttt_conv, ttt_proj/W_target) and the TTT-layer down_proj (W_down) on a single dataset selected via --dataset, then pushes the result to the Hub (bundled with the modeling code + auto_map).
```bash
# Primary run: 500k TinyStories samples
make train-tinystories HF_USER=<you>

# Long-context variant
make train-longalpaca HF_USER=<you>

# or directly:
uv run python -m train.main --dataset tinystories --hf-user <you> --max-samples 500000
uv run python -m train.main --dataset longalpaca --hf-user <you>
```

Supported datasets: `tinystories` (roneneldan/TinyStories) and `longalpaca` (Yukang/LongAlpaca-12k). See train/README.md for the full table of default hyperparameters, every CLI flag, wandb setup, and a Colab walkthrough.
train/main.py handles this automatically: it sets auto_map, copies config_gemma3.py + model_gemma3.py next to the weights, and pushes to <hf-user>/<base>-ttt-<dataset> (override with --repo-id). Authenticate once with make login-hf. Use --no-push to skip the upload.
For manually-built checkpoints, make push-hub HF_REPO_ID=... CKPT_DIR=... is still available; the repo must contain config.json with an auto_map block, the two .py modeling files, weights, and ideally a model card noting the Gemma base license. See HuggingFace's custom code documentation for the standard layout.
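Given this repo's module and class names, the `auto_map` block in `config.json` presumably follows the standard HuggingFace custom-code layout, along these lines:

```json
{
  "auto_map": {
    "AutoConfig": "config_gemma3.Gemma3TTTConfig",
    "AutoModelForCausalLM": "model_gemma3.Gemma3ForCausalLMTTT"
  }
}
```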
```bash
uv run python -m benchmark.scripts.generate --profile dev
uv run python -m benchmark.scripts.evaluate --profile dev --predictor benchmark.eval.factories:gemma3_in_context_factory
uv run python -m benchmark.scripts.evaluate --profile dev --predictor benchmark.eval.factories:gemma3_ttt_paper_factory
uv run python -m benchmark.scripts.evaluate --profile dev --predictor benchmark.eval.factories:gemma3_ttt_strict_factory
uv run python -m benchmark.scripts.aggregate
uv run python -m benchmark.scripts.report
uv run python -m benchmark.scripts.plot
```

This runs the three modes described above and reports accuracy, peak GPU memory, and latency as a function of context length. The harness lives under benchmark/ (configs, data_gen, eval, scripts) and pulls tasks from NVIDIA RULER and Princeton HELMET (third_party/). See benchmark/README.md for setup, profiles (dev/full), and how to register new predictors.
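The `--predictor module:function` flags above point at factory callables. As a rough, hypothetical illustration of what a custom in-context factory might look like (the actual registration contract is documented in benchmark/README.md; the `predict` signature here is an assumption):

```python
from transformers import AutoTokenizer

from models.hf_gemma3 import Gemma3ForCausalLMTTT

# Hypothetical factory; the real ones live in benchmark/eval/factories.py and
# the predictor interface they must satisfy is defined by the harness.
def my_in_context_factory():
    tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
    model = Gemma3ForCausalLMTTT.from_pretrained("google/gemma-3-1b-it").eval()

    def predict(doc: str, q: str) -> str:
        ids = tok(doc + "\n" + q, return_tensors="pt").input_ids
        out = model.generate(ids, max_new_tokens=64)
        return tok.decode(out[0, ids.shape[1]:], skip_special_tokens=True)

    return predict
```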
Run make help for the full list. Highlights:
| Target | Description |
|---|---|
| `make install` | `uv sync --all-groups` |
| `make test` | fast pytest suite (skips slow) |
| `make test-slow` | downloads real Gemma3-1B and exercises the load path |
| `make train DATASET=...` | trains on tinystories/longalpaca and pushes (`HF_USER=...`) |
| `make train-tinystories` / `make train-longalpaca` | dataset-specific shortcuts |
| `make eval` | runs the benchmark pipeline (see benchmark/README.md) |
| `make login-hf` | `huggingface-cli login` |
| `make push-hub` | upload `$(CKPT_DIR)` to `$(HF_REPO_ID)` |
| `make clean` | nuke `__pycache__`, `.pytest_cache`, etc. |
PyTorch, HuggingFace Transformers, NVIDIA RULER, HuggingFace Datasets, Weights & Biases.
Can the TTT adapter modules (Conv1D, W_target) be trained while keeping the base model frozen, and still recover some fraction of the long-context gains reported in the paper?
- The adapter-enhanced model matches the baseline at short contexts (no damage to base capability).
- It improves over the baseline at longer contexts, but by a smaller margin than the paper's fully-trained variant.
- The size of that gap quantifies how much of the paper's reported gains require base-model co-adaptation.
- A finding of no improvement (or degradation) is itself a meaningful negative result: it would say the base model's adaptation is load-bearing, not just the adapter's learned target.
Full report: Adapter-Only In-Place Test-Time Training: Isolating the Contribution of Fast-Weight Target Modules in Frozen LLMs. Plot sources are under benchmark/results/plots/.
After pretraining two In-Place TTT models, one on TinyStories (500K samples) and one on LongAlpaca (2×12K samples), we evaluated them against the HELMET+RULER-inspired benchmark, which excludes LLM-as-a-Judge metrics.
| Common Words Extraction (CWE) | Frequent Words Extraction (FWE) |
|---|---|
| ![CWE accuracy vs. context length](benchmark/results/plots/plot1.png) | ![FWE accuracy vs. context length](benchmark/results/plots/plot2.png) |

| Variable Tracking (VT) | HELMET – Banking77 |
|---|---|
| ![VT accuracy vs. context length](benchmark/results/plots/plot3.png) | ![Banking77 accuracy vs. context length](benchmark/results/plots/plot4.png) |

| HELMET – TREC Coarse | |
|---|---|
| ![TREC Coarse accuracy vs. context length](benchmark/results/plots/plot5.png) | |
As the plots show, vanilla Gemma3 (in-context) outperforms the In-Place TTT models on every evaluation task. The one exception worth noting is that the LongAlpaca-trained model beats the vanilla model on Common Words Extraction (CWE) at an input context of 8192 tokens. One possibility is that the LongAlpaca dataset is constructed in a way that aligns the In-Place TTT MLP layers particularly well with this specific task and context length; this should be investigated further.
| Peak GPU memory | Mean latency |
|---|---|
| ![Peak GPU memory vs. context length](benchmark/results/plots/plot6.png) | ![Mean latency vs. context length](benchmark/results/plots/plot7.png) |
In terms of peak GPU memory, the TinyStories model used significantly more than the other two models. This is most likely because the TinyStories model was trained with a much smaller TTT chunk size (128) than the LongAlpaca model (2048): a smaller chunk size means many more per-chunk deltas to create and apply over the same context, which drives up peak GPU memory.
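A quick back-of-envelope check of that explanation, using the chunk sizes from the two training runs:

```python
# Number of per-chunk fast-weight updates during one 32K-token prefill.
ctx = 32768
for chunk in (128, 2048):  # TinyStories vs. LongAlpaca TTT chunk size
    print(f"chunk={chunk:5d}: {ctx // chunk:4d} delta-W updates")
# chunk=  128:  256 delta-W updates
# chunk= 2048:   16 delta-W updates
```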
Mean latency is similar across the In-Place TTT models and higher than the vanilla model's, which is expected: the architecture adds extra computation, such as computing update deltas, aggregating them with a cumsum, and applying them to produce the MLP output (see Algorithm 1 in the appendix of the original paper).
As it stands, pretraining In-Place TTT MLP layers while freezing the base model does not improve performance. These findings are, however, limited to the Gemma3 model, where the In-Place TTT MLP layers were implemented in the global-attention layers. Furthermore, pretraining of the In-Place TTT MLP layers was limited (500K TinyStories samples and 2×12K LongAlpaca samples) by resource constraints. Retaining or improving performance may require pretraining the In-Place TTT MLP layers on larger and better datasets, and possibly training the base model jointly rather than freezing it.
Modeling code is Apache 2.0. See LICENSE and NOTICE for full attribution to HuggingFace Transformers (Apache 2.0), Google (Gemma 3 architecture and weights, subject to the Gemma Terms of Use), and the Bytedance In-Place TTT reference implementation (Apache 2.0).
Chang Min Bark and Hung Ngo
CSCI357 (Spring 2026) β AI with Neural Nets
Professor Brian King
May 6, 2026
AI tools like Claude Code were used to write documentation and parts of the code.






