Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package ai.koog.integration.tests

import ai.koog.prompt.dsl.prompt
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
import ai.koog.prompt.executor.clients.google.GoogleModels
import ai.koog.prompt.executor.clients.google.GoogleParams
import ai.koog.prompt.executor.clients.google.GoogleSearchConfig
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
import ai.koog.prompt.message.Message
import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.string.shouldContain
import io.kotest.matchers.string.shouldNotBeBlank
import io.kotest.matchers.types.shouldBeInstanceOf
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assumptions.assumeTrue
import org.junit.jupiter.api.BeforeAll
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import kotlin.time.Duration.Companion.seconds

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class GoogleGroundingLiveTest {

private lateinit var executor: MultiLLMPromptExecutor

@BeforeAll
fun setup() {
val apiKey = System.getenv("GEMINI_API_TEST_KEY")
assumeTrue(apiKey != null, "GEMINI_API_TEST_KEY not set — skipping live grounding tests")
executor = MultiLLMPromptExecutor(GoogleLLMClient(apiKey!!))
}

@Test
fun `grounding enabled returns correct answer for 2026 ICC Cricket World Cup winner`() = runTest(timeout = 60.seconds) {
val p = prompt(
"grounding-on-test",
params = GoogleParams(groundingSearchConfig = GoogleSearchConfig(groundingEnabled = true))
) {
user("Who won the ICC Cricket World Cup 2026? Answer in one word.")
}
val response = executor.execute(p, GoogleModels.Gemini2_5Flash)
response.shouldNotBeEmpty()
val content = response.first().shouldBeInstanceOf<Message.Assistant>().content
content.shouldNotBeBlank()
content.lowercase() shouldContain "india"
}

@Test
fun `grounding disabled answers from training data`() = runTest(timeout = 60.seconds) {
val p = prompt("grounding-off-test", params = GoogleParams()) {
user("What is the capital of France? Answer in one word.")
}
val response = executor.execute(p, GoogleModels.Gemini2_5Flash)
response.shouldNotBeEmpty()
val content = response.first().shouldBeInstanceOf<Message.Assistant>().content
content.shouldNotBeBlank()
content.lowercase() shouldContain "paris"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ import ai.koog.prompt.executor.clients.google.models.GoogleModelsResponse
import ai.koog.prompt.executor.clients.google.models.GooglePart
import ai.koog.prompt.executor.clients.google.models.GoogleRequest
import ai.koog.prompt.executor.clients.google.models.GoogleResponse
import ai.koog.prompt.executor.clients.google.models.GoogleSearch
import ai.koog.prompt.executor.clients.google.models.GoogleTool
import ai.koog.prompt.executor.clients.google.models.GoogleToolConfig
import ai.koog.prompt.executor.clients.google.models.ImageSearch
import ai.koog.prompt.executor.clients.google.models.Interval
import ai.koog.prompt.executor.clients.google.models.SearchTypes
import ai.koog.prompt.executor.clients.google.models.WebSearch
import ai.koog.prompt.executor.clients.google.structure.GoogleBasicJsonSchemaGenerator
import ai.koog.prompt.executor.clients.google.structure.GoogleResponseFormat
import ai.koog.prompt.executor.clients.google.structure.GoogleStandardJsonSchemaGenerator
Expand Down Expand Up @@ -402,10 +407,36 @@ public open class GoogleLLMClient @JvmOverloads constructor(
null -> null
}

val groundingConfig = googleParams.groundingSearchConfig
val groundingTool: GoogleTool? = if (groundingConfig?.groundingEnabled == true) {
val interval = if (groundingConfig.groundingStartTime != null && groundingConfig.groundingEndTime != null) {
Interval(startTime = groundingConfig.groundingStartTime, endTime = groundingConfig.groundingEndTime)
} else {
null
}
val searchTypes = if (groundingConfig.webSearch || groundingConfig.imageSearch) {
SearchTypes(
webSearch = if (groundingConfig.webSearch) WebSearch() else null,
imageSearch = if (groundingConfig.imageSearch) ImageSearch() else null,
)
} else {
null
}
GoogleTool(googleSearch = GoogleSearch(timeRangeFilter = interval, searchTypes = searchTypes))
} else {
null
}

val allTools = when {
groundingTool != null && googleTools != null -> googleTools + groundingTool
groundingTool != null -> listOf(groundingTool)
else -> googleTools
}

return GoogleRequest(
contents = contents,
systemInstruction = googleSystemInstruction,
tools = googleTools,
tools = allTools,
generationConfig = generationConfig,
toolConfig = GoogleToolConfig(functionCallingConfig),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,59 @@ package ai.koog.prompt.executor.clients.google
import ai.koog.prompt.executor.clients.google.models.GoogleThinkingConfig
import ai.koog.prompt.params.LLMParams
import kotlinx.serialization.json.JsonElement
import kotlin.time.Instant

/**
* Configuration for Google Search grounding, which allows Gemini models to ground
* responses with real-time web search results.
*
* @property groundingEnabled When `true`, enables grounding with Google Search.
* All other properties require this to be `true`.
* @property groundingStartTime Optional RFC3339 timestamp (e.g. `"2025-01-01T00:00:00Z"`)
* restricting search results to those published after this time.
* Must be set together with [groundingEndTime].
* @property groundingEndTime Optional RFC3339 timestamp restricting search results to those
* published before this time. Must be set together with [groundingStartTime].
* @property webSearch When `true`, explicitly requests web search results in addition to
* the default grounding behavior.
* @property imageSearch When `true`, explicitly requests image search results.
*
* API reference: https://ai.google.dev/api/caching#GoogleSearch
*/
public data class GoogleSearchConfig(
val groundingEnabled: Boolean = false,
val groundingStartTime: String? = null,
val groundingEndTime: String? = null,
val webSearch: Boolean = false,
val imageSearch: Boolean = false,
) {
init {
require(
groundingEnabled ||
(groundingStartTime == null && groundingEndTime == null && !webSearch && !imageSearch)
) {
"groundingEnabled must be true when groundingStartTime/groundingEndTime/searchTypes are configured"
}
require((groundingStartTime == null) == (groundingEndTime == null)) {
"Both groundingStartTime and groundingEndTime must be set together, or both must be null"
}
if (groundingStartTime != null && groundingEndTime != null) {
val start = parseRfc3339("groundingStartTime", groundingStartTime)
val end = parseRfc3339("groundingEndTime", groundingEndTime)
require(start <= end) {
"groundingStartTime must be <= groundingEndTime, but was $groundingStartTime > $groundingEndTime"
}
}
}

private companion object {
private fun parseRfc3339(fieldName: String, value: String): Instant = try {
Instant.parse(value)
} catch (_: IllegalArgumentException) {
throw IllegalArgumentException("$fieldName must be a valid RFC3339 timestamp, but was $value")
}
}
}

internal fun LLMParams.toGoogleParams(): GoogleParams {
if (this is GoogleParams) return this
Expand Down Expand Up @@ -35,6 +88,8 @@ internal fun LLMParams.toGoogleParams(): GoogleParams {
* @property topK The maximum number of tokens to consider when sampling.
* @property thinkingConfig Controls whether the model should expose its chain-of-thought
* and how many tokens it may spend on it (see [GoogleThinkingConfig]).
* @property groundingSearchConfig Google Search grounding configuration.
* Set [GoogleSearchConfig.groundingEnabled] to true to enable grounding.
*/
@Suppress("LongParameterList")
public class GoogleParams(
Expand All @@ -49,6 +104,7 @@ public class GoogleParams(
public val topP: Double? = null,
public val topK: Int? = null,
public val thinkingConfig: GoogleThinkingConfig? = null,
public val groundingSearchConfig: GoogleSearchConfig? = null,
) : LLMParams(
temperature,
maxTokens,
Expand Down Expand Up @@ -92,6 +148,7 @@ public class GoogleParams(
topP = topP,
topK = topK,
thinkingConfig = thinkingConfig,
groundingSearchConfig = groundingSearchConfig,
)

/**
Expand All @@ -109,6 +166,7 @@ public class GoogleParams(
topP: Double? = this.topP,
topK: Int? = this.topK,
thinkingConfig: GoogleThinkingConfig? = this.thinkingConfig,
groundingSearchConfig: GoogleSearchConfig? = this.groundingSearchConfig,
): GoogleParams = GoogleParams(
temperature = temperature,
maxTokens = maxTokens,
Expand All @@ -121,6 +179,7 @@ public class GoogleParams(
topP = topP,
topK = topK,
thinkingConfig = thinkingConfig,
groundingSearchConfig = groundingSearchConfig,
)

override fun equals(other: Any?): Boolean = when {
Expand All @@ -137,13 +196,15 @@ public class GoogleParams(
additionalProperties == other.additionalProperties &&
topP == other.topP &&
topK == other.topK &&
thinkingConfig == other.thinkingConfig
thinkingConfig == other.thinkingConfig &&
groundingSearchConfig == other.groundingSearchConfig
}

override fun hashCode(): Int = listOf(
temperature, maxTokens, numberOfChoices,
speculation, schema, toolChoice, user,
additionalProperties, topP, topK, thinkingConfig
additionalProperties, topP, topK, thinkingConfig,
groundingSearchConfig
).fold(0) { acc, element ->
31 * acc + (element?.hashCode() ?: 0)
}
Expand All @@ -161,6 +222,7 @@ public class GoogleParams(
append(", topP=$topP")
append(", topK=$topK")
append(", thinkingConfig=$thinkingConfig")
append(", groundingSearchConfig=$groundingSearchConfig")
append(")")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,34 @@ internal sealed interface GoogleData {
* The next conversation turn may contain a `FunctionResponse` with the `Content.role` "function" generation context for
* the next model turn.
*/
@Serializable
internal class GoogleSearch(
val timeRangeFilter: Interval? = null,
val searchTypes: SearchTypes? = null,
)

@Serializable
internal class Interval(
val startTime: String,
val endTime: String,
)

@Serializable
internal class SearchTypes(
val webSearch: WebSearch? = null,
val imageSearch: ImageSearch? = null,
)

@Serializable
internal class WebSearch

@Serializable
internal class ImageSearch

@Serializable
internal class GoogleTool(
val functionDeclarations: List<GoogleFunctionDeclaration>? = null,
val googleSearch: GoogleSearch? = null,
)

/**
Expand Down
Loading