Commit 39607d1

feat: add naive-bayes algorithm in machine learning (#997)
1 parent 5a4e21f commit 39607d1

File tree

3 files changed: +291 -0 lines changed


DIRECTORY.md

Lines changed: 1 addition & 0 deletions

@@ -207,6 +207,7 @@
 * [K-Nearest Neighbors](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_nearest_neighbors.rs)
 * [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs)
 * [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs)
+* [Naive Bayes](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/naive_bayes.rs)
 * Loss Function
 * [Average Margin Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs)
 * [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs)

src/machine_learning/mod.rs

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@ mod k_nearest_neighbors;
 mod linear_regression;
 mod logistic_regression;
 mod loss_function;
+mod naive_bayes;
 mod optimization;

 pub use self::cholesky::cholesky;
@@ -18,5 +19,6 @@ pub use self::loss_function::kld_loss;
 pub use self::loss_function::mae_loss;
 pub use self::loss_function::mse_loss;
 pub use self::loss_function::neg_log_likelihood;
+pub use self::naive_bayes::naive_bayes;
 pub use self::optimization::gradient_descent;
 pub use self::optimization::Adam;
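
The re-export `pub use self::naive_bayes::naive_bayes;` makes the classifier callable alongside the other machine-learning helpers. A minimal usage sketch from elsewhere in the crate, assuming `machine_learning` is exposed as a public module at the crate root (only `naive_bayes` itself comes from this commit):

use crate::machine_learning::naive_bayes;

fn classify_point() -> Option<f64> {
    // Two well-separated 2-D clusters labelled 0.0 and 1.0.
    let training_data = vec![
        (vec![1.0, 1.0], 0.0),
        (vec![1.2, 0.9], 0.0),
        (vec![5.0, 5.0], 1.0),
        (vec![5.2, 4.9], 1.0),
    ];
    // A point near the first cluster; the expected result is Some(0.0).
    naive_bayes(training_data, vec![1.1, 1.0])
}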
src/machine_learning/naive_bayes.rs

Lines changed: 288 additions & 0 deletions

@@ -0,0 +1,288 @@
/// Naive Bayes classifier for classification tasks.
/// This implementation uses Gaussian Naive Bayes, which assumes that
/// features follow a normal (Gaussian) distribution.
/// The algorithm calculates class priors and feature statistics (mean and variance)
/// for each class, then uses Bayes' theorem to predict class probabilities.
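///
/// For a test point x, the predicted label is the class c that maximizes
/// ln P(c) + sum over features i of ln N(x_i; mean_{c,i}, variance_{c,i}),
/// which is exactly the log-space score accumulated in `predict_naive_bayes`.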

pub struct ClassStatistics {
    pub class_label: f64,
    pub prior: f64,
    pub feature_means: Vec<f64>,
    pub feature_variances: Vec<f64>,
}

fn calculate_class_statistics(
    training_data: &[(Vec<f64>, f64)],
    class_label: f64,
    num_features: usize,
) -> Option<ClassStatistics> {
    let class_samples: Vec<&(Vec<f64>, f64)> = training_data
        .iter()
        .filter(|(_, label)| (*label - class_label).abs() < 1e-10)
        .collect();

    if class_samples.is_empty() {
        return None;
    }

    let prior = class_samples.len() as f64 / training_data.len() as f64;

    let mut feature_means = vec![0.0; num_features];
    let mut feature_variances = vec![0.0; num_features];

    // Calculate means
    for (features, _) in &class_samples {
        for (i, &feature) in features.iter().enumerate() {
            if i < num_features {
                feature_means[i] += feature;
            }
        }
    }

    let n = class_samples.len() as f64;
    for mean in &mut feature_means {
        *mean /= n;
    }

    // Calculate variances
    for (features, _) in &class_samples {
        for (i, &feature) in features.iter().enumerate() {
            if i < num_features {
                let diff = feature - feature_means[i];
                feature_variances[i] += diff * diff;
            }
        }
    }

    let epsilon = 1e-9;
    for variance in &mut feature_variances {
        *variance = (*variance / n).max(epsilon);
    }

    Some(ClassStatistics {
        class_label,
        prior,
        feature_means,
        feature_variances,
    })
}

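/// Log-density of a univariate Gaussian:
/// ln N(x; mean, variance) = -0.5 * ln(2 * pi * variance) - (x - mean)^2 / (2 * variance).
/// Working in log space keeps per-feature contributions additive and avoids underflow.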
fn gaussian_log_pdf(x: f64, mean: f64, variance: f64) -> f64 {
    let diff = x - mean;
    let exponent_term = -(diff * diff) / (2.0 * variance);
    let log_coefficient = -0.5 * (2.0 * std::f64::consts::PI * variance).ln();
    log_coefficient + exponent_term
}

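/// Estimates a Gaussian Naive Bayes model: the prior and per-feature mean/variance for
/// every distinct class label in `training_data`. Returns `None` for empty input,
/// zero-length feature vectors, or samples with inconsistent feature dimensions.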
pub fn train_naive_bayes(training_data: Vec<(Vec<f64>, f64)>) -> Option<Vec<ClassStatistics>> {
    if training_data.is_empty() {
        return None;
    }

    let num_features = training_data[0].0.len();
    if num_features == 0 {
        return None;
    }

    // Verify all samples have the same number of features
    if !training_data
        .iter()
        .all(|(features, _)| features.len() == num_features)
    {
        return None;
    }

    // Get unique class labels
    let mut unique_classes = Vec::new();
    for (_, label) in &training_data {
        if !unique_classes
            .iter()
            .any(|&c: &f64| (c - *label).abs() < 1e-10)
        {
            unique_classes.push(*label);
        }
    }

    let mut class_stats = Vec::new();

    for class_label in unique_classes {
        if let Some(mut stats) =
            calculate_class_statistics(&training_data, class_label, num_features)
        {
            stats.class_label = class_label;
            class_stats.push(stats);
        }
    }

    if class_stats.is_empty() {
        return None;
    }

    Some(class_stats)
}

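/// Scores `test_point` against each class as log prior plus the sum of per-feature
/// Gaussian log-densities, returning the label with the highest score. Returns `None`
/// if the model is empty or the point's dimension does not match the model.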
pub fn predict_naive_bayes(model: &[ClassStatistics], test_point: &[f64]) -> Option<f64> {
    if model.is_empty() || test_point.is_empty() {
        return None;
    }

    // Get number of features from the first class statistics
    let num_features = model[0].feature_means.len();
    if test_point.len() != num_features {
        return None;
    }

    let mut best_class = None;
    let mut best_log_prob = f64::NEG_INFINITY;

    for stats in model {
        // Calculate log probability to avoid underflow
        let mut log_prob = stats.prior.ln();

        for (i, &feature) in test_point.iter().enumerate() {
            if i < stats.feature_means.len() && i < stats.feature_variances.len() {
                // Use log PDF directly to avoid numerical underflow
                log_prob +=
                    gaussian_log_pdf(feature, stats.feature_means[i], stats.feature_variances[i]);
            }
        }

        if log_prob > best_log_prob {
            best_log_prob = log_prob;
            best_class = Some(stats.class_label);
        }
    }

    best_class
}

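/// Convenience wrapper that trains a model on `training_data` and immediately
/// classifies `test_point`, propagating `None` from either step.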
pub fn naive_bayes(training_data: Vec<(Vec<f64>, f64)>, test_point: Vec<f64>) -> Option<f64> {
    let model = train_naive_bayes(training_data)?;
    predict_naive_bayes(&model, &test_point)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_naive_bayes_simple_classification() {
        let training_data = vec![
            (vec![1.0, 1.0], 0.0),
            (vec![1.1, 1.0], 0.0),
            (vec![1.0, 1.1], 0.0),
            (vec![5.0, 5.0], 1.0),
            (vec![5.1, 5.0], 1.0),
            (vec![5.0, 5.1], 1.0),
        ];

        // Test point closer to class 0
        let test_point = vec![1.05, 1.05];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(0.0));

        // Test point closer to class 1
        let test_point = vec![5.05, 5.05];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, Some(1.0));
    }

    #[test]
    fn test_naive_bayes_one_dimensional() {
        let training_data = vec![
            (vec![1.0], 0.0),
            (vec![1.1], 0.0),
            (vec![1.2], 0.0),
            (vec![5.0], 1.0),
            (vec![5.1], 1.0),
            (vec![5.2], 1.0),
        ];

        let test_point = vec![1.15];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(0.0));

        let test_point = vec![5.15];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, Some(1.0));
    }

    #[test]
    fn test_naive_bayes_empty_training_data() {
        let training_data = vec![];
        let test_point = vec![1.0, 2.0];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_empty_test_point() {
        let training_data = vec![(vec![1.0, 2.0], 0.0)];
        let test_point = vec![];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_dimension_mismatch() {
        let training_data = vec![(vec![1.0, 2.0], 0.0), (vec![3.0, 4.0], 1.0)];
        let test_point = vec![1.0]; // Wrong dimension
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_inconsistent_feature_dimensions() {
        let training_data = vec![
            (vec![1.0, 2.0], 0.0),
            (vec![3.0], 1.0), // Different dimension
        ];
        let test_point = vec![1.0, 2.0];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, None);
    }

    #[test]
    fn test_naive_bayes_multiple_classes() {
        let training_data = vec![
            (vec![1.0, 1.0], 0.0),
            (vec![1.1, 1.0], 0.0),
            (vec![5.0, 5.0], 1.0),
            (vec![5.1, 5.0], 1.0),
            (vec![9.0, 9.0], 2.0),
            (vec![9.1, 9.0], 2.0),
        ];

        let test_point = vec![1.05, 1.05];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(0.0));

        let test_point = vec![5.05, 5.05];
        let result = naive_bayes(training_data.clone(), test_point);
        assert_eq!(result, Some(1.0));

        let test_point = vec![9.05, 9.05];
        let result = naive_bayes(training_data, test_point);
        assert_eq!(result, Some(2.0));
    }

    #[test]
    fn test_train_and_predict_separately() {
        let training_data = vec![
            (vec![1.0, 1.0], 0.0),
            (vec![1.1, 1.0], 0.0),
            (vec![5.0, 5.0], 1.0),
            (vec![5.1, 5.0], 1.0),
        ];

        let model = train_naive_bayes(training_data);
        assert!(model.is_some());

        let model = model.unwrap();
        assert_eq!(model.len(), 2);

        let test_point = vec![1.05, 1.05];
        let result = predict_naive_bayes(&model, &test_point);
        assert_eq!(result, Some(0.0));
    }
}
