Data dtype used for AlphaGenome

Hi All, thanks for your wonderful work!

I noticed that in the AlphaGenome paper your team uses brain floating point (bfloat16) to store track data. I'm wondering which dtype the model itself uses during training: float32, bfloat16, or float16? Looking forward to your answer, thanks!

Hello @Wildest, welcome to the forum!

Yes, we use bfloat16 targets during training, typically normalized by the track's mean of non-zero values.
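For illustration, here is a minimal pure-Python sketch of scaling each track by the mean of its non-zero values. The function name is hypothetical, each row is assumed to be one track, and the toy input stands in for a whole track (in practice the statistic would presumably be computed over the full track, not a 3 bp window). bfloat16 storage is not modeled, and all-zero tracks are left unchanged as an assumption, since their non-zero mean is undefined.

```python
def normalize_by_nonzero_mean(tracks):
    """Divide each track (row) by the mean of its non-zero values.

    Hypothetical helper for illustration only. All-zero tracks are returned
    unchanged (an assumption: their non-zero mean is undefined).
    """
    normalized = []
    for track in tracks:
        nonzero = [v for v in track if v != 0]
        if not nonzero:
            normalized.append([float(v) for v in track])
            continue
        scale = sum(nonzero) / len(nonzero)  # mean of the non-zero values only
        normalized.append([v / scale for v in track])
    return normalized


# Toy input: 2 tracks (experiments) x 3 base pairs.
print(normalize_by_nonzero_mean([[6, 3, 9], [2, 0, 2]]))
# [[1.0, 0.5, 1.5], [1.0, 0.0, 1.0]]
```

Zeros stay zero after scaling, and the non-zero values of each track end up with a mean of 1.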

Thx for your reply!

I also wonder whether you use mixed precision during training.

If the original target is a 2 × 3 matrix (2 different experiments by signals at 3 base pairs):
6, 3, 9 (in this track, the mean of non-zero values is 6)
2, 0, 2 (in this track, the mean of non-zero values is 2)
then I need to normalize the target into:
1, 0.5, 1.5
1, 0, 1
Right?

Hi!

We do use JAX mixed precision. I'm not sure whether it differs from the implementations in other frameworks. For example, in JAX, if x is a BF16 array and we call jax.numpy.sum(x), the accumulation is carried out in FP32, but the result is cast back to BF16. This matters for quantities such as the normalization statistics.
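To see why the accumulator dtype matters, here is a pure-Python simulation (no JAX required) that emulates bfloat16 rounding with `struct`. The helper names are hypothetical, and the sequential bfloat16 accumulator is an illustrative worst case, not how JAX or XLA actually performs reductions; the second summing function mirrors the behavior described above: accumulate in FP32, then cast the result back to BF16.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even), returned as a float.

    NaN/inf and float32 overflow are not handled; this is only for illustration.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even bfloat16 mantissa
    return struct.unpack("<f", struct.pack("<I", (bits + bias) & 0xFFFF0000))[0]


def to_f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def sum_with_bf16_accumulator(xs):
    acc = 0.0
    for x in xs:
        acc = to_bf16(acc + x)  # every partial sum is rounded to bfloat16
    return acc


def sum_with_f32_accumulator(xs):
    acc = 0.0
    for x in xs:
        acc = to_f32(acc + x)  # partial sums stay in float32
    return to_bf16(acc)  # only the final result is cast back to bfloat16


values = [1.0] * 4096  # 1.0 is exactly representable in bfloat16
print(sum_with_bf16_accumulator(values))  # 256.0
print(sum_with_f32_accumulator(values))   # 4096.0
```

With a bfloat16 accumulator the sum stalls at 256: bfloat16 keeps only 8 significant bits, so representable values near 256 are 2 apart and 256 + 1 rounds back to 256. A track statistic such as the mean of non-zero values, accumulated this way over a long track, would be badly biased.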

We explicitly compute the losses in FP32 and, as is common practice, cast the attention logits to FP32 before the softmax and cast the softmax output back to BF16.
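As an illustration of where those casts sit, here is a pure-Python sketch, not AlphaGenome's code; the function names are hypothetical and bfloat16 is emulated with `struct`. Python floats are float64, so the "FP32" region is only approximated by a wider type; the contrast function keeps the exponentials and their running sum in BF16 to show what the upcast protects against.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even); finite inputs only."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even bfloat16 mantissa
    return struct.unpack("<f", struct.pack("<I", (bits + bias) & 0xFFFF0000))[0]


def softmax_fp32_out_bf16(logits_bf16):
    """BF16 logits in, softmax computed in a wider float, BF16 probabilities out."""
    logits = [float(v) for v in logits_bf16]  # upcast: BF16 values convert exactly
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)  # denominator accumulated in high precision
    return [to_bf16(e / total) for e in exps]  # downcast the probabilities to BF16


def softmax_bf16_denominator(logits_bf16):
    """Contrast case: exponentials and their running sum are kept in BF16."""
    m = max(logits_bf16)
    exps = [to_bf16(math.exp(v - m)) for v in logits_bf16]
    total = 0.0
    for e in exps:
        total = to_bf16(total + e)  # every partial sum rounded to BF16
    return [to_bf16(e / total) for e in exps]


logits = [0.0] * 1024  # 1024 equal attention logits
print(sum(softmax_fp32_out_bf16(logits)))     # 1.0
print(sum(softmax_bf16_denominator(logits)))  # 4.0
```

With 1024 equal logits the wide-precision version assigns each position 1/1024, while the BF16 denominator stalls at 256, so every probability comes out as 1/256 and the "distribution" sums to 4.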

Your math on the non-zero mean scaling is correct :slight_smile:
Note that the AlphaGenome API returns unscaled predictions.