diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index 3c5217f13345..2bf66842a8d8 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -35,6 +35,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj index 4a0ae4032f3e..535cb3e6389a 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Connectors.Google.UnitTests.csproj @@ -25,6 +25,7 @@ + diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs index 844a2341bbc9..fc05f3aaefaf 100644 --- a/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Extensions/GoogleAIServiceCollectionExtensionsTests.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using Google.GenAI; using Microsoft.Extensions.AI; using Microsoft.Extensions.DependencyInjection; using Microsoft.SemanticKernel; @@ -113,4 +114,147 @@ public void GoogleAIEmbeddingGeneratorShouldBeRegisteredInServiceCollection() Assert.NotNull(embeddingsGenerationService); Assert.IsType(embeddingsGenerationService); } + +#if NET + [Fact] + public void GoogleGenAIChatClientShouldBeRegisteredInKernelServicesWithApiKey() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act + kernelBuilder.AddGoogleGenAIChatClient("modelId", "apiKey"); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleGenAIChatClientShouldBeRegisteredInServiceCollectionWithApiKey() + { + // Arrange + var services = new ServiceCollection(); + + // Act + services.AddGoogleGenAIChatClient("modelId", "apiKey"); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleVertexAIChatClientShouldBeRegisteredInKernelServices() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act + kernelBuilder.AddGoogleVertexAIChatClient("modelId", project: "test-project", location: "us-central1"); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + var kernel = kernelBuilder.Build(); + Assert.NotNull(kernel.Services); + } + + [Fact] + public void GoogleVertexAIChatClientShouldBeRegisteredInServiceCollection() + { + // Arrange + var services = new ServiceCollection(); + + // Act + services.AddGoogleVertexAIChatClient("modelId", project: "test-project", location: "us-central1"); + var serviceProvider = services.BuildServiceProvider(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(serviceProvider); + } + + [Fact] + public void GoogleAIChatClientShouldBeRegisteredInKernelServicesWithClient() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + using var googleClient = new Client(apiKey: "apiKey"); + + // Act + kernelBuilder.AddGoogleAIChatClient("modelId", googleClient); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleAIChatClientShouldBeRegisteredInServiceCollectionWithClient() + { + // Arrange + var services = new ServiceCollection(); + using var googleClient = new Client(apiKey: "apiKey"); + + // Act + services.AddGoogleAIChatClient("modelId", googleClient); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleGenAIChatClientShouldBeRegisteredWithServiceId() + { + // Arrange + var services = new ServiceCollection(); + const string serviceId = "test-service-id"; + + // Act + services.AddGoogleGenAIChatClient("modelId", "apiKey", serviceId: serviceId); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetKeyedService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void GoogleVertexAIChatClientShouldBeRegisteredWithServiceId() + { + // Arrange + var services = new ServiceCollection(); + const string serviceId = "test-service-id"; + + // Act + services.AddGoogleVertexAIChatClient("modelId", project: "test-project", location: "us-central1", serviceId: serviceId); + var serviceProvider = services.BuildServiceProvider(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(serviceProvider); + } + + [Fact] + public void GoogleAIChatClientShouldResolveFromServiceProviderWhenClientNotProvided() + { + // Arrange + var services = new ServiceCollection(); + using var googleClient = new Client(apiKey: "apiKey"); + services.AddSingleton(googleClient); + + // Act + services.AddGoogleAIChatClient("modelId"); + var serviceProvider = services.BuildServiceProvider(); + + // Assert + var chatClient = serviceProvider.GetRequiredService(); + Assert.NotNull(chatClient); + } +#endif } diff --git a/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleGeminiChatClientTests.cs b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleGeminiChatClientTests.cs new file mode 100644 index 000000000000..91bf5435efdc --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Google.UnitTests/Services/GoogleGeminiChatClientTests.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NET + +using System; +using Google.GenAI; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using Xunit; + +namespace SemanticKernel.Connectors.Google.UnitTests.Services; + +public sealed class GoogleGeminiChatClientTests +{ + [Fact] + public void GenAIChatClientShouldBeCreatedWithApiKey() + { + // Arrange + string modelId = "gemini-1.5-pro"; + string apiKey = "test-api-key"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleGenAIChatClient(modelId, apiKey); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void VertexAIChatClientShouldBeCreated() + { + // Arrange + string modelId = "gemini-1.5-pro"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleVertexAIChatClient(modelId, project: "test-project", location: "us-central1"); + var kernel = kernelBuilder.Build(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(kernel.Services); + } + + [Fact] + public void ChatClientShouldBeCreatedWithGoogleClient() + { + // Arrange + string modelId = "gemini-1.5-pro"; + using var googleClient = new Client(apiKey: "test-api-key"); + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleAIChatClient(modelId, googleClient); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(); + Assert.NotNull(chatClient); + } + + [Fact] + public void GenAIChatClientShouldBeCreatedWithServiceId() + { + // Arrange + string modelId = "gemini-1.5-pro"; + string apiKey = "test-api-key"; + string serviceId = "test-service"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleGenAIChatClient(modelId, apiKey, serviceId: serviceId); + var kernel = kernelBuilder.Build(); + + // Assert + var chatClient = kernel.GetRequiredService(serviceId); + Assert.NotNull(chatClient); + } + + [Fact] + public void VertexAIChatClientShouldBeCreatedWithServiceId() + { + // Arrange + string modelId = "gemini-1.5-pro"; + string serviceId = "test-service"; + + // Act + var kernelBuilder = Kernel.CreateBuilder(); + kernelBuilder.AddGoogleVertexAIChatClient(modelId, project: "test-project", location: "us-central1", serviceId: serviceId); + var kernel = kernelBuilder.Build(); + + // Assert - just verify no exception during registration + // Resolution requires real credentials, so skip that in unit tests + Assert.NotNull(kernel.Services); + } + + [Fact] + public void GenAIChatClientThrowsForNullModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient(null!, "apiKey")); + } + + [Fact] + public void GenAIChatClientThrowsForEmptyModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient("", "apiKey")); + } + + [Fact] + public void GenAIChatClientThrowsForNullApiKey() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient("modelId", null!)); + } + + [Fact] + public void GenAIChatClientThrowsForEmptyApiKey() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleGenAIChatClient("modelId", "")); + } + + [Fact] + public void VertexAIChatClientThrowsForNullModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleVertexAIChatClient(null!, project: "test-project", location: "us-central1")); + } + + [Fact] + public void VertexAIChatClientThrowsForEmptyModelId() + { + // Arrange + var kernelBuilder = Kernel.CreateBuilder(); + + // Act & Assert + Assert.ThrowsAny(() => kernelBuilder.AddGoogleVertexAIChatClient("", project: "test-project", location: "us-central1")); + } +} + +#endif diff --git a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj index e71d80d17a00..7e104ef8b230 100644 --- a/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj +++ b/dotnet/src/Connectors/Connectors.Google/Connectors.Google.csproj @@ -24,6 +24,11 @@ + + + + + diff --git a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs index d6ab3768d0e0..72518e91aaf8 100644 --- a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs +++ b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIKernelBuilderExtensions.cs @@ -118,4 +118,102 @@ public static IKernelBuilder AddGoogleAIEmbeddingGenerator( dimensions: dimensions); return builder; } + +#if NET + /// + /// Add Google GenAI to the . + /// + /// The kernel builder. + /// The model for chat completion. + /// The API key for authentication with the Google GenAI API. + /// The optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated kernel builder. + public static IKernelBuilder AddGoogleGenAIChatClient( + this IKernelBuilder builder, + string modelId, + string apiKey, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddGoogleGenAIChatClient( + modelId, + apiKey, + serviceId, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } + + /// + /// Add Google Vertex AI to the . + /// + /// The kernel builder. + /// The model for chat completion. + /// The Google Cloud project ID. If null, will attempt to use the GOOGLE_CLOUD_PROJECT environment variable. + /// The Google Cloud location (e.g., "us-central1"). If null, will attempt to use the GOOGLE_CLOUD_LOCATION environment variable. + /// The optional for authentication. If null, the client will use its internal discovery implementation to get credentials from the environment. + /// The optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated kernel builder. + public static IKernelBuilder AddGoogleVertexAIChatClient( + this IKernelBuilder builder, + string modelId, + string? project = null, + string? location = null, + Google.Apis.Auth.OAuth2.ICredential? credential = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddGoogleVertexAIChatClient( + modelId, + project, + location, + credential, + serviceId, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } + + /// + /// Add Google AI to the . + /// + /// The kernel builder. + /// The model for chat completion. + /// The to use for the service. If null, one must be available in the service provider when this service is resolved. + /// The optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated kernel builder. + public static IKernelBuilder AddGoogleAIChatClient( + this IKernelBuilder builder, + string modelId, + Google.GenAI.Client? googleClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(builder); + + builder.Services.AddGoogleAIChatClient( + modelId, + googleClient, + serviceId, + openTelemetrySourceName, + openTelemetryConfig); + + return builder; + } +#endif } diff --git a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs index a45001278e9a..bb094ef08cde 100644 --- a/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs +++ b/dotnet/src/Connectors/Connectors.Google/Extensions/GoogleAIServiceCollectionExtensions.DependencyInjection.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. +using System; using System.Net.Http; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; @@ -47,4 +48,146 @@ public static IServiceCollection AddGoogleAIEmbeddingGenerator( loggerFactory: serviceProvider.GetService(), dimensions: dimensions)); } + +#if NET + /// + /// Add Google GenAI to the specified service collection. + /// + /// The service collection to add the Google GenAI Chat Client to. + /// The model for chat completion. + /// The API key for authentication with the Google GenAI API. + /// Optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated service collection. + public static IServiceCollection AddGoogleGenAIChatClient( + this IServiceCollection services, + string modelId, + string apiKey, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + Verify.NotNullOrWhiteSpace(apiKey); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var googleClient = new Google.GenAI.Client(apiKey: apiKey); + + var builder = new GoogleGenAIChatClient(googleClient, modelId) + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Add Google Vertex AI to the specified service collection. + /// + /// The service collection to add the Google Vertex AI Chat Client to. + /// The model for chat completion. + /// The Google Cloud project ID. If null, will attempt to use the GOOGLE_CLOUD_PROJECT environment variable. + /// The Google Cloud location (e.g., "us-central1"). If null, will attempt to use the GOOGLE_CLOUD_LOCATION environment variable. + /// The optional for authentication. If null, the client will use its internal discovery implementation to get credentials from the environment. + /// Optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated service collection. + public static IServiceCollection AddGoogleVertexAIChatClient( + this IServiceCollection services, + string modelId, + string? project = null, + string? location = null, + Google.Apis.Auth.OAuth2.ICredential? credential = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var googleClient = new Google.GenAI.Client(vertexAI: true, credential: credential, project: project, location: location); + + var builder = new GoogleGenAIChatClient(googleClient, modelId) + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } + + /// + /// Add Google AI to the specified service collection. + /// + /// The service collection to add the Google AI Chat Client to. + /// The model for chat completion. + /// The to use for the service. If null, one must be available in the service provider when this service is resolved. + /// Optional service ID. + /// An optional name for the OpenTelemetry source. + /// An optional callback that can be used to configure the instance. + /// The updated service collection. + public static IServiceCollection AddGoogleAIChatClient( + this IServiceCollection services, + string modelId, + Google.GenAI.Client? googleClient = null, + string? serviceId = null, + string? openTelemetrySourceName = null, + Action? openTelemetryConfig = null) + { + Verify.NotNull(services); + Verify.NotNullOrWhiteSpace(modelId); + + IChatClient Factory(IServiceProvider serviceProvider, object? _) + { + var loggerFactory = serviceProvider.GetService(); + + var client = googleClient ?? serviceProvider.GetRequiredService(); + + var builder = new GoogleGenAIChatClient(client, modelId) + .AsBuilder() + .UseKernelFunctionInvocation(loggerFactory) + .UseOpenTelemetry(loggerFactory, openTelemetrySourceName, openTelemetryConfig); + + if (loggerFactory is not null) + { + builder.UseLogging(loggerFactory); + } + + return builder.Build(); + } + + services.AddKeyedSingleton(serviceId, (Func)Factory); + + return services; + } +#endif } diff --git a/dotnet/src/Connectors/Connectors.Google/Services/GoogleGenAIChatClient.cs b/dotnet/src/Connectors/Connectors.Google/Services/GoogleGenAIChatClient.cs new file mode 100644 index 000000000000..1a6fbceb408b --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Google/Services/GoogleGenAIChatClient.cs @@ -0,0 +1,508 @@ +// Copyright (c) Microsoft. All rights reserved. + +#if NET + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Google.GenAI; +using Google.GenAI.Types; +using Microsoft.Extensions.AI; +using AITextContent = Microsoft.Extensions.AI.TextContent; +using AIDataContent = Microsoft.Extensions.AI.DataContent; +using AIUriContent = Microsoft.Extensions.AI.UriContent; +using AIFunctionCallContent = Microsoft.Extensions.AI.FunctionCallContent; +using AIFunctionResultContent = Microsoft.Extensions.AI.FunctionResultContent; + +namespace Microsoft.SemanticKernel.Connectors.Google; + +/// +/// Provides an implementation based on Google.GenAI . +/// +internal sealed class GoogleGenAIChatClient : IChatClient +{ + /// The wrapped instance (optional). + private readonly Client? _client; + + /// The wrapped instance. + private readonly Models _models; + + /// The default model that should be used when no override is specified. + private readonly string? _defaultModelId; + + /// Lazily-initialized metadata describing the implementation. + private ChatClientMetadata? _metadata; + + /// Initializes a new instance. + /// The to wrap. + /// The default model ID to use for chat requests if not specified. + public GoogleGenAIChatClient(Client client, string? defaultModelId) + { + Verify.NotNull(client); + + this._client = client; + this._models = client.Models; + this._defaultModelId = defaultModelId; + } + + /// Initializes a new instance. + /// The client to wrap. + /// The default model ID to use for chat requests if not specified. + public GoogleGenAIChatClient(Models models, string? defaultModelId) + { + Verify.NotNull(models); + + this._models = models; + this._defaultModelId = defaultModelId; + } + + /// + public async Task GetResponseAsync(IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) + { + Verify.NotNull(messages); + + // Create the request. + (string? modelId, List contents, GenerateContentConfig config) = this.CreateRequest(messages, options); + + // Send it. + GenerateContentResponse generateResult = await this._models.GenerateContentAsync(modelId!, contents, config).ConfigureAwait(false); + + // Create the response. + ChatResponse chatResponse = new(new ChatMessage(ChatRole.Assistant, new List())) + { + CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null, + ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId, + RawRepresentation = generateResult, + ResponseId = generateResult.ResponseId, + }; + + // Populate the response messages. + chatResponse.FinishReason = PopulateResponseContents(generateResult, chatResponse.Messages[0].Contents); + + // Populate usage information if there is any. + if (generateResult.UsageMetadata is { } usageMetadata) + { + chatResponse.Usage = ExtractUsageDetails(usageMetadata); + } + + // Return the response. + return chatResponse; + } + + /// + public async IAsyncEnumerable GetStreamingResponseAsync(IEnumerable messages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Verify.NotNull(messages); + + // Create the request. + (string? modelId, List contents, GenerateContentConfig config) = this.CreateRequest(messages, options); + + // Send it, and process the results. + await foreach (GenerateContentResponse generateResult in this._models.GenerateContentStreamAsync(modelId!, contents, config).WithCancellation(cancellationToken).ConfigureAwait(false)) + { + // Create a response update for each result in the stream. + ChatResponseUpdate responseUpdate = new(ChatRole.Assistant, new List()) + { + CreatedAt = generateResult.CreateTime is { } dt ? new DateTimeOffset(dt) : null, + ModelId = !string.IsNullOrWhiteSpace(generateResult.ModelVersion) ? generateResult.ModelVersion : modelId, + RawRepresentation = generateResult, + ResponseId = generateResult.ResponseId, + }; + + // Populate the response update contents. + responseUpdate.FinishReason = PopulateResponseContents(generateResult, responseUpdate.Contents); + + // Populate usage information if there is any. + if (generateResult.UsageMetadata is { } usageMetadata) + { + responseUpdate.Contents.Add(new UsageContent(ExtractUsageDetails(usageMetadata))); + } + + // Yield the update. + yield return responseUpdate; + } + } + + /// + public object? GetService(System.Type serviceType, object? serviceKey = null) + { + Verify.NotNull(serviceType); + + if (serviceKey is null) + { + // If there's a request for metadata, lazily-initialize it and return it. + if (serviceType == typeof(ChatClientMetadata)) + { + return this._metadata ??= new("google.genai", new Uri("https://generativelanguage.googleapis.com/"), defaultModelId: this._defaultModelId); + } + + // Allow a consumer to access the underlying client if they need it. + if (serviceType.IsInstanceOfType(this._models)) + { + return this._models; + } + + if (this._client is not null && serviceType.IsInstanceOfType(this._client)) + { + return this._client; + } + + if (serviceType.IsInstanceOfType(this)) + { + return this; + } + } + + return null; + } + + /// + void IDisposable.Dispose() { /* nop */ } + + /// Creates the message parameters for from and . + private (string? ModelId, List Contents, GenerateContentConfig Config) CreateRequest(IEnumerable messages, ChatOptions? options) + { + // Create the GenerateContentConfig object. If the options contains a RawRepresentationFactory, try to use it to + // create the request instance, allowing the caller to populate it with GenAI-specific options. Otherwise, create + // a new instance directly. + string? model = this._defaultModelId; + List contents = []; + GenerateContentConfig config = options?.RawRepresentationFactory?.Invoke(this) as GenerateContentConfig ?? new(); + + if (options is not null) + { + if (options.FrequencyPenalty is { } frequencyPenalty) + { + config.FrequencyPenalty ??= frequencyPenalty; + } + + if (options.Instructions is { } instructions) + { + ((config.SystemInstruction ??= new()).Parts ??= []).Add(new() { Text = instructions }); + } + + if (options.MaxOutputTokens is { } maxOutputTokens) + { + config.MaxOutputTokens ??= maxOutputTokens; + } + + if (!string.IsNullOrWhiteSpace(options.ModelId)) + { + model = options.ModelId; + } + + if (options.PresencePenalty is { } presencePenalty) + { + config.PresencePenalty ??= presencePenalty; + } + + if (options.Seed is { } seed) + { + config.Seed ??= (int)seed; + } + + if (options.StopSequences is { } stopSequences) + { + (config.StopSequences ??= []).AddRange(stopSequences); + } + + if (options.Temperature is { } temperature) + { + config.Temperature ??= temperature; + } + + if (options.TopP is { } topP) + { + config.TopP ??= topP; + } + + if (options.TopK is { } topK) + { + config.TopK ??= topK; + } + + // Populate tools. Each kind of tool is added on its own, except for function declarations, + // which are grouped into a single FunctionDeclaration. + List? functionDeclarations = null; + if (options.Tools is { } tools) + { + foreach (var tool in tools) + { + switch (tool) + { + case AIFunction af: + functionDeclarations ??= []; + functionDeclarations.Add(new() + { + Name = af.Name, + Description = af.Description ?? "", + }); + break; + } + } + } + + if (functionDeclarations is { Count: > 0 }) + { + Tool functionTools = new(); + (functionTools.FunctionDeclarations ??= []).AddRange(functionDeclarations); + (config.Tools ??= []).Add(functionTools); + } + + // Transfer over the tool mode if there are any tools. + if (options.ToolMode is { } toolMode && config.Tools?.Count > 0) + { + switch (toolMode) + { + case NoneChatToolMode: + config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.NONE } }; + break; + + case AutoChatToolMode: + config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.AUTO } }; + break; + + case RequiredChatToolMode required: + config.ToolConfig = new() { FunctionCallingConfig = new() { Mode = FunctionCallingConfigMode.ANY } }; + if (required.RequiredFunctionName is not null) + { + ((config.ToolConfig.FunctionCallingConfig ??= new()).AllowedFunctionNames ??= []).Add(required.RequiredFunctionName); + } + break; + } + } + + // Set the response format if specified. + if (options.ResponseFormat is ChatResponseFormatJson responseFormat) + { + config.ResponseMimeType = "application/json"; + if (responseFormat.Schema is { } schema) + { + config.ResponseJsonSchema = schema; + } + } + } + + // Transfer messages to request, handling system messages specially + Dictionary? callIdToFunctionNames = null; + foreach (var message in messages) + { + if (message.Role == ChatRole.System) + { + string instruction = message.Text; + if (!string.IsNullOrWhiteSpace(instruction)) + { + ((config.SystemInstruction ??= new()).Parts ??= []).Add(new() { Text = instruction }); + } + + continue; + } + + Content content = new() { Role = message.Role == ChatRole.Assistant ? "model" : "user" }; + content.Parts ??= []; + AddPartsForAIContents(ref callIdToFunctionNames, message.Contents, content.Parts); + + contents.Add(content); + } + + // Make sure the request contains at least one content part (the request would always fail if empty). + if (!contents.SelectMany(c => c.Parts ?? Enumerable.Empty()).Any()) + { + contents.Add(new() { Role = "user", Parts = [new() { Text = "" }] }); + } + + return (model, contents, config); + } + + /// Creates s for and adds them to . + private static void AddPartsForAIContents(ref Dictionary? callIdToFunctionNames, IList contents, List parts) + { + for (int i = 0; i < contents.Count; i++) + { + var content = contents[i]; + + Part? part = null; + switch (content) + { + case AITextContent textContent: + part = new() { Text = textContent.Text }; + break; + + case AIDataContent dataContent: + part = new() + { + InlineData = new() + { + MimeType = dataContent.MediaType, + Data = dataContent.Data.ToArray(), + } + }; + break; + + case AIUriContent uriContent: + part = new() + { + FileData = new() + { + FileUri = uriContent.Uri.AbsoluteUri, + MimeType = uriContent.MediaType, + } + }; + break; + + case AIFunctionCallContent functionCallContent: + (callIdToFunctionNames ??= [])[functionCallContent.CallId] = functionCallContent.Name; + callIdToFunctionNames[""] = functionCallContent.Name; // track last function name in case calls don't have IDs + + part = new() + { + FunctionCall = new() + { + Id = functionCallContent.CallId, + Name = functionCallContent.Name, + Args = functionCallContent.Arguments is null ? null : functionCallContent.Arguments as Dictionary ?? new(functionCallContent.Arguments!), + } + }; + break; + + case AIFunctionResultContent functionResultContent: + part = new() + { + FunctionResponse = new() + { + Id = functionResultContent.CallId, + Name = callIdToFunctionNames?.TryGetValue(functionResultContent.CallId, out string? functionName) is true || callIdToFunctionNames?.TryGetValue("", out functionName) is true ? + functionName : + null, + Response = functionResultContent.Result is null ? null : new() { ["result"] = functionResultContent.Result }, + } + }; + break; + } + + if (part is not null) + { + parts.Add(part); + } + } + } + + /// Creates s for and adds them to . + private static void AddAIContentsForParts(List parts, IList contents) + { + foreach (var part in parts) + { + AIContent? content = null; + + if (!string.IsNullOrEmpty(part.Text)) + { + content = new AITextContent(part.Text); + } + else if (part.InlineData is { } inlineData) + { + content = new AIDataContent(inlineData.Data, inlineData.MimeType ?? "application/octet-stream"); + } + else if (part.FileData is { FileUri: not null } fileData) + { + content = new AIUriContent(new Uri(fileData.FileUri), fileData.MimeType ?? "application/octet-stream"); + } + else if (part.FunctionCall is { Name: not null } functionCall) + { + content = new AIFunctionCallContent(functionCall.Id ?? "", functionCall.Name, functionCall.Args!); + } + else if (part.FunctionResponse is { } functionResponse) + { + content = new AIFunctionResultContent( + functionResponse.Id ?? "", + functionResponse.Response?.TryGetValue("output", out var output) is true ? output : + functionResponse.Response?.TryGetValue("error", out var error) is true ? error : + null); + } + + if (content is not null) + { + content.RawRepresentation = part; + contents.Add(content); + } + } + } + + private static ChatFinishReason? PopulateResponseContents(GenerateContentResponse generateResult, IList responseContents) + { + ChatFinishReason? finishReason = null; + + // Populate the response messages. There should only be at most one candidate, but if there are more, ignore all but the first. + if (generateResult.Candidates is { Count: > 0 } && + generateResult.Candidates[0] is { Content: { } candidateContent } candidate) + { + // Grab the finish reason if one exists. + finishReason = ConvertFinishReason(candidate.FinishReason); + + // Add all of the response content parts as AIContents. + if (candidateContent.Parts is { } parts) + { + AddAIContentsForParts(parts, responseContents); + } + } + + // Populate error information if there is any. + if (generateResult.PromptFeedback is { } promptFeedback) + { + responseContents.Add(new ErrorContent(promptFeedback.BlockReasonMessage)); + } + + return finishReason; + } + + /// Creates an M.E.AI from a Google . + private static ChatFinishReason? ConvertFinishReason(FinishReason? finishReason) + { + return finishReason switch + { + null => null, + + FinishReason.MAX_TOKENS => + ChatFinishReason.Length, + + FinishReason.MALFORMED_FUNCTION_CALL or + FinishReason.UNEXPECTED_TOOL_CALL => + ChatFinishReason.ToolCalls, + + FinishReason.FINISH_REASON_UNSPECIFIED or + FinishReason.STOP => + ChatFinishReason.Stop, + + _ => ChatFinishReason.ContentFilter, + }; + } + + /// Creates a populated from the supplied . + private static UsageDetails ExtractUsageDetails(GenerateContentResponseUsageMetadata usageMetadata) + { + UsageDetails details = new() + { + InputTokenCount = usageMetadata.PromptTokenCount, + OutputTokenCount = usageMetadata.CandidatesTokenCount, + TotalTokenCount = usageMetadata.TotalTokenCount, + }; + + AddIfPresent(nameof(usageMetadata.CachedContentTokenCount), usageMetadata.CachedContentTokenCount); + AddIfPresent(nameof(usageMetadata.ThoughtsTokenCount), usageMetadata.ThoughtsTokenCount); + AddIfPresent(nameof(usageMetadata.ToolUsePromptTokenCount), usageMetadata.ToolUsePromptTokenCount); + + return details; + + void AddIfPresent(string key, int? value) + { + if (value is int i) + { + (details.AdditionalCounts ??= [])[key] = i; + } + } + } +} + +#endif diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatClientTests.cs new file mode 100644 index 000000000000..cf649b24a09f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiChatClientTests.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using xRetry; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiGenAIChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipReason = "This test is for manual verification."; + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientGenerationReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("Large Language Model", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Brandon", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and write a long story about my name.") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + Assert.False(string.IsNullOrWhiteSpace(message)); + this.Output.WriteLine(message); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. Your name is Roger."), + new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("1520", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. Your name is Roger."), + new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(message); + Assert.Contains("1520", message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", message, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientReturnsUsageDetailsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Usage); + this.Output.WriteLine($"Input tokens: {response.Usage.InputTokenCount}"); + this.Output.WriteLine($"Output tokens: {response.Usage.OutputTokenCount}"); + this.Output.WriteLine($"Total tokens: {response.Usage.TotalTokenCount}"); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithChatOptionsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Generate a random number between 1 and 100.") + }; + + var chatOptions = new ChatOptions + { + Temperature = 0.0f, + MaxOutputTokens = 100 + }; + + var sut = this.GetGenAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiFunctionCallingChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiFunctionCallingChatClientTests.cs new file mode 100644 index 000000000000..9173365a60b9 --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiFunctionCallingChatClientTests.cs @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.ComponentModel; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using xRetry; +using Xunit; +using Xunit.Abstractions; +using AIFunctionCallContent = Microsoft.Extensions.AI.FunctionCallContent; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiGenAIFunctionCallingChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipMessage = "This test is for manual verification."; + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType(nameof(CustomerPlugin)); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + + var functionCallContent = response.Messages + .SelectMany(m => m.Contents) + .OfType() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType(nameof(CustomerPlugin)); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + + var functionCallContent = responses + .SelectMany(r => r.Contents) + .OfType() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var responses = await autoInvokingClient.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + var content = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithMultipleFunctionCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetGenAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers first and next return age of Anna customer?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("28", content, StringComparison.OrdinalIgnoreCase); + } + + public sealed class CustomerPlugin + { + [KernelFunction(nameof(GetCustomers))] + [Description("Get list of customers.")] + [return: Description("List of customers.")] + public string[] GetCustomers() + { + return + [ + "John Kowalski", + "Anna Nowak", + "Steve Smith", + ]; + } + + [KernelFunction(nameof(GetCustomerAge))] + [Description("Get age of customer.")] + [return: Description("Age of customer.")] + public int GetCustomerAge([Description("Name of customer")] string customerName) + { + return customerName switch + { + "John Kowalski" => 35, + "Anna Nowak" => 28, + "Steve Smith" => 42, + _ => throw new ArgumentException("Customer not found."), + }; + } + } + + public sealed class MathPlugin + { + [KernelFunction(nameof(Sum))] + [Description("Sum numbers.")] + public int Sum([Description("Numbers to sum")] int[] numbers) + { + return numbers.Sum(); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIChatClientTests.cs new file mode 100644 index 000000000000..9ccca355133c --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIChatClientTests.cs @@ -0,0 +1,173 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using xRetry; +using Xunit; +using Xunit.Abstractions; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiVertexAIChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipReason = "This test is for manual verification."; + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientGenerationReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("Large Language Model", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Brandon", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingReturnsValidResponseAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and write a long story about my name.") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + Assert.False(string.IsNullOrWhiteSpace(message)); + this.Output.WriteLine(message); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. Your name is Roger."), + new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("1520", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientStreamingWithSystemMessagesAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.System, "You are helpful assistant. Your name is Roger."), + new ChatMessage(ChatRole.System, "You know ACDD equals 1520"), + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Tell me your name and the value of ACDD.") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + Assert.True(responses.Count > 1); + var message = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(message); + Assert.Contains("1520", message, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Roger", message, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientReturnsUsageDetailsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, I'm Brandon, how are you?"), + new ChatMessage(ChatRole.Assistant, "I'm doing well, thanks for asking."), + new ChatMessage(ChatRole.User, "Call me by my name and expand this abbreviation: LLM") + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Usage); + this.Output.WriteLine($"Input tokens: {response.Usage.InputTokenCount}"); + this.Output.WriteLine($"Output tokens: {response.Usage.OutputTokenCount}"); + this.Output.WriteLine($"Total tokens: {response.Usage.TotalTokenCount}"); + } + + [RetryFact(Skip = SkipReason)] + public async Task ChatClientWithChatOptionsAsync() + { + // Arrange + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Generate a random number between 1 and 100.") + }; + + var chatOptions = new ChatOptions + { + Temperature = 0.0f, + MaxOutputTokens = 100 + }; + + var sut = this.GetVertexAIChatClient(); + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIFunctionCallingChatClientTests.cs b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIFunctionCallingChatClientTests.cs new file mode 100644 index 000000000000..964260a69f6f --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Google/Gemini/GeminiVertexAIFunctionCallingChatClientTests.cs @@ -0,0 +1,257 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.ComponentModel; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.Extensions.AI; +using Microsoft.SemanticKernel; +using xRetry; +using Xunit; +using Xunit.Abstractions; +using AIFunctionCallContent = Microsoft.Extensions.AI.FunctionCallContent; + +namespace SemanticKernel.IntegrationTests.Connectors.Google.Gemini; + +public sealed class GeminiVertexAIFunctionCallingChatClientTests(ITestOutputHelper output) : TestsBase(output) +{ + private const string SkipMessage = "This test is for manual verification."; + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType(nameof(CustomerPlugin)); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var response = await sut.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Messages); + Assert.NotEmpty(response.Messages); + + var functionCallContent = response.Messages + .SelectMany(m => m.Contents) + .OfType() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithFunctionCallingReturnsToolCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType(nameof(CustomerPlugin)); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools + }; + + // Act + var responses = await sut.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + + var functionCallContent = responses + .SelectMany(r => r.Contents) + .OfType() + .FirstOrDefault(); + + Assert.NotNull(functionCallContent); + Assert.Contains("GetCustomers", functionCallContent.Name, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientStreamingWithAutoInvokeFunctionsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var responses = await autoInvokingClient.GetStreamingResponseAsync(chatHistory, chatOptions).ToListAsync(); + + // Assert + Assert.NotEmpty(responses); + var content = string.Concat(responses.Select(c => c.Text)); + this.Output.WriteLine(content); + Assert.Contains("John Kowalski", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Anna Nowak", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("Steve Smith", content, StringComparison.OrdinalIgnoreCase); + } + + [RetryFact(Skip = SkipMessage)] + public async Task ChatClientWithMultipleFunctionCallsAsync() + { + // Arrange + var kernel = new Kernel(); + kernel.ImportPluginFromType("CustomerPlugin"); + + var sut = this.GetVertexAIChatClient(); + + var chatHistory = new[] + { + new ChatMessage(ChatRole.User, "Hello, could you show me list of customers first and next return age of Anna customer?") + }; + + var tools = kernel.Plugins + .SelectMany(p => p) + .Cast() + .ToList(); + + var chatOptions = new ChatOptions + { + Tools = tools, + ToolMode = ChatToolMode.Auto + }; + + // Use FunctionInvokingChatClient for auto-invoke + using var autoInvokingClient = new FunctionInvokingChatClient(sut); + + // Act + var response = await autoInvokingClient.GetResponseAsync(chatHistory, chatOptions); + + // Assert + Assert.NotNull(response); + var content = string.Join("", response.Messages.Select(m => m.Text)); + this.Output.WriteLine(content); + Assert.Contains("28", content, StringComparison.OrdinalIgnoreCase); + } + + public sealed class CustomerPlugin + { + [KernelFunction(nameof(GetCustomers))] + [Description("Get list of customers.")] + [return: Description("List of customers.")] + public string[] GetCustomers() + { + return + [ + "John Kowalski", + "Anna Nowak", + "Steve Smith", + ]; + } + + [KernelFunction(nameof(GetCustomerAge))] + [Description("Get age of customer.")] + [return: Description("Age of customer.")] + public int GetCustomerAge([Description("Name of customer")] string customerName) + { + return customerName switch + { + "John Kowalski" => 35, + "Anna Nowak" => 28, + "Steve Smith" => 42, + _ => throw new ArgumentException("Customer not found."), + }; + } + } + + public sealed class MathPlugin + { + [KernelFunction(nameof(Sum))] + [Description("Sum numbers.")] + public int Sum([Description("Numbers to sum")] int[] numbers) + { + return numbers.Sum(); + } + } +} diff --git a/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs b/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs index 723785497ccd..7e6bb8a45f54 100644 --- a/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs +++ b/dotnet/src/IntegrationTests/Connectors/Google/TestsBase.cs @@ -3,6 +3,8 @@ using System; using Microsoft.Extensions.AI; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.Google; using Microsoft.SemanticKernel.Embeddings; @@ -65,6 +67,52 @@ protected TestsBase(ITestOutputHelper output) _ => throw new ArgumentOutOfRangeException(nameof(serviceType), serviceType, null) }; + protected IChatClient GetGenAIChatClient(string? overrideModelId = null) + { + var modelId = overrideModelId ?? this.GoogleAI.Gemini.ModelId; + var apiKey = this.GoogleAI.ApiKey; + + var kernel = Kernel.CreateBuilder() + .AddGoogleGenAIChatClient(modelId, apiKey) + .Build(); + + return kernel.GetRequiredService(); + } + + protected IChatClient GetVertexAIChatClient(string? overrideModelId = null) + { + var modelId = overrideModelId ?? this.VertexAI.Gemini.ModelId; + + var kernel = Kernel.CreateBuilder() + .AddGoogleVertexAIChatClient(modelId, project: this.VertexAI.ProjectId, location: this.VertexAI.Location) + .Build(); + + return kernel.GetRequiredService(); + } + + protected IChatClient GetGenAIChatClientWithVision() + { + var modelId = this.GoogleAI.Gemini.VisionModelId; + var apiKey = this.GoogleAI.ApiKey; + + var kernel = Kernel.CreateBuilder() + .AddGoogleGenAIChatClient(modelId, apiKey) + .Build(); + + return kernel.GetRequiredService(); + } + + protected IChatClient GetVertexAIChatClientWithVision() + { + var modelId = this.VertexAI.Gemini.VisionModelId; + + var kernel = Kernel.CreateBuilder() + .AddGoogleVertexAIChatClient(modelId, project: this.VertexAI.ProjectId, location: this.VertexAI.Location) + .Build(); + + return kernel.GetRequiredService(); + } + [Obsolete("Temporary test utility for Obsolete ITextEmbeddingGenerationService")] protected ITextEmbeddingGenerationService GetEmbeddingService(ServiceType serviceType) => serviceType switch { diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj index d0e45a75f94f..ec65cb12f288 100644 --- a/dotnet/src/IntegrationTests/IntegrationTests.csproj +++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj @@ -41,6 +41,7 @@ +