Skip to content
4 changes: 2 additions & 2 deletions flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@
"config": {
"supported_file_types": (
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
".pptx, .csv, .html, .mhtml, .txt, .md, .zip, .mp3"
),
"private": False,
},
Expand All @@ -336,7 +336,7 @@
"config": {
"supported_file_types": (
".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
".pptx, .csv, .html, .mhtml, .txt, .md, .zip"
".pptx, .csv, .html, .mhtml, .txt, .md, .zip, .mp3"
),
"private": False,
},
Expand Down
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/indices/ingests/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
HtmlReader,
MathpixPDFReader,
MhtmlReader,
MP3Reader,
OCRReader,
PandasExcelReader,
PDFThumbnailReader,
Expand Down Expand Up @@ -53,6 +54,7 @@
".tiff": unstructured,
".tif": unstructured,
".pdf": PDFThumbnailReader(),
".mp3": MP3Reader(),
".txt": TxtReader(),
".md": TxtReader(),
}
Expand Down
2 changes: 2 additions & 0 deletions libs/kotaemon/kotaemon/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .excel_loader import ExcelReader, PandasExcelReader
from .html_loader import HtmlReader, MhtmlReader
from .mathpix_loader import MathpixPDFReader
from .mp3_loader import MP3Reader
from .ocr_loader import ImageReader, OCRReader
from .pdf_loader import PDFThumbnailReader
from .txt_loader import TxtReader
Expand All @@ -30,6 +31,7 @@
"AdobeReader",
"TxtReader",
"PDFThumbnailReader",
"MP3Reader",
"WebReader",
"DoclingReader",
]
101 changes: 101 additions & 0 deletions libs/kotaemon/kotaemon/loaders/mp3_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

from loguru import logger

from kotaemon.base import Document, Param

from .base import BaseReader

if TYPE_CHECKING:
from transformers import pipeline


class MP3Reader(BaseReader):
    """Read an audio file and return its transcript as a single Document.

    The audio is decoded with ``librosa`` and transcribed by a Hugging Face
    ``transformers`` automatic-speech-recognition pipeline. Transcription is
    best-effort: any failure is logged and yields a Document with empty text
    rather than an exception.
    """

    model_name_or_path: str = Param(
        help="The model name or path to use for speech recognition.",
        default="distil-whisper/distil-large-v3",
    )
    cache_dir: str = Param(
        help="The cache directory to use for the model.",
        default="models",
    )

    @Param.auto()
    def asr_pipeline(self) -> "pipeline":
        """Setup the ASR pipeline for speech recognition

        Returns:
            The ``transformers`` automatic-speech-recognition pipeline built
            from ``model_name_or_path``.

        Raises:
            ImportError: if ``accelerate``, ``torch`` or ``transformers`` is
                not installed.
            Exception: any error raised while loading the model or building
                the pipeline is logged and re-raised unchanged.
        """
        try:
            # accelerate is never referenced below; it is imported only so a
            # missing install fails here with a clear message (presumably it
            # is needed by ``low_cpu_mem_usage=True`` -- TODO confirm).
            import accelerate  # noqa: F401
            import torch
            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
        except ImportError as e:
            # Chain the cause so the module that actually failed stays visible.
            raise ImportError(
                "Please install the required packages to use the MP3Reader: "
                "'pip install accelerate torch transformers'"
            ) from e

        try:
            # Device and model configuration: half precision on GPU only
            use_cuda = torch.cuda.is_available()
            torch_dtype = torch.float16 if use_cuda else torch.float32
            device = "cuda:0" if use_cuda else "cpu"

            # Model and processor initialization
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                self.model_name_or_path,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                cache_dir=self.cache_dir,
            ).to(device)

            # Use the same cache directory as the model so that every
            # downloaded artifact lands in the configured location.
            processor = AutoProcessor.from_pretrained(
                self.model_name_or_path,
                cache_dir=self.cache_dir,
            )

            # ASR pipeline setup
            asr_pipeline = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                max_new_tokens=128,
                torch_dtype=torch_dtype,
                device=device,
                return_timestamps=True,
            )
            logger.info("ASR pipeline setup successful.")
        except Exception as e:
            logger.error(f"Error occurred during ASR pipeline setup: {e}")
            raise

        return asr_pipeline

    def speech_to_text(self, audio_path: str) -> str:
        """Transcribe the audio file at ``audio_path``.

        Args:
            audio_path: path of the audio file to transcribe.

        Returns:
            The stripped transcript, or an empty string when no text was
            recognized or when any step failed (the failure is logged).
        """
        try:
            import librosa
        except ImportError as e:
            # Stay best-effort like every other failure in this method, but
            # tell the user how to fix the missing optional dependency.
            logger.error(
                f"Error occurred during speech recognition: {e}. "
                "Please install librosa to use the MP3Reader: "
                "'pip install librosa'"
            )
            return ""

        try:
            # Performing speech recognition
            audio_array, _ = librosa.load(audio_path, sr=16000)  # 16kHz sampling rate
            result = self.asr_pipeline(audio_array)

            text = result.get("text", "").strip()
            if text == "":
                logger.warning("No text found in the audio file.")
            return text
        except Exception as e:
            # Deliberate best-effort: this also covers errors raised while the
            # ASR pipeline is being built on first access.
            logger.error(f"Error occurred during speech recognition: {e}")
            return ""

    def run(
        self, file_path: str | Path, extra_info: Optional[dict] = None, **kwargs
    ) -> list[Document]:
        """Transcribe ``file_path``; thin wrapper around ``load_data``."""
        return self.load_data(str(file_path), extra_info=extra_info, **kwargs)

    def load_data(
        self, audio_file: str, extra_info: Optional[dict] = None, **kwargs
    ) -> List[Document]:
        """Transcribe ``audio_file`` into a one-element list of Documents.

        Args:
            audio_file: path of the audio file to transcribe.
            extra_info: optional metadata to attach to the Document.
            **kwargs: accepted but not used by this reader.

        Returns:
            A list holding one Document whose text is the transcript (empty
            when transcription failed) and whose metadata is ``extra_info``.
        """
        # Get text from the audio file
        text = self.speech_to_text(audio_file)
        metadata = extra_info or {}

        return [Document(text=text, metadata=metadata)]
