Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 99 additions & 33 deletions frontend/src-tauri/src/whisper_engine/whisper_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl WhisperEngine {
.join("models")
}
};

log::info!("WhisperEngine using models directory: {}", models_dir.display());
log::info!("Debug mode: {}", cfg!(debug_assertions));

Expand All @@ -146,7 +146,7 @@ impl WhisperEngine {

#[cfg(feature = "openmp")]
log::info!("OpenMP parallel processing: enabled");

let engine = Self {
models_dir,
current_context: Arc::new(RwLock::new(None)),
Expand All @@ -162,17 +162,20 @@ impl WhisperEngine {
// Initialize active downloads tracking
active_downloads: Arc::new(RwLock::new(HashSet::new())),
};

Ok(engine)
}

pub async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
let models_dir = &self.models_dir;
let mut models = Vec::new();
// Use centralized model catalog from config.rs
let model_configs = WHISPER_MODEL_CATALOG;

let mut known_filenames = std::collections::HashSet::new();

for &(name, filename, size_mb, accuracy, speed, description) in model_configs {
known_filenames.insert(filename.to_string());
let model_path = models_dir.join(filename);
let status = if model_path.exists() {
// Check if file size is reasonable (at least 1MB for a valid model)
Expand Down Expand Up @@ -232,7 +235,7 @@ impl WhisperEngine {
} else {
ModelStatus::Missing
};

let model_info = ModelInfo {
name: name.to_string(),
path: model_path,
Expand All @@ -242,20 +245,83 @@ impl WhisperEngine {
status,
description: description.to_string(),
};

models.push(model_info);
}


// Scan for custom models in the directory
if models_dir.exists() {
if let Ok(mut entries) = tokio::fs::read_dir(models_dir).await {
while let Ok(Some(entry)) = entries.next_entry().await {
let path = entry.path();
if path.is_file() {
if let Some(extension) = path.extension() {
if extension == "bin" {
if let Some(file_name_os) = path.file_name() {
let filename = file_name_os.to_string_lossy().to_string();

// Skip known models
if !known_filenames.contains(&filename) {
if let Ok(metadata) = std::fs::metadata(&path) {
let file_size_bytes = metadata.len();
let file_size_mb = (file_size_bytes / (1024 * 1024)) as u32;

// Extract a clean name (remove ggml- prefix if present, and .bin suffix)
let mut clean_name = filename.strip_suffix(".bin").unwrap_or(&filename).to_string();
if clean_name.starts_with("ggml-") {
clean_name = clean_name.strip_prefix("ggml-").unwrap().to_string();
}

let status = if file_size_mb > 1 {
match self.validate_model_file(&path).await {
Ok(_) => ModelStatus::Available,
Err(_) => {
log::warn!("Custom model file {} appears corrupted (failed validation)", filename);
ModelStatus::Corrupted {
file_size: file_size_bytes,
expected_min_size: 1024 * 1024
}
}
}
} else {
ModelStatus::Corrupted {
file_size: file_size_bytes,
expected_min_size: 1024 * 1024
}
};

let model_info = ModelInfo {
name: clean_name.clone(),
path: path.clone(),
size_mb: file_size_mb,
accuracy: "Unknown".to_string(), // Custom models
speed: "Unknown".to_string(),
status,
description: format!("Custom model: {}", filename),
};

log::info!("Discovered custom model: {}", clean_name);
models.push(model_info);
}
}
}
}
}
}
}
}
}

// Update internal cache
let mut available_models = self.available_models.write().await;
available_models.clear();
for model in &models {
available_models.insert(model.name.clone(), model.clone());
}

Ok(models)
}

pub async fn load_model(&self, model_name: &str) -> Result<()> {
let models = self.available_models.read().await;
let model_info = models.get(model_name)
Expand Down Expand Up @@ -358,11 +424,11 @@ impl WhisperEngine {
/// Returns the name of the currently loaded model, if any.
///
/// Takes a read lock on `current_model`; the stored `Option<String>`
/// is cloned so the guard is released before returning.
pub async fn get_current_model(&self) -> Option<String> {
    let guard = self.current_model.read().await;
    guard.clone()
}

/// Reports whether a whisper context is currently loaded.
///
/// Briefly acquires a read lock on `current_context` and checks for a
/// present value; no state is modified.
pub async fn is_model_loaded(&self) -> bool {
    let ctx = self.current_context.read().await;
    ctx.is_some()
}

// Enhanced function to clean repetitive text patterns and meaningless outputs
fn clean_repetitive_text(text: &str) -> String {
if text.is_empty() {
Expand Down Expand Up @@ -511,7 +577,7 @@ impl WhisperEngine {

repeated_words as f32 / total_words
}

/// Transcribe audio with streaming support for partial results and adaptive quality
pub async fn transcribe_audio_with_confidence(&self, audio_data: Vec<f32>, language: Option<String>) -> Result<(String, f32, bool)> {
let ctx_lock = self.current_context.read().await;
Expand Down Expand Up @@ -802,7 +868,7 @@ impl WhisperEngine {

Ok(cleaned_result)
}

/// Returns an owned copy of the directory where model files are stored.
///
/// The path is fixed at engine construction; callers receive their own
/// `PathBuf`, so no lock is needed here.
pub async fn get_models_directory(&self) -> PathBuf {
    self.models_dir.to_path_buf()
}
Expand Down Expand Up @@ -893,7 +959,7 @@ impl WhisperEngine {
}
}
}

pub async fn download_model(&self, model_name: &str, progress_callback: Option<Box<dyn Fn(u8) + Send>>) -> Result<()> {
log::info!("Starting download for model: {}", model_name);

Expand Down Expand Up @@ -940,56 +1006,56 @@ impl WhisperEngine {

_ => return Err(anyhow!("Unsupported model: {}", model_name))
};

log::info!("Model URL for {}: {}", model_name, model_url);

// Generate correct filename - all models follow ggml-{model_name}.bin pattern
let filename = format!("ggml-{}.bin", model_name);
let file_path = self.models_dir.join(&filename);

log::info!("Downloading to file path: {}", file_path.display());

// Create models directory if it doesn't exist
if !self.models_dir.exists() {
fs::create_dir_all(&self.models_dir).await
.map_err(|e| anyhow!("Failed to create models directory: {}", e))?;
}

// Update model status to downloading
{
let mut models = self.available_models.write().await;
if let Some(model_info) = models.get_mut(model_name) {
model_info.status = ModelStatus::Downloading { progress: 0 };
}
}

log::info!("Creating HTTP client and starting request...");
let client = Client::new();

log::info!("Sending GET request to: {}", model_url);
let response = client.get(model_url).send().await
.map_err(|e| anyhow!("Failed to start download: {}", e))?;

log::info!("Received response with status: {}", response.status());
if !response.status().is_success() {
// Remove from active downloads on error
let mut active = self.active_downloads.write().await;
active.remove(model_name);
return Err(anyhow!("Download failed with status: {}", response.status()));
}

let total_size = response.content_length().unwrap_or(0);
log::info!("Response successful, content length: {} bytes ({:.1} MB)", total_size, total_size as f64 / (1024.0 * 1024.0));

if total_size == 0 {
log::warn!("Content length is 0 or unknown - download may not show accurate progress");
}

let mut file = fs::File::create(&file_path).await
.map_err(|e| anyhow!("Failed to create file: {}", e))?;

log::info!("File created successfully at: {}", file_path.display());

// Stream download with real progress reporting
log::info!("Starting streaming download...");
log::info!("Expected size: {:.1} MB", total_size as f64 / (1024.0 * 1024.0));
Expand Down Expand Up @@ -1060,24 +1126,24 @@ impl WhisperEngine {
}

log::info!("Streaming download completed: {} bytes", downloaded);

// Ensure 100% progress is always reported
{
let mut models = self.available_models.write().await;
if let Some(model_info) = models.get_mut(model_name) {
model_info.status = ModelStatus::Downloading { progress: 100 };
}
}

if let Some(ref callback) = progress_callback {
callback(100);
}

file.flush().await
.map_err(|e| anyhow!("Failed to flush file: {}", e))?;

log::info!("Download completed for model: {}", model_name);

// Update model status to available
{
let mut models = self.available_models.write().await;
Expand All @@ -1095,7 +1161,7 @@ impl WhisperEngine {

Ok(())
}

pub async fn cancel_download(&self, model_name: &str) -> Result<()> {
log::info!("Cancelling download for model: {}", model_name);

Expand Down
Loading