Reproducing AlphaGenome benchmarks

Hello,

I’m trying to reproduce the AlphaGenome benchmark from Extended Data Fig. 3g, but my metrics differ from the ones reported in the article, especially the per-base Pearson r.
For GM12878, I get JSD 0.5937 (instead of reported 0.62), log total Pearson r 0.8447 (instead of 0.78) and per-base Pearson r 0.4817 (instead of 0.65).

Here is relevant code:

import numpy as np
import torch
from tqdm import tqdm
import json
import pyBigWig
from alphagenome.data import genome
from alphagenome.models import dna_client
import pandas as pd
import os
from scipy.spatial.distance import jensenshannon
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns

# Inputs -- fill these in before running the script.
api_key = ''      # AlphaGenome API key, passed to dna_client.create() in main()
csv_file = ''     # sample sheet loaded by main() with pd.read_csv()
output_dir = ''   # forwarded from main() to evaluate_sample()
bin_size = 1      # bin width handed to compute_binned_correlation()
batching = True   # when True, evaluate_sample() scores a random subset of regions
batch_size = 500  # size of that random subset

def compute_log_total_count_correlation(pred_sum, truth_sum):
    """Return Pearson r between log1p-transformed predicted and observed totals.

    Args:
        pred_sum: Array of predicted totals, one entry per region.
        truth_sum: Array of observed totals, aligned with ``pred_sum``.

    Returns:
        The Pearson correlation coefficient of ``log1p(pred_sum)`` against
        ``log1p(truth_sum)``, or NaN when either transformed array has no
        spread (the coefficient is undefined for a constant vector).
    """
    log_pred, log_truth = np.log1p(pred_sum), np.log1p(truth_sum)

    has_spread = np.std(log_pred) > 0 and np.std(log_truth) > 0
    if not has_spread:
        return np.nan

    r_value, _p_value = pearsonr(log_pred, log_truth)
    return r_value

def compute_jensen_shannon_distance(pred_values, truth_values, pseudocount=1e-6):
    """Return the Jensen-Shannon distance between two signal profiles.

    Both profiles are shifted by ``pseudocount`` and normalised to sum to one
    before being compared with ``scipy.spatial.distance.jensenshannon``.

    Note that the SciPy function returns the Jensen-Shannon *distance*, i.e.
    the square root of the Jensen-Shannon divergence, using the natural
    logarithm by default.

    Args:
        pred_values: Array of predicted signal values.
        truth_values: Array of observed signal values, aligned with
            ``pred_values``.
        pseudocount: Constant added to every position of both profiles.

    Returns:
        The Jensen-Shannon distance, or NaN if SciPy yields NaN.
    """
    smoothed_pred = pred_values + pseudocount
    smoothed_truth = truth_values + pseudocount

    distance = jensenshannon(
        smoothed_pred / np.sum(smoothed_pred),
        smoothed_truth / np.sum(smoothed_truth),
    )

    if np.isnan(distance):
        return np.nan
    return distance

def compute_binned_correlation(pred_values, truth_values, bin_size):
    """Return Pearson r between two profiles after averaging into bins.

    With ``bin_size == 1`` the profiles are correlated position by position.
    With a larger ``bin_size`` each profile is first averaged over consecutive,
    non-overlapping windows of ``bin_size`` positions; a shorter final window
    is averaged over the positions it actually contains.

    Args:
        pred_values: Array-like of predicted signal values.
        truth_values: Array-like of observed signal values with the same shape
            as ``pred_values``.
        bin_size: Positive integer window width.

    Returns:
        The Pearson correlation coefficient, or NaN when it is undefined
        (fewer than two values, or either binned profile is constant).

    Raises:
        ValueError: If ``bin_size`` is not a positive integer or the two
            inputs have different shapes.
    """
    if bin_size != int(bin_size) or bin_size < 1:
        raise ValueError(f"bin_size must be a positive integer, got {bin_size!r}")
    bin_size = int(bin_size)

    pred = np.asarray(pred_values, dtype=float)
    truth = np.asarray(truth_values, dtype=float)
    if pred.shape != truth.shape:
        raise ValueError(
            f"shape mismatch: pred_values {pred.shape} vs truth_values {truth.shape}"
        )
    pred = pred.ravel()
    truth = truth.ravel()

    # A correlation needs at least two observations.
    if pred.size < 2:
        return np.nan

    if bin_size > 1:
        starts = np.arange(0, pred.size, bin_size)
        # The last window may be shorter than bin_size; divide by its true
        # width so that it is on the same scale as the full windows.
        widths = np.minimum(bin_size, pred.size - starts)
        pred = np.add.reduceat(pred, starts) / widths
        truth = np.add.reduceat(truth, starts) / widths

    # Pearson r is undefined when either vector has zero variance.
    if np.std(pred) > 0 and np.std(truth) > 0:
        return np.corrcoef(pred, truth)[0, 1]
    return np.nan

