LoRA Finetuning on Custom CUT&Tag Dataset

Hi,

Thank you for open-sourcing the AlphaGenome utility and for building this community.

I am attempting to fine-tune the released AlphaGenome weights using LoRA on a custom CUT&Tag dataset. My setup is as follows:

  • I preprocess the CUT&Tag data similarly to ATAC-seq: since CUT&Tag also relies on Tn5 tagmentation, I follow the same preprocessing strategy.

  • I use the ['atac'] prediction head.

I currently have two questions:


1. Gradient norm becomes NaN during training

During train_step, the first few batches run normally, but after ~20 batches the gradient norm becomes NaN. Below is an excerpt from the logs:

Epoch 1/1

Compiling JAX functions (this may take 2–5 minutes on first batch)…
✓ Compilation complete! Took 181.1s

Array Sharding Visualization:
Sequences shape: (8, 131072, 4)
Sequences sharding: NamedSharding(mesh=Mesh('batch': 8, axis_types=(Auto,)), spec=PartitionSpec('batch', None, None), memory_kind=device)
Sequences[:, 0, 0] (batch dimension):
GPU 0    GPU 1    GPU 2    GPU 3    GPU 4    GPU 5    GPU 6    GPU 7

Loss value: 0.38322877883911133
Gradient norm: 5.1783525123028085e-05

Training on 4412 batches…
Batch 10/4412: Loss = 0.758550, Grad Norm = 0.000071
Batch 20/4412: Loss = 0.244102, Grad Norm = nan
Batch 30/4412: Loss = 0.046306, Grad Norm = nan
Batch 40/4412: Loss = 0.042764, Grad Norm = nan
Batch 50/4412: Loss = 0.051070, Grad Norm = nan

The loss does not appear to explode, but the gradient norm becomes NaN and remains so afterward. Do you have suggestions on what might cause this instability? Could it be related to the CUT&Tag signal distribution, LoRA scaling, mixed precision, or optimizer settings?


2. Converting base-pair Tn5 profiles back to BigWig

After training, the model outputs base-pair–resolution Tn5 tagmentation profiles (preprocessed with ChromBPNet-style bias correction). For visualization, should I convert these predictions back to a standard BigWig track?

If so, what utility do you recommend for converting the predicted arrays into BigWig format for genome browser visualization?


Thank you in advance for your help.

Hi,

1. Gradient norm becomes NaN during training
We’ve just released a tutorial on how to finetune AlphaGenome on a new dataset: https://github.com/google-deepmind/alphagenome_research/blob/main/colabs/finetune.ipynb. Please have a look to see if you can spot any differences from your training setup.

It’s surprising that your gradient norm is NaN while the loss isn’t. Perhaps there’s an underflow in the gradient norm computation itself. I’d suggest casting the gradients to float32 before the norm computation.
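To illustrate the failure mode, here is a standalone NumPy sketch (not AlphaGenome code; float16 is used because the underflow is easy to trigger there). Squaring small half-precision gradients underflows to zero, so the norm comes out as 0.0, and any later division by that norm (e.g. in gradient clipping) produces inf/NaN even though the loss itself looks fine:

```python
import numpy as np

# Gradients of realistic magnitude (~1e-4, matching the logged grad norms).
grads = np.full(1000, 1e-4, dtype=np.float16)

# Norm computed in half precision: each squared term (~1e-8) underflows to 0,
# so the norm is exactly 0.0 and a subsequent 1/norm yields inf/NaN.
norm_fp16 = np.sqrt(np.sum(grads * grads))

# Same gradients cast to float32 before the norm, as suggested above.
norm_fp32 = np.sqrt(np.sum(grads.astype(np.float32) ** 2))

print(norm_fp16)  # 0.0
print(norm_fp32)  # ~3.16e-3
```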

Finally, can you check if this behavior is consistent across both fold_0 and all_folds versions?

2. Converting base-pair Tn5 profiles back to BigWig
To visualize your results you can use the AlphaGenome plotting library directly: https://github.com/google-deepmind/alphagenome/tree/main/src/alphagenome/visualization. For example, check out the last cell in our finetuning Colab.

Hi Vincent,

Much appreciated! I think you’re correct: the gradient norm going NaN while the loss stays finite points to an issue in the grad_norm computation.

Comparing the finetune.ipynb notebook with my implementation, there are two major differences:

  1. My previous code used a mixed-precision policy with bfloat16 and float32. The finetune.ipynb notebook enforced float32.
  2. The finetune.ipynb notebook used head replacement as opposed to LoRA.

Switching to a precision policy of params=float32,compute=float32,output=float32 resolved the NaN issue.

I checked both fold_0 and all_folds and confirmed that this behavior is consistent across versions.

Best regards,

Ray