Can't reproduce AlphaGenome's benchmarks

Hello,

I’m trying to reproduce the ChromBPNet benchmark results from Extended Data Fig 3g and am encountering a large discrepancy in the DNase track correlation for IMR-90 (ENCSR477RTP) on test fold 0.

  • Reported Track Pearson r: 0.79

  • My Track Pearson r: ~0.35

My other metrics are aligned with the paper’s results (log total count correlation ~0.78, JSD ~0.44), so the issue seems specific to the per-base correlation.
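For reference, here is how I understand the three metrics (a minimal numpy sketch; preds and targs are lists of aligned 1D per-base arrays, one per evaluation window, and I'm not certain whether the paper reports the Jensen-Shannon distance or the squared divergence):

import numpy as np
from scipy.spatial.distance import jensenshannon

def benchmark_metrics(preds, targs):
    # Per-base Pearson r, averaged over windows ("Track Pearson r").
    track_r = np.mean([np.corrcoef(p, t)[0, 1] for p, t in zip(preds, targs)])
    # Pearson r of log(1 + total counts) across windows.
    log_p = np.log1p([p.sum() for p in preds])
    log_t = np.log1p([t.sum() for t in targs])
    count_r = np.corrcoef(log_p, log_t)[0, 1]
    # Mean Jensen-Shannon distance between the normalized profiles
    # (scipy normalizes the inputs to sum to 1 internally).
    jsd = np.mean([jensenshannon(p, t, base=2) for p, t in zip(preds, targs)])
    return track_r, count_r, jsd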

What I’ve done:

  1. Prepared the ground truth data using the exact steps from the ChromBPNet repository.

  2. Used the ChromBPNet fold 0 test peak regions, filtering for overlaps with the AlphaGenome training set.

  3. Used the standard AlphaGenome model for predictions.

I'm attaching my prediction script. Any help would be greatly appreciated.

import logging
import os
import random
from typing import Dict

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from alphagenome.data import genome
from alphagenome.models import dna_client

logger = logging.getLogger(__name__)


class TestRegulatoryDataset(Dataset):
    def __init__(self, **kwargs):
        data_file = kwargs.pop("data_file")
        cache_dir = kwargs.pop("cache_dir", None)
        subset_size = kwargs.pop("subset_size", None)
        self.split_name = list(data_file.keys())[0]
        self.dataset = load_dataset("json", data_files=data_file, cache_dir=cache_dir)
        self.max_length = 2114  # ChromBPNet evaluation window width

        if subset_size is not None:
            total_size = len(self.dataset[self.split_name])
            if subset_size > total_size:
                logger.warning(
                    f"subset_size ({subset_size}) is larger than the total dataset size ({total_size}). Using the full dataset."
                )
            else:
                indices = list(range(total_size))
                random.shuffle(indices)
                subset_indices = indices[:subset_size]
                self.dataset[self.split_name] = self.dataset[self.split_name].select(
                    subset_indices
                )

        logger.info(
            f"Loaded {len(self.dataset[self.split_name])} sequences from {data_file}"
        )

    def __len__(self) -> int:
        return len(self.dataset[self.split_name])

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        item = self.dataset[self.split_name][idx]
        biosample_type = item["biosample_term"]
        assay_values = [torch.tensor(item["signal"])]

        # Shape (1, L, 1); scale so the full track sums to 100 million counts
        # (11,005,729,934 is the unscaled genome-wide total for this track).
        assay_targets = torch.stack(assay_values).unsqueeze(-1)
        assay_targets *= float(10**8) / 11005729934

        return {
            "assay_targets": assay_targets,
            "biosample_id": "EFO:0001196",  # IMR-90
            "chromosome": item["chr"],
            "start": item["start"],
            "end": item["end"],
            "strand": "+",
            "biosample_name": biosample_type,
        }

def main():
    # _parse_args, _load_configs, and _get_data_files are project-local helpers
    # (argparse wrapper, config loading, and .jsonl discovery) omitted here.
    args = _parse_args()

    # Model selection matches the ChromBPNet benchmark
    fold_to_use = dna_client.ModelVersion.ALL_FOLDS
    seq_length = dna_client.SEQUENCE_LENGTH_16KB

    test_config, modality_config, train_config = _load_configs(
        args.test_config, args.architecture_config_path, args.tokenizer_config_path
    )
    
    dataset_kwargs = {
        "subset_size": 100,
    }

    data_files = _get_data_files(test_config)
    if not data_files:
        raise FileNotFoundError("No .jsonl files found under data_dir/chrombpnet_fold0_eval")

    # Ensure deterministic dataset subset shuffling
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset = TestRegulatoryDataset(
        data_file={"test": data_files},
        **dataset_kwargs,
    )

    generator = torch.Generator()
    if "sampler_seed" in test_config:
        generator.manual_seed(int(test_config["sampler_seed"]))
    else:
        generator.manual_seed(int(args.seed))
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, generator=generator)

    # API key: single key from the environment, no threading
    keys_env = os.environ.get("ALPHAGENOME_API_KEY")
    if not keys_env:
        raise ValueError("Set ALPHAGENOME_API_KEY in the environment.")
    api_key = keys_env.strip()
    model = dna_client.create(api_key, model_version=fold_to_use)

    processed = 0
    total_valid_positions = 0
    total_pred_sum = 0.0
    total_target_sum = 0.0
    per_sample_correlations = []
    per_sample_sum_preds = []
    per_sample_sum_targs = []
    with tqdm(total=len(dataloader), desc="AlphaGenome (single-thread)") as pbar:
        for sample in dataloader:
            chromosome = sample["chromosome"][0]
            start = sample["start"][0].item()
            end = sample["end"][0].item()
            strand = sample["strand"][0]
            biosample_id = sample["biosample_id"][0]

            # Center the model input window on the original interval.
            interval = genome.Interval(chromosome=chromosome, start=start, end=end, strand=strand).resize(seq_length)

            output = model.predict_interval(
                interval=interval,
                requested_outputs=[dna_client.OutputType.DNASE],
                ontology_terms=[biosample_id],
            )

            # Slice the prediction back to the original window.
            orig_interval = genome.Interval(chromosome=chromosome, start=start, end=end, strand=strand)
            dnase_td = output.dnase.slice_by_interval(orig_interval, match_resolution=True)

            # Reshape to (batch, track, length, channel) to mirror the targets.
            dnase_values = np.asarray(dnase_td.values).squeeze()
            alphagenome_preds = torch.from_numpy(dnase_values).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
            assay_targets = sample["assay_targets"]
            pred_seq = alphagenome_preds[0, 0, :, 0]
            targ_seq = assay_targets[0, 0, :].to(dtype=pred_seq.dtype)

            L_pred = pred_seq.shape[0]
            L_targ = targ_seq.shape[0]
            win_len = min(1000, L_pred, L_targ)  # evaluate on the center 1,000 bp

            if win_len > 1:
                ps = (L_pred - win_len) // 2
                ts = (L_targ - win_len) // 2
                pred_win = pred_seq[ps:ps + win_len]
                targ_win = targ_seq[ts:ts + win_len].squeeze(-1)

                total_valid_positions += int(win_len)
                total_pred_sum += float(pred_win.sum().item())
                total_target_sum += float(targ_win.sum().item())
                per_sample_sum_preds.append(float(pred_win.sum().item()))
                per_sample_sum_targs.append(float(targ_win.sum().item()))
                corr = torch.corrcoef(torch.stack([pred_win, targ_win]))[0, 1].item()
                per_sample_correlations.append(corr)

            processed += 1
            pbar.update(1)

    valid_corrs = [c for c in per_sample_correlations if isinstance(c, float) and not np.isnan(c)]
    mean_corr = float(np.mean(valid_corrs)) if len(valid_corrs) > 0 else float("nan")
    sum_corr = float("nan")
    if len(per_sample_sum_preds) > 1 and len(per_sample_sum_preds) == len(per_sample_sum_targs):
        try:
            sp = torch.log1p(torch.tensor(per_sample_sum_preds, dtype=torch.float32))
            st = torch.log1p(torch.tensor(per_sample_sum_targs, dtype=torch.float32))
            sum_corr = float(torch.corrcoef(torch.stack([sp, st]))[0, 1].item())
        except Exception:
            pass

    print(f"Processed {processed} samples")
    print(f"Mean correlation (center 1000bp): {mean_corr}")
    print(f"Total valid positions: {total_valid_positions}")
    print(f"Sum of predictions over valid positions: {total_pred_sum}")
    print(f"Sum of targets over valid positions: {total_target_sum}")
    print(f"Correlation of per-sample sums: {sum_corr}")

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
    main()




Hey! I would suggest trying out the following things:

  • Use a 1Mb input sequence length, as this is the length the model was trained and evaluated on.

  • dna_client.ModelVersion.ALL_FOLDS is a distilled model that was trained across all intervals in the genome via the teacher models. You need to use one of the other model folds with the right held-out data (see the sketch below).

  • Visualize the output alongside the experimental data for one example. It could be that an off-by-one error somewhere is causing a large drop in the correlation.
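A quick sketch of both suggestions (the per-fold member names are assumptions; check the enum on your installed client):

import matplotlib.pyplot as plt
from alphagenome.models import dna_client

# List the available model versions; pick a per-fold model whose held-out
# chromosomes cover the ChromBPNet fold-0 test split instead of ALL_FOLDS.
print(list(dna_client.ModelVersion))

def plot_center(pred, targ, width=1000):
    # Overlay predicted and observed signal over the window center so an
    # off-by-one shift is visible by eye; pred/targ are aligned 1D arrays.
    c_p, c_t = len(pred) // 2, len(targ) // 2
    plt.plot(targ[c_t - width // 2 : c_t + width // 2], label="observed")
    plt.plot(pred[c_p - width // 2 : c_p + width // 2], label="predicted")
    plt.legend()
    plt.show()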

1 Like

Hey, thanks for getting back!

  1. I have now used the 1Mb sequence length, and

  2. I have tested this on all the folds available on the API.
    → After both changes I am still getting low per-base correlation, similar to before.

  3. Further, I tried shifting the prediction by a range of offsets (±10 bp) and still get low correlation, with the highest score (~0.3) at my current alignment (no offset). I then visualized the predictions vs. targets without any manual offset; I have attached the picture for the center 1,000 bp for IMR-90 (DNase).

Do you think there is any other possible source of offset/misalignment between the nucleotide coordinates (chr, start, end) fed to the AG API and the targets I constructed? If there were, the manual shifting I did (±10 bp, sketched below) should presumably have boosted the per-bp correlation, and I also can't identify an obvious shift in the images I've attached.
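For reference, the offset scan (a minimal sketch; pred and targ are the aligned per-base arrays for one window):

import numpy as np

def best_offset(pred, targ, max_shift=10):
    # Slide the prediction +/- max_shift bp against the target and return
    # the shift with the highest Pearson r, plus all scores.
    results = {}
    for s in range(-max_shift, max_shift + 1):
        if s >= 0:
            p, t = pred[s:], targ[: len(targ) - s]
        else:
            p, t = pred[:s], targ[-s:]
        results[s] = np.corrcoef(p, t)[0, 1]
    return max(results, key=results.get), results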

To give you more context, I'm attaching my entire data-processing pipeline. I also tried to recreate the ChromBPNet benchmark for K562 ATAC-seq using an already-merged BAM file with precomputed peaks, available at https://zenodo.org/records/7445373. Even in this setup, with the exact same code and the changes from points 1-3 applied, I still get a significantly lower per-base correlation than reported in the paper.

Also for context, I'm evaluating the predictions against 1,000 bp target windows centered on the peak summit.

Would it be possible for you to provide the ChromBPNet sequences and targets you used?

Here is my data pipeline process:

ChromBPNet Fold-0 Test Intervals Generation

This is the full preprocessing workflow I followed to create the fold-0 test intervals for the ChromBPNet evaluation, following the ChromBPNet Preprocessing wiki exactly. I'm posting it here so the process is clear and reproducible.


Requirements

  • Experiment accessions (from the AlphaGenome paper, used in the ChromBPNet comparison):

  • Reference files:

    • hg38.chrom.sizes
    • ENCODE hg38 blacklist BED
  • Tools:
    samtools, macs2, bedtools, bedClip, bedGraphToBigWig


Step 0: Download BAM files

Example for K562 DNase (ENCSR000EOT):

# From the ENCODE portal, copy the BAM file links. Example:
wget https://www.encodeproject.org/files/ENCFF123ABC/@@download/ENCFF123ABC.bam -O rep1.bam
wget https://www.encodeproject.org/files/ENCFF456DEF/@@download/ENCFF456DEF.bam -O rep2.bam

Same for the accessions above (ATAC IMR-90 and DNase for K562, HepG2, GM12878, IMR-90).


Step A: Prepare BAMs

ChromBPNet requires filtered, sorted, indexed BAMs.

  • Single-end (unfiltered BAMs) → filter manually:

    samtools view -b -@50 -F 780 -q 30 rep1.unfiltered.bam > rep1.filtered.bam
    

    Wiki reference

  • Paired-end (filtered BAMs provided by ENCODE) → skip filtering.

  • If one replicate only:

    samtools sort -@50 -o sample.sorted.bam rep1.filtered.bam
    samtools index sample.sorted.bam
    
  • If multiple replicates:

    # Filter each replicate if single-end + unfiltered
    samtools view -b -@50 -F 780 -q 30 rep1.unfiltered.bam > rep1.filtered.bam
    samtools view -b -@50 -F 780 -q 30 rep2.unfiltered.bam > rep2.filtered.bam
    
    # Merge
    samtools merge -f merged_unsorted.bam rep1.filtered.bam rep2.filtered.bam ...
    
    # Sort + index once after merge
    samtools sort -@50 -o sample.sorted.bam merged_unsorted.bam
    samtools index sample.sorted.bam
    

Step B: Call Relaxed Peaks

ChromBPNet uses relaxed peaks (p=0.01) to mimic ENCODE “overlap peaks.”
Wiki reference

macs2 callpeak -t sample.sorted.bam -f BAM -g hs -n sample_relaxed --call-summits -p 1e-2

Output: sample_relaxed_peaks.narrowPeak


Step C: Blacklist Filtering

ChromBPNet removes peaks overlapping blacklisted regions ±1057 bp (because evaluation windows are 2,114 bp wide).
Wiki reference

bedtools slop -i hg38.blacklist.bed -g hg38.chrom.sizes -b 1057 > blacklist.pad1057.bed
bedtools intersect -v -a sample_relaxed_peaks.narrowPeak -b blacklist.pad1057.bed   > sample.peaks.noBL.bed

Step D: Make 2,114 bp Summit Windows

ChromBPNet evaluates predictions on 2,114 bp windows centered at the peak summit.
Column 10 of narrowPeak = summit offset.
Wiki reference

awk 'BEGIN{OFS="\t"}{
  center=$2+$10; start=center-1057; end=center+1057;
  if(start<0) start=0;
  print $1,start,end,$4
}' sample.peaks.noBL.bed | sort -k1,1 -k2,2n > sample.summit2114.raw.bed

bedClip sample.summit2114.raw.bed hg38.chrom.sizes sample.summit2114.bed

Step E: Fold-0 Split

ChromBPNet defines data splits by chromosome. For fold-0:

  • Test: chr1, chr3, chr6
  • Valid: chr8, chr20
  • Train: everything else
    Wiki reference

Extract test intervals:

awk '($1=="chr1"||$1=="chr3"||$1=="chr6")' sample.summit2114.bed > sample.fold0_test.bed

Optional dedup (to ensure no repeated intervals):

awk 'BEGIN{OFS="\t"} !seen[$1 FS $2 FS $3]++' sample.fold0_test.bed > sample.fold0_test.uniq.bed

Step F: Build Coverage BigWig (for ground truth signal)

ChromBPNet evaluates against coverage profiles from the BAM.
Generate a BigWig:

bedtools genomecov -ibam sample.sorted.bam -bg -g hg38.chrom.sizes > sample.coverage.bedGraph
LC_COLLATE=C sort -k1,1 -k2,2n sample.coverage.bedGraph > sample.coverage.sorted.bedGraph
bedGraphToBigWig sample.coverage.sorted.bedGraph hg38.chrom.sizes sample.coverage.bw

Final Outputs for Evaluation

  • sample.fold0_test.bed = fold-0 test intervals (2,114 bp windows on chr1/3/6)
  • sample.coverage.bw = observed DNase/ATAC signal

These are the inputs for evaluating predictions (per-base Pearson r, log-total-counts Pearson r, JSD). The exact targets are extracted from sample.coverage.bw and sample.fold0_test.bed using the pyBigWig library; BED files and pyBigWig both use 0-based, half-open intervals.

We also ensured that the tracks were scaled to sum to 100 million total counts. The extraction and scaling step looks like the sketch below.
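A minimal sketch of that step (file names follow the pipeline above; the genome-wide total is taken from the bigWig header):

import numpy as np
import pyBigWig

bw = pyBigWig.open("sample.coverage.bw")
# Genome-wide total signal, used to rescale the track to 100M counts.
scale = 1e8 / bw.header()["sumData"]

targets = []
with open("sample.fold0_test.bed") as f:
    for line in f:
        chrom, start, end = line.split()[:3]
        # BED and pyBigWig are both 0-based, half-open.
        vals = bw.values(chrom, int(start), int(end), numpy=True)
        targets.append(np.nan_to_num(vals) * scale)
bw.close()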


Hi,

You also need to apply the following preprocessing step (reads_to_bigwig.py) to shift the reads accordingly: chrombpnet/chrombpnet/helpers/preprocessing at master · kundajelab/chrombpnet · GitHub

You can see from the plots that the peaks are not as sharp as those shown in, for example, the ChromBPNet or AlphaGenome papers.

The above script shifts the reads appropriately. The ChromBPNet tutorial doesn't list this step because it is done inside the full pipeline.

From Preprocessing · kundajelab/chrombpnet Wiki · GitHub “When using the ChromBPNet pipelines to train models, you do not need to worry about pre-shifting the fragments, bams or tagAligns. The training tools are designed to automatically detect and correct any shifts in the files, specifically for ATAC and DNase, so you can run training without any additional shifting.”
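Invocation looks roughly like this (the flag names here are from memory, so check the script's -h output before running; use -d ATAC for ATAC-seq):

# Shift reads and write the corrected coverage bigWig.
python chrombpnet/helpers/preprocessing/reads_to_bigwig.py \
  -g hg38.fa \
  -ibam sample.sorted.bam \
  -c hg38.chrom.sizes \
  -op sample_shifted \
  -d DNASE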

Ziga


You can also download the already-processed observed bigwigs from ENCSR389HIH – ENCODE.

The accessions for the other ChromBPNet models are listed here: ChromBPNet_release


Hey! I was able to reproduce the ChromBPNet numbers with this. Thanks for your help!

I am, however, still getting low profile correlation (~0.3) for DNase-seq on ENCODE cCRE sequences (<500 bp), which have more variable shape patterns (not just peak-centered regions). I formed these intervals using bigwigs derived exactly as described in the AlphaGenome paper. Is this expected?


Great to hear that you were able to reproduce the numbers.

0.3 sounds too low. You'd expect a number similar to what one sees for DNase peaks, since cCREs are mostly accessible peaks as well. Are you using DNase-seq outputs or ATAC-seq outputs?


I realized that there was a BAM-file filtering bug in my data-processing pipeline. I have since fixed the issue, and now all my file accessions match the files listed in the AlphaGenome metadata. Additionally, I have verified that the non-zero mean for the different tracks matches the values reported in the AlphaGenome metadata to 2 decimal places, so I'm assuming my pipeline and processing are fine now.

With these changes, the scores on the cCREs with AlphaGenome jump to 0.7-0.8 using 1Mb context. However, when I use only the context of the cCREs (<500 bp), the per-bp DNase scores are in the range 0.4-0.5.

Does this make sense?

Thanks for your help and patience!


Glad to hear that you managed to reproduce the results.

The drop in performance makes sense. We only trained the model with 1Mb input sequences, so feeding in shorter sequences causes a distribution shift inside the model, making the predictions less accurate than if the model had been trained at that sequence length.

Fig. 7 in the AlphaGenome preprint shows how performance drops as you lower the sequence length, either at inference time only or during training as well.
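So for short cCREs, keep the 1Mb context and slice the prediction back to the cCRE afterwards; a minimal sketch reusing the client calls from your script (the coordinates are made up, and SEQUENCE_LENGTH_1MB is assumed to be the client's 1Mb constant):

# Resize around the cCRE center to 1Mb, predict, then slice the output
# back to the original short interval.
ccre = genome.Interval(chromosome="chr1", start=1_000_000, end=1_000_350)  # example cCRE
model_input = ccre.resize(dna_client.SEQUENCE_LENGTH_1MB)
output = model.predict_interval(
    interval=model_input,
    requested_outputs=[dna_client.OutputType.DNASE],
    ontology_terms=["EFO:0001196"],  # IMR-90
)
dnase_ccre = output.dnase.slice_by_interval(ccre, match_resolution=True)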


This makes sense! Thanks for all your help and patience along the way!