def evaluate_sample(sample_name, biosample, bw_path, bed, data_type,
                    model, seq_length, batch_size, bin_size, output_dir):
    """Score model predictions for one sample against a bigWig signal track.

    For every region in the BED file (or a random subset, see below) the
    function requests a prediction centred on the region, slices it back to
    the region itself, reads the matching bigWig values and computes a
    per-region correlation, total signal and Jensen-Shannon distance.

    Regions that fail for any reason are reported on stdout and recorded as
    NaN in every per-region list, so the lists stay aligned with the region
    order and failed regions do not enter the aggregate metrics.

    Args:
        sample_name: Label used in log output and in the result dict.
        biosample: Ontology term passed to ``model.predict_interval``.
        bw_path: Path of the ground-truth bigWig file.
        bed: Path of a tab-separated region file, read as ten columns
            (chr, start, end, name, value, strand, value1, value2, value3,
            summit).
        data_type: Name of a ``dna_client.OutputType`` member; its lower-case
            form is also used as the attribute name on the prediction object.
        model: Object exposing ``predict_interval``.
        seq_length: Width the region is resized to before prediction.
        batch_size: Number of regions to sample when the module-level
            ``batching`` flag is true and the file has more regions than this.
        bin_size: Bin width forwarded to ``compute_binned_correlation``.
        output_dir: Currently unused by this function.

    Returns:
        A dict with the sample identifiers, region counts, aggregate metrics
        (mean/median correlation, log-total-count Pearson r, mean/median
        Jensen-Shannon distance, summed totals and their ratio) and the
        per-region lists.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating: {sample_name}")
    print(f"{'='*60}")

    # Load regions
    regions_df = pd.read_table(
        bed,
        names=['chr', 'start', 'end', 'name', 'value', 'strand',
               'value1', 'value2', 'value3', 'summit'],
    )

    # Sample if batching (fixed seed keeps the subset reproducible)
    if batching and len(regions_df) > batch_size:
        batch_df = regions_df.sample(n=batch_size, random_state=42)
        batch_df.reset_index(inplace=True, drop=True)
    else:
        batch_df = regions_df

    # Initialize tracking variables
    processed = 0
    per_sample_correlations = []
    per_sample_sum_preds = []
    per_sample_sum_targs = []
    per_sample_js_distances = []
    total_pred_sum = 0.0
    total_target_sum = 0.0

    print(f"Processing {len(batch_df)} regions...")

    truth_bw = pyBigWig.open(bw_path)
    try:
        with tqdm(total=len(batch_df), desc=f"  {sample_name}") as pbar:
            for row in range(len(batch_df)):
                # Defaults used when the region fails; exactly one entry per
                # region is appended below, so the lists never drift apart.
                corr = jsd = pred_sum = truth_sum = np.nan
                try:
                    chrom = batch_df.at[row, 'chr']
                    start = int(batch_df.at[row, 'start'])
                    end = int(batch_df.at[row, 'end'])

                    interval = genome.Interval(
                        chromosome=chrom, start=start, end=end
                    ).resize(seq_length)
                    orig_interval = genome.Interval(
                        chromosome=chrom, start=start, end=end
                    )

                    # Get predictions
                    pred = model.predict_interval(
                        interval=interval,
                        requested_outputs=[dna_client.OutputType[str(data_type)]],
                        ontology_terms=[biosample],
                    )
                    pred_orig_interval = getattr(pred, data_type.lower()).slice_by_interval(
                        orig_interval, match_resolution=True
                    )
                    # NOTE(review): squeeze() only yields a 1-D profile if the
                    # prediction holds a single track -- confirm for the
                    # ontology term in use.
                    pred_values = np.asarray(pred_orig_interval.values).squeeze()

                    # Get ground truth (NaN positions are treated as zero)
                    truth_values = np.nan_to_num(
                        truth_bw.values(chrom, start, end, numpy=True)
                    )

                    region_corr = compute_binned_correlation(
                        pred_values, truth_values, bin_size
                    )
                    # Treat a missing result as "undefined" rather than crash.
                    region_corr = np.nan if region_corr is None else float(region_corr)

                    region_pred_sum = float(np.sum(pred_values))
                    region_truth_sum = float(np.sum(truth_values))

                    # Jensen-Shannon distance needs signal in both profiles
                    if region_pred_sum > 0 and region_truth_sum > 0:
                        region_jsd = float(
                            compute_jensen_shannon_distance(pred_values, truth_values)
                        )
                    else:
                        region_jsd = np.nan
                except Exception as e:
                    # Best effort: report the region and keep going.
                    print(f"  Error processing region {row}: {e}")
                else:
                    corr, jsd = region_corr, region_jsd
                    pred_sum, truth_sum = region_pred_sum, region_truth_sum
                    total_pred_sum += pred_sum
                    total_target_sum += truth_sum

                per_sample_correlations.append(corr)
                per_sample_sum_preds.append(pred_sum)
                per_sample_sum_targs.append(truth_sum)
                per_sample_js_distances.append(jsd)

                processed += 1
                pbar.update(1)
    finally:
        # Release the file handle even if the loop is interrupted.
        truth_bw.close()

    # Compute aggregate metrics
    valid_corrs = [c for c in per_sample_correlations if not np.isnan(c)]
    valid_jsd = [j for j in per_sample_js_distances if not np.isnan(j)]

    # Log total count correlation, over successfully processed regions only;
    # failed regions would otherwise act as artificial (0, 0) points.
    pred_sums = np.asarray(per_sample_sum_preds, dtype=float)
    truth_sums = np.asarray(per_sample_sum_targs, dtype=float)
    usable = np.isfinite(pred_sums) & np.isfinite(truth_sums)
    if np.count_nonzero(usable) > 1:
        log_count_corr = compute_log_total_count_correlation(
            pred_sums[usable], truth_sums[usable]
        )
    else:
        log_count_corr = np.nan

    results = {
        "sample_name": sample_name,
        "biosample": biosample,
        "data_type": data_type,
        "n_regions": processed,
        "n_valid_correlations": len(valid_corrs),
        "n_valid_jsd": len(valid_jsd),

        # Primary metrics
        f"pearson_r_{bin_size}bp": float(np.mean(valid_corrs)) if valid_corrs else np.nan,
        f"pearson_r_{bin_size}bp_median": float(np.median(valid_corrs)) if valid_corrs else np.nan,
        "log_total_count_pearson_r": float(log_count_corr) if not np.isnan(log_count_corr) else np.nan,
        "mean_js_distance": float(np.mean(valid_jsd)) if valid_jsd else np.nan,
        "median_js_distance": float(np.median(valid_jsd)) if valid_jsd else np.nan,

        # Summary
        "total_pred_sum": float(total_pred_sum),
        "total_truth_sum": float(total_target_sum),
        "ratio_pred_truth": float(total_pred_sum / total_target_sum) if total_target_sum > 0 else np.nan,

        # Per-sample data (NaN marks a region that could not be processed)
        "per_sample_correlations": per_sample_correlations,
        "per_sample_js_distances": per_sample_js_distances,
        "per_sample_pred_sums": per_sample_sum_preds,
        "per_sample_truth_sums": per_sample_sum_targs
    }

    return results

def main():
    """Evaluate every sample listed in the sample sheet.

    Reads the CSV at the module-level ``csv_file`` (columns used:
    ``sample_name``, ``biosample``, ``bw_path``, ``bed``, ``data_type``),
    creates one model client shared by all samples, and runs
    ``evaluate_sample`` for each row whose bigWig and BED files exist.

    Returns:
        A list with one result dict per evaluated sample, in sheet order.
        Samples with a missing input file are skipped with a warning.
    """
    # Load sample information from CSV
    print(f"Loading sample information from: {csv_file}")
    samples_df = pd.read_csv(csv_file)

    print(f"Found {len(samples_df)} samples to process:")
    for _, row in samples_df.iterrows():
        print(f"  - {row['sample_name']} ({row['data_type']})")

    # Initialize model (once for all samples)
    print("\nInitializing AlphaGenome model...")
    # NOTE(review): FOLD_0 selects one specific model version -- confirm it is
    # the one behind the numbers being reproduced.
    fold_to_use = dna_client.ModelVersion.FOLD_0
    seq_length = dna_client.SEQUENCE_LENGTH_1MB
    model = dna_client.create(api_key, model_version=fold_to_use)

    # Process each sample
    all_results = []

    for _, row in samples_df.iterrows():
        bw_path = row['bw_path']
        bed = row['bed']

        # Validate files exist
        if not os.path.exists(bw_path):
            print(f"  WARNING: BigWig file not found: {bw_path}")
            continue
        if not os.path.exists(bed):
            print(f"  WARNING: BED file not found: {bed}")
            continue

        # Evaluate sample
        results = evaluate_sample(
            sample_name=row['sample_name'],
            biosample=row['biosample'],
            bw_path=bw_path,
            bed=bed,
            data_type=row['data_type'],
            model=model,
            seq_length=seq_length,
            batch_size=batch_size,
            bin_size=bin_size,
            output_dir=output_dir
        )

        all_results.append(results)

    # Hand the collected metrics back instead of discarding them.
    return all_results

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()

I’m using the ground-truth bigWig downloaded from ENCSR003WJE – ENCODE, and the regions BED file was made according to this Colab - Google Colab

I would be grateful for any insight on what I’m doing wrong here.

For what it is worth, in my experience you cannot just take the “raw” (in the sense of unprocessed) bigWig files from ENCODE as the ground truth for AlphaGenome (AG), because it was trained on a processed version of them (see the “DNase and ATAC Data” section in Methods, which describes the data processing). So you would need either to repeat that processing as described in the paper, or to download the processed data from Hugging Face in the form of TFRecord files (see google/alphagenome-all-folds · Hugging Face).

Thank you for the suggestion!
I tried normalizing my bigwig files to 100000000 (using normalize_bigwig.py from RSeQC), but I don’t think it helped, as all my metrics are still the same.
Thank you for the link to the processed data on Hugging Face, although I don’t think I understand how to load the data for custom intervals rather than the predefined train/valid/test subsets. If you could help me with that, I’d really appreciate the assistance.