Hello,
I’m trying to reproduce the AlphaGenome benchmark from Extended Data Fig. 3g, but my metrics differ from the ones reported in the article, especially the per-base Pearson r.
For GM12878, I get a JSD of 0.5937 (instead of the reported 0.62), a log total count Pearson r of 0.8447 (instead of 0.78), and a per-base Pearson r of 0.4817 (instead of 0.65).
Here is the relevant code:
import numpy as np
import torch
from tqdm import tqdm
import json
import pyBigWig
from alphagenome.data import genome
from alphagenome.models import dna_client
import pandas as pd
import os
from scipy.spatial.distance import jensenshannon
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import seaborn as sns
# Inputs: module-level configuration read by main() and evaluate_sample().
# CSV read by main(); the code expects the columns sample_name, biosample,
# bw_path, bed and data_type.
csv_file = ''
# Passed through to evaluate_sample(); not used by the code shown here.
output_dir = ''
# API key handed to dna_client.create() in main().
api_key = ''
# When True, evaluate_sample() evaluates a random subset of at most
# batch_size regions (fixed seed) instead of every region in the BED file.
batching = True
batch_size = 500
# Bin width passed to compute_binned_correlation() (1 = per-base).
bin_size = 1
def compute_log_total_count_correlation(pred_sum, truth_sum):
    """Return the Pearson r between log1p-transformed total counts.

    Args:
        pred_sum: Array-like of per-region predicted totals.
        truth_sum: Array-like of per-region observed totals, same length
            as ``pred_sum``.

    Returns:
        The Pearson correlation of ``log1p(pred_sum)`` and
        ``log1p(truth_sum)``, or ``np.nan`` when fewer than two finite
        pairs remain or either side has zero variance.
    """
    pred_log = np.log1p(np.asarray(pred_sum, dtype=float))
    truth_log = np.log1p(np.asarray(truth_sum, dtype=float))

    # A single NaN/inf total would otherwise turn the whole correlation
    # into NaN, so only pairs that are finite on both sides are compared.
    finite = np.isfinite(pred_log) & np.isfinite(truth_log)
    pred_log = pred_log[finite]
    truth_log = truth_log[finite]

    if pred_log.size < 2:
        return np.nan
    if not (np.std(pred_log) > 0 and np.std(truth_log) > 0):
        # Pearson r is undefined for a constant vector.
        return np.nan

    corr, _ = pearsonr(pred_log, truth_log)
    return corr
def compute_jensen_shannon_distance(pred_values, truth_values, pseudocount=1e-6):
    """Return the Jensen-Shannon distance between two coverage profiles.

    Both profiles are shifted by ``pseudocount`` and normalised to sum to
    one before being compared with ``scipy.spatial.distance.jensenshannon``.

    Note that SciPy returns the Jensen-Shannon *distance*, i.e. the square
    root of the divergence, using the natural logarithm by default, so the
    result lies in ``[0, sqrt(ln 2)]``.
    NOTE(review): a reference value computed as a divergence, or with
    base 2, will not be directly comparable -- confirm the convention of
    the number being reproduced.

    Args:
        pred_values: Array-like of predicted signal.
        truth_values: Array-like of observed signal, same shape as
            ``pred_values``.
        pseudocount: Constant added to every position of both profiles.

    Returns:
        The distance as returned by SciPy (``nan`` if it is undefined).
    """
    # Coerce so that plain Python lists work as well as NumPy arrays.
    pred = np.asarray(pred_values, dtype=float) + pseudocount
    truth = np.asarray(truth_values, dtype=float) + pseudocount

    pred_prob = pred / np.sum(pred)
    truth_prob = truth / np.sum(truth)
    return jensenshannon(pred_prob, truth_prob)
def compute_binned_correlation(pred_values, truth_values, bin_size):
    """Return the Pearson r between two 1-D profiles after optional binning.

    With ``bin_size == 1`` the profiles are correlated position by
    position. With ``bin_size > 1`` each profile is first averaged over
    consecutive, non-overlapping bins of ``bin_size`` positions; a trailing
    partial bin is dropped so that every bin covers the same width.

    Args:
        pred_values: 1-D array-like of predicted signal.
        truth_values: 1-D array-like of observed signal, same length as
            ``pred_values``.
        bin_size: Positive integer bin width in positions.

    Returns:
        The Pearson correlation, or ``np.nan`` when it is undefined
        (fewer than two bins, or a profile with zero variance).

    Raises:
        ValueError: If ``bin_size`` is not a positive integer, or the
            inputs are not 1-D arrays of equal length.
    """
    if bin_size < 1 or bin_size != int(bin_size):
        raise ValueError(f"bin_size must be a positive integer, got {bin_size!r}")
    bin_size = int(bin_size)

    pred = np.asarray(pred_values, dtype=float)
    truth = np.asarray(truth_values, dtype=float)
    # np.corrcoef on 2-D input correlates rows with each other, which would
    # silently return a meaningless number; insist on matching 1-D profiles.
    if pred.ndim != 1 or pred.shape != truth.shape:
        raise ValueError(
            "pred_values and truth_values must be 1-D with equal length, "
            f"got shapes {pred.shape} and {truth.shape}"
        )

    if bin_size > 1:
        n_bins = pred.size // bin_size
        usable = n_bins * bin_size
        pred = pred[:usable].reshape(n_bins, bin_size).mean(axis=1)
        truth = truth[:usable].reshape(n_bins, bin_size).mean(axis=1)

    if pred.size < 2:
        return np.nan
    if not (np.std(pred) > 0 and np.std(truth) > 0):
        # Pearson r is undefined for a constant profile.
        return np.nan
    return np.corrcoef(pred, truth)[0, 1]
