RLMedNAS is an autonomous decision-making system designed to engineer optimal Convolutional Neural Network (CNN) topologies for medical image classification. By utilising a Reinforcement Learning (RL) agent, the project replaces labor-intensive manual architecture design with a mathematically optimised, sequential decision-making process.
The manual design of CNNs for medical data is often biased toward human-manageable structures rather than mathematical optimality. RLMedNAS addresses this by employing a Recurrent Neural Network (RNN) Controller trained via the REINFORCE policy gradient algorithm to explore a discrete, block-based search space.
To ensure clinical feasibility on resource-constrained devices, the system integrates Action Masking—a neuro-symbolic approach that prunes domain-inefficient architectures (such as overly wide networks) before training begins.
- RNN Controller: A single-layer LSTM with 100 hidden units that emits architectural tokens.
- Neuro-Symbolic Action Masking: Enforces symbolic "validity rules" by setting forbidden action indices (e.g., the 128-filter count) to a large negative constant ($-10^{10}$) before the softmax operation.
- Proxy Evaluation: Uses the OrganAMNIST dataset (28×28 grayscale CT images) as a lightweight proxy to enable rapid policy convergence.
- Hardware Optimized: Native acceleration for Apple Silicon via the Metal Performance Shaders (MPS) backend in PyTorch.
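The action-masking mechanism described above can be sketched in a few lines of PyTorch. The helper name `masked_softmax` and the uniform logits are illustrative; only the masking constant and the forbidden 128-filter index come from the description above:

```python
import torch
import torch.nn.functional as F

MASK_VALUE = -1e10  # large negative constant applied before softmax


def masked_softmax(logits: torch.Tensor, forbidden: list) -> torch.Tensor:
    """Assign (near-)zero probability to forbidden action indices by
    overwriting their logits with a large negative constant."""
    masked = logits.clone()
    masked[..., forbidden] = MASK_VALUE
    return F.softmax(masked, dim=-1)


# Filter-count head over {16, 32, 64, 128}; index 3 (128 filters) is forbidden.
logits = torch.zeros(4)
probs = masked_softmax(logits, forbidden=[3])
# probs[3] is effectively zero; the valid actions share the probability mass.
```

Because the mask is applied to the logits rather than the sampled action, the REINFORCE gradient never flows through invalid choices, so the controller cannot waste training budget on pruned architectures.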
The Controller navigates the search space by maximising a multi-objective reward signal that balances diagnostic precision with computational parsimony.
The composite utility score is

$$R = \text{Acc}_{\text{val}} - \lambda \cdot P$$

where:

- Validation Accuracy ($\text{Acc}_{\text{val}}$): The primary performance metric on the OrganAMNIST validation set.
- Penalty Weight ($\lambda$): A small constant ($10^{-6}$) used to penalize overly complex architectures.
- Parameter Count ($P$): The total number of trainable parameters in the sampled child network.

A moving average baseline is subtracted from the reward when computing the REINFORCE policy gradient to reduce its variance.
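The reward and baseline bookkeeping can be sketched as follows. The penalty weight $\lambda = 10^{-6}$ comes from the reward definition; the exponential-moving-average decay of 0.9 and the class name `MovingBaseline` are assumptions for illustration:

```python
LAMBDA = 1e-6  # complexity penalty weight from the reward definition
BETA = 0.9     # EMA decay for the baseline (assumed value)


def composite_reward(val_accuracy: float, num_params: int) -> float:
    """R = validation accuracy minus a small penalty on parameter count."""
    return val_accuracy - LAMBDA * num_params


class MovingBaseline:
    """Moving-average baseline that reduces REINFORCE gradient variance."""

    def __init__(self, beta: float = BETA):
        self.beta = beta
        self.value = None

    def advantage(self, reward: float) -> float:
        if self.value is None:
            self.value = reward
        self.value = self.beta * self.value + (1 - self.beta) * reward
        return reward - self.value


baseline = MovingBaseline()
# Mean accuracy and parameter count of the RL Masked strategy:
r = composite_reward(val_accuracy=0.9549, num_params=46_557)
adv = baseline.advantage(r)
```

As a sanity check, plugging the RL Masked means from the results table into this formula (0.9549 accuracy, 46,557 parameters) yields a reward of about 0.9083, matching the reported mean reward of 0.9084.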
The agent samples three critical hyperparameters for the "Child Network":
- Filter Count: {16, 32, 64, 128}
- Kernel Size: {3×3, 5×5}
- Dropout Rate: {0.1, 0.25, 0.5}
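Token-by-token sampling over this discrete space can be sketched with a single LSTM cell, in the spirit of the 100-unit controller described earlier. The exact wiring (learned start token, one linear head per decision, hidden state fed forward as the next input) is an illustrative assumption, not the project's verified implementation:

```python
import torch
import torch.nn as nn

SEARCH_SPACE = {
    "filters": [16, 32, 64, 128],
    "kernel": [3, 5],
    "dropout": [0.1, 0.25, 0.5],
}


class Controller(nn.Module):
    """Single-layer LSTM (100 hidden units) emitting one token per decision."""

    def __init__(self, hidden: int = 100):
        super().__init__()
        self.hidden = hidden
        self.lstm = nn.LSTMCell(hidden, hidden)
        self.start = nn.Parameter(torch.zeros(1, hidden))  # learned start token
        self.heads = nn.ModuleDict(
            {k: nn.Linear(hidden, len(v)) for k, v in SEARCH_SPACE.items()}
        )

    def sample(self):
        h = torch.zeros(1, self.hidden)
        c = torch.zeros(1, self.hidden)
        x = self.start
        arch, log_probs = {}, []
        for name, choices in SEARCH_SPACE.items():
            h, c = self.lstm(x, (h, c))
            dist = torch.distributions.Categorical(logits=self.heads[name](h))
            idx = dist.sample()
            log_probs.append(dist.log_prob(idx))
            arch[name] = choices[idx.item()]
            x = h  # feed the hidden state forward as the next input
        return arch, torch.stack(log_probs).sum()


arch, logp = Controller().sample()
```

The summed log-probability `logp` is what REINFORCE multiplies by the baseline-adjusted reward to update the controller's weights.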
| Search Strategy | Mean Reward | Mean Parameters | Mean Accuracy |
|---|---|---|---|
| Random Search | 0.8215 | 132,503 | 95.40% |
| RL Baseline | 0.8571 | 98,987 | 95.61% |
| RL Masked | 0.9084 | 46,557 | 95.49% |
Key Finding: The RL Masked agent discovered architectures approximately 3x smaller than the random baseline while achieving superior mean rewards.
The agent actively learned specific design rules that align with expert intuition for medical texture classification:
- High Regularization: RL agents developed a strong consensus for a 0.5 Dropout Rate, identifying it as critical for preventing overfitting.
- Efficiency over Width: The agent rejected the 128-filter count, discovering that 16 filters provided sufficient representational capacity for the input space.
- Receptive Field Trade-offs: "Champion" models favored 5×5 kernels combined with lower filter counts to prioritize a wider receptive field for capturing anatomical structures.
```bash
git clone https://github.com/Ashwashhere/RLMedNAS.git
cd RLMedNAS
pip install torch torchvision pandas medmnist matplotlib numpy
```
To execute the comparison between Random Search, RL Baseline, and RL Masked:
```bash
python nas_experiment.py
```
- Scaling to High-Res: Transferring the verified RL framework to the full CirrMRI600+ dataset for liver cirrhosis staging.
- Explainability: Integrating Grad-CAM to visualize which liver regions the discovered architectures attend to.
- Hybrid Optimisation: Combining RL for graph structure with Bayesian Optimization for scalar hyperparameter fine-tuning.