diff --git a/Cargo.toml b/Cargo.toml index 5de98e0c..f2b8dae0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,6 +92,7 @@ unused_qualifications = "warn" [features] default = [] bedrock = ["polaris_internal/bedrock"] +openai = ["polaris_internal/openai"] [dependencies] polaris_internal = { path = "crates/polaris_internal", version = "0.0.1" } diff --git a/crates/polaris_internal/Cargo.toml b/crates/polaris_internal/Cargo.toml index 3e9cf67c..5be6b4b1 100644 --- a/crates/polaris_internal/Cargo.toml +++ b/crates/polaris_internal/Cargo.toml @@ -7,6 +7,7 @@ keywords = [] [features] default = [] +openai = ["polaris_model_providers/openai"] bedrock = ["polaris_model_providers/bedrock"] [dependencies] diff --git a/crates/polaris_model_providers/Cargo.toml b/crates/polaris_model_providers/Cargo.toml index e4b5fd15..a3497a0f 100644 --- a/crates/polaris_model_providers/Cargo.toml +++ b/crates/polaris_model_providers/Cargo.toml @@ -10,6 +10,7 @@ workspace = true [features] default = ["anthropic"] anthropic = [] +openai = ["dep:async-openai"] bedrock = [ "dep:aws-sdk-bedrockruntime", "dep:aws-config", @@ -27,6 +28,13 @@ serde_json = "1.0" reqwest = { version = "0.13.1", features = ["json"] } tracing = "0.1" +# OpenAI dependencies +async-openai = { version = "0.33", optional = true, default-features = false, features = [ + "rustls", + "responses", + "response-types", +] } + # AWS Bedrock dependencies aws-sdk-bedrockruntime = { version = "1.124.0", optional = true } aws-config = { version = "1.5", features = [ diff --git a/crates/polaris_model_providers/src/lib.rs b/crates/polaris_model_providers/src/lib.rs index 9f5b4cfb..b596fdd5 100644 --- a/crates/polaris_model_providers/src/lib.rs +++ b/crates/polaris_model_providers/src/lib.rs @@ -7,6 +7,7 @@ //! | Provider | Feature Flag | Description | //! |----------|--------------|-------------| //! | Anthropic | `anthropic` (default) | Direct Anthropic API access | +//! | `OpenAI` | `openai` | `OpenAI` Responses API | //! 
| AWS Bedrock | `bedrock` | AWS Bedrock Converse API | //! //! # Feature Flags @@ -20,8 +21,11 @@ //! # Enable only Bedrock //! polaris_model_providers = { path = "../polaris_model_providers", default-features = false, features = ["bedrock"] } //! -//! # Enable both providers -//! polaris_model_providers = { path = "../polaris_model_providers", features = ["bedrock"] } +//! # Enable OpenAI +//! polaris_model_providers = { path = "../polaris_model_providers", default-features = false, features = ["openai"] } +//! +//! # Enable multiple providers +//! polaris_model_providers = { path = "../polaris_model_providers", features = ["openai", "bedrock"] } //! ``` //! //! # Usage @@ -49,6 +53,18 @@ //! server.add_plugins(ModelsPlugin); //! server.add_plugins(BedrockPlugin::from_env()); //! ``` +//! +//! For `OpenAI`, provide an API key via environment variable: +//! +//! ```ignore +//! use polaris_model_providers::OpenAiPlugin; +//! use polaris_models::ModelsPlugin; +//! use polaris_system::server::Server; +//! +//! let mut server = Server::new(); +//! server.add_plugins(ModelsPlugin); +//! server.add_plugins(OpenAiPlugin::from_env("OPENAI_API_KEY")); +//! ``` mod schema; @@ -58,6 +74,12 @@ pub mod anthropic; #[cfg(feature = "anthropic")] pub use anthropic::AnthropicPlugin; +#[cfg(feature = "openai")] +pub mod openai; + +#[cfg(feature = "openai")] +pub use openai::OpenAiPlugin; + #[cfg(feature = "bedrock")] pub mod bedrock; diff --git a/crates/polaris_model_providers/src/openai/mod.rs b/crates/polaris_model_providers/src/openai/mod.rs new file mode 100644 index 00000000..235474ca --- /dev/null +++ b/crates/polaris_model_providers/src/openai/mod.rs @@ -0,0 +1,9 @@ +//! `OpenAI` provider backend. +//! +//! Uses the `OpenAI` Responses API. 
+
+mod plugin;
+mod provider;
+
+pub use plugin::OpenAiPlugin;
+pub use provider::OpenAiProvider;
diff --git a/crates/polaris_model_providers/src/openai/plugin.rs b/crates/polaris_model_providers/src/openai/plugin.rs
new file mode 100644
index 00000000..f2f9a63e
--- /dev/null
+++ b/crates/polaris_model_providers/src/openai/plugin.rs
@@ -0,0 +1,50 @@
+//! `OpenAI` provider plugin.
+
+use super::provider::OpenAiProvider;
+use polaris_models::{ModelRegistry, ModelsPlugin};
+use polaris_system::plugin::{Plugin, PluginId, Version};
+use polaris_system::server::Server;
+use std::sync::Arc;
+
+/// Plugin providing support for `OpenAI` models via the Responses API.
+///
+/// ```ignore
+/// server.add_plugins(OpenAiPlugin::from_env("OPENAI_API_KEY"));
+/// ```
+pub struct OpenAiPlugin {
+    api_key: String,
+}
+
+impl OpenAiPlugin {
+    /// Creates a plugin that reads the API key from the specified environment variable.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the environment variable is not set.
+    #[must_use]
+    pub fn from_env(env_var: &str) -> Self {
+        let api_key = std::env::var(env_var).unwrap_or_else(|_| {
+            panic!("Environment variable {env_var} for OpenAiPlugin not set. Please set it to your OpenAI API key.");
+        });
+        Self { api_key }
+    }
+}
+
+impl Plugin for OpenAiPlugin {
+    const ID: &'static str = "polaris::provider::openai";
+    const VERSION: Version = Version::new(0, 0, 1);
+
+    fn dependencies(&self) -> Vec<PluginId> {
+        vec![PluginId::of::<ModelsPlugin>()]
+    }
+
+    fn build(&self, server: &mut Server) {
+        let provider = OpenAiProvider::new(self.api_key.clone());
+
+        let Some(mut registry) = server.get_resource_mut::<ModelRegistry>() else {
+            panic!("ModelRegistry not found.
Make sure to add ModelsPlugin before OpenAiPlugin.");
+        };
+
+        registry.register_llm_provider("openai", Arc::new(provider));
+    }
+}
diff --git a/crates/polaris_model_providers/src/openai/provider.rs b/crates/polaris_model_providers/src/openai/provider.rs
new file mode 100644
index 00000000..a59fba60
--- /dev/null
+++ b/crates/polaris_model_providers/src/openai/provider.rs
@@ -0,0 +1,480 @@
+//! `OpenAI` [`LlmProvider`] implementation using the Responses API.
+
+use crate::schema::normalize_schema_for_strict_mode;
+use async_openai::config::OpenAIConfig;
+use async_openai::error::OpenAIError;
+use async_openai::types::responses::{
+    CreateResponseArgs, EasyInputContent, EasyInputMessage, FunctionCallOutput,
+    FunctionCallOutputItemParam, FunctionTool, FunctionToolCall, InputContent, InputImageContent,
+    InputItem, InputParam, InputTextContent, Item, OutputItem, OutputMessageContent, ReasoningItem,
+    Response, ResponseFormatJsonSchema, ResponseTextParam, ResponseUsage, Role, SummaryPart,
+    SummaryTextContent, TextResponseFormatConfiguration, Tool, ToolChoiceFunction,
+    ToolChoiceOptions, ToolChoiceParam,
+};
+use async_trait::async_trait;
+use polaris_models::llm::{
+    AssistantBlock, GenerationError, GenerationRequest, GenerationResponse, ImageMediaType,
+    LlmProvider, Message, ReasoningBlock, TextBlock, ToolCall, ToolChoice, ToolFunction,
+    ToolResultContent as PolarisToolResult, ToolResultStatus, Usage, UserBlock,
+};
+
+/// `OpenAI` [`LlmProvider`] implementation using the Responses API.
+pub struct OpenAiProvider {
+    client: async_openai::Client<OpenAIConfig>,
+}
+
+impl OpenAiProvider {
+    /// Creates a new provider with the given API key.
+    #[must_use]
+    pub fn new(api_key: impl Into<String>) -> Self {
+        let config = OpenAIConfig::new().with_api_key(api_key);
+        Self {
+            client: async_openai::Client::with_config(config),
+        }
+    }
+}
+
+#[async_trait]
+impl LlmProvider for OpenAiProvider {
+    async fn generate(
+        &self,
+        model: &str,
+        request: GenerationRequest,
+    ) -> Result<GenerationResponse, GenerationError> {
+        let create_response = convert_request(model, &request)?;
+        let response = self
+            .client
+            .responses()
+            .create(create_response)
+            .await
+            .map_err(convert_error)?;
+        convert_response(response)
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Request conversion (Polaris -> OpenAI)
+// ---------------------------------------------------------------------------
+
+fn convert_request(
+    model: &str,
+    request: &GenerationRequest,
+) -> Result<async_openai::types::responses::CreateResponse, GenerationError> {
+    let input_items = convert_messages(&request.messages)?;
+
+    let tools: Option<Vec<Tool>> = request.tools.as_ref().map(|tools| {
+        tools
+            .iter()
+            .map(|tool| {
+                let normalized_parameters =
+                    normalize_schema_for_strict_mode(tool.parameters.clone());
+                Tool::Function(FunctionTool {
+                    name: tool.name.clone(),
+                    description: Some(tool.description.clone()),
+                    parameters: Some(normalized_parameters),
+                    strict: Some(true),
+                })
+            })
+            .collect()
+    });
+
+    let tool_choice = request.tool_choice.as_ref().map(convert_tool_choice);
+
+    let text = request.output_schema.as_ref().map(|schema| {
+        let normalized = normalize_schema_for_strict_mode(schema.clone());
+        ResponseTextParam {
+            format: TextResponseFormatConfiguration::JsonSchema(ResponseFormatJsonSchema {
+                name: "structured_output".to_string(),
+                description: None,
+                schema: Some(normalized),
+                strict: Some(true),
+            }),
+            verbosity: None,
+        }
+    });
+
+    let mut builder = CreateResponseArgs::default();
+    builder.model(model).input(InputParam::Items(input_items));
+
+    if let Some(system) = &request.system {
+        builder.instructions(system.clone());
+    }
+    if let Some(tools) = tools {
+        builder.tools(tools);
+    }
+    if
let Some(tool_choice) = tool_choice {
+        builder.tool_choice(tool_choice);
+    }
+    if let Some(text) = text {
+        builder.text(text);
+    }
+
+    builder.build().map_err(|build_err| {
+        GenerationError::InvalidRequest(format!("Failed to build CreateResponse: {build_err}"))
+    })
+}
+
+fn convert_messages(messages: &[Message]) -> Result<Vec<InputItem>, GenerationError> {
+    let mut items = Vec::new();
+
+    for message in messages {
+        match message {
+            Message::User { content } => {
+                convert_user_message(content, &mut items)?;
+            }
+            Message::Assistant { content, .. } => {
+                convert_assistant_message(content, &mut items)?;
+            }
+        }
+    }
+
+    Ok(items)
+}
+
+fn convert_user_message(
+    blocks: &[UserBlock],
+    items: &mut Vec<InputItem>,
+) -> Result<(), GenerationError> {
+    // Separate tool results from regular content blocks.
+    // Tool results become top-level InputItem entries, while text/image
+    // blocks get grouped into a single EasyInputMessage.
+    let mut content_parts: Vec<InputContent> = Vec::new();
+
+    for block in blocks {
+        match block {
+            UserBlock::Text(block) => {
+                content_parts.push(InputContent::InputText(InputTextContent {
+                    text: block.text.clone(),
+                }));
+            }
+            UserBlock::Image(image) => {
+                let data_url = build_image_data_url(image)?;
+                content_parts.push(InputContent::InputImage(InputImageContent {
+                    image_url: Some(data_url),
+                    file_id: None,
+                    detail: Default::default(),
+                }));
+            }
+            UserBlock::ToolResult(result) => {
+                // Each tool result is a separate top-level item.
+                // Flush any accumulated content first.
+                flush_content_parts(&mut content_parts, Role::User, items);
+
+                let output_text = match &result.content {
+                    PolarisToolResult::Text(text) => text.clone(),
+                    PolarisToolResult::Image(_) => {
+                        return Err(GenerationError::UnsupportedContent(
+                            "Image tool results are not supported by OpenAI".to_string(),
+                        ));
+                    }
+                };
+
+                let output_text = match result.status {
+                    ToolResultStatus::Success => output_text,
+                    ToolResultStatus::Error => format!("Error: {output_text}"),
+                };
+
+                // OpenAI uses call_id to link function outputs back to function calls.
+                let call_id = result.call_id.clone().ok_or_else(|| {
+                    GenerationError::InvalidRequest(
+                        "Tool result is missing a call_id, which is required by OpenAI to link function outputs back to function calls".to_string(),
+                    )
+                })?;
+
+                items.push(InputItem::Item(Item::FunctionCallOutput(
+                    FunctionCallOutputItemParam {
+                        call_id,
+                        output: FunctionCallOutput::Text(output_text),
+                        id: None,
+                        status: None,
+                    },
+                )));
+            }
+            UserBlock::Audio(_) => {
+                return Err(GenerationError::UnsupportedContent(
+                    "Audio content is not yet supported by the OpenAI Responses provider"
+                        .to_string(),
+                ));
+            }
+            UserBlock::Document(_) => {
+                return Err(GenerationError::UnsupportedContent(
+                    "Document content is not yet supported by the OpenAI Responses provider"
+                        .to_string(),
+                ));
+            }
+        }
+    }
+
+    // Flush any remaining content.
+    flush_content_parts(&mut content_parts, Role::User, items);
+
+    Ok(())
+}
+
+fn convert_assistant_message(
+    blocks: &[AssistantBlock],
+    items: &mut Vec<InputItem>,
+) -> Result<(), GenerationError> {
+    // Text blocks get grouped into a single EasyInputMessage with role assistant.
+    // Tool calls and reasoning blocks become individual top-level Item entries.
+    let mut text_parts: Vec<InputContent> = Vec::new();
+
+    for block in blocks {
+        match block {
+            AssistantBlock::Text(block) => {
+                text_parts.push(InputContent::InputText(InputTextContent {
+                    text: block.text.clone(),
+                }));
+            }
+            AssistantBlock::ToolCall(call) => {
+                flush_content_parts(&mut text_parts, Role::Assistant, items);
+
+                let arguments =
+                    serde_json::to_string(&call.function.arguments).map_err(|json_err| {
+                        GenerationError::InvalidRequest(format!(
+                            "Failed to serialize tool call arguments: {json_err}"
+                        ))
+                    })?;
+
+                items.push(InputItem::Item(Item::FunctionCall(FunctionToolCall {
+                    call_id: call.call_id.clone().ok_or_else(|| {
+                        GenerationError::InvalidRequest(
+                            "Tool call is missing a call_id, which is required by OpenAI to link function calls to their outputs".to_string(),
+                        )
+                    })?,
+                    name: call.function.name.clone(),
+                    arguments,
+                    id: Some(call.id.clone()),
+                    status: None,
+                })));
+            }
+            AssistantBlock::Reasoning(reasoning) => {
+                flush_content_parts(&mut text_parts, Role::Assistant, items);
+
+                let summary = reasoning
+                    .reasoning
+                    .iter()
+                    .map(|text| SummaryPart::SummaryText(SummaryTextContent { text: text.clone() }))
+                    .collect();
+
+                if reasoning.id.is_none() {
+                    tracing::warn!(
+                        "Reasoning block is missing an ID; using empty string as fallback"
+                    );
+                }
+
+                items.push(InputItem::Item(Item::Reasoning(ReasoningItem {
+                    id: reasoning.id.clone().unwrap_or_default(),
+                    summary,
+                    content: None,
+                    encrypted_content: None,
+                    status: None,
+                })));
+            }
+        }
+    }
+
+    flush_content_parts(&mut text_parts, Role::Assistant, items);
+
+    Ok(())
+}
+
+/// Flushes accumulated content parts into an [`EasyInputMessage`] and appends
+/// it to the items list. Does nothing if `parts` is empty.
+fn flush_content_parts(parts: &mut Vec<InputContent>, role: Role, items: &mut Vec<InputItem>) {
+    if parts.is_empty() {
+        return;
+    }
+
+    let content = if parts.len() == 1 {
+        // Single text block can use the simpler Text variant.
+        if let InputContent::InputText(ref text_content) = parts[0] {
+            EasyInputContent::Text(text_content.text.clone())
+        } else {
+            EasyInputContent::ContentList(core::mem::take(parts))
+        }
+    } else {
+        EasyInputContent::ContentList(core::mem::take(parts))
+    };
+
+    items.push(InputItem::EasyMessage(EasyInputMessage {
+        content,
+        role,
+        r#type: Default::default(),
+    }));
+
+    parts.clear();
+}
+
+fn build_image_data_url(
+    image: &polaris_models::llm::ImageBlock,
+) -> Result<String, GenerationError> {
+    let mime = match image.media_type {
+        ImageMediaType::JPEG => "image/jpeg",
+        ImageMediaType::PNG => "image/png",
+        ImageMediaType::GIF => "image/gif",
+        ImageMediaType::WEBP => "image/webp",
+        ref other => {
+            return Err(GenerationError::UnsupportedContent(format!(
+                "Unsupported image media type for OpenAI: {other:?}"
+            )));
+        }
+    };
+
+    let polaris_models::llm::DocumentSource::Base64(data) = &image.data;
+    Ok(format!("data:{mime};base64,{data}"))
+}
+
+fn convert_tool_choice(choice: &ToolChoice) -> ToolChoiceParam {
+    match choice {
+        ToolChoice::Auto => ToolChoiceParam::Mode(ToolChoiceOptions::Auto),
+        ToolChoice::Required => ToolChoiceParam::Mode(ToolChoiceOptions::Required),
+        ToolChoice::None => ToolChoiceParam::Mode(ToolChoiceOptions::None),
+        ToolChoice::Specific(name) => {
+            ToolChoiceParam::Function(ToolChoiceFunction { name: name.clone() })
+        }
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Response conversion (OpenAI -> Polaris)
+// ---------------------------------------------------------------------------
+
+fn convert_response(response: Response) -> Result<GenerationResponse, GenerationError> {
+    let content = response
+        .output
+        .into_iter()
+        .map(convert_output_item)
+        .collect::<Result<Vec<_>, _>>()?
+        .into_iter()
+        .flatten()
+        .collect();
+
+    let usage = response.usage.map(convert_usage).unwrap_or_default();
+
+    Ok(GenerationResponse { content, usage })
+}
+
+fn convert_output_item(item: OutputItem) -> Result<Vec<AssistantBlock>, GenerationError> {
+    match item {
+        OutputItem::Message(msg) => msg
+            .content
+            .into_iter()
+            .map(convert_output_message_content)
+            .collect::<Result<Vec<_>, _>>(),
+        OutputItem::FunctionCall(call) => {
+            let arguments: serde_json::Value = serde_json::from_str(&call.arguments)
+                .unwrap_or_else(|err| {
+                    tracing::warn!(
+                        error = %err,
+                        raw_arguments = call.arguments,
+                        "Failed to parse tool call arguments as JSON, falling back to Null"
+                    );
+                    serde_json::Value::Null
+                });
+
+            if call.id.is_none() {
+                tracing::warn!(
+                    call_id = call.call_id,
+                    function = call.name,
+                    "OpenAI function call is missing an item ID"
+                );
+            }
+
+            Ok(vec![AssistantBlock::ToolCall(ToolCall {
+                id: call.id.unwrap_or_default(),
+                call_id: Some(call.call_id),
+                function: ToolFunction {
+                    name: call.name,
+                    arguments,
+                },
+                signature: None,
+                additional_params: None,
+            })])
+        }
+        OutputItem::Reasoning(reasoning) => {
+            let texts: Vec<String> = reasoning
+                .summary
+                .into_iter()
+                .map(|part| {
+                    let SummaryPart::SummaryText(text_content) = part;
+                    text_content.text
+                })
+                .collect();
+
+            if texts.is_empty() {
+                Ok(vec![])
+            } else {
+                Ok(vec![AssistantBlock::Reasoning(ReasoningBlock {
+                    id: Some(reasoning.id),
+                    reasoning: texts,
+                    signature: None,
+                })])
+            }
+        }
+        // Other output item types (file search, web search, computer use, etc.)
+        // are not mapped to Polaris types yet.
+        other => {
+            tracing::warn!(
+                item = ?other,
+                "Dropping unsupported OpenAI output item type during response conversion"
+            );
+            Ok(vec![])
+        }
+    }
+}
+
+fn convert_output_message_content(
+    content: OutputMessageContent,
+) -> Result<AssistantBlock, GenerationError> {
+    match content {
+        OutputMessageContent::OutputText(text) => {
+            Ok(AssistantBlock::Text(TextBlock { text: text.text }))
+        }
+        OutputMessageContent::Refusal(refusal) => Err(GenerationError::Refusal(refusal.refusal)),
+    }
+}
+
+fn convert_usage(usage: ResponseUsage) -> Usage {
+    Usage {
+        input_tokens: Some(u64::from(usage.input_tokens)),
+        output_tokens: Some(u64::from(usage.output_tokens)),
+        total_tokens: Some(u64::from(usage.total_tokens)),
+    }
+}
+
+// ---------------------------------------------------------------------------
+// Error conversion
+// ---------------------------------------------------------------------------
+
+fn convert_error(err: OpenAIError) -> GenerationError {
+    match err {
+        OpenAIError::ApiError(api_err) => GenerationError::Provider {
+            status: None,
+            message: api_err.message.clone(),
+            source: Some(Box::new(OpenAIError::ApiError(api_err))),
+        },
+        OpenAIError::Reqwest(ref reqwest_err) => {
+            if reqwest_err
+                .status()
+                .is_some_and(|s| s == reqwest::StatusCode::UNAUTHORIZED)
+            {
+                GenerationError::Auth(err.to_string())
+            } else if reqwest_err
+                .status()
+                .is_some_and(|s| s == reqwest::StatusCode::TOO_MANY_REQUESTS)
+            {
+                GenerationError::RateLimited { retry_after: None }
+            } else {
+                GenerationError::Http(err.to_string())
+            }
+        }
+        OpenAIError::JSONDeserialize(serde_err, ref _body) => GenerationError::Json(serde_err),
+        OpenAIError::InvalidArgument(msg) => GenerationError::InvalidRequest(msg),
+        _ => GenerationError::Provider {
+            status: None,
+            message: err.to_string(),
+            source: Some(Box::new(err)),
+        },
+    }
+}
diff --git a/crates/polaris_model_providers/src/schema.rs b/crates/polaris_model_providers/src/schema.rs
index bb56d1a7..430e96e4 100644
---
a/crates/polaris_model_providers/src/schema.rs
+++ b/crates/polaris_model_providers/src/schema.rs
@@ -20,6 +20,7 @@ const SUPPORTED_FORMATS: &[&str] = &[
 ///
 /// This function:
 /// - Sets `additionalProperties: false` on all object types.
+/// - Ensures all property keys appear in `required`.
 /// - Removes unsupported properties like `minimum`, `maximum`, `multipleOf`, etc.
 /// - Filters string formats to only supported values.
 /// - Removes `minItems` values greater than 1.
@@ -100,6 +101,12 @@ pub fn normalize_schema_for_strict_mode(mut schema: Value) -> Value {
 
     if is_object {
         obj.insert("additionalProperties".to_string(), Value::Bool(false));
+
+        // Ensure every property key appears in the `required` array.
+        if let Some(Value::Object(props)) = obj.get("properties") {
+            let all_keys: Vec<Value> = props.keys().map(|k| Value::String(k.clone())).collect();
+            obj.insert("required".to_string(), Value::Array(all_keys));
+        }
     }
 
     // Recursively process nested schemas
diff --git a/crates/polaris_model_providers/tests/openai_integration.rs b/crates/polaris_model_providers/tests/openai_integration.rs
new file mode 100644
index 00000000..692eb7a1
--- /dev/null
+++ b/crates/polaris_model_providers/tests/openai_integration.rs
@@ -0,0 +1,75 @@
+//! Integration tests for the `OpenAI` provider.
+//!
+//! These tests are ignored by default because they require:
+//! - `OPENAI_API_KEY` environment variable (or in `.env` file)
+//! - Network access to the `OpenAI` API
+//! - May incur API costs
+//!
+//! To run these tests:
+//! ```sh
+//! cargo test -p polaris_model_providers --features openai --test openai_integration -- --ignored
+//!
```
+
+#![cfg(feature = "openai")]
+
+mod common;
+
+use common::{LlmTestExt, init_env};
+use polaris_model_providers::openai::OpenAiPlugin;
+use polaris_models::llm::Llm;
+use polaris_models::{ModelRegistry, ModelsPlugin};
+use polaris_system::server::Server;
+
+const MODEL: &str = "openai/gpt-4o";
+
+fn get_llm(model_id: &str) -> Llm {
+    init_env();
+
+    let mut server = Server::new();
+    server.add_plugins(ModelsPlugin);
+    server.add_plugins(OpenAiPlugin::from_env("OPENAI_API_KEY"));
+    server.finish();
+
+    let registry = server
+        .get_global::<ModelRegistry>()
+        .expect("ModelRegistry should be available");
+    registry.llm(model_id).expect("model should be valid")
+}
+
+#[tokio::test]
+#[ignore = "requires OPENAI_API_KEY"]
+async fn test_basic_generation() {
+    get_llm(MODEL).test_basic_generation().await;
+}
+
+#[tokio::test]
+#[ignore = "requires OPENAI_API_KEY"]
+async fn test_system_prompt() {
+    get_llm(MODEL).test_system_prompt().await;
+}
+
+#[tokio::test]
+#[ignore = "requires OPENAI_API_KEY"]
+async fn test_tool_calling() {
+    get_llm(MODEL).test_tool_calling().await;
+}
+
+#[tokio::test]
+#[ignore = "requires OPENAI_API_KEY"]
+async fn test_structured_output() {
+    get_llm(MODEL).test_structured_output().await;
+}
+
+#[tokio::test]
+#[ignore = "requires OPENAI_API_KEY"]
+async fn test_invalid_model_error() {
+    get_llm("openai/not-a-real-model")
+        .test_invalid_model_error()
+        .await;
+}
+
+#[tokio::test]
+#[ignore = "requires OPENAI_API_KEY"]
+async fn test_image_input() {
+    get_llm(MODEL).test_image_input().await;
+}
diff --git a/crates/polaris_models/src/llm/error.rs b/crates/polaris_models/src/llm/error.rs
index 843374db..306e8b30 100644
--- a/crates/polaris_models/src/llm/error.rs
+++ b/crates/polaris_models/src/llm/error.rs
@@ -56,6 +56,10 @@ pub enum GenerationError {
     #[error("unsupported content: {0}")]
     UnsupportedContent(String),
 
+    /// The model refused to fulfill the request (e.g. content policy).
+ #[error("model refused the request: {0}")] + Refusal(String), + /// Error returned by the model provider. #[error("provider error: {message}")] Provider {