Skip to content

Commit c6a61ef

Browse files
committed
refactor(prompt): unify grounding settings under GoogleSearchConfig
1 parent 91c9b09 commit c6a61ef

4 files changed

Lines changed: 94 additions & 66 deletions

File tree

integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/GoogleGroundingLiveTest.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import ai.koog.prompt.dsl.prompt
44
import ai.koog.prompt.executor.clients.google.GoogleLLMClient
55
import ai.koog.prompt.executor.clients.google.GoogleModels
66
import ai.koog.prompt.executor.clients.google.GoogleParams
7+
import ai.koog.prompt.executor.clients.google.GoogleSearchConfig
78
import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor
89
import ai.koog.prompt.message.Message
910
import io.kotest.matchers.collections.shouldNotBeEmpty
@@ -31,7 +32,10 @@ class GoogleGroundingLiveTest {
3132

3233
@Test
3334
fun `grounding enabled returns correct answer for 2026 ICC Cricket World Cup winner`() = runTest(timeout = 60.seconds) {
34-
val p = prompt("grounding-on-test", params = GoogleParams(groundingEnabled = true)) {
35+
val p = prompt(
36+
"grounding-on-test",
37+
params = GoogleParams(groundingSearchConfig = GoogleSearchConfig(groundingEnabled = true))
38+
) {
3539
user("Who won the ICC Cricket World Cup 2026? Answer in one word.")
3640
}
3741
val response = executor.execute(p, GoogleModels.Gemini2_5Flash)
@@ -43,7 +47,7 @@ class GoogleGroundingLiveTest {
4347

4448
@Test
4549
fun `grounding disabled answers from training data`() = runTest(timeout = 60.seconds) {
46-
val p = prompt("grounding-off-test", params = GoogleParams(groundingEnabled = false)) {
50+
val p = prompt("grounding-off-test", params = GoogleParams()) {
4751
user("What is the capital of France? Answer in one word.")
4852
}
4953
val response = executor.execute(p, GoogleModels.Gemini2_5Flash)

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -511,17 +511,20 @@ public open class GoogleLLMClient @JvmOverloads constructor(
511511
null -> null
512512
}
513513

514-
val groundingTool: GoogleTool? = if (googleParams.groundingEnabled) {
515-
val interval = if (googleParams.groundingStartTime != null && googleParams.groundingEndTime != null) {
516-
Interval(startTime = googleParams.groundingStartTime, endTime = googleParams.groundingEndTime)
514+
val groundingConfig = googleParams.groundingSearchConfig
515+
val groundingTool: GoogleTool? = if (groundingConfig?.groundingEnabled == true) {
516+
val interval = if (groundingConfig.groundingStartTime != null && groundingConfig.groundingEndTime != null) {
517+
Interval(startTime = groundingConfig.groundingStartTime, endTime = groundingConfig.groundingEndTime)
517518
} else {
518519
null
519520
}
520-
val searchTypes = googleParams.groundingSearchConfig?.let { cfg ->
521+
val searchTypes = if (groundingConfig.webSearch || groundingConfig.imageSearch) {
521522
SearchTypes(
522-
webSearch = if (cfg.webSearch) WebSearch() else null,
523-
imageSearch = if (cfg.imageSearch) ImageSearch() else null,
523+
webSearch = if (groundingConfig.webSearch) WebSearch() else null,
524+
imageSearch = if (groundingConfig.imageSearch) ImageSearch() else null,
524525
)
526+
} else {
527+
null
525528
}
526529
GoogleTool(googleSearch = GoogleSearch(timeRangeFilter = interval, searchTypes = searchTypes))
527530
} else {

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleParams.kt

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,39 @@ import kotlinx.serialization.json.JsonElement
66
import kotlin.time.Instant
77

88
public data class GoogleSearchConfig(
9+
val groundingEnabled: Boolean = false,
10+
val groundingStartTime: String? = null,
11+
val groundingEndTime: String? = null,
912
val webSearch: Boolean = false,
1013
val imageSearch: Boolean = false,
11-
)
14+
) {
15+
init {
16+
require(
17+
groundingEnabled ||
18+
(groundingStartTime == null && groundingEndTime == null && !webSearch && !imageSearch)
19+
) {
20+
"groundingEnabled must be true when groundingStartTime/groundingEndTime/searchTypes are configured"
21+
}
22+
require((groundingStartTime == null) == (groundingEndTime == null)) {
23+
"Both groundingStartTime and groundingEndTime must be set together, or both must be null"
24+
}
25+
if (groundingStartTime != null && groundingEndTime != null) {
26+
val start = parseRfc3339("groundingStartTime", groundingStartTime)
27+
val end = parseRfc3339("groundingEndTime", groundingEndTime)
28+
require(start <= end) {
29+
"groundingStartTime must be <= groundingEndTime, but was $groundingStartTime > $groundingEndTime"
30+
}
31+
}
32+
}
33+
34+
private companion object {
35+
private fun parseRfc3339(fieldName: String, value: String): Instant = try {
36+
Instant.parse(value)
37+
} catch (_: IllegalArgumentException) {
38+
throw IllegalArgumentException("$fieldName must be a valid RFC3339 timestamp, but was $value")
39+
}
40+
}
41+
}
1242

1343
internal fun LLMParams.toGoogleParams(): GoogleParams {
1444
if (this is GoogleParams) return this
@@ -41,8 +71,8 @@ internal fun LLMParams.toGoogleParams(): GoogleParams {
4171
* @property topK The maximum number of tokens to consider when sampling.
4272
* @property thinkingConfig Controls whether the model should expose its chain-of-thought
4373
* and how many tokens it may spend on it (see [GoogleThinkingConfig]).
44-
* @property groundingEnabled Enables grounding with Google Search to augment responses with
45-
* real-time information. Supported by Gemini 2.0+ models.
74+
* @property groundingSearchConfig Google Search grounding configuration.
75+
* Set [GoogleSearchConfig.groundingEnabled] to true to enable grounding.
4676
*/
4777
@Suppress("LongParameterList")
4878
public class GoogleParams(
@@ -57,9 +87,6 @@ public class GoogleParams(
5787
public val topP: Double? = null,
5888
public val topK: Int? = null,
5989
public val thinkingConfig: GoogleThinkingConfig? = null,
60-
public val groundingEnabled: Boolean = false,
61-
public val groundingStartTime: String? = null,
62-
public val groundingEndTime: String? = null,
6390
public val groundingSearchConfig: GoogleSearchConfig? = null,
6491
) : LLMParams(
6592
temperature,
@@ -81,19 +108,6 @@ public class GoogleParams(
81108
require(topK == null || topK >= 0) {
82109
"topK must be >= 0, but was $topK"
83110
}
84-
require((groundingStartTime == null) == (groundingEndTime == null)) {
85-
"Both groundingStartTime and groundingEndTime must be set together, or both must be null"
86-
}
87-
if (groundingStartTime != null && groundingEndTime != null) {
88-
val start = parseRfc3339("groundingStartTime", groundingStartTime)
89-
val end = parseRfc3339("groundingEndTime", groundingEndTime)
90-
require(start <= end) {
91-
"groundingStartTime must be <= groundingEndTime, but was $groundingStartTime > $groundingEndTime"
92-
}
93-
}
94-
require(groundingSearchConfig == null || groundingEnabled) {
95-
"groundingSearchConfig requires groundingEnabled = true"
96-
}
97111
}
98112

99113
override fun copy(
@@ -117,9 +131,6 @@ public class GoogleParams(
117131
topP = topP,
118132
topK = topK,
119133
thinkingConfig = thinkingConfig,
120-
groundingEnabled = groundingEnabled,
121-
groundingStartTime = groundingStartTime,
122-
groundingEndTime = groundingEndTime,
123134
groundingSearchConfig = groundingSearchConfig,
124135
)
125136

@@ -138,9 +149,6 @@ public class GoogleParams(
138149
topP: Double? = this.topP,
139150
topK: Int? = this.topK,
140151
thinkingConfig: GoogleThinkingConfig? = this.thinkingConfig,
141-
groundingEnabled: Boolean = this.groundingEnabled,
142-
groundingStartTime: String? = this.groundingStartTime,
143-
groundingEndTime: String? = this.groundingEndTime,
144152
groundingSearchConfig: GoogleSearchConfig? = this.groundingSearchConfig,
145153
): GoogleParams = GoogleParams(
146154
temperature = temperature,
@@ -154,9 +162,6 @@ public class GoogleParams(
154162
topP = topP,
155163
topK = topK,
156164
thinkingConfig = thinkingConfig,
157-
groundingEnabled = groundingEnabled,
158-
groundingStartTime = groundingStartTime,
159-
groundingEndTime = groundingEndTime,
160165
groundingSearchConfig = groundingSearchConfig,
161166
)
162167

@@ -175,16 +180,13 @@ public class GoogleParams(
175180
topP == other.topP &&
176181
topK == other.topK &&
177182
thinkingConfig == other.thinkingConfig &&
178-
groundingEnabled == other.groundingEnabled &&
179-
groundingStartTime == other.groundingStartTime &&
180-
groundingEndTime == other.groundingEndTime &&
181183
groundingSearchConfig == other.groundingSearchConfig
182184
}
183185

184186
override fun hashCode(): Int = listOf(
185187
temperature, maxTokens, numberOfChoices,
186188
speculation, schema, toolChoice, user,
187-
additionalProperties, topP, topK, thinkingConfig, groundingEnabled, groundingStartTime, groundingEndTime,
189+
additionalProperties, topP, topK, thinkingConfig,
188190
groundingSearchConfig
189191
).fold(0) { acc, element ->
190192
31 * acc + (element?.hashCode() ?: 0)
@@ -203,18 +205,7 @@ public class GoogleParams(
203205
append(", topP=$topP")
204206
append(", topK=$topK")
205207
append(", thinkingConfig=$thinkingConfig")
206-
append(", groundingEnabled=$groundingEnabled")
207-
append(", groundingStartTime=$groundingStartTime")
208-
append(", groundingEndTime=$groundingEndTime")
209208
append(", groundingSearchConfig=$groundingSearchConfig")
210209
append(")")
211210
}
212-
213-
private companion object {
214-
private fun parseRfc3339(fieldName: String, value: String): Instant = try {
215-
Instant.parse(value)
216-
} catch (_: IllegalArgumentException) {
217-
throw IllegalArgumentException("$fieldName must be a valid RFC3339 timestamp, but was $value")
218-
}
219-
}
220211
}

prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -668,15 +668,15 @@ class GoogleLLMClientTest {
668668
}
669669

670670
@Test
671-
fun `createGoogleRequest injects googleSearch tool when groundingEnabled is true`() {
671+
fun `createGoogleRequest injects googleSearch tool when groundingSearchConfig is enabled`() {
672672
val client = GoogleLLMClient(apiKey = "apiKey")
673673
val model = GoogleModels.Gemini2_5Flash
674674

675675
val request = client.createGoogleRequest(
676676
prompt = Prompt(
677677
messages = emptyList(),
678678
id = "id",
679-
params = GoogleParams(groundingEnabled = true)
679+
params = GoogleParams(groundingSearchConfig = GoogleSearchConfig(groundingEnabled = true))
680680
),
681681
model = model,
682682
tools = emptyList()
@@ -688,13 +688,31 @@ class GoogleLLMClientTest {
688688
tools.first().functionDeclarations shouldBe null
689689
}
690690

691+
@Test
692+
fun `createGoogleRequest does not inject googleSearch tool when groundingSearchConfig is absent`() {
693+
val client = GoogleLLMClient(apiKey = "apiKey")
694+
val model = GoogleModels.Gemini2_5Flash
695+
696+
val request = client.createGoogleRequest(
697+
prompt = Prompt(messages = emptyList(), id = "id", params = GoogleParams()),
698+
model = model,
699+
tools = emptyList()
700+
)
701+
702+
request.tools shouldBe null
703+
}
704+
691705
@Test
692706
fun `createGoogleRequest googleSearch has no timeRangeFilter when no times are provided`() {
693707
val client = GoogleLLMClient(apiKey = "apiKey")
694708
val model = GoogleModels.Gemini2_5Flash
695709

696710
val request = client.createGoogleRequest(
697-
prompt = Prompt(messages = emptyList(), id = "id", params = GoogleParams(groundingEnabled = true)),
711+
prompt = Prompt(
712+
messages = emptyList(),
713+
id = "id",
714+
params = GoogleParams(groundingSearchConfig = GoogleSearchConfig(groundingEnabled = true))
715+
),
698716
model = model,
699717
tools = emptyList()
700718
)
@@ -713,9 +731,11 @@ class GoogleLLMClientTest {
713731
messages = emptyList(),
714732
id = "id",
715733
params = GoogleParams(
716-
groundingEnabled = true,
717-
groundingStartTime = "2025-01-01T00:00:00Z",
718-
groundingEndTime = "2025-07-01T00:00:00Z"
734+
groundingSearchConfig = GoogleSearchConfig(
735+
groundingEnabled = true,
736+
groundingStartTime = "2025-01-01T00:00:00Z",
737+
groundingEndTime = "2025-07-01T00:00:00Z"
738+
)
719739
)
720740
),
721741
model = model,
@@ -729,26 +749,26 @@ class GoogleLLMClientTest {
729749
}
730750

731751
@Test
732-
fun `GoogleParams throws when only one of groundingStartTime or groundingEndTime is set`() {
752+
fun `GoogleSearchConfig throws when only one of groundingStartTime or groundingEndTime is set`() {
733753
shouldThrow<IllegalArgumentException> {
734-
GoogleParams(groundingEnabled = true, groundingStartTime = "2025-01-01T00:00:00Z")
754+
GoogleSearchConfig(groundingEnabled = true, groundingStartTime = "2025-01-01T00:00:00Z")
735755
}
736756
shouldThrow<IllegalArgumentException> {
737-
GoogleParams(groundingEnabled = true, groundingEndTime = "2025-07-01T00:00:00Z")
757+
GoogleSearchConfig(groundingEnabled = true, groundingEndTime = "2025-07-01T00:00:00Z")
738758
}
739759
}
740760

741761
@Test
742-
fun `GoogleParams throws when grounding times are not valid RFC3339`() {
762+
fun `GoogleSearchConfig throws when grounding times are not valid RFC3339`() {
743763
shouldThrow<IllegalArgumentException> {
744-
GoogleParams(
764+
GoogleSearchConfig(
745765
groundingEnabled = true,
746766
groundingStartTime = "2025-01-01 00:00:00",
747767
groundingEndTime = "2025-07-01T00:00:00Z"
748768
)
749769
}
750770
shouldThrow<IllegalArgumentException> {
751-
GoogleParams(
771+
GoogleSearchConfig(
752772
groundingEnabled = true,
753773
groundingStartTime = "2025-01-01T00:00:00Z",
754774
groundingEndTime = "not-a-timestamp"
@@ -757,9 +777,19 @@ class GoogleLLMClientTest {
757777
}
758778

759779
@Test
760-
fun `GoogleParams throws when groundingStartTime is after groundingEndTime`() {
780+
fun `GoogleSearchConfig throws when grounding options are set while grounding is disabled`() {
781+
shouldThrow<IllegalArgumentException> {
782+
GoogleSearchConfig(webSearch = true)
783+
}
784+
shouldThrow<IllegalArgumentException> {
785+
GoogleSearchConfig(groundingStartTime = "2025-01-01T00:00:00Z", groundingEndTime = "2025-07-01T00:00:00Z")
786+
}
787+
}
788+
789+
@Test
790+
fun `GoogleSearchConfig throws when groundingStartTime is after groundingEndTime`() {
761791
shouldThrow<IllegalArgumentException> {
762-
GoogleParams(
792+
GoogleSearchConfig(
763793
groundingEnabled = true,
764794
groundingStartTime = "2025-07-01T00:00:00Z",
765795
groundingEndTime = "2025-01-01T00:00:00Z"
@@ -778,7 +808,7 @@ class GoogleLLMClientTest {
778808
prompt = Prompt(
779809
messages = emptyList(),
780810
id = "id",
781-
params = GoogleParams(groundingEnabled = true)
811+
params = GoogleParams(groundingSearchConfig = GoogleSearchConfig(groundingEnabled = true))
782812
),
783813
model = model,
784814
tools = listOf(tool)

0 commit comments

Comments
 (0)