From 0e4eab6bf3a101f8d55e5b661fef742e5ac47698 Mon Sep 17 00:00:00 2001 From: Lucas Carlson Date: Mon, 29 Dec 2025 08:43:57 -0800 Subject: [PATCH 1/2] fix: accept array of categories in classifier initialization Bayes.new and LogisticRegression.new now accept either: - Splat arguments: Bayes.new('Spam', 'Ham') - Array argument: Bayes.new(['Spam', 'Ham']) Both forms are now equivalent, fixing the issue where array arguments were treated as a single category name. Fixes #110 --- lib/classifier/bayes.rb | 5 +++-- lib/classifier/logistic_regression.rb | 4 +++- test/bayes/bayesian_test.rb | 11 +++++++++++ test/logistic_regression/logistic_regression_test.rb | 10 ++++++++++ 4 files changed, 27 insertions(+), 3 deletions(-) diff --git a/lib/classifier/bayes.rb b/lib/classifier/bayes.rb index 750aeea..394880c 100644 --- a/lib/classifier/bayes.rb +++ b/lib/classifier/bayes.rb @@ -26,11 +26,12 @@ class Bayes # rubocop:disable Metrics/ClassLength # The class can be created with one or more categories, each of which will be # initialized and given a training method. E.g., # b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam' - # @rbs (*String | Symbol) -> void + # b = Classifier::Bayes.new ['Interesting', 'Uninteresting', 'Spam'] + # @rbs (*String | Symbol | Array[String | Symbol]) -> void def initialize(*categories) super() @categories = {} - categories.each { |category| @categories[category.prepare_category_name] = {} } + categories.flatten.each { |category| @categories[category.prepare_category_name] = {} } @total_words = 0 @category_counts = Hash.new(0) @category_word_count = Hash.new(0) diff --git a/lib/classifier/logistic_regression.rb b/lib/classifier/logistic_regression.rb index 04e3f03..314904f 100644 --- a/lib/classifier/logistic_regression.rb +++ b/lib/classifier/logistic_regression.rb @@ -46,6 +46,7 @@ class LogisticRegression # rubocop:disable Metrics/ClassLength # # classifier = Classifier::LogisticRegression.new(:spam, :ham) # classifier = Classifier::LogisticRegression.new('Positive', 'Negative', 'Neutral') + # classifier = Classifier::LogisticRegression.new(['Positive', 'Negative', 'Neutral']) # # Options: # - learning_rate: Step size for gradient descent (default: 0.1) @@ -53,13 +54,14 @@ class LogisticRegression # rubocop:disable Metrics/ClassLength # - max_iterations: Maximum training iterations (default: 100) # - tolerance: Convergence threshold (default: 1e-4) # - # @rbs (*String | Symbol, ?learning_rate: Float, ?regularization: Float, + # @rbs (*String | Symbol | Array[String | Symbol], ?learning_rate: Float, ?regularization: Float, # ?max_iterations: Integer, ?tolerance: Float) -> void def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE, regularization: DEFAULT_REGULARIZATION, max_iterations: DEFAULT_MAX_ITERATIONS, tolerance: DEFAULT_TOLERANCE) super() + categories = categories.flatten raise ArgumentError, 'At least two categories required' if categories.size < 2 @categories = categories.map { |c| c.to_s.prepare_category_name } diff --git a/test/bayes/bayesian_test.rb b/test/bayes/bayesian_test.rb index 411fa62..43838a7 100644 --- a/test/bayes/bayesian_test.rb +++ b/test/bayes/bayesian_test.rb @@ -17,6 +17,17 @@ def test_categories assert_equal %w[Interesting Uninteresting].sort, @classifier.categories.sort end + def test_array_initialization + classifier = Classifier::Bayes.new(%w[Spam Ham]) + + assert_equal %w[Ham Spam], classifier.categories.sort + + classifier.train_spam 'bad nasty spam email' + classifier.train_ham 'good legitimate email' + + assert_equal 'Spam', classifier.classify('this is spam') + end + def test_add_category @classifier.add_category 'Test' diff --git a/test/logistic_regression/logistic_regression_test.rb b/test/logistic_regression/logistic_regression_test.rb index 28deaf4..419121f 100644 --- a/test/logistic_regression/logistic_regression_test.rb +++ b/test/logistic_regression/logistic_regression_test.rb @@ -19,6 +19,16 @@ def test_accepts_symbols_and_strings assert_equal %w[Spam Ham].sort, classifier2.categories.sort end + def test_accepts_array_of_categories + classifier = Classifier::LogisticRegression.new(%w[Spam Ham]) + + assert_equal %w[Ham Spam], classifier.categories.sort + end + + def test_array_initialization_requires_at_least_two + assert_raises(ArgumentError) { Classifier::LogisticRegression.new(['Only']) } + end + def test_custom_hyperparameters classifier = Classifier::LogisticRegression.new( :spam, :ham, From 34c606e72a942969c10a48856e69f3608d7bd9c3 Mon Sep 17 00:00:00 2001 From: Lucas Carlson Date: Mon, 29 Dec 2025 08:49:13 -0800 Subject: [PATCH 2/2] docs: add visual indicators to comparison table and fix URL Add checkmark and X emoji to the Why This Library comparison table for better visual scanning. Fix Logistic Regression guide URL to use correct path (logisticregression without hyphen). --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index bbca62a..0cfb9a8 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,11 @@ Text classification in Ruby. Five algorithms, native performance, streaming supp | | This Gem | Other Forks | |:--|:--|:--| -| **Algorithms** | 5 classifiers | 2 only | -| **Incremental LSI** | Brand's algorithm (no rebuild) | Full SVD rebuild on every add | -| **LSI Performance** | Native C extension (5-50x faster) | Pure Ruby or requires GSL | -| **Streaming** | Train on multi-GB datasets | Must load all data in memory | -| **Persistence** | Pluggable (file, Redis, S3) | Marshal only | +| **Algorithms** | ✅ 5 classifiers | ❌ 2 only | +| **Incremental LSI** | ✅ Brand's algorithm (no rebuild) | ❌ Full SVD rebuild on every add | +| **LSI Performance** | ✅ Native C extension (5-50x faster) | ❌ Pure Ruby or requires GSL | +| **Streaming** | ✅ Train on multi-GB datasets | ❌ Must load all data in memory | +| **Persistence** | ✅ Pluggable (file, Redis, S3) | ❌ Marshal only | ## Installation @@ -42,7 +42,7 @@ classifier = Classifier::LogisticRegression.new(:positive, :negative) classifier.train(positive: "Great product!", negative: "Terrible experience") classifier.classify "Loved it!" # => "Positive" ``` -[Logistic Regression Guide →](https://rubyclassifier.com/docs/guides/logistic-regression/basics) +[Logistic Regression Guide →](https://rubyclassifier.com/docs/guides/logisticregression/basics) ### LSI (Latent Semantic Indexing)