Hello,
I’m trying to reproduce the ChromBPNet benchmark results from Extended Data Fig 3g and am encountering a large discrepancy in the DNase track correlation for IMR-90 (ENCSR477RTP) on test fold 0.
- Reported Track Pearson r: 0.79
- My Track Pearson r: ~0.35
My other metrics are aligned with the paper’s results (log total count correlation ~0.78, JSD ~0.44), so the issue seems specific to the per-base correlation.
What I’ve done:
- Prepared the ground truth data using the exact steps from the ChromBPNet repository.
- Used the ChromBPNet fold 0 test peak regions, filtering for overlaps with the AlphaGenome training set (sketch below).
- Used the standard AlphaGenome model for predictions.
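For concreteness, here is a minimal sketch of the peak filtering step (using pybedtools; the file paths are placeholders, and I am assuming "filtering for overlaps" means keeping test peaks that intersect at least one AlphaGenome training interval):

import pybedtools

# ChromBPNet fold 0 test peaks and AlphaGenome training regions
# (placeholder paths; substitute local files).
fold0_peaks = pybedtools.BedTool("chrombpnet_fold0_test_peaks.bed")
train_regions = pybedtools.BedTool("alphagenome_train_regions.bed")

# -u reports each peak at most once if it overlaps any training interval.
filtered = fold0_peaks.intersect(train_regions, u=True)
filtered.saveas("fold0_peaks_in_alphagenome_train.bed")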
I’m attaching my prediction script below. Any help would be greatly appreciated.
import argparse
import logging
import os
import random
from typing import Dict

import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from alphagenome.data import genome
from alphagenome.models import dna_client

logger = logging.getLogger(__name__)


class TestRegulatoryDataset(Dataset):
def __init__(self, **kwargs):
data_file = kwargs.pop("data_file")
cache_dir = kwargs.pop("cache_dir", None)
subset_size = kwargs.pop("subset_size", None)
self.split_name = list(data_file.keys())[0]
self.dataset = load_dataset("json", data_files=data_file, cache_dir=cache_dir)
self.max_length = 2114
if subset_size is not None:
total_size = len(self.dataset[self.split_name])
if subset_size > total_size:
logger.warning(
f"subset_size ({subset_size}) is larger than the total dataset size ({total_size}). Using the full dataset."
)
else:
indices = list(range(total_size))
random.shuffle(indices)
subset_indices = indices[:subset_size]
self.dataset[self.split_name] = self.dataset[self.split_name].select(
subset_indices
)
logger.info(
f"Loaded {len(self.dataset[self.split_name])} sequences from {data_file}"
)
def __len__(self) -> int:
return len(self.dataset[self.split_name])
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
item = self.dataset[self.split_name][idx]
biosample_type = item["biosample_term"]
assay_values = []
        assay_sequence = item["signal"]
assay_values.append(torch.tensor(assay_sequence))
assay_targets = torch.stack(assay_values).unsqueeze(-1)
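        # Depth normalization: rescale the target signal so that a genome-wide
        # total of 11,005,729,934 (which I believe is the raw total signal for
        # this experiment) maps to 1e8.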
assay_targets *= float(10**8) / 11005729934
return {
"assay_targets": assay_targets,
"biosample_id": "EFO:0001196",
"chromosome": item["chr"],
"start": item["start"],
"end": item["end"],
"strand": "+",
"biosample_name": biosample_type,
}
def main():
    # Minimal argument parsing so the snippet is self-contained; the flags
    # mirror what the rest of the script expects from `args`.
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_config", required=True)
    parser.add_argument("--architecture_config_path", required=True)
    parser.add_argument("--tokenizer_config_path", required=True)
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    # Model selection matches the ChromBPNet benchmark setup.
    fold_to_use = dna_client.ModelVersion.ALL_FOLDS
seq_length = dna_client.SEQUENCE_LENGTH_16KB
test_config, modality_config, train_config = _load_configs(
args.test_config, args.architecture_config_path, args.tokenizer_config_path
)
dataset_kwargs = {
"subset_size": 100,
}
data_files = _get_data_files(test_config)
if not data_files:
raise FileNotFoundError("No .jsonl files found under data_dir/chrombpnet_fold0_eval")
# Ensure deterministic dataset subset shuffling
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
dataset = TestRegulatoryDataset(
data_file={"test": data_files},
**dataset_kwargs,
)
generator = torch.Generator()
if "sampler_seed" in test_config:
generator.manual_seed(int(test_config["sampler_seed"]))
else:
generator.manual_seed(int(args.seed))
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, generator=generator)
    # API key: read a single key from the environment (no threading).
    api_key = os.environ.get("ALPHAGENOME_API_KEY", "").strip()
    if not api_key:
        raise ValueError("Set ALPHAGENOME_API_KEY in the environment.")
model = dna_client.create(api_key, model_version=fold_to_use)
processed = 0
total_valid_positions = 0
total_pred_sum = 0.0
total_target_sum = 0.0
per_sample_correlations = []
per_sample_sum_preds = []
per_sample_sum_targs = []
with tqdm(total=len(dataloader), desc="AlphaGenome (single-thread)") as pbar:
for sample in dataloader:
chromosome = sample["chromosome"][0]
start = sample["start"][0].item()
end = sample["end"][0].item()
strand = sample["strand"][0]
biosample_id = sample["biosample_id"][0]
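            # Center-resize the query region to the model's 16 kb input length
            # (Interval.resize keeps the interval midpoint fixed).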
interval = genome.Interval(chromosome=chromosome, start=start, end=end, strand=strand).resize(seq_length)
output = model.predict_interval(
interval=interval,
requested_outputs=[dna_client.OutputType.DNASE],
ontology_terms=[biosample_id],
)
orig_interval = genome.Interval(chromosome=chromosome, start=start, end=end, strand=strand)
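            # Slice the prediction back to the original (pre-resize) region so
            # it aligns base-for-base with the stored target signal.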
dnase_td = output.dnase.slice_by_interval(orig_interval, match_resolution=True)
dnase_values = np.asarray(dnase_td.values).squeeze()
alphagenome_preds = torch.from_numpy(dnase_values).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
assay_targets = sample["assay_targets"]
pred_seq = alphagenome_preds[0, 0, :, 0]
targ_seq = assay_targets[0, 0, :].to(dtype=pred_seq.dtype)
L_pred = pred_seq.shape[0]
L_targ = targ_seq.shape[0]
            win_len = min(1000, L_pred, L_targ)
if win_len > 1:
ps = (L_pred - win_len) // 2
ts = (L_targ - win_len) // 2
pred_win = pred_seq[ps:ps + win_len]
targ_win = targ_seq[ts:ts + win_len].squeeze(-1)
total_valid_positions += int(win_len)
total_pred_sum += float(pred_win.sum().item())
total_target_sum += float(targ_win.sum().item())
per_sample_sum_preds.append(float(pred_win.sum().item()))
per_sample_sum_targs.append(float(targ_win.sum().item()))
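                # Per-sample, per-base Pearson r over the centered window.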
corr = torch.corrcoef(torch.stack([pred_win, targ_win]))[0, 1].item()
per_sample_correlations.append(corr)
processed += 1
pbar.update(1)
valid_corrs = [c for c in per_sample_correlations if isinstance(c, float) and not np.isnan(c)]
mean_corr = float(np.mean(valid_corrs)) if len(valid_corrs) > 0 else float("nan")
sum_corr = float("nan")
if len(per_sample_sum_preds) > 1 and len(per_sample_sum_preds) == len(per_sample_sum_targs):
try:
sp = torch.log1p(torch.tensor(per_sample_sum_preds, dtype=torch.float32))
st = torch.log1p(torch.tensor(per_sample_sum_targs, dtype=torch.float32))
sum_corr = float(torch.corrcoef(torch.stack([sp, st]))[0, 1].item())
except Exception:
pass
print(f"Processed {processed} samples")
print(f"Mean correlation (center 1000bp): {mean_corr}")
print(f"Total valid positions: {total_valid_positions}")
print(f"Sum of predictions over valid positions: {total_pred_sum}")
print(f"Sum of targets over valid positions: {total_target_sum}")
print(f"Correlation of per-sample sums: {sum_corr}")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
main()
