123 changes: 0 additions & 123 deletions rust/cocoindex/src/llm/azureopenai.rs

This file was deleted.

5 changes: 2 additions & 3 deletions rust/cocoindex/src/llm/mod.rs
@@ -116,7 +116,6 @@ pub trait LlmEmbeddingClient: Send + Sync {
 }
 
 mod anthropic;
-mod azureopenai;
 mod bedrock;
 mod gemini;
 mod litellm;
@@ -157,7 +156,7 @@ pub async fn new_llm_generation_client(
                 as Box<dyn LlmGenerationClient>
         }
         LlmApiType::AzureOpenAi => {
-            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+            Box::new(openai::Client::new_azure(address, api_key, api_config).await?)
                 as Box<dyn LlmGenerationClient>
         }
         LlmApiType::Voyage => {
@@ -196,7 +195,7 @@ pub async fn new_llm_embedding_client(
                 as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::AzureOpenAi => {
-            Box::new(azureopenai::Client::new_azure_openai(address, api_key, api_config).await?)
+            Box::new(openai::Client::new_azure(address, api_key, api_config).await?)
                 as Box<dyn LlmEmbeddingClient>
         }
         LlmApiType::LiteLlm | LlmApiType::Vllm | LlmApiType::Anthropic | LlmApiType::Bedrock => {
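
For orientation, here is a minimal, self-contained sketch (toy names, not the crate's real signatures) of the pattern both factory functions now share: one generic client type, instantiated with different configs, erased to the same boxed trait object.

// Toy sketch of the mod.rs dispatch pattern; all names are illustrative.
trait Config: Send + Sync {
    fn provider(&self) -> &'static str;
}
struct OpenAIConfig;
struct AzureConfig;
impl Config for OpenAIConfig {
    fn provider(&self) -> &'static str { "openai" }
}
impl Config for AzureConfig {
    fn provider(&self) -> &'static str { "azure_openai" }
}

// One client type, generic over its config, as in openai.rs below.
struct Client<C: Config> {
    config: C,
}

trait LlmGenerationClient: Send + Sync {
    fn provider(&self) -> &'static str;
}

impl<C: Config> LlmGenerationClient for Client<C> {
    fn provider(&self) -> &'static str {
        self.config.provider()
    }
}

fn new_llm_generation_client(use_azure: bool) -> Box<dyn LlmGenerationClient> {
    // Both arms produce the same trait object, so callers never see the
    // concrete config type.
    if use_azure {
        Box::new(Client { config: AzureConfig }) as Box<dyn LlmGenerationClient>
    } else {
        Box::new(Client { config: OpenAIConfig }) as Box<dyn LlmGenerationClient>
    }
}

fn main() {
    assert_eq!(new_llm_generation_client(true).provider(), "azure_openai");
}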
60 changes: 53 additions & 7 deletions rust/cocoindex/src/llm/openai.rs
@@ -4,7 +4,7 @@ use base64::prelude::*;
 use super::{LlmEmbeddingClient, LlmGenerationClient, detect_image_mime_type};
 use async_openai::{
     Client as OpenAIClient,
-    config::OpenAIConfig,
+    config::{AzureConfig, OpenAIConfig},
     types::{
         ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
         ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage,
@@ -22,13 +22,15 @@ static DEFAULT_EMBEDDING_DIMENSIONS: phf::Map<&str, u32> = phf_map! {
     "text-embedding-ada-002" => 1536,
 };
 
-pub struct Client {
-    client: async_openai::Client<OpenAIConfig>,
+pub struct Client<C: async_openai::config::Config = OpenAIConfig> {
+    client: async_openai::Client<C>,
 }
 
 impl Client {
-    pub(crate) fn from_parts(client: async_openai::Client<OpenAIConfig>) -> Self {
-        Self { client }
+    pub(crate) fn from_parts<C: async_openai::config::Config>(
+        client: async_openai::Client<C>,
+    ) -> Client<C> {
+        Client { client }
     }
 
     pub fn new(
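
The struct now takes a config type parameter that defaults to OpenAIConfig. A toy sketch (illustrative types only, not the real crate) of why this keeps existing call sites compiling unchanged while allowing Client<AzureConfig> as a second instantiation of the same type:

// Bare `Client` still means `Client<OpenAIConfig>` thanks to the default
// type parameter, so pre-existing code needs no edits.
trait Config {}
struct OpenAIConfig;
struct AzureConfig;
impl Config for OpenAIConfig {}
impl Config for AzureConfig {}

struct Client<C: Config = OpenAIConfig> {
    config: C,
}

// `impl Client` (as in the diff) targets only the default instantiation...
impl Client {
    fn kind(&self) -> &'static str { "openai" }
}
// ...while `impl Client<AzureConfig>` targets the Azure one.
impl Client<AzureConfig> {
    fn kind(&self) -> &'static str { "azure" }
}

fn main() {
    let plain = Client { config: OpenAIConfig }; // inferred as Client<OpenAIConfig>
    let azure = Client { config: AzureConfig };
    assert_eq!(plain.kind(), "openai");
    assert_eq!(azure.kind(), "azure");
}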
@@ -67,6 +69,44 @@ impl Client {
     }
 }
 
+impl Client<AzureConfig> {
+    pub async fn new_azure(
+        address: Option<String>,
+        api_key: Option<String>,
+        api_config: Option<super::LlmApiConfig>,
+    ) -> Result<Self> {
+        let config = match api_config {
+            Some(super::LlmApiConfig::AzureOpenAi(config)) => config,
+            Some(_) => api_bail!("unexpected config type, expected AzureOpenAiConfig"),
+            None => api_bail!("AzureOpenAiConfig is required for Azure OpenAI"),
+        };
+
+        let api_base =
+            address.ok_or_else(|| anyhow::anyhow!("address is required for Azure OpenAI"))?;
+
+        // Default to API version that supports structured outputs (json_schema).
+        let api_version = config
+            .api_version
+            .unwrap_or_else(|| "2024-08-01-preview".to_string());
+
+        let api_key = api_key
+            .or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
+            .ok_or_else(|| anyhow::anyhow!(
+                "AZURE_OPENAI_API_KEY must be set either via api_key parameter or environment variable"
+            ))?;
+
+        let azure_config = AzureConfig::new()
+            .with_api_base(api_base)
+            .with_api_version(api_version)
+            .with_deployment_id(config.deployment_id)
+            .with_api_key(api_key);
+
+        Ok(Self {
+            client: OpenAIClient::with_config(azure_config),
+        })
+    }
+}
+
 pub(super) fn create_llm_generation_request(
     request: &super::LlmGenerateRequest,
 ) -> Result<CreateChatCompletionRequest> {
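
A hedged usage sketch of the new constructor. The argument order and the deployment_id/api_version fields are taken from the hunk above, but the struct-literal shape, endpoint, and deployment name are assumptions, and the enclosing async function and imports are elided:

// Hypothetical call site; endpoint and deployment name are placeholders.
let client = Client::new_azure(
    // `address` becomes the Azure api_base.
    Some("https://my-resource.openai.azure.com".to_string()),
    // None falls back to the AZURE_OPENAI_API_KEY environment variable.
    None,
    Some(LlmApiConfig::AzureOpenAi(AzureOpenAiConfig {
        deployment_id: "gpt-4o".to_string(),
        api_version: None, // defaults to "2024-08-01-preview" per the code above
    })),
)
.await?;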
@@ -136,7 +176,10 @@ pub(super) fn create_llm_generation_request(
 }
 
 #[async_trait]
-impl LlmGenerationClient for Client {
+impl<C> LlmGenerationClient for Client<C>
+where
+    C: async_openai::config::Config + Send + Sync,
+{
     async fn generate<'req>(
         &self,
         request: super::LlmGenerateRequest<'req>,
}

#[async_trait]
impl LlmEmbeddingClient for Client {
impl<C> LlmEmbeddingClient for Client<C>
where
C: async_openai::config::Config + Send + Sync,
{
async fn embed_text<'req>(
&self,
request: super::LlmEmbeddingRequest<'req>,
Expand Down
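
Both trait impls add Send + Sync to the config bound. A toy illustration (not project code) of why: the traits are themselves declared Send + Sync, so any Client<C> implementing them, config field included, must satisfy both before it can be erased into a boxed trait object.

trait Generation: Send + Sync {
    fn id(&self) -> String;
}

struct Client<C> {
    config: C,
}

impl<C> Generation for Client<C>
where
    // Without Send + Sync here the impl itself fails to compile, because
    // `Generation` requires Self: Send + Sync.
    C: std::fmt::Display + Send + Sync,
{
    fn id(&self) -> String {
        format!("client({})", self.config)
    }
}

fn erase(client: Client<&'static str>) -> Box<dyn Generation> {
    Box::new(client)
}

fn main() {
    assert_eq!(erase(Client { config: "azure" }).id(), "client(azure)");
}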