Training Guide
This guide details how to fine-tune or train the Kiri OCR recognition model from scratch.
Kiri OCR supports two data formats: a dataset hosted on the Hugging Face Hub, or local image files with TAB-separated label files.

Option 1: Hugging Face Dataset (Recommended)

Store your dataset on the Hugging Face Hub. The dataset should have:
- An image column (containing image objects or paths).
- A text column (containing the ground truth string).
Example dataset: mrrtmob/km_en_image_line
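Before training, it can help to sanity-check that every row carries the two required columns. A minimal sketch using plain dicts to stand in for dataset rows; the column names `image` and `text` match the `hf_image_col` / `hf_text_col` defaults:

```python
# Column names follow the hf_image_col / hf_text_col defaults ("image", "text").
REQUIRED_COLUMNS = {"image", "text"}

def check_rows(rows):
    """Return True if every row contains both an image and a text field."""
    return all(REQUIRED_COLUMNS <= set(row) for row in rows)

rows = [
    {"image": "001.jpg", "text": "Hello World"},
    {"image": "002.jpg", "text": "Khmer Text Here"},
]
print(check_rows(rows))  # True
```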
Option 2: Local Files

Organize your local files as follows:
```
data/
├── train/
│   ├── images/
│   │   ├── 001.jpg
│   │   └── ...
│   └── labels.txt
└── val/
    ├── images/
    │   ├── 001.jpg
    │   └── ...
    └── labels.txt
```
Format of labels.txt:
Each line contains the image path and the text label, separated by a TAB character.
```
data/train/images/001.jpg	Hello World
data/train/images/002.jpg	Khmer Text Here
```
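The TAB-separated format above can be read in a few lines. A sketch of a loader (the function name is hypothetical, not part of the Kiri OCR API):

```python
# Parse labels.txt lines of the form "<image path>\t<label>".
def parse_labels(lines):
    samples = []
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue  # skip blank lines
        # Split on the first TAB only, so the label may contain further tabs.
        path, label = line.split("\t", 1)
        samples.append((path, label))
    return samples

lines = [
    "data/train/images/001.jpg\tHello World",
    "data/train/images/002.jpg\tKhmer Text Here",
]
print(parse_labels(lines))
```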
Use the `kiri-ocr train` command:

```bash
kiri-ocr train \
  --hf-dataset mrrtmob/km_en_image_line \
  --output-dir output_model \
  --epochs 100 \
  --batch-size 32 \
  --device cuda
```

You can combine multiple datasets by passing them as a space-separated list. This is useful for mixing different languages or data sources. The trainer will automatically concatenate them.
```bash
kiri-ocr train \
  --hf-dataset mrrtmob/khmer_text_v1 mrrtmob/english_text_v2 \
  --output-dir output_model \
  --epochs 100
```

Note: All datasets must share the same schema (column names for image and text).
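The schema requirement can be checked up front. A hypothetical sketch, with plain sets standing in for each dataset's column names:

```python
def same_schema(*column_sets):
    """True if every dataset exposes exactly the same column names."""
    first = set(column_sets[0])
    return all(set(cols) == first for cols in column_sets)

print(same_schema({"image", "text"}, {"image", "text"}))  # True
print(same_schema({"image", "text"}, {"img", "label"}))   # False
```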
To train from local label files:

```bash
kiri-ocr train \
  --train-labels data/train/labels.txt \
  --val-labels data/val/labels.txt \
  --output-dir output_model \
  --device cuda
```

You can customize training parameters via the command line or a YAML config file. The config file allows you to store and reuse complex configurations.
Generate a template config file:

```bash
kiri-ocr init-config -o training_config.yaml
```

Here is a complete reference of the available options in the config file:
```yaml
# ==========================================
# Kiri OCR Training Configuration
# ==========================================

# --- Image Dimensions ---
# Must match what your model expects. Default is 48x640.
height: 48
width: 640

# --- Training Hyperparameters ---
batch_size: 32       # Reduce if running out of VRAM
epochs: 100          # Total training passes
lr: 0.0003           # Learning rate (AdamW optimizer)
weight_decay: 0.01   # Regularization to prevent overfitting

# --- Loss Balancing ---
# Weights for the two loss functions. Sum typically 1.0.
ctc_weight: 0.5      # Alignment-free loss (fast convergence)
dec_weight: 0.5      # Attention decoder loss (better semantics)

# --- Sequence Handling ---
# Truncate sequences longer than this to prevent OOM
max_seq_len: 512

# --- Paths & IO ---
output_dir: models   # Where to save checkpoints
save_steps: 1000     # Save freq (steps). 0 = save every epoch only.
resume: false        # Resume from 'latest.safetensors' in output_dir
# from_model: path/to/model.pt  # Initialize from specific checkpoint

# --- Hardware ---
device: cuda         # 'cuda' for GPU, 'cpu' for CPU

# --- Dataset Configuration ---
# OPTION 1: Hugging Face Dataset (Recommended)
hf_dataset: mrrtmob/khmer_english_ocr_image_line
hf_train_split: train
hf_val_split: validation  # Optional, auto-detected if omitted
hf_image_col: image
hf_text_col: text

# OPTION 2: Local Files (Comment out HF options if using this)
# train_labels: data/train/labels.txt
# val_labels: data/val/labels.txt
# vocab: vocab.json  # Optional, auto-generated if missing

# --- Model Architecture Customization ---
# Leave commented to use defaults (Small: 256 dim, 4/3 layers)

# Encoder (Transformer)
# encoder_dim: 256
# encoder_heads: 8
# encoder_layers: 4
# encoder_ffn_dim: 1024

# Decoder (Transformer)
# decoder_dim: 256
# decoder_heads: 8
# decoder_layers: 3
# decoder_ffn_dim: 1024

# Regularization
# dropout: 0.15
```

Once you have edited your config file, pass it to the training command:
```bash
kiri-ocr train --config training_config.yaml
```

Note: You can override config file values by passing command line arguments. CLI arguments take precedence.

```bash
# Uses config.yaml but overrides epochs to 50
kiri-ocr train --config config.yaml --epochs 50
```

You can customize the Transformer encoder and decoder architecture to balance between speed and accuracy.
| Argument | Default | Description |
|---|---|---|
| `--encoder-dim` | 256 | Encoder hidden dimension |
| `--encoder-heads` | 8 | Encoder attention heads |
| `--encoder-layers` | 4 | Number of encoder layers |
| `--encoder-ffn-dim` | 1024 | Encoder feedforward dimension |
| `--decoder-dim` | 256 | Decoder hidden dimension |
| `--decoder-heads` | 8 | Decoder attention heads |
| `--decoder-layers` | 3 | Number of decoder layers |
| `--decoder-ffn-dim` | 1024 | Decoder feedforward dimension |
| `--dropout` | 0.15 | Dropout rate |
For resource-constrained environments or when speed is critical:

```bash
kiri-ocr train \
  --hf-dataset mrrtmob/khmer_english_ocr_image_line \
  --encoder-dim 128 \
  --encoder-layers 3 \
  --encoder-ffn-dim 512 \
  --decoder-dim 128 \
  --decoder-layers 2 \
  --decoder-ffn-dim 512 \
  --batch-size 128 \
  --device cuda
```

For maximum accuracy when you have ample GPU memory:
```bash
kiri-ocr train \
  --hf-dataset mrrtmob/khmer_english_ocr_image_line \
  --encoder-dim 512 \
  --encoder-layers 8 \
  --encoder-ffn-dim 2048 \
  --decoder-dim 512 \
  --decoder-layers 6 \
  --decoder-ffn-dim 2048 \
  --batch-size 8 \
  --device cuda
```

| Config | Params | VRAM | Speed | Use Case |
|---|---|---|---|---|
| Tiny (dim=128) | ~2M | ~500MB | Fast | Mobile, embedded |
| Small (dim=256) | ~8M | ~2GB | Medium | Default, good balance |
| Medium (dim=384) | ~18M | ~4GB | Slower | Higher accuracy |
| Large (dim=512) | ~32M | ~8GB | Slow | Best accuracy |
| XL (dim=768) | ~72M | ~16GB | Very Slow | Complex scripts |
Tips:
- `encoder_dim` must be divisible by `encoder_heads`
- FFN dimension is typically 4x the hidden dimension
- Match encoder and decoder dims for better results
- Reduce `--batch-size` when using larger models
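The tips above can be encoded as a quick pre-flight check before launching a long run. A hypothetical sketch; the parameter names mirror the CLI flags:

```python
def check_architecture(dim, heads, ffn_dim):
    """Validate a transformer config against the tips above."""
    problems = []
    if dim % heads != 0:
        problems.append(f"dim={dim} is not divisible by heads={heads}")
    if ffn_dim != 4 * dim:
        # A convention rather than a hard requirement.
        problems.append(f"ffn_dim={ffn_dim} is not 4x dim={dim}")
    return problems

print(check_architecture(256, 8, 1024))  # [] -> the Small default passes
print(check_architecture(300, 8, 1024))  # flags both problems
```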
During training, the system logs:
- Loss: Total loss (CTC + Decoder).
- Val Acc: Validation accuracy (character-level).

```
Epoch 50/100 | Loss: 0.60 (CTC: 0.7, Dec: 0.5) | Val Acc: 78%
```
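The logged total is the weighted sum of the two losses. With the default weights from the config (`ctc_weight: 0.5`, `dec_weight: 0.5`), the example line checks out:

```python
ctc_weight, dec_weight = 0.5, 0.5  # defaults from the config reference
ctc_loss, dec_loss = 0.7, 0.5      # values from the example log line

total = ctc_weight * ctc_loss + dec_weight * dec_loss
print(round(total, 2))  # 0.6, matching "Loss: 0.60 (CTC: 0.7, Dec: 0.5)"
```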
The `--output-dir` will contain:
- `model.safetensors`: The best model weights.
- `vocab.json`: The character vocabulary (crucial for inference!).
- `model_meta.json`: Metadata about the model architecture.
- `history.json`: Training metrics log.
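A small sketch that verifies a finished run produced all four files. The filenames come from the list above; the helper itself is hypothetical:

```python
import os
import tempfile

EXPECTED = ["model.safetensors", "vocab.json", "model_meta.json", "history.json"]

def missing_artifacts(output_dir):
    """Return the expected output files that are absent from output_dir."""
    return [name for name in EXPECTED
            if not os.path.exists(os.path.join(output_dir, name))]

# Example with a scratch directory containing only the vocab:
with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "vocab.json"), "w").close()
    print(missing_artifacts(d))  # the three files still missing
```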
To use your trained model with the Python API:

```python
from kiri_ocr import OCR

ocr = OCR(model_path="output_model/model.safetensors")
text, _ = ocr.recognize_single_line_image("test_image.jpg")
```

When loading your model, choose the decode method based on your needs:

```python
# Fast CTC for batch processing
ocr = OCR(model_path="output_model/model.safetensors", decode_method="fast")

# Accurate decoder for balanced quality (default)
ocr = OCR(model_path="output_model/model.safetensors", decode_method="accurate")

# Beam search for highest quality
ocr = OCR(model_path="output_model/model.safetensors", decode_method="beam")
```

To train on Google Colab:

```python
# Cell 1: Setup
!pip install -q kiri-ocr datasets
from google.colab import drive
drive.mount('/content/drive')
```
```python
# Cell 2: Train
!kiri-ocr train \
  --hf-dataset mrrtmob/khmer_english_ocr_image_line \
  --output-dir /content/drive/MyDrive/kiri_models/v1 \
  --height 48 \
  --width 640 \
  --batch-size 32 \
  --epochs 100 \
  --lr 0.0003 \
  --ctc-weight 0.5 \
  --dec-weight 0.5 \
  --save-steps 5000 \
  --device cuda
```
```python
# Cell 3: Test
from kiri_ocr import OCR

ocr = OCR(
    model_path="/content/drive/MyDrive/kiri_models/v1/model.safetensors",
    device="cuda"
)
text, confidence = ocr.recognize_single_line_image("test.png")
print(f"'{text}' ({confidence:.1%})")
```

Rough training time estimates at 100 epochs:

| Dataset Size | Epochs | Time |
|---|---|---|
| 10K samples | 100 | ~10 hours |
| 50K samples | 100 | ~24 hours |
| 100K samples | 100 | ~48 hours |
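The table suggests sub-linear scaling (50K samples take ~24 hours, not 5x the 10K time). A rough, purely illustrative interpolation helper built from the table's three data points; actual times depend heavily on hardware and batch size:

```python
# Anchor points from the table above: (samples, hours) at 100 epochs.
POINTS = [(10_000, 10), (50_000, 24), (100_000, 48)]

def estimate_hours(n_samples):
    """Piecewise-linear interpolation between the table's anchor points."""
    if n_samples <= POINTS[0][0]:
        return POINTS[0][1]
    for (x0, y0), (x1, y1) in zip(POINTS, POINTS[1:]):
        if n_samples <= x1:
            return y0 + (y1 - y0) * (n_samples - x0) / (x1 - x0)
    return POINTS[-1][1]  # beyond the table there is no data; clamp

print(estimate_hours(30_000))  # 17.0 -> midway between the 10h and 24h rows
```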
Troubleshooting

Cause: Model not trained enough (a minimum of 50-100 epochs is usually needed).

```bash
# Check your model metadata
python -c "
import json
with open('model_meta.json') as f:
    meta = json.load(f)
print(f\"Epoch: {meta.get('epoch', 'unknown')}\")
print(f\"Step: {meta.get('step', 'unknown')}\")
"
```

Cause: vocab.json must be in the same directory as the model.

```bash
ls output/
# Should show: vocab.json, model.safetensors, model_meta.json, etc.
```

Cause 1: Batch size too large.
```bash
kiri-ocr train --batch-size 16 ...
```

Cause 2: Very long text sequences in the dataset.

```bash
# Limit maximum sequence length
kiri-ocr train --max-seq-len 256 ...
```

Try different decode methods:

```python
# CTC tends to give more reliable confidence scores
ocr = OCR(model_path="model.safetensors", decode_method="fast")

# Or use the decoder for better accuracy, but slower
ocr = OCR(model_path="model.safetensors", decode_method="accurate")
```

To resume an interrupted run, pass `--resume`:

```bash
kiri-ocr train \
  --hf-dataset mrrtmob/khmer_english_ocr_image_line \
  --output-dir output \
  --resume \
  --device cuda
```

Kiri OCR Home | GitHub Repository | Report Issue
© 2026 Kiri OCR. Released under the Apache 2.0 License.