13 changes: 13 additions & 0 deletions libs/kotaemon/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def if_llama_cpp_not_installed():
return False


def if_librosa_not_installed():
    """Return True when ``librosa`` cannot be imported, False otherwise."""
    try:
        import librosa  # noqa: F401
    except ImportError:
        # Optional dependency is missing.
        return True
    return False


skip_when_haystack_not_installed = pytest.mark.skipif(
if_haystack_not_installed(), reason="Haystack is not installed"
)
Expand Down Expand Up @@ -97,3 +106,7 @@ def if_llama_cpp_not_installed():
skip_llama_cpp_not_installed = pytest.mark.skipif(
if_llama_cpp_not_installed(), reason="llama_cpp is not installed"
)

# Marker for tests that need librosa: the decorated test is skipped when
# if_librosa_not_installed() reports that the package cannot be imported.
skip_when_librosa_not_installed = pytest.mark.skipif(
    if_librosa_not_installed(), reason="librosa is not installed"
)
Binary file added libs/kotaemon/tests/resources/dummy.mp3
Binary file not shown.
21 changes: 20 additions & 1 deletion libs/kotaemon/tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
DocxReader,
HtmlReader,
MhtmlReader,
MP3Reader,
UnstructuredReader,
)

from .conftest import skip_when_unstructured_pdf_not_installed
from .conftest import (
skip_when_librosa_not_installed,
skip_when_unstructured_pdf_not_installed,
)


def test_docx_reader():
Expand Down Expand Up @@ -93,3 +97,18 @@ def test_azureai_document_intelligence_reader(mock_client):

assert len(docs) == 1
mock_client.assert_called_once()


@skip_when_librosa_not_installed
@patch("kotaemon.loaders.MP3Reader.asr_pipeline")
def test_mp3_reader(mock_pipeline):
    """MP3Reader returns one Document carrying the ASR transcript."""
    # The reader extracts the transcript with ``result.get("text", "")``, so
    # the mocked pipeline must return a mapping. A bare string would raise
    # AttributeError inside the reader, which is swallowed there and would
    # leave the document text empty.
    mock_pipeline.return_value = {"text": "This is the transcript"}

    reader = MP3Reader()
    docs = reader.load_data(str(Path(__file__).parent / "resources" / "dummy.mp3"))

    assert len(docs) == 1
    # The transcript must reach the document, not just the call count
    assert docs[0].text == "This is the transcript"

    # Assert that the ASR pipeline was called
    mock_pipeline.assert_called_once()