Classifies fungal genomes as Subsurface (cave/deep-earth) or Terrestrial (surface-soil) using Evo-2 genomic language model embeddings and/or functional annotation features. Trained models persist to disk so future runs — including short-read metagenome data — can skip the embedding step.
Genome files live under `classify/`, organised by niche and data type:

```
classify/
├── Subsurface/               (5 cave fungi)
│   ├── cds/                  *.cds-transcripts.fa
│   ├── dna/                  *.scaffolds.fa
│   └── annotation_summary/   *.annotation_summary.tsv[.gz]
└── Terrestrial/              (191 surface fungi)
    ├── cds/                  *.cds-transcripts.fa
    ├── dna/                  *.scaffolds.fa
    └── annotation_summary/   *.annotation_summary.tsv[.gz]
```
Each annotation_summary TSV contains one row per protein with columns:

```
protein_id, pfam_domains, signalp_start/end/prob, merops_id/pct_id/evalue,
tmhmm_pred_hel/exp_aa/topology, cazy_family/EC/substrate
```

Both uncompressed `.tsv` and gzip-compressed `.tsv.gz` files are supported.
The pfam_domains column holds pipe-delimited PF_ACC:PF_NAME:EVALUE tokens; a single
protein may carry several domains and the same accession may appear more than once
(multiple hit regions along the sequence).
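As an illustration, a minimal parser for this column might look like the following sketch. The token format comes from the description above; the helper name and the empty-cell markers are assumptions, not the project's actual code:

```python
from collections import Counter

def parse_pfam_domains(cell: str) -> Counter:
    """Parse a pipe-delimited PF_ACC:PF_NAME:EVALUE cell into per-accession
    hit counts (the same accession may repeat when a domain matches several
    regions of one protein)."""
    counts: Counter = Counter()
    if not cell or cell in (".", "NA"):  # assumed empty-cell markers
        return counts
    for token in cell.split("|"):
        acc, _name, _evalue = token.split(":", 2)
        counts[acc] += 1
    return counts

# Two hit regions of PF00069 plus one PF07714 on a single protein:
hits = parse_pfam_domains(
    "PF00069:Pkinase:1e-30|PF00069:Pkinase:2e-5|PF07714:PK_Tyr_Ser-Thr:3e-8"
)  # hits["PF00069"] == 2
```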
This project uses pixi for environment management.
```
curl -fsSL https://pixi.sh/install.sh | bash
cd SSF_contrast
pixi install
```

Evo-2 depends on transformer-engine and flash-attn, both of which must be compiled against the system CUDA installation. After `pixi install`, run:

```
pixi run pip3 install --no-build-isolation "transformer-engine[pytorch]"
pixi run pip install flash-attn --no-build-isolation
```

`--no-build-isolation` is required so the build process links against the CUDA libraries already present on the host. Verify `nvcc` is on PATH before building (`which nvcc`).
These steps are only needed for embedding or hybrid mode; the annotation-only baseline
has no GPU dependency.
All subsequent commands should be prefixed with `pixi run` to use the managed environment, or run inside a shell activated with `pixi shell`.
Annotation mode builds a classifier from functional annotation features derived from the annotation_summary TSV files:
- CAZyme classes — GH, GT, PL, CE, AA, CBM counts and rates
- Secretome — SignalP probability > 0.5
- Membrane proteins — TMHMM predicted helices > 0
- Proteases — MEROPS family hit
- PFAM scalar features — proteins with ≥1 domain, total domain instances, unique accessions, average domains per annotated protein, multi-domain protein count
- Per-PFAM-family features — for every domain accession present in ≥2 genomes, the count and rate of proteins carrying that domain (5 976 families in the current dataset, giving ~11 982 features total)
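The per-family features amount to a two-pass build: first collect the cross-genome PFAM vocabulary (keeping accessions seen in ≥2 genomes), then emit a count and a rate column per retained accession. A minimal stand-alone sketch with made-up genome data — the real implementation lives in `src/annotation_features.py`, and the rate denominator here (total annotated proteins) is an assumption:

```python
from collections import Counter

def pfam_feature_matrix(genomes: dict[str, Counter]):
    """genomes maps genome name -> Counter of {PF accession: protein count}.
    Returns (feature_names, per-genome rows) with one count and one rate
    column per accession present in at least 2 genomes."""
    # Pass 1: cross-genome vocabulary (accessions shared by >= 2 genomes)
    seen_in = Counter()
    for counts in genomes.values():
        seen_in.update(counts.keys())
    vocab = sorted(acc for acc, n in seen_in.items() if n >= 2)

    # Pass 2: feature rows
    names = [f"{acc}_{kind}" for acc in vocab for kind in ("count", "rate")]
    rows = {}
    for genome, counts in genomes.items():
        total = sum(counts.values()) or 1  # assumed rate denominator
        row: list[float] = []
        for acc in vocab:
            c = counts.get(acc, 0)
            row += [float(c), c / total]
        rows[genome] = row
    return names, rows

names, rows = pfam_feature_matrix({
    "cave_1": Counter({"PF00069": 4, "PF07714": 1}),
    "soil_1": Counter({"PF00069": 2}),
    "soil_2": Counter({"PF01234": 3}),
})
# Only PF00069 is shared by >= 2 genomes, so it alone enters the vocabulary.
```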
```
pixi run python train.py \
  --mode annotation \
  --data-dir classify \
  --model-dir models/annotation
```

Annotation loading is parallelised across 8 threads by default. Adjust with `--n-workers`:
```
pixi run python train.py \
  --mode annotation \
  --data-dir classify \
  --model-dir models/annotation \
  --n-workers 16
```

Evo-2 model weights (~7 GB for the 1B variant) download automatically on first use.
```
pixi run python train.py \
  --mode embedding \
  --seq-type cds \
  --data-dir classify \
  --model-dir models/evo2_cds \
  --embedding-cache models/embeddings
```

Use `--seq-type scaffolds` to embed whole-genome scaffold sequences instead of CDS transcripts (slower, higher GPU memory).
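Conceptually, sequences longer than the model context are tiled into windows, each window is embedded, and the window embeddings are pooled into one fixed-length vector per genome. A minimal sketch of that tile-and-mean-pool pattern — the window size and the stand-in `toy_embed` function are assumptions; the real Evo-2 call and pooling details live in `src/embeddings.py`:

```python
def embed_genome(seq: str, embed_window, window: int = 8192) -> list[float]:
    """Tile `seq` into non-overlapping windows, embed each window with
    `embed_window`, then mean-pool into one fixed-length vector."""
    tiles = [seq[i:i + window] for i in range(0, len(seq), window)]
    vecs = [embed_window(t) for t in tiles]
    dim = len(vecs[0])
    return [sum(v[i] for v in vecs) / len(vecs) for i in range(dim)]

# Stand-in embedder: returns base-composition frequencies, just so the
# sketch runs without a GPU or the Evo-2 weights.
def toy_embed(window: str) -> list[float]:
    return [window.count(b) / len(window) for b in "ACGT"]

vec = embed_genome("ACGT" * 5000, toy_embed)  # one vector for the whole sequence
```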
```
pixi run python train.py \
  --mode hybrid \
  --seq-type cds \
  --data-dir classify \
  --model-dir models/hybrid \
  --embedding-cache models/embeddings
```

After training an annotation model, generate a full SHAP interpretability report:
```
pixi run python explain_annotation.py \
  --model-dir models/annotation \
  --data-dir classify \
  --results-dir results \
  --top-dependence 5
```

Outputs written to `results/`:
| File | Description |
|---|---|
| `shap_beeswarm.png` | Per-sample SHAP values coloured by feature value |
| `shap_bar.png` | Mean \|SHAP\| global feature ranking |
| `shap_waterfall_<name>.png` | Waterfall for each Subsurface genome (5 plots) |
| `shap_dependence_<feat>.png` | SHAP vs raw value for top-N features |
| `shap_values.csv` | Full SHAP matrix (genome × feature) |
| `shap_class_means.csv` | Mean SHAP per class per feature |
| `logistic_coefficients.csv`/`.png` | LR weights (logistic model only) |
| `permutation_importance.csv`/`.png` | Balanced-accuracy drop per feature |
A formatted summary table is also printed to stdout showing each feature's mean |SHAP|, class-directional sign, and the top-3 driving features for each individual Subsurface genome.
```
# Single FASTA
pixi run python predict.py \
  --input unknown_genome.scaffolds.fa \
  --model-dir models/hybrid

# Directory of FASTA files, save results to CSV
pixi run python predict.py \
  --input-dir /path/to/new_genomes/ \
  --model-dir models/hybrid \
  --out results/predictions.csv
```

For short-read metagenome data, pass `--short-reads`:

```
pixi run python predict.py \
  --input metagenome_bin.fa \
  --model-dir models/hybrid \
  --short-reads
```

| Flag | Default | Description |
|---|---|---|
| `--mode` | `annotation` | `annotation` / `embedding` / `hybrid` |
| `--seq-type` | `cds` | Sequences to embed: `cds` or `scaffolds` |
| `--evo2-model` | `evo2_1b_base` | Evo-2 variant (`evo2_1b_base`, `evo2_7b_base`, `evo2_40b_base`) |
| `--clf-type` | `logistic` | Classifier head: `logistic` or `mlp` |
| `--cv-folds` | `5` | Cross-validation folds (auto-clamped to minority class size) |
| `--embedding-cache` | `models/embeddings` | Directory for cached `.npy` embedding files |
| `--overwrite-embeddings` | off | Re-compute embeddings even if cache exists |
| `--results-dir` | `results` | Output directory for plots and CSVs |
| `--n-workers` | `8` | Threads for parallel annotation TSV loading (annotation and hybrid modes) |
After training, `models/<name>/` contains:

- `pipeline.pkl` — fitted sklearn pipeline (scaler + classifier)
- `metadata.json` — feature names, hyperparameters, CV metrics, label map
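Given those two artifacts, inference code can reload a trained model without retraining. A minimal loader sketch — the file names come from the list above, but the exact metadata keys are assumptions:

```python
import json
import pickle
from pathlib import Path

def load_model(model_dir: str):
    """Reload a trained pipeline and its metadata from models/<name>/."""
    root = Path(model_dir)
    with open(root / "pipeline.pkl", "rb") as fh:
        pipeline = pickle.load(fh)   # scaler + classifier
    with open(root / "metadata.json") as fh:
        metadata = json.load(fh)     # feature names, CV metrics, label map, ...
    return pipeline, metadata
```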
`results/` contains (where applicable):

- `logistic_coefficients.csv` / `top_features_logistic.png`
- `permutation_importance.csv` / `top_features_permutation.png`
- `umap_embeddings.png`
```
src/
  data_loader.py           GenomeRecord dataclass; discovers genomes from classify/
  embeddings.py            Evo-2 tiling, pooling, and .npy caching
  annotation_features.py   Scalar annotation features + per-PFAM-family features from
                           annotation_summary TSVs; two-pass matrix build with
                           cross-genome PFAM vocabulary; parallel loading via
                           ThreadPoolExecutor
  classifier.py            sklearn pipeline; balanced class weights; parallel CV; save/load
  features.py              Coefficients, permutation importance (parallel), SHAP, UMAP
train.py                   Training pipeline
predict.py                 Inference on new FASTA files
explain_annotation.py      SHAP-based explanation of the annotation classifier;
                           beeswarm, bar, waterfall, and dependence plots + CSV exports
```
| Stage | Mechanism | Controlled by |
|---|---|---|
| Annotation TSV loading | ThreadPoolExecutor (I/O-bound) | `--n-workers` (default 8) |
| Cross-validation folds | sklearn `n_jobs=-1` (all CPUs) | automatic |
| Permutation importance | sklearn `n_jobs=-1` (all CPUs) | automatic |
Evo-2 embedding is GPU-bound and runs serially on a single GPU; the per-genome
embedding cache (.npy files) avoids redundant GPU work across training runs.
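The cache pattern reduces to: compute an embedding only when no `.npy` file exists for that genome. A minimal sketch, assuming one file per genome named after it (the cache layout and helper name are assumptions; the real logic is in `src/embeddings.py`):

```python
from pathlib import Path

import numpy as np

def cached_embedding(genome_name: str, cache_dir: str, compute,
                     overwrite: bool = False) -> np.ndarray:
    """Return the cached .npy embedding for `genome_name`, computing and
    saving it only when missing (or when `overwrite` is set, mirroring
    --overwrite-embeddings)."""
    cache = Path(cache_dir)
    cache.mkdir(parents=True, exist_ok=True)
    path = cache / f"{genome_name}.npy"
    if path.exists() and not overwrite:
        return np.load(path)        # cache hit: no GPU work
    emb = compute(genome_name)      # expensive Evo-2 forward pass
    np.save(path, emb)
    return emb
```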
With 5 Subsurface and 193 Terrestrial genomes (1:38 ratio), all classifiers use
class_weight='balanced'. Cross-validation is stratified and folds are clamped to
the minority class count. Prefer balanced accuracy, F1 (Subsurface class), and
ROC-AUC over raw accuracy when interpreting results.
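With this skew, both the fold clamp and the preferred metric reduce to simple arithmetic: stratified CV needs at least one minority sample per fold, and balanced accuracy is the mean of per-class recall. A small pure-Python sketch of both (function names are illustrative, not the project's API):

```python
def clamp_folds(requested: int, class_counts: dict[str, int]) -> int:
    """Stratified CV needs >= 1 minority sample per fold, so the fold
    count is clamped to the minority class size."""
    return min(requested, min(class_counts.values()))

def balanced_accuracy(y_true: list[str], y_pred: list[str]) -> float:
    """Mean per-class recall: immune to class skew, unlike raw accuracy."""
    recalls = []
    for c in set(y_true):
        idx = [i for i, y in enumerate(y_true) if y == c]
        recalls.append(sum(y_pred[i] == c for i in idx) / len(idx))
    return sum(recalls) / len(recalls)

folds = clamp_folds(5, {"Subsurface": 5, "Terrestrial": 193})  # stays 5
```

A classifier that predicts Terrestrial for everything scores high raw accuracy on this dataset but only 0.5 balanced accuracy, which is why the latter is the headline metric.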