Inference speed on NVIDIA H100 GPU

Thanks for making the model public. Very excited to apply this to some real-life genetic data.

In doing some testing with the quickstart code on GitHub (google-deepmind/alphagenome_research: research code accompanying AlphaGenome), the model took more than a few minutes to produce six predictions (2 input sequences on 3 tracks) on an NVIDIA H100 GPU. I read in the paper that the student model should take less than a second on this hardware. Is this running time expected, or could it indicate something is wrong with my hardware? Perhaps there is more inference going on than I realize?

On a related note, if any performance benchmarks have been released or are planned, it would be helpful for researchers to assess whether their setups are performing as expected. Thanks!

Hi @apt,

Thanks for the post! Yes, this is somewhat expected; there are a few key reasons:

  1. The first time you make a prediction, JAX compiles the model for that sequence length, which can take some time (on the order of 1 minute). Subsequent calls should be fast, though, as JAX caches the compiled graph.
  2. The DNA model wrapper performs some extra quality-of-life steps to make the model easier to interact with. Examples include fetching the DNA sequence from the assembly, extracting splice junctions from the dense matrix, removing padding, reverse complementing, upcasting from bfloat16, and transferring to CPU. These add overhead on top of the raw model prediction.
  3. Making XLA deterministic disables some performance optimizations (e.g. Triton GEMMs). Re-enabling these can improve inference speed, but also increases non-determinism. You might also want to consider TPUs, which exhibit less non-deterministic behavior.
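The compilation-caching behavior in point 1 is easy to see with a small standalone benchmark. This is a generic `jax.jit` sketch using a toy function, not the AlphaGenome model, so the absolute timings are illustrative only:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_forward(x):
  # Stand-in for a model forward pass.
  return (x * 2.0).sum()


x = jnp.ones((1024, 1024))

# First call: JAX traces the function and XLA compiles it for this shape/dtype.
start = time.time()
y = jax.block_until_ready(toy_forward(x))
compile_time = time.time() - start

# Second call with the same shape/dtype: the cached executable runs directly.
start = time.time()
y = jax.block_until_ready(toy_forward(x))
cached_time = time.time() - start

print(f'first call (incl. compile): {compile_time:.4f}s, cached: {cached_time:.4f}s')
```

Calling the jitted function with a different input shape triggers a fresh compile, which is why the first prediction per sequence length is slow.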

Good idea on having some notebooks with expected timings; we’re working on a few new notebooks and will keep this in mind.

In the meantime, you can try this example to measure raw inference speed:

import time

from alphagenome_research.model import dna_model
from alphagenome_research.model import model as model_lib
from alphagenome_research.model.metadata import metadata as metadata_lib
import haiku as hk
import huggingface_hub
import jax
import jax.numpy as jnp
import jmp
import orbax.checkpoint as ocp

checkpoint_path = huggingface_hub.snapshot_download(
    repo_id='google/alphagenome-all-folds'
)
checkpointer = ocp.StandardCheckpointer()
params, state = checkpointer.restore(checkpoint_path)
metadata = {o: metadata_lib.load(o) for o in dna_model.Organism}

jmp_policy = jmp.get_policy('params=float32,compute=bfloat16,output=bfloat16')


@jax.jit
def forward(params, state, dna_sequence, organism_index):
  @hk.transform_with_state
  def _forward(dna_sequence, organism_index):
    with hk.mixed_precision.push_policy(model_lib.AlphaGenome, jmp_policy):
      return model_lib.AlphaGenome(metadata)(dna_sequence, organism_index)

  (predictions, _), _ = _forward.apply(
      params, state, None, dna_sequence, organism_index
  )
  return dna_model.extract_predictions(predictions)


dna_sequence = jnp.zeros((1, 2**20, 4), dtype=jnp.bfloat16)
organism_idx = jnp.zeros((1,), dtype=jnp.int32)
params, state = jax.device_put((params, state))

# Warm-up run (triggers JIT compilation; subsequent calls use the cached graph)
preds = forward(params, state, dna_sequence, organism_idx)
jax.block_until_ready(preds)

NUM_REPEATS = 10

start_time = time.time()
for _ in range(NUM_REPEATS):
  preds = forward(params, state, dna_sequence, organism_idx)
  jax.block_until_ready(preds)

avg_time = (time.time() - start_time) / NUM_REPEATS
print(f'Inference time: {avg_time:.2f} seconds')

This should yield around 300 ms on an H100.
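If exact reproducibility is not required, the optimizations mentioned in point 3 could be toggled back on via XLA flags set before JAX is imported. The flag names below are assumptions and vary across JAX/XLA versions, so check the XLA flags documentation for your install:

```python
import os

# XLA flags must be set before importing jax; flag names vary by XLA version.
os.environ['XLA_FLAGS'] = ' '.join([
    '--xla_gpu_deterministic_ops=false',  # allow faster non-deterministic kernels
    '--xla_gpu_enable_triton_gemm=true',  # re-enable Triton GEMM kernels
])

import jax  # reads XLA_FLAGS at import time
```

With these flags, repeated runs on the same input may produce bitwise-different outputs.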
