diff --git a/rust/cocoindex/src/llm/azureopenai.rs b/rust/cocoindex/src/llm/azureopenai.rs
deleted file mode 100644
index e0824485e..000000000
--- a/rust/cocoindex/src/llm/azureopenai.rs
+++ /dev/null
@@ -1,123 +0,0 @@
-use crate::prelude::*;
-
-use super::LlmEmbeddingClient;
-use super::LlmGenerationClient;
-use async_openai::{Client as OpenAIClient, config::AzureConfig};
-use phf::phf_map;
-
-static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
-    "text-embedding-3-small" => 1536,
-    "text-embedding-3-large" => 3072,
-    "text-embedding-ada-002" => 1536,
-};
-
-pub struct Client {
-    client: async_openai::Client<AzureConfig>,
-}
-
-impl Client {
-    pub async fn new_azure_openai(
-        address: Option<String>,
-        api_key: Option<String>,
-        api_config: Option<super::LlmApiConfig>,
-    ) -> anyhow::Result<Self> {
-        let config = match api_config {
-            Some(super::LlmApiConfig::AzureOpenAi(config)) => config,
-            Some(_) => anyhow::bail!("unexpected config type, expected AzureOpenAiConfig"),
-            None => anyhow::bail!("AzureOpenAiConfig is required for Azure OpenAI"),
-        };
-
-        let api_base =
-            address.ok_or_else(|| anyhow::anyhow!("address is required for Azure OpenAI"))?;
-
-        // Default to API version that supports structured outputs (json_schema).
-        // See: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/api-version-lifecycle
-        let api_version = config
-            .api_version
-            .unwrap_or_else(|| "2024-08-01-preview".to_string());
-
-        let api_key = api_key.or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
-            .ok_or_else(|| anyhow::anyhow!("AZURE_OPENAI_API_KEY must be set either via api_key parameter or environment variable"))?;
-
-        let azure_config = AzureConfig::new()
-            .with_api_base(api_base)
-            .with_api_version(api_version)
-            .with_deployment_id(config.deployment_id)
-            .with_api_key(api_key);
-
-        Ok(Self {
-            client: OpenAIClient::with_config(azure_config),
-        })
-    }
-}
-
-#[async_trait]
-impl LlmGenerationClient for Client {
-    async fn generate<'req>(
-        &self,
-        request: super::LlmGenerateRequest<'req>,
-    ) -> Result<super::LlmGenerateResponse> {
-        let request = &request;
-        let response = retryable::run(
-            || async {
-                let req = super::openai::create_llm_generation_request(request)?;
-                let response = self.client.chat().create(req).await?;
-                retryable::Ok(response)
-            },
-            &retryable::RetryOptions::default(),
-        )
-        .await?;
-
-        // Extract the response text from the first choice
-        let text = response
-            .choices
-            .into_iter()
-            .next()
-            .and_then(|choice| choice.message.content)
-            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
-
-        Ok(super::LlmGenerateResponse { text })
-    }
-
-    fn json_schema_options(&self) -> super::ToJsonSchemaOptions {
-        super::ToJsonSchemaOptions {
-            fields_always_required: true,
-            supports_format: false,
-            extract_descriptions: false,
-            top_level_must_be_object: true,
-            supports_additional_properties: true,
-        }
-    }
-}
-
-#[async_trait]
-impl LlmEmbeddingClient for Client {
-    async fn embed_text<'req>(
-        &self,
-        request: super::LlmEmbeddingRequest<'req>,
-    ) -> Result<super::LlmEmbeddingResponse> {
-        let response = retryable::run(
-            || async {
-                let texts: Vec<String> = request.texts.iter().map(|t| t.to_string()).collect();
-                self.client
-                    .embeddings()
-                    .create(async_openai::types::CreateEmbeddingRequest {
-                        model: request.model.to_string(),
-                        input: async_openai::types::EmbeddingInput::StringArray(texts),
-                        dimensions: request.output_dimension,
-                        ..Default::default()
-                    })
-                    .await
-            },
-            &retryable::RetryOptions::default(),
-        )
-        .await?;
-        Ok(super::LlmEmbeddingResponse {
-            embeddings: response.data.into_iter().map(|e| e.embedding).collect(),
-        })
-    }
-
-    fn get_default_embedding_dimension(&self, model: &str) -> Option<u32> {
-        DEFAULT_EMBEDDING_DIMENSIONS.get(model).copied()
-    }
-}
diff --git a/rust/cocoindex/src/llm/mod.rs b/rust/cocoindex/src/llm/mod.rs
index e63d71fca..a7a1cb5b9 100644
--- a/rust/cocoindex/src/llm/mod.rs
+++ b/rust/cocoindex/src/llm/mod.rs
@@ -116,7 +116,6 @@ pub trait LlmEmbeddingClient: Send + Sync {
 }
 
 mod anthropic;
-mod azureopenai;
 mod bedrock;
 mod gemini;
 mod litellm;
@@ -157,7 +156,7 @@ pub async fn new_llm_generation_client(
                 as Box<dyn LlmGenerationClient>
         }
         LlmApiType::AzureOpenAi => {
-            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+            Box::new(openai::Client::new_azure(address, api_key, api_config).await?)
                 as Box<dyn LlmGenerationClient>
         }
         LlmApiType::Voyage => {
@@ -196,7 +195,7 @@ pub async fn new_llm_embedding_client(
                 as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::AzureOpenAi => {
-            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+            Box::new(openai::Client::new_azure(address, api_key, api_config).await?)
                 as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic | LlmApiType::Bedrock => {
diff --git a/rust/cocoindex/src/llm/openai.rs b/rust/cocoindex/src/llm/openai.rs
index e102bfa8b..20faf5312 100644
--- a/rust/cocoindex/src/llm/openai.rs
+++ b/rust/cocoindex/src/llm/openai.rs
@@ -4,7 +4,7 @@ use base64::prelude::*;
 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
 use async_openai::{
     Client as OpenAIClient,
-    config::OpenAIConfig,
+    config::{AzureConfig, OpenAIConfig},
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
         ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
@@ -22,13 +22,15 @@ static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
     "text-embedding-ada-002" => 1536,
 };
 
-pub struct Client {
-    client: async_openai::Client<OpenAIConfig>,
+pub struct Client<C: async_openai::config::Config = OpenAIConfig> {
+    client: async_openai::Client<C>,
 }
 
 impl Client {
-    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>) -> Self {
-        Self { client }
+    pub(crate) fn from_parts<C: async_openai::config::Config>(
+        client: async_openai::Client<C>,
+    ) -> Client<C> {
+        Client { client }
     }
 
     pub fn new(
@@ -67,6 +69,44 @@ impl Client {
     }
 }
 
+impl Client<AzureConfig> {
+    pub async fn new_azure(
+        address: Option<String>,
+        api_key: Option<String>,
+        api_config: Option<super::LlmApiConfig>,
+    ) -> Result<Self> {
+        let config = match api_config {
+            Some(super::LlmApiConfig::AzureOpenAi(config)) => config,
+            Some(_) => api_bail!("unexpected config type, expected AzureOpenAiConfig"),
+            None => api_bail!("AzureOpenAiConfig is required for Azure OpenAI"),
+        };
+
+        let api_base =
+            address.ok_or_else(|| anyhow::anyhow!("address is required for Azure OpenAI"))?;
+
+        // Default to API version that supports structured outputs (json_schema).
+        let api_version = config
+            .api_version
+            .unwrap_or_else(|| "2024-08-01-preview".to_string());
+
+        let api_key = api_key
+            .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
+            .ok_or_else(|| anyhow::anyhow!(
+                "AZURE_OPENAI_API_KEY must be set either via api_key parameter or environment variable"
+            ))?;
+
+        let azure_config = AzureConfig::new()
+            .with_api_base(api_base)
+            .with_api_version(api_version)
+            .with_deployment_id(config.deployment_id)
+            .with_api_key(api_key);
+
+        Ok(Self {
+            client: OpenAIClient::with_config(azure_config),
+        })
+    }
+}
+
 pub(super) fn create_llm_generation_request(
     request: &super::LlmGenerateRequest,
 ) -> Result<CreateChatCompletionRequest> {
@@ -136,7 +176,10 @@ pub(super) fn create_llm_generation_request(
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl<C> LlmGenerationClient for Client<C>
+where
+    C: async_openai::config::Config + Send + Sync,
+{
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
@@ -175,7 +218,10 @@ impl LlmGenerationClient for Client {
 }
 
 #[async_trait]
-impl LlmEmbeddingClient for Client {
+impl<C> LlmEmbeddingClient for Client<C>
+where
+    C: async_openai::config::Config + Send + Sync,
+{
     async fn embed_text<'req>(
         &self,
         request: super::LlmEmbeddingRequest<'req>,
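
Note on the refactor above: rather than keeping a near-duplicate Azure client in azureopenai.rs, the diff makes the shared openai.rs `Client` generic over async_openai's `Config` trait, with `OpenAIConfig` as the default type parameter so existing OpenAI call sites compile unchanged; `new_azure` then instantiates the same wrapper with `AzureConfig`, and a single pair of `LlmGenerationClient`/`LlmEmbeddingClient` impls (bounded by `C: Config + Send + Sync`) serves both providers. The sketch below is a minimal self-contained illustration of that pattern, assuming only the async_openai crate; it puts the generic parameter on the impl block rather than on `from_parts` as the diff does, and it omits the trait impls, retry logic, and cocoindex-specific config plumbing.

    use async_openai::config::{AzureConfig, Config, OpenAIConfig};

    // One wrapper type for every OpenAI-compatible backend; the default
    // type parameter keeps plain-OpenAI call sites spelled as `Client`.
    pub struct Client<C: Config = OpenAIConfig> {
        client: async_openai::Client<C>,
    }

    impl<C: Config> Client<C> {
        pub fn from_parts(client: async_openai::Client<C>) -> Self {
            Self { client }
        }
    }

    fn main() {
        // Plain OpenAI: the default `OpenAIConfig` applies.
        let _openai: Client = Client::from_parts(async_openai::Client::with_config(
            OpenAIConfig::new().with_api_key("sk-placeholder"),
        ));

        // Azure OpenAI: same wrapper, different config type, mirroring what
        // `new_azure` builds (base URL, API version, deployment id, key).
        let _azure: Client<AzureConfig> = Client::from_parts(async_openai::Client::with_config(
            AzureConfig::new()
                .with_api_base("https://example-resource.openai.azure.com")
                .with_api_version("2024-08-01-preview")
                .with_deployment_id("example-deployment")
                .with_api_key("placeholder"),
        ));
    }

Because both values share the one `Client<C>` definition, mod.rs can route `LlmApiType::AzureOpenAi` through `openai::Client::new_azure` and still erase the result to `Box<dyn LlmGenerationClient>` or `Box<dyn LlmEmbeddingClient>` at the same call sites as every other provider.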