def evaluate_sample(sample_name, biosample, bw_path, bed, data_type,
                    model, seq_length, batch_size, bin_size, output_dir):
    """Compare model predictions with a bigWig track over a set of BED regions.

    For every region the prediction is requested on the region resized to
    ``seq_length``, sliced back to the original region, and compared with
    the bigWig signal using a per-region Pearson r, a Jensen-Shannon
    distance and the summed signal. Per-region values are then averaged,
    and the log total count correlation is computed across regions.

    Regions whose processing raises are reported on stdout and recorded as
    ``nan`` metrics with zero sums; they are excluded from the log total
    count correlation so that artificial (0, 0) points do not influence it.

    NOTE(review): the reported Pearson r is the mean of per-region
    correlations; a reference value pooled over all positions would not be
    comparable -- confirm which definition the target number uses.
    NOTE(review): assumes the sliced prediction squeezes to a single 1-D
    track; if ``biosample`` matches several tracks the array would be 2-D
    -- confirm against the model output.

    Args:
        sample_name: Label used in log output and in the result dict.
        biosample: Ontology term passed to ``model.predict_interval``.
        bw_path: Path of the ground-truth bigWig file.
        bed: Path of a headerless, tab-separated, 10-column peak file.
        data_type: Name of a ``dna_client.OutputType`` member; its
            lower-case form is the attribute read from the prediction.
        model: Object providing ``predict_interval``.
        seq_length: Length each region is resized to for prediction.
        batch_size: Maximum number of regions evaluated when the
            module-level ``batching`` flag is true.
        bin_size: Bin width forwarded to ``compute_binned_correlation``.
        output_dir: Accepted for interface compatibility; unused here.

    Returns:
        A dict with the aggregate metrics and the per-region values.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating: {sample_name}")
    print(f"{'='*60}")

    # Load regions
    regions_df = pd.read_table(
        bed,
        names=['chr', 'start', 'end', 'name', 'value', 'strand',
               'value1', 'value2', 'value3', 'summit'],
    )

    # Sample if batching (fixed seed so repeated runs use the same regions)
    if batching and len(regions_df) > batch_size:
        batch_df = regions_df.sample(n=batch_size, random_state=42)
        batch_df = batch_df.reset_index(drop=True)
    else:
        batch_df = regions_df

    # Initialize tracking variables
    processed = 0
    per_sample_correlations = []
    per_sample_sum_preds = []
    per_sample_sum_targs = []
    per_sample_js_distances = []
    region_ok = []  # True where the region was processed without error
    total_pred_sum = 0.0
    total_target_sum = 0.0

    print(f"Processing {len(batch_df)} regions...")
    truth_bw = pyBigWig.open(bw_path)
    try:
        with tqdm(total=len(batch_df), desc=f" {sample_name}") as pbar:
            coords = zip(batch_df['chr'], batch_df['start'], batch_df['end'])
            for row, (chrom, start, end) in enumerate(coords):
                try:
                    interval = genome.Interval(
                        chromosome=chrom, start=start, end=end
                    ).resize(seq_length)
                    orig_interval = genome.Interval(
                        chromosome=chrom, start=start, end=end
                    )

                    # Get predictions
                    pred = model.predict_interval(
                        interval=interval,
                        requested_outputs=[dna_client.OutputType[str(data_type)]],
                        ontology_terms=[biosample],
                    )
                    pred_orig_interval = getattr(
                        pred, data_type.lower()
                    ).slice_by_interval(orig_interval, match_resolution=True)
                    pred_values = np.asarray(pred_orig_interval.values).squeeze()

                    # Get ground truth (missing coverage is treated as zero)
                    truth_values = np.nan_to_num(
                        truth_bw.values(chrom, start, end, numpy=True)
                    )

                    # All metrics are computed before anything is stored so
                    # a failure part-way through cannot leave the
                    # per-region lists with different lengths.
                    corr = compute_binned_correlation(
                        pred_values, truth_values, bin_size
                    )
                    pred_sum = np.sum(pred_values)
                    truth_sum = np.sum(truth_values)
                    if pred_sum > 0 and truth_sum > 0:
                        jsd = compute_jensen_shannon_distance(
                            pred_values, truth_values
                        )
                    else:
                        jsd = np.nan
                except Exception as e:
                    print(f" Error processing region {row}: {e}")
                    corr, jsd = np.nan, np.nan
                    pred_sum, truth_sum = 0, 0
                    region_ok.append(False)
                else:
                    region_ok.append(True)
                    total_pred_sum += float(pred_sum)
                    total_target_sum += float(truth_sum)

                per_sample_correlations.append(corr)
                per_sample_sum_preds.append(pred_sum)
                per_sample_sum_targs.append(truth_sum)
                per_sample_js_distances.append(jsd)
                processed += 1
                pbar.update(1)
    finally:
        # Release the file handle even if the loop is interrupted.
        truth_bw.close()

    # Compute aggregate metrics
    valid_corrs = [
        c for c in per_sample_correlations
        if c is not None and not np.isnan(c)
    ]
    valid_jsd = [j for j in per_sample_js_distances if not np.isnan(j)]

    # Log total count correlation, restricted to successfully processed
    # regions: failed regions carry placeholder sums of 0.
    ok_mask = np.array(region_ok, dtype=bool)
    ok_pred_sums = np.array(per_sample_sum_preds, dtype=float)[ok_mask]
    ok_truth_sums = np.array(per_sample_sum_targs, dtype=float)[ok_mask]
    if ok_pred_sums.size > 1:
        log_count_corr = compute_log_total_count_correlation(
            ok_pred_sums, ok_truth_sums
        )
    else:
        log_count_corr = np.nan

    results = {
        "sample_name": sample_name,
        "biosample": biosample,
        "data_type": data_type,
        "n_regions": processed,
        "n_failed_regions": int(processed - ok_mask.sum()),
        "n_valid_correlations": len(valid_corrs),
        "n_valid_jsd": len(valid_jsd),
        # Primary metrics
        f"pearson_r_{bin_size}bp": float(np.mean(valid_corrs)) if valid_corrs else np.nan,
        f"pearson_r_{bin_size}bp_median": float(np.median(valid_corrs)) if valid_corrs else np.nan,
        "log_total_count_pearson_r": float(log_count_corr) if not np.isnan(log_count_corr) else np.nan,
        "mean_js_distance": float(np.mean(valid_jsd)) if valid_jsd else np.nan,
        "median_js_distance": float(np.median(valid_jsd)) if valid_jsd else np.nan,
        # Summary
        "total_pred_sum": float(total_pred_sum),
        "total_truth_sum": float(total_target_sum),
        "ratio_pred_truth": float(total_pred_sum / total_target_sum) if total_target_sum > 0 else np.nan,
        # Per-sample data
        "per_sample_correlations": per_sample_correlations,
        "per_sample_js_distances": per_sample_js_distances,
        "per_sample_pred_sums": per_sample_sum_preds,
        "per_sample_truth_sums": per_sample_sum_targs,
    }
    return results
def main():
    """Evaluate every sample listed in ``csv_file`` and return the results.

    The CSV is expected to provide the columns ``sample_name``,
    ``biosample``, ``bw_path``, ``bed`` and ``data_type``. Samples whose
    bigWig or BED file does not exist are skipped with a warning.

    Returns:
        A list with one result dict per evaluated sample, in CSV order.
    """
    # Load sample information from CSV
    print(f"Loading sample information from: {csv_file}")
    samples_df = pd.read_csv(csv_file)
    print(f"Found {len(samples_df)} samples to process:")
    for _, row in samples_df.iterrows():
        print(f" - {row['sample_name']} ({row['data_type']})")

    # Initialize model (once for all samples)
    print("\nInitializing AlphaGenome model...")
    fold_to_use = dna_client.ModelVersion.FOLD_0
    seq_length = dna_client.SEQUENCE_LENGTH_1MB
    model = dna_client.create(api_key, model_version=fold_to_use)

    # Process each sample
    all_results = []
    for _, row in samples_df.iterrows():
        bw_path = row['bw_path']
        bed = row['bed']

        # Validate files exist
        if not os.path.exists(bw_path):
            print(f" WARNING: BigWig file not found: {bw_path}")
            continue
        if not os.path.exists(bed):
            print(f" WARNING: BED file not found: {bed}")
            continue

        # Evaluate sample
        results = evaluate_sample(
            sample_name=row['sample_name'],
            biosample=row['biosample'],
            bw_path=bw_path,
            bed=bed,
            data_type=row['data_type'],
            model=model,
            seq_length=seq_length,
            batch_size=batch_size,
            bin_size=bin_size,
            output_dir=output_dir,
        )
        all_results.append(results)

    # Hand the collected metrics back to the caller instead of dropping them.
    return all_results


if __name__ == "__main__":
    main()
I’m using the ground-truth bigWig downloaded from ENCSR003WJE – ENCODE, and the regions BED file was made according to this Colab notebook - Google Colab
I would be grateful for any insight on what I’m doing wrong here.