Skip to content

Commit caf7923

Browse files
committed
feat(bayes): add keyword argument API for train and untrain
Replace method_missing magic with explicit keyword arguments as the primary training API. This provides better IDE discoverability and a more modern Ruby interface. New API: classifier.train(spam: "text") classifier.train(spam: ["msg1", "msg2"], ham: ["msg3"]) Batch training and multi-category training in a single call are now supported. Legacy positional args and dynamic methods still work for backwards compatibility. Closes #96
1 parent 90c12a1 commit caf7923

File tree

3 files changed

+154
-50
lines changed

3 files changed

+154
-50
lines changed

README.md

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,17 @@ Fast, accurate classification with modest memory requirements. Ideal for spam fi
8181
```ruby
8282
require 'classifier'
8383

84-
classifier = Classifier::Bayes.new('Spam', 'Ham')
84+
classifier = Classifier::Bayes.new(:spam, :ham)
8585

86-
# Train the classifier
87-
classifier.train_spam "Buy cheap viagra now! Limited offer!"
88-
classifier.train_spam "You've won a million dollars! Claim now!"
89-
classifier.train_ham "Meeting scheduled for tomorrow at 10am"
90-
classifier.train_ham "Please review the attached document"
86+
# Train with keyword arguments
87+
classifier.train(spam: "Buy cheap viagra now! Limited offer!")
88+
classifier.train(ham: "Meeting scheduled for tomorrow at 10am")
89+
90+
# Train multiple items at once
91+
classifier.train(
92+
spam: ["You've won a million dollars!", "Free money!!!"],
93+
ham: ["Please review the document", "Lunch tomorrow?"]
94+
)
9195

9296
# Classify new text
9397
classifier.classify "Congratulations! You've won a prize!"
@@ -152,9 +156,9 @@ Save and load trained classifiers with pluggable storage backends. Works with bo
152156
```ruby
153157
require 'classifier'
154158

155-
classifier = Classifier::Bayes.new('Spam', 'Ham')
156-
classifier.train_spam "Buy now! Limited offer!"
157-
classifier.train_ham "Meeting tomorrow at 3pm"
159+
classifier = Classifier::Bayes.new(:spam, :ham)
160+
classifier.train(spam: "Buy now! Limited offer!")
161+
classifier.train(ham: "Meeting tomorrow at 3pm")
158162

159163
# Configure storage and save
160164
classifier.storage = Classifier::Storage::File.new(path: "spam_filter.json")

lib/classifier/bayes.rb

Lines changed: 72 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -41,57 +41,45 @@ def initialize(*categories)
4141

4242
# Provides a general training method for all categories specified in Bayes#new
4343
# For example:
44-
# b = Classifier::Bayes.new 'This', 'That', 'the_other'
45-
# b.train :this, "This text"
46-
# b.train "that", "That text"
47-
# b.train "The other", "The other text"
44+
# b = Classifier::Bayes.new :spam, :ham
4845
#
49-
# @rbs (String | Symbol, String) -> void
50-
def train(category, text)
51-
category = category.prepare_category_name
52-
word_hash = text.word_hash
53-
synchronize do
54-
invalidate_caches
55-
@dirty = true
56-
@category_counts[category] += 1
57-
word_hash.each do |word, count|
58-
@categories[category][word] ||= 0
59-
@categories[category][word] += count
60-
@total_words += count
61-
@category_word_count[category] += count
62-
end
46+
# # Keyword argument API (preferred)
47+
# b.train(spam: "Buy cheap viagra now!!!")
48+
# b.train(spam: ["msg1", "msg2"], ham: ["msg3", "msg4"])
49+
#
50+
# # Positional argument API (legacy)
51+
# b.train :spam, "This text"
52+
# b.train "ham", "That text"
53+
#
54+
# @rbs (?String | Symbol, ?String, **String | Array[String]) -> void
55+
def train(category = nil, text = nil, **categories)
56+
return train_single(category, text) if category
57+
58+
categories.each do |cat, texts|
59+
Array(texts).each { |t| train_single(cat, t) }
6360
end
6461
end
6562

6663
# Provides a untraining method for all categories specified in Bayes#new
6764
# Be very careful with this method.
6865
#
6966
# For example:
70-
# b = Classifier::Bayes.new 'This', 'That', 'the_other'
71-
# b.train :this, "This text"
72-
# b.untrain :this, "This text"
67+
# b = Classifier::Bayes.new :spam, :ham
7368
#
74-
# @rbs (String | Symbol, String) -> void
75-
def untrain(category, text)
76-
category = category.prepare_category_name
77-
word_hash = text.word_hash
78-
synchronize do
79-
invalidate_caches
80-
@dirty = true
81-
@category_counts[category] -= 1
82-
word_hash.each do |word, count|
83-
next unless @total_words >= 0
69+
# # Keyword argument API (preferred)
70+
# b.train(spam: "Buy cheap viagra now!!!")
71+
# b.untrain(spam: "Buy cheap viagra now!!!")
72+
#
73+
# # Positional argument API (legacy)
74+
# b.train :spam, "This text"
75+
# b.untrain :spam, "This text"
76+
#
77+
# @rbs (?String | Symbol, ?String, **String | Array[String]) -> void
78+
def untrain(category = nil, text = nil, **categories)
79+
return untrain_single(category, text) if category
8480

