ISM implementation/efficiency

Hi. I have been using alphagenome’s score_ism_variants function and just realized that it does quite a lot of unnecessary forward passes:

Forward passes in score_ism_variants
ism_variants generates 3 × L variants (3 non-reference substitutions at each of L positions in the ism_interval).
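To make the 3 × L count concrete, here's a minimal sketch of the variant generation (make_ism_variants is an illustrative name of my own, not the actual alphagenome API):

```python
# Illustrative sketch only: for each position in the ISM interval,
# substitute each of the 3 non-reference bases.
BASES = "ACGT"

def make_ism_variants(reference: str, start: int, end: int):
    """Return (position, alt_base, variant_sequence) for every non-ref substitution."""
    variants = []
    for pos in range(start, end):
        ref_base = reference[pos]
        for alt in BASES:
            if alt == ref_base:
                continue  # skip the reference allele itself
            variants.append((pos, alt, reference[:pos] + alt + reference[pos + 1:]))
    return variants

# 3 positions x 3 non-reference bases = 9 variants for this toy interval
variants = make_ism_variants("ACGTACGT", start=2, end=5)
```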

Each variant is scored via score_variant, whose call to _predict_variant makes 4 forward passes:

apply_fn(reference_sequence, …) — main model, ref
apply_fn(alternate_sequence, …) — main model, alt
junctions_apply_fn(ref_embeddings, …) — splice junction head, ref
junctions_apply_fn(alt_embeddings, …) — splice junction head, alt

All forward passes use batch size = 1 (sequence is encoded with [np.newaxis] giving shape [1, S, 4]).

Each of the 3L variants gets its own thread, and within that thread _predict_variant runs 4 forward passes serially: main-ref → main-alt → junction-ref → junction-alt. The parallelism is only across variants (threads), not within a single variant’s scoring.

This seems quite wasteful: the reference prediction is recomputed 3L times (once per variant), and everything runs at batch size 1 when in theory more sequences could fit in memory (at least for smaller context sizes).
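As a sketch of what sharing the reference pass could look like (counted_model and fake_model here are hypothetical stand-ins, not the real apply_fn):

```python
# Hypothetical sketch: compute the reference prediction once and reuse it
# for all 3L variants, instead of re-running the model on the same
# reference sequence for every variant.
import numpy as np

def fake_model(seq_onehot):
    # Stand-in for an expensive forward pass; any deterministic function works.
    return seq_onehot.sum(axis=-1)

calls = {"n": 0}

def counted_model(seq):
    calls["n"] += 1
    return fake_model(seq)

def score_ism_shared_ref(ref_onehot, variant_onehots):
    ref_pred = counted_model(ref_onehot)  # one forward pass, shared across variants
    return [counted_model(v) - ref_pred for v in variant_onehots]

ref = np.eye(4)[np.array([0, 1, 2, 3])]        # toy 4-bp one-hot sequence
variant_seqs = [ref.copy() for _ in range(9)]  # toy stand-ins for 3L variants
scores = score_ism_shared_ref(ref, variant_seqs)
# 1 (ref) + 9 (variants) = 10 model calls instead of 18
```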

I was wondering if you have considered making the ISM workflow more efficient. Is there at least a way to increase the batch size, or to do a single forward pass for the ref allele? I'm quite bottlenecked by this at the moment, so any help or suggestion (or code changes) would be much appreciated :slight_smile: Thank you!

Hi @valeh,

Thanks for the post! For the open-source model, we prioritized readability and correctness over speed. For this reason, we opted for a simpler single-device setup and chose not to share the reference prediction between the 3 alleles in ISM. We'll look into adding this, but we're keen to keep the code somewhat simple to make it easier for folks to follow. Increasing the batch size can be difficult, as you can quickly run out of GPU memory, though this very much depends on the diversity of variant scoring you're doing.
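To illustrate the shape of batching across variants (with a toy stand-in model, not the real apply_fn; the GPU-memory caveat still applies, since a larger batch multiplies activation memory):

```python
# Hypothetical sketch: run the model on stacked chunks of variant
# sequences instead of one sequence at a time.
import numpy as np

def batched_predict(model, sequences, batch_size):
    """Score sequences in chunks of batch_size rather than batch size 1."""
    outputs = []
    for i in range(0, len(sequences), batch_size):
        batch = np.stack(sequences[i:i + batch_size])  # shape [B, S, 4]
        outputs.append(model(batch))                   # one forward pass per chunk
    return np.concatenate(outputs, axis=0)

passes = {"n": 0}

def toy_model(batch):
    passes["n"] += 1
    return batch.sum(axis=(1, 2))  # one scalar "score" per sequence

seqs = [np.eye(4)[np.arange(4)] for _ in range(10)]  # ten toy [4, 4] one-hots
preds = batched_predict(toy_model, seqs, batch_size=4)
# 10 sequences at batch_size=4 -> 3 forward passes instead of 10
```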

Note that there are only 2 full forward passes per variant prediction: the junction predictions re-use the trunk embeddings from the first forward pass :smiley:
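As a toy illustration of why the junction head doesn't add full passes (trunk and junction_head here are stand-in functions, not the real model):

```python
# Illustrative two-stage setup: the splice-junction head consumes cached
# trunk embeddings, so ref and alt each cost one trunk pass and the
# junction predictions are a cheap head on top.
import numpy as np

trunk_calls = {"n": 0}

def trunk(seq):
    trunk_calls["n"] += 1
    return seq * 2.0  # stand-in for the expensive trunk

def junction_head(embeddings):
    return embeddings.sum()  # cheap head; does not re-run the trunk

def predict_variant(ref_seq, alt_seq):
    ref_emb = trunk(ref_seq)   # full pass 1
    alt_emb = trunk(alt_seq)   # full pass 2
    return junction_head(ref_emb), junction_head(alt_emb)

ref_score, alt_score = predict_variant(np.ones(4), np.zeros(4))
# the trunk ran exactly twice for this variant
```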

Hi @tward Thank you for the response (and the clarification regarding the junction embeddings). I do need the speed, since I'm scoring a very large number of variants. I've also made some code changes to the open-source model: I need to do the aggregation a bit differently than the CenterMaskScorer currently does, and I also want to expose (return) the raw unaggregated predictions from the ISM forward passes. (I suppose if I can at least have the latter change, I won't need the former, since I can do the aggregation on the raw tracks post-hoc.)
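For the post-hoc aggregation, something like this minimal center-mask mean over raw tracks is what I have in mind (center_mask_aggregate is my own sketch, not the CenterMaskScorer API):

```python
# Sketch only: given raw unaggregated tracks of shape [L, T]
# (positions x tracks), average over a centered window of positions.
import numpy as np

def center_mask_aggregate(tracks, mask_width):
    """Mean over a centered window along the positional axis."""
    length = tracks.shape[0]
    start = (length - mask_width) // 2
    return tracks[start:start + mask_width].mean(axis=0)

tracks = np.arange(8.0).reshape(8, 1)  # toy [L=8, T=1] track
agg = center_mask_aggregate(tracks, mask_width=4)
# mean of the 4 centered positions (values 2, 3, 4, 5) -> 3.5
```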

Is there any way I can keep my code changes but also use the faster implementation?

Thanks!