RLMedNAS: Reinforcement Learning-driven Neural Architecture Search

RLMedNAS is an autonomous decision-making system designed to engineer optimal Convolutional Neural Network (CNN) topologies for medical image classification. By utilising a Reinforcement Learning (RL) agent, the project replaces labor-intensive manual architecture design with a mathematically optimised, sequential decision-making process.

🔬 Project Overview

The manual design of CNNs for medical data is often biased toward human-manageable structures rather than mathematical optimality. RLMedNAS addresses this by employing a Recurrent Neural Network (RNN) Controller trained via the REINFORCE policy gradient algorithm to explore a discrete, block-based search space.

To ensure clinical feasibility on resource-constrained devices, the system integrates Action Masking—a neuro-symbolic approach that prunes domain-inefficient architectures (such as overly wide networks) before training begins.


🛠 Key Features

  • RNN Controller: A single-layer LSTM with 100 hidden units that emits architectural tokens.

  • Neuro-Symbolic Action Masking: Enforces symbolic "validity rules" by setting forbidden action indices (e.g., the 128-filter count) to a large negative constant ($-10^{10}$) before the softmax operation.

  • Proxy Evaluation: Uses the OrganAMNIST dataset (28x28 grayscale CT images) as a lightweight proxy to enable rapid policy convergence.

  • Hardware Optimized: Native acceleration for Apple Silicon via the Metal Performance Shaders (MPS) backend in PyTorch.
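The masking step can be illustrated with a minimal, stdlib-only sketch. The action list, logits, and function names below are illustrative assumptions, not the repository's actual code; the mechanism (overwrite forbidden logits with a large negative constant, then normalise) is the one described above.

```python
import math

# Hypothetical filter-count action space; index 3 (128 filters) is the
# forbidden action in this illustration.
FILTER_ACTIONS = [16, 32, 64, 128]
MASK_VALUE = -1e10  # large negative constant applied before softmax

def masked_softmax(logits, forbidden_indices):
    """Overwrite forbidden logits with MASK_VALUE, then apply softmax."""
    masked = [MASK_VALUE if i in forbidden_indices else x
              for i, x in enumerate(logits)]
    m = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

# The masked action's probability collapses to (numerically) zero,
# so the controller can never sample the 128-filter choice.
probs = masked_softmax([0.2, 1.1, -0.4, 2.0], forbidden_indices={3})
```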


🔢 Mathematical Framework

The Controller navigates the search space by maximising a multi-objective reward signal that balances diagnostic precision with computational parsimony.

The Reward Function

The composite utility score $R$ is defined as:

$$R = \text{ValAccuracy} - \lambda \times \text{TotalParameters}$$

  • Validation Accuracy: The primary performance metric on the OrganAMNIST validation set.

  • Penalty Weight ($\lambda$): A small constant ($10^{-6}$) used to penalize overly complex architectures.
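The reward definition above translates directly into code. This is a minimal sketch; the function name is illustrative, but the constant $\lambda = 10^{-6}$ and the formula match the definition given.

```python
LAMBDA = 1e-6  # penalty weight from the reward definition above

def composite_reward(val_accuracy, total_parameters):
    """Composite utility: R = ValAccuracy - lambda * TotalParameters."""
    return val_accuracy - LAMBDA * total_parameters

# Example using the RL Masked mean result reported below:
# 95.49% accuracy at ~46,557 parameters yields R ≈ 0.9083.
r = composite_reward(0.9549, 46_557)
```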

Policy Optimization

The parameters $\theta_c$ are refined using the REINFORCE algorithm to maximize the expected reward $J(\theta_c)$:

$$\nabla_{\theta_c}J(\theta_c)=E_{\tau\sim\pi_{\theta_c}}\left[\sum_{t=1}^T\nabla_{\theta_c}\log P(a_t|a_{t-1:1};\theta_c)(R-b)\right]$$

A moving average baseline $b$ is subtracted to reduce variance and ensure stability.
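A minimal sketch of the surrogate loss and baseline update implied by the gradient above. The decay rate and function names are assumptions for illustration; in practice `log_probs` would be the log-probabilities of the actions sampled by the LSTM controller, and the loss would be backpropagated through an autograd framework such as PyTorch.

```python
BASELINE_DECAY = 0.9  # illustrative EMA decay, not the repo's value

def reinforce_loss(log_probs, reward, baseline):
    """Policy-gradient surrogate: -(R - b) * sum_t log pi(a_t | a_{t-1:1})."""
    advantage = reward - baseline
    return -advantage * sum(log_probs)

def update_baseline(baseline, reward, decay=BASELINE_DECAY):
    """Moving-average baseline b, subtracted to reduce gradient variance."""
    return decay * baseline + (1.0 - decay) * reward
```

Minimising this surrogate with gradient descent follows the REINFORCE gradient: architectures whose reward exceeds the baseline have their action probabilities increased, and vice versa.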


📊 Search Space & Performance

The agent samples three critical hyperparameters for the "Child Network":

  • Filter Count: {16, 32, 64, 128}.

  • Kernel Size: {3x3, 5x5}.

  • Dropout Rate: {0.1, 0.25, 0.5}.
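The three choice sets above define a discrete space of 4 × 2 × 3 = 24 candidate configurations. A minimal sketch of enumerating and sampling it (the dictionary keys and function name are illustrative assumptions):

```python
import itertools
import random

# The discrete, block-based search space described above.
SEARCH_SPACE = {
    "filters": [16, 32, 64, 128],
    "kernel_size": [3, 5],
    "dropout": [0.1, 0.25, 0.5],
}

# Cartesian product: 4 * 2 * 3 = 24 candidate child networks in total.
all_configs = list(itertools.product(*SEARCH_SPACE.values()))

def sample_architecture(rng=random):
    """Draw one child-network configuration uniformly at random,
    as a Random Search baseline would."""
    return {name: rng.choice(options) for name, options in SEARCH_SPACE.items()}
```

The controller replaces this uniform sampling with a learned, sequential policy over the same three tokens.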

Quantitative Results

| Search Strategy | Mean Reward | Mean Parameters | Mean Accuracy |
|-----------------|-------------|-----------------|---------------|
| Random Search   | 0.8215      | 132,503         | 95.40%        |
| RL Baseline     | 0.8571      | 98,987          | 95.61%        |
| RL Masked       | 0.9084      | 46,557          | 95.49%        |

Key Finding: The RL Masked agent discovered architectures approximately 3x smaller than the random baseline while achieving superior mean rewards.


💡 Knowledge Discovery

The agent actively learned specific design rules that align with expert intuition for medical texture classification:

  • High Regularization: RL agents developed a strong consensus for a 0.5 Dropout Rate, identifying it as critical for preventing overfitting.

  • Efficiency over Width: The agent rejected the 128-filter count, discovering that 16 filters provided sufficient representational capacity for the input space.

  • Receptive Field Trade-offs: "Champion" models favored 5x5 kernels combined with lower filter counts to prioritize a wider receptive field for capturing anatomical structures.


💻 Getting Started

Installation

git clone https://github.com/Ashwashhere/RLMedNAS.git
cd RLMedNAS
pip install torch torchvision pandas medmnist matplotlib numpy

Running the Search

To execute the comparison between Random Search, RL Baseline, and RL Masked:

python nas_experiment.py

🗺️ Clinical Roadmap

  • Scaling to High-Res: Transferring the verified RL framework to the full CirrMRI600+ dataset for liver cirrhosis staging.

  • Explainability: Integrating Grad-CAM to visualize why specific architectural choices focus on specific liver regions.

  • Hybrid Optimization: Combining RL for graph structure with Bayesian Optimization for scalar hyperparameter fine-tuning.

About

RLMedNAS addresses the systemic inefficiency of manual neural network design in medical imaging. Traditional "human-in-the-loop" design is labor-intensive and often biased toward conceptually simple rather than mathematically optimal structures.