85-
orig = @categories[category][word] || 0
86-
@categories[category][word] ||= 0
87-
@categories[category][word] -= count
88-
if @categories[category][word] <= 0
89-
@categories[category].delete(word)
90-
count = orig
91-
end
92-
@category_word_count[category] -= count if @category_word_count[category] >= count
93-
@total_words -= count
94-
end
81+
categories.each do |cat, texts|
82+
Array(texts).each { |t| untrain_single(cat, t) }
9583
end
9684
end
9785

@@ -340,6 +328,49 @@ def remove_category(category)
340328

341329
private
342330

331+
# Core training logic for a single category and text.
332+
# @rbs (String | Symbol, String) -> void
333+
def train_single(category, text)
334+
category = category.prepare_category_name
335+
word_hash = text.word_hash
336+
synchronize do
337+
invalidate_caches
338+
@dirty = true
339+
@category_counts[category] += 1
340+
word_hash.each do |word, count|
341+
@categories[category][word] ||= 0
342+
@categories[category][word] += count
343+
@total_words += count
344+
@category_word_count[category] += count
345+
end
346+
end
347+
end
348+
349+
# Core untraining logic for a single category and text.
350+
# @rbs (String | Symbol, String) -> void
351+
def untrain_single(category, text)
352+
category = category.prepare_category_name
353+
word_hash = text.word_hash
354+
synchronize do
355+
invalidate_caches
356+
@dirty = true
357+
@category_counts[category] -= 1
358+
word_hash.each do |word, count|
359+
next unless @total_words >= 0
360+
361+
orig = @categories[category][word] || 0
362+
@categories[category][word] ||= 0
363+
@categories[category][word] -= count
364+
if @categories[category][word] <= 0
365+
@categories[category].delete(word)
366+
count = orig
367+
end
368+
@category_word_count[category] -= count if @category_word_count[category] >= count
369+
@total_words -= count
370+
end
371+
end
372+
end
373+
343374
# Restores classifier state from a JSON string (used by reload)
344375
# @rbs (String) -> void
345376
def restore_from_json(json)

test/bayes/bayesian_test.rb

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,75 @@ def test_untrain_decrements_category_word_count
246246
assert_operator new_word_count, :<, initial_word_count, 'Category word count should decrease'
247247
end
248248

249+
# Keyword argument API tests
250+
251+
def test_train_with_keyword_argument
252+
@classifier.train(interesting: 'good words here')
253+
@classifier.train(uninteresting: 'bad words there')
254+
255+
assert_equal 'Interesting', @classifier.classify('good words')
256+
assert_equal 'Uninteresting', @classifier.classify('bad words')
257+
end
258+
259+
def test_train_with_array_value
260+
@classifier.train(interesting: ['good words', 'great content', 'love this'])
261+
@classifier.train(uninteresting: 'bad stuff')
262+
263+
assert_equal 'Interesting', @classifier.classify('good great love')
264+
end
265+
266+
def test_train_with_multiple_categories
267+
@classifier.train(
268+
interesting: ['good words', 'great content'],
269+
uninteresting: ['bad words', 'boring stuff']
270+
)
271+
272+
assert_equal 'Interesting', @classifier.classify('good great')
273+
assert_equal 'Uninteresting', @classifier.classify('bad boring')
274+
end
275+
276+
def test_untrain_with_keyword_argument
277+
@classifier.train(interesting: 'hello world')
278+
@classifier.untrain(interesting: 'hello world')
279+
280+
category_words = @classifier.instance_variable_get(:@categories)[:Interesting]
281+
282+
assert_empty category_words
283+
end
284+
285+
def test_untrain_with_array_value
286+
@classifier.train(interesting: ['apple', 'banana', 'cherry'])
287+
@classifier.untrain(interesting: ['apple', 'banana'])
288+
289+
result = @classifier.classify('cherry')
290+
291+
assert_equal 'Interesting', result
292+
end
293+
294+
def test_keyword_and_positional_produce_same_result
295+
classifier1 = Classifier::Bayes.new 'Spam', 'Ham'
296+
classifier1.train :spam, 'buy now'
297+
classifier1.train :ham, 'hello friend'
298+
299+
classifier2 = Classifier::Bayes.new 'Spam', 'Ham'
300+
classifier2.train(spam: 'buy now')
301+
classifier2.train(ham: 'hello friend')
302+
303+
assert_equal classifier1.classify('buy'), classifier2.classify('buy')
304+
assert_equal classifier1.classify('hello'), classifier2.classify('hello')
305+
end
306+
307+
def test_keyword_and_dynamic_method_produce_same_result
308+
classifier1 = Classifier::Bayes.new 'Spam', 'Ham'
309+
classifier1.train_spam 'buy now'
310+
classifier1.train_ham 'hello friend'
311+
312+
classifier2 = Classifier::Bayes.new 'Spam', 'Ham'
313+
classifier2.train(spam: 'buy now', ham: 'hello friend')
314+
315+
assert_equal classifier1.classifications('buy'), classifier2.classifications('buy')
316+
end
317+
249318
# Edge case tests
250319

251320
def test_empty_string_training

0 commit comments

Comments
 (0)