A PyTorch-based deep learning library for multi-task brain MRI analysis. Train models for dementia classification, brain tumor detection, and 3D volumetric NIfTI analysis with just a few lines of code.
- 2D & 3D Support: RGB image analysis and NIfTI (.nii/.nii.gz) volumetric processing
- Flexible Task Selection: Train for dementia only, tumor only, or both simultaneously
- Easy to Use: Keras-like API for quick training
- Task-Specific Attention: Separate attention mechanisms for each task
- MRI Artifact Simulation: Bias field, ghosting, spike noise, Rician noise
- Elastic Deformation: Anatomical variation simulation for 2D and 3D
- MixUp / CutMix: Modern batch-level augmentation techniques
- AutoAugment: Automatic MRI-specific augmentation policy
- HuggingFace Hub: Share and download models via HuggingFace
- ONNX Export: Production deployment for both 2D and 3D models
- Model Zoo: Registry of pretrained model configurations
- Visualization: Built-in attention heatmap visualization
- Configurable: YAML/JSON configuration support
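Rician noise, listed among the artifact simulations above, is the noise model that magnitude MR images actually follow: Gaussian noise enters the real and imaginary channels before the magnitude is taken. The snippet below is an independent NumPy illustration of that idea — the function name and parameters are made up here and this is not vbai's internal implementation:

```python
import numpy as np

def rician_noise(image: np.ndarray, std: float = 0.03, seed: int = 0) -> np.ndarray:
    """Add Rician noise: magnitude of the image plus complex Gaussian noise."""
    rng = np.random.default_rng(seed)
    real = image + rng.normal(0.0, std, image.shape)  # noise on the real channel
    imag = rng.normal(0.0, std, image.shape)          # noise on the imaginary channel
    return np.sqrt(real ** 2 + imag ** 2)

volume = np.zeros((8, 8, 8), dtype=np.float32)
noisy = rician_noise(volume, std=0.05)
# On a zero image the result is strictly positive (Rayleigh-distributed) --
# this positive bias in dark regions is the hallmark of Rician noise.
print(noisy.min() > 0)
```

The positive bias in low-signal regions is why Rician noise cannot be simulated by simply adding Gaussian noise to the magnitude image.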
2D - Dementia (6 classes):
- AD Alzheimer's Disease
- AD Mild Demented
- AD Moderate Demented
- AD Very Mild Demented
- CN Non-Demented (Cognitively Normal)
- PD Parkinson's Disease
2D - Brain Tumor (4 classes):
- Glioma
- Meningioma
- No Tumor
- Pituitary
3D - Alzheimer's (NIfTI):
- CN (Cognitively Normal)
- MCI (Mild Cognitive Impairment)
- AD (Alzheimer's Disease)
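When decoding raw model outputs outside the library, it helps to keep the label tables above as plain Python lists. This is an illustrative snippet, not part of the vbai API; the index order shown is an assumption (alphabetical folder order, as folder-style datasets commonly assign it) that you should verify against your own dataset:

```python
# Hypothetical label tables mirroring the class lists above.
# The index order is an assumption -- confirm it against your dataset.
DEMENTIA_CLASSES = [
    "AD_Alzheimer", "AD_Mild_Demented", "AD_Moderate_Demented",
    "AD_Very_Mild_Demented", "CN_Non_Demented", "PD_Parkinson",
]
TUMOR_CLASSES = ["Glioma", "Meningioma", "No_Tumor", "Pituitary"]
ALZHEIMER_3D_CLASSES = ["CN", "MCI", "AD"]

def idx_to_label(task_classes: list, index: int) -> str:
    """Translate an argmax index into its human-readable class name."""
    return task_classes[index]

print(idx_to_label(TUMOR_CLASSES, 2))  # -> No_Tumor
```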
# Basic installation
pip install vbai
# With NIfTI (3D) support
pip install vbai[nifti]
# With HuggingFace Hub integration
pip install vbai[hub]
# With ONNX export
pip install vbai[onnx]
# With all optional dependencies
pip install vbai[full]
# Development installation
git clone https://github.com/Neurazum-AI-Department/vbai.git
cd vbai
pip install -e .[dev]

import vbai
# Create model for both tasks (default)
model = vbai.MultiTaskBrainModel(variant='q') # 'q' for quality, 'f' for fast
# Prepare dataset
dataset = vbai.UnifiedMRIDataset(
    dementia_path='./data/dementia/train',
    tumor_path='./data/tumor/train',
    is_training=True
)
# Create trainer and train
trainer = vbai.Trainer(model=model, lr=0.0005, device='cuda')
history = trainer.fit(train_data=dataset, epochs=10, batch_size=32)
trainer.save('brain_model.pt')

import vbai
# Create 3D model
model = vbai.MultiTask3DBrainModel(
    variant='q',
    tasks={'alzheimer': 3},
    input_shape=(96, 96, 96)
)
# Create NIfTI dataset
dataset = vbai.NIfTIDataset(
    root='./data/alzheimer_3d',  # Subfolders: CN/, MCI/, AD/
    target_shape=(96, 96, 96),
    is_training=True
)
# Dataloaders
train_loader, val_loader = vbai.create_3d_dataloaders(
    root='./data/alzheimer_3d',
    batch_size=4, val_split=0.2
)
# Train
trainer = vbai.Trainer3D(model=model, lr=1e-4, device='cuda')
history = trainer.fit(train_loader, val_loader, epochs=25)
trainer.save('alzheimer_3d.pt')

import vbai
# Dementia only
model = vbai.MultiTaskBrainModel(variant='q', tasks=['dementia'])
# Tumor only
model = vbai.MultiTaskBrainModel(variant='q', tasks=['tumor'])

import vbai
# 2D prediction
model = vbai.load('brain_model.pt', device='cuda')
result = model.predict('brain_scan.jpg')
print(f"Dementia: {result.dementia_class} ({result.dementia_confidence:.1%})")
print(f"Tumor: {result.tumor_class} ({result.tumor_confidence:.1%})")
# 3D NIfTI prediction
model_3d = vbai.load_3d('alzheimer_3d.pt', device='cuda')
result = model_3d.predict('scan.nii.gz', task='alzheimer',
                          class_names=['CN', 'MCI', 'AD'])
print(f"{result.predicted_class}: {result.confidence:.1%}")

import vbai
# List available models
models = vbai.list_models() # All models
models_3d = vbai.list_models('3d') # Only 3D models
# Download and load from Hub
model = vbai.from_hub('Neurazum/vbai-3d-q', device='cuda')
# Push your trained model to Hub
model.push_to_hub('username/my-brain-model')
# Or use the functional API
vbai.push_to_hub(model, 'username/my-brain-model', private=True)

import vbai
# Export 2D model
model_2d = vbai.MultiTaskBrainModel(variant='q')
model_2d.export_onnx('model_2d.onnx')
# Export 3D model
model_3d = vbai.MultiTask3DBrainModel(variant='q', tasks={'alzheimer': 3})
model_3d.export_onnx('model_3d.onnx')
# Or use the functional API
vbai.export_onnx(model_2d, 'model_2d.onnx')
# PyTorch-free inference with ONNX
onnx_model = vbai.ONNXModel('model_3d.onnx')
output = onnx_model.predict_nifti('brain_scan.nii.gz')
probs = onnx_model.softmax(output)

import vbai
import numpy as np
# ── MRI Artifact Simulation ──
volume = np.random.rand(96, 96, 96).astype(np.float32)
# Simulate individual artifacts
volume = vbai.simulate_bias_field(volume, intensity=0.3)
volume = vbai.simulate_ghosting(volume, num_ghosts=3, intensity=0.15)
volume = vbai.simulate_rician_noise(volume, std=0.03)
volume = vbai.simulate_spike_noise(volume, num_spikes=1, intensity=0.5)
# Or apply random artifacts in one call
volume = vbai.simulate_mri_artifacts(volume, p=0.5)
# ── Elastic Deformation ──
deformed_2d = vbai.elastic_deformation_2d(image_2d, alpha=50, sigma=5)
deformed_3d = vbai.elastic_deformation_3d(volume, alpha=30, sigma=4)
# ── MixUp / CutMix (batch-level, works with both 2D and 3D) ──
mixed, labels_a, labels_b, lam = vbai.mixup(images, labels, alpha=0.2)
outputs = model(mixed)  # single forward pass, reused for both label sets
loss = lam * criterion(outputs, labels_a) + (1 - lam) * criterion(outputs, labels_b)
mixed, labels_a, labels_b, lam = vbai.cutmix(images, labels, alpha=1.0)
# ── AutoAugment (automatic MRI-specific policy) ──
augmenter = vbai.MRIAutoAugment(mode='3d', num_policies=10)
augmented_volume = augmenter(volume)
augmenter_2d = vbai.MRIAutoAugment(mode='2d', num_policies=10)
augmented_image = augmenter_2d(image_2d)

import vbai
model = vbai.MultiTaskBrainModel(variant='q')
callbacks = [
    vbai.EarlyStopping(monitor='val_loss', patience=5),
    vbai.ModelCheckpoint(
        filepath='checkpoints/model_{epoch:02d}.pt',
        monitor='val_loss',
        save_best_only=True
    )
]
trainer = vbai.Trainer(model=model, callbacks=callbacks)
trainer.fit(train_data, val_data, epochs=50)

import vbai
# Use preset configurations
config = vbai.get_default_config('quality') # 'default', 'fast', 'quality', 'debug'
# 3D configuration
config_3d = vbai.get_default_3d_config('quality')
# Custom config
model_config = vbai.ModelConfig(
    variant='q',
    tasks=['dementia', 'tumor'],
    dropout=0.3,
    use_edge_branch=True
)

# Train 2D model
vbai-train --dementia_path ./data/dementia --tumor_path ./data/tumor --epochs 10
# Single-task training
vbai-train --dementia_path ./data/dementia --tasks dementia --epochs 10
# Prediction
vbai-predict --model brain_model.pt --image brain_scan.jpg
# With visualization
vbai-predict --model brain_model.pt --image brain_scan.jpg --visualize --output result.png

| Variant | Layers | Channels | Speed | Accuracy |
|---|---|---|---|---|
| f (fast) | 3 | 32-64-128 | Fast | Good |
| q (quality) | 4 | 64-128-256-512 | Slower | Better |
| Variant | Stages | Channels | Speed | Accuracy |
|---|---|---|---|---|
| f (fast) | 3x1 blocks | 32-64-128 | Fast | Good |
| q (quality) | 3x2 blocks | 64-128-256 | Slower | Better |
data/
├── dementia/
│ ├── train/
│ │ ├── AD_Alzheimer/
│ │ ├── AD_Mild_Demented/
│ │ ├── AD_Moderate_Demented/
│ │ ├── AD_Very_Mild_Demented/
│ │ ├── CN_Non_Demented/
│ │ └── PD_Parkinson/
│ └── val/
└── tumor/
├── train/
│ ├── Glioma/
│ ├── Meningioma/
│ ├── No_Tumor/
│ └── Pituitary/
└── val/
data/alzheimer_3d/
├── CN/
│ ├── subject_001.nii.gz
│ └── subject_002.nii.gz
├── MCI/
│ └── subject_003.nii.gz
└── AD/
└── subject_004.nii.gz
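Before starting a long training run, it can pay to sanity-check that your data matches the layout above. This is a generic `pathlib` sketch, not a vbai utility; the temporary tree built at the end exists only to demonstrate the check:

```python
import tempfile
from pathlib import Path

def check_nifti_layout(root, classes=("CN", "MCI", "AD")) -> dict:
    """Count .nii / .nii.gz files in each expected class subfolder under root."""
    root_path = Path(root)
    counts = {}
    for cls in classes:
        folder = root_path / cls
        files = list(folder.glob("*.nii")) + list(folder.glob("*.nii.gz"))
        counts[cls] = len(files)
    return counts

# Build a tiny throwaway tree mirroring the layout above, just for the demo
tmp = Path(tempfile.mkdtemp())
for cls, n in [("CN", 2), ("MCI", 1), ("AD", 1)]:
    (tmp / cls).mkdir()
    for i in range(n):
        (tmp / cls / f"subject_{i:03d}.nii.gz").touch()

print(check_nifti_layout(tmp))  # {'CN': 2, 'MCI': 1, 'AD': 1}
```

A class folder reporting zero files usually means a misspelled folder name or files saved with an unexpected extension.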
- `MultiTaskBrainModel` - 2D multi-task model (dementia + tumor)
- `MultiTask3DBrainModel` - 3D volumetric model (NIfTI)
- `Trainer` - 2D training loop manager
- `Trainer3D` - 3D training loop manager

- `UnifiedMRIDataset` - 2D dataset (RGB images)
- `NIfTIDataset` - 3D dataset (NIfTI volumes)
- `UnifiedNIfTIDataset` - 3D multi-task dataset

- `simulate_bias_field()` / `simulate_ghosting()` / `simulate_spike_noise()` / `simulate_rician_noise()` - MRI artifact simulation
- `simulate_mri_artifacts()` - Combined random artifact application
- `elastic_deformation_2d()` / `elastic_deformation_3d()` - Elastic deformation
- `mixup()` / `cutmix()` - Batch-level augmentation (2D & 3D)
- `MRIAutoAugment` - Automatic augmentation policy

- `list_models()` / `get_model_info()` - Model zoo registry
- `from_hub()` / `push_to_hub()` - HuggingFace Hub integration
- `export_onnx()` - ONNX export (2D & 3D)
- `ONNXModel` - PyTorch-free ONNX inference

- `ModelConfig` / `Model3DConfig` - Architecture settings
- `TrainingConfig` / `Training3DConfig` - Training hyperparameters
- `get_default_config()` / `get_default_3d_config()` - Presets

- `EarlyStopping` - Stop training when no improvement
- `ModelCheckpoint` - Save best/all checkpoints
- `TensorBoardLogger` - Log to TensorBoard
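The EarlyStopping callback follows the standard patience rule: stop once the monitored metric has gone `patience` consecutive epochs without improvement. A framework-free sketch of that logic (illustrative only — not vbai's actual class):

```python
class PatienceStopper:
    """Minimal early-stopping logic: stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.wait = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: reset the counter
            self.wait = 0
        else:
            self.wait += 1        # no improvement this epoch
        return self.wait >= self.patience

stopper = PatienceStopper(patience=3)
losses = [1.0, 0.8, 0.79, 0.81, 0.80, 0.82]  # loss stalls after epoch 3
stops = [stopper.step(loss) for loss in losses]
print(stops)  # last entry True: three consecutive epochs without improvement
```

With `min_delta > 0`, tiny fluctuations below that threshold no longer count as improvement, which makes the stopping criterion less noisy.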
See the examples/ directory:
- `train_basic.py` - Basic 2D training
- `train_advanced.py` - Advanced training with callbacks
- `train_3d.py` - 3D NIfTI training
- `inference.py` - Model inference
@software{vbai,
title = {Vbai: Visual Brain AI Library},
author = {Neurazum},
year = {2025},
url = {https://github.com/Neurazum-AI-Department/vbai}
}

MIT License - see LICENSE for details.
Planned - details coming soon.
- Website: Neurazum - HealFuture
- Email: contact@neurazum.com
Neurazum AI Department