Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ Text classification in Ruby. Five algorithms, native performance, streaming supp

| | This Gem | Other Forks |
|:--|:--|:--|
| **Algorithms** | 5 classifiers | 2 only |
| **Incremental LSI** | Brand's algorithm (no rebuild) | Full SVD rebuild on every add |
| **LSI Performance** | Native C extension (5-50x faster) | Pure Ruby or requires GSL |
| **Streaming** | Train on multi-GB datasets | Must load all data in memory |
| **Persistence** | Pluggable (file, Redis, S3) | Marshal only |
| **Algorithms** | 5 classifiers | 2 only |
| **Incremental LSI** | Brand's algorithm (no rebuild) | Full SVD rebuild on every add |
| **LSI Performance** | Native C extension (5-50x faster) | Pure Ruby or requires GSL |
| **Streaming** | Train on multi-GB datasets | Must load all data in memory |
| **Persistence** | Pluggable (file, Redis, S3) | Marshal only |

## Installation

Expand All @@ -42,7 +42,7 @@ classifier = Classifier::LogisticRegression.new(:positive, :negative)
classifier.train(positive: "Great product!", negative: "Terrible experience")
classifier.classify "Loved it!" # => "Positive"
```
[Logistic Regression Guide →](https://rubyclassifier.com/docs/guides/logistic-regression/basics)
[Logistic Regression Guide →](https://rubyclassifier.com/docs/guides/logisticregression/basics)

### LSI (Latent Semantic Indexing)

Expand Down
5 changes: 3 additions & 2 deletions lib/classifier/bayes.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ class Bayes # rubocop:disable Metrics/ClassLength
# The class can be created with one or more categories, each of which will be
# initialized and given a training method. E.g.,
# b = Classifier::Bayes.new 'Interesting', 'Uninteresting', 'Spam'
# @rbs (*String | Symbol) -> void
# b = Classifier::Bayes.new ['Interesting', 'Uninteresting', 'Spam']
# @rbs (*String | Symbol | Array[String | Symbol]) -> void
def initialize(*categories)
super()
@categories = {}
categories.each { |category| @categories[category.prepare_category_name] = {} }
categories.flatten.each { |category| @categories[category.prepare_category_name] = {} }
@total_words = 0
@category_counts = Hash.new(0)
@category_word_count = Hash.new(0)
Expand Down
4 changes: 3 additions & 1 deletion lib/classifier/logistic_regression.rb
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,22 @@ class LogisticRegression # rubocop:disable Metrics/ClassLength
#
# classifier = Classifier::LogisticRegression.new(:spam, :ham)
# classifier = Classifier::LogisticRegression.new('Positive', 'Negative', 'Neutral')
# classifier = Classifier::LogisticRegression.new(['Positive', 'Negative', 'Neutral'])
#
# Options:
# - learning_rate: Step size for gradient descent (default: 0.1)
# - regularization: L2 regularization strength (default: 0.01)
# - max_iterations: Maximum training iterations (default: 100)
# - tolerance: Convergence threshold (default: 1e-4)
#
# @rbs (*String | Symbol, ?learning_rate: Float, ?regularization: Float,
# @rbs (*String | Symbol | Array[String | Symbol], ?learning_rate: Float, ?regularization: Float,
# ?max_iterations: Integer, ?tolerance: Float) -> void
def initialize(*categories, learning_rate: DEFAULT_LEARNING_RATE,
regularization: DEFAULT_REGULARIZATION,
max_iterations: DEFAULT_MAX_ITERATIONS,
tolerance: DEFAULT_TOLERANCE)
super()
categories = categories.flatten
raise ArgumentError, 'At least two categories required' if categories.size < 2

@categories = categories.map { |c| c.to_s.prepare_category_name }
Expand Down
11 changes: 11 additions & 0 deletions test/bayes/bayesian_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ def test_categories
assert_equal %w[Interesting Uninteresting].sort, @classifier.categories.sort
end

def test_array_initialization
classifier = Classifier::Bayes.new(%w[Spam Ham])

assert_equal %w[Ham Spam], classifier.categories.sort

classifier.train_spam 'bad nasty spam email'
classifier.train_ham 'good legitimate email'

assert_equal 'Spam', classifier.classify('this is spam')
end

def test_add_category
@classifier.add_category 'Test'

Expand Down
10 changes: 10 additions & 0 deletions test/logistic_regression/logistic_regression_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def test_accepts_symbols_and_strings
assert_equal %w[Spam Ham].sort, classifier2.categories.sort
end

def test_accepts_array_of_categories
classifier = Classifier::LogisticRegression.new(%w[Spam Ham])

assert_equal %w[Ham Spam], classifier.categories.sort
end

def test_array_initialization_requires_at_least_two
assert_raises(ArgumentError) { Classifier::LogisticRegression.new(['Only']) }
end

def test_custom_hyperparameters
classifier = Classifier::LogisticRegression.new(
:spam, :ham,
Expand Down