Skip to content

Commit 2e7423a

Browse files
itamargolan and claude authored
[OPIK-5845] [BE] fix: fix dataset expansion for test suites and all LLM providers (#6277)
* [OPIK-5845] [BE] fix: fix dataset expansion for test suites and all LLM providers - Add maxCompletionTokens (4000) to expansion requests, fixing Anthropic models that require this field - Re-throw ClientErrorException/ServerErrorException with original messages instead of wrapping in generic BadRequestException - Skip _generated/_generation_model metadata for test suite expansions so synthetic fields don't pollute data passed to agents in local runner scenarios Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(expansion): use proper HTTP status codes for error handling - Catch BadRequestException separately to preserve validation errors - Rethrow ClientErrorException/ServerErrorException from LLM providers - Use InternalServerErrorException for unexpected failures instead of mapping everything to 400 BadRequest - Use generic message for 500s, log full details server-side Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(expansion): address PR review comments - Use lighter getById() instead of findById() to avoid unnecessary dataset enrichment when only the type is needed - Extract buildDatasetItem() to DRY the duplicated item construction - Make maxCompletionTokens provider-aware: only set for Anthropic by default (4000), skip for other providers to avoid impacting results - Allow users to override maxCompletionTokens via the API request - Inject LlmProviderFactory to resolve provider from model name Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * test(expansion): add unit tests and @min validation for maxCompletionTokens Add DatasetExpansionServiceTest covering maxCompletionTokens resolution, dataset item building (metadata for regular vs test suite), and error handling. Add @min(100) validation on maxCompletionTokens field. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 134915b commit 2e7423a

3 files changed

Lines changed: 436 additions & 53 deletions

File tree

apps/opik-backend/src/main/java/com/comet/opik/api/DatasetExpansion.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ public record DatasetExpansion(
2525
@JsonView({
2626
DatasetExpansion.View.Write.class}) @Schema(description = "Additional instructions for data variation", example = "Create variations that test edge cases") String variationInstructions,
2727
@JsonView({
28-
DatasetExpansion.View.Write.class}) @Schema(description = "Custom prompt to use for generation instead of auto-generated one") String customPrompt){
28+
DatasetExpansion.View.Write.class}) @Schema(description = "Custom prompt to use for generation instead of auto-generated one") String customPrompt,
29+
@JsonView({
30+
DatasetExpansion.View.Write.class}) @Min(100) @Schema(description = "Maximum number of tokens for the LLM response. Required by Anthropic, used as maxOutputTokens for Gemini. If not provided, defaults to 4000 for Anthropic models only.") Integer maxCompletionTokens){
2931

3032
public static class View {
3133
public static class Write {

apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetExpansionService.java

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
import com.comet.opik.api.DatasetExpansion;
44
import com.comet.opik.api.DatasetExpansionResponse;
55
import com.comet.opik.api.DatasetItem;
6+
import com.comet.opik.api.DatasetType;
7+
import com.comet.opik.api.LlmProvider;
68
import com.comet.opik.domain.llm.ChatCompletionService;
9+
import com.comet.opik.domain.llm.LlmProviderFactory;
710
import com.comet.opik.infrastructure.auth.RequestContext;
811
import com.comet.opik.utils.AsyncUtils;
912
import com.comet.opik.utils.JsonUtils;
@@ -15,6 +18,9 @@
1518
import jakarta.inject.Provider;
1619
import jakarta.inject.Singleton;
1720
import jakarta.ws.rs.BadRequestException;
21+
import jakarta.ws.rs.ClientErrorException;
22+
import jakarta.ws.rs.InternalServerErrorException;
23+
import jakarta.ws.rs.ServerErrorException;
1824
import lombok.NonNull;
1925
import lombok.RequiredArgsConstructor;
2026
import lombok.extern.slf4j.Slf4j;
@@ -33,8 +39,12 @@
3339
@RequiredArgsConstructor(onConstructor_ = @Inject)
3440
public class DatasetExpansionService {
3541

42+
private static final int DEFAULT_MAX_COMPLETION_TOKENS = 4000;
43+
3644
private final @NonNull ChatCompletionService chatCompletionService;
45+
private final @NonNull LlmProviderFactory llmProviderFactory;
3746
private final @NonNull DatasetItemService datasetItemService;
47+
private final @NonNull DatasetService datasetService;
3848
private final @NonNull Provider<RequestContext> requestContext;
3949
private final @NonNull ObjectMapper objectMapper;
4050
private final @NonNull IdGenerator idGenerator;
@@ -60,12 +70,17 @@ public DatasetExpansionResponse expandDataset(@NonNull UUID datasetId, @NonNull
6070
throw new BadRequestException("Cannot expand empty dataset. Add at least one sample first");
6171
}
6272

73+
var datasetType = datasetService.getById(datasetId, workspaceId)
74+
.orElseThrow(() -> new BadRequestException("Dataset not found"))
75+
.type();
76+
6377
// Use custom prompt if provided, otherwise build default prompt
6478
var generationPrompt = StringUtils.isNotBlank(request.customPrompt())
6579
? request.customPrompt().trim()
6680
: buildGenerationPrompt(existingItems.content(), request);
6781
// Generate samples using LLM with batch processing for large requests
68-
var generatedSamples = generateSamplesInBatches(generationPrompt, request, datasetId, workspaceId);
82+
var generatedSamples = generateSamplesInBatches(generationPrompt, request, datasetId, workspaceId,
83+
datasetType);
6984
log.info("Finished dataset expansion for datasetId '{}', workspaceId '{}', total samples '{}'",
7085
datasetId, workspaceId, generatedSamples.size());
7186
return DatasetExpansionResponse.builder()
@@ -122,8 +137,10 @@ private String buildGenerationPrompt(List<DatasetItem> existingItems, DatasetExp
122137
}
123138

124139
private List<DatasetItem> generateSamplesInBatches(
125-
String basePrompt, DatasetExpansion request, UUID datasetId, String workspaceId) {
140+
String basePrompt, DatasetExpansion request, UUID datasetId, String workspaceId,
141+
DatasetType datasetType) {
126142
var allSamples = new ArrayList<DatasetItem>();
143+
var maxCompletionTokens = resolveMaxCompletionTokens(request);
127144
var totalSamples = request.sampleCount();
128145
var batchSize = Math.min(20, totalSamples); // Process in batches of up to 20
129146
var remainingSamples = totalSamples;
@@ -162,7 +179,7 @@ private List<DatasetItem> generateSamplesInBatches(
162179
.preserveFields(request.preserveFields())
163180
.variationInstructions(request.variationInstructions())
164181
.build(),
165-
datasetId, workspaceId);
182+
datasetId, workspaceId, datasetType, maxCompletionTokens);
166183

167184
allSamples.addAll(batchSamples);
168185
remainingSamples -= currentBatchSize;
@@ -173,15 +190,20 @@ private List<DatasetItem> generateSamplesInBatches(
173190
}
174191

175192
private List<DatasetItem> generateSamples(
176-
String prompt, DatasetExpansion request, UUID datasetId, String workspaceId) {
193+
String prompt, DatasetExpansion request, UUID datasetId, String workspaceId,
194+
DatasetType datasetType, Integer maxCompletionTokens) {
177195
try {
178-
// Create chat completion request, request should handle most models including reasoning models like GPT-5, Sonnet, etc.
179-
var chatRequest = ChatCompletionRequest.builder()
196+
var builder = ChatCompletionRequest.builder()
180197
.model(request.model())
181198
.addUserMessage(prompt)
182-
.temperature(1.0) // Set temperature to 1.0 for consistent output
183-
.stream(false) // Non-streaming request for dataset expansion
184-
.build();
199+
.temperature(1.0)
200+
.stream(false);
201+
202+
if (maxCompletionTokens != null) {
203+
builder.maxCompletionTokens(maxCompletionTokens);
204+
}
205+
206+
var chatRequest = builder.build();
185207

186208
// Call LLM
187209
var response = chatCompletionService.create(chatRequest, workspaceId);
@@ -191,22 +213,38 @@ private List<DatasetItem> generateSamples(
191213

192214
// Parse the JSON response
193215
var parsedSamples = parseGeneratedSamples(
194-
generatedContent, datasetId, request.model(), request.sampleCount());
216+
generatedContent, datasetId, request.model(), request.sampleCount(), datasetType);
195217
log.debug("Parsed '{}' samples from LLM response", parsedSamples.size());
196218
return parsedSamples;
197219

220+
} catch (BadRequestException exception) {
221+
log.error("Validation error during sample generation", exception);
222+
throw exception;
223+
} catch (ClientErrorException | ServerErrorException exception) {
224+
log.error("LLM service error during sample generation", exception);
225+
throw exception;
198226
} catch (Exception exception) {
199227
log.error("Failed to generate samples using LLM", exception);
200-
// If it's already a RuntimeException with a detailed message, preserve it
201-
if (exception instanceof BadRequestException && exception.getMessage().contains("AI model")) {
202-
throw exception;
203-
}
204-
throw new BadRequestException("Failed to generate synthetic samples", exception);
228+
throw new InternalServerErrorException("Failed to generate synthetic samples", exception);
205229
}
206230
}
207231

232+
private Integer resolveMaxCompletionTokens(DatasetExpansion request) {
233+
if (request.maxCompletionTokens() != null) {
234+
return request.maxCompletionTokens();
235+
}
236+
237+
var provider = llmProviderFactory.getLlmProvider(request.model());
238+
if (provider == LlmProvider.ANTHROPIC) {
239+
return DEFAULT_MAX_COMPLETION_TOKENS;
240+
}
241+
242+
return null;
243+
}
244+
208245
private List<DatasetItem> parseGeneratedSamples(
209-
String generatedContent, UUID datasetId, String model, int requestedSampleCount) {
246+
String generatedContent, UUID datasetId, String model, int requestedSampleCount,
247+
DatasetType datasetType) {
210248
try {
211249
// Clean the response - sometimes LLMs add markdown formatting
212250
String cleanedContent = generatedContent.trim();
@@ -241,46 +279,12 @@ private List<DatasetItem> parseGeneratedSamples(
241279
if (rootNode.isArray()) {
242280
for (var sampleNode : rootNode) {
243281
if (sampleNode.isObject()) {
244-
var dataNode = (ObjectNode) sampleNode;
245-
246-
// Add metadata to indicate this is synthetic
247-
dataNode.put("_generated", true);
248-
dataNode.put("_generation_model", model);
249-
250-
// Convert to Map for DatasetItem
251-
Map<String, JsonNode> dataMap = objectMapper.convertValue(dataNode,
252-
objectMapper.getTypeFactory().constructMapType(Map.class, String.class,
253-
JsonNode.class));
254-
255-
var sample = DatasetItem.builder()
256-
.id(idGenerator.generateId())
257-
.datasetId(datasetId)
258-
.data(dataMap)
259-
.source(com.comet.opik.api.DatasetItemSource.MANUAL)
260-
.build();
261-
262-
samples.add(sample);
282+
samples.add(buildDatasetItem((ObjectNode) sampleNode, datasetId, model, datasetType));
263283
}
264284
}
265285
} else if (rootNode.isObject()) {
266-
// Handle case where LLM returns a single object instead of array
267286
log.warn("LLM returned single object instead of array, wrapping in array");
268-
var dataNode = (ObjectNode) rootNode;
269-
dataNode.put("_generated", true);
270-
dataNode.put("_generation_model", model);
271-
272-
Map<String, JsonNode> dataMap = objectMapper.convertValue(dataNode,
273-
objectMapper.getTypeFactory().constructMapType(Map.class, String.class,
274-
JsonNode.class));
275-
276-
var sample = DatasetItem.builder()
277-
.id(idGenerator.generateId())
278-
.datasetId(datasetId)
279-
.data(dataMap)
280-
.source(com.comet.opik.api.DatasetItemSource.MANUAL)
281-
.build();
282-
283-
samples.add(sample);
287+
samples.add(buildDatasetItem((ObjectNode) rootNode, datasetId, model, datasetType));
284288
} else {
285289
throw new BadRequestException(
286290
"Expected JSON array or object, but got: '%s'".formatted(rootNode.getNodeType()));
@@ -314,6 +318,24 @@ private List<DatasetItem> parseGeneratedSamples(
314318
}
315319
}
316320

321+
private DatasetItem buildDatasetItem(ObjectNode dataNode, UUID datasetId, String model,
322+
DatasetType datasetType) {
323+
if (datasetType != DatasetType.TEST_SUITE) {
324+
dataNode.put("_generated", true);
325+
dataNode.put("_generation_model", model);
326+
}
327+
328+
Map<String, JsonNode> dataMap = objectMapper.convertValue(dataNode,
329+
objectMapper.getTypeFactory().constructMapType(Map.class, String.class, JsonNode.class));
330+
331+
return DatasetItem.builder()
332+
.id(idGenerator.generateId())
333+
.datasetId(datasetId)
334+
.data(dataMap)
335+
.source(com.comet.opik.api.DatasetItemSource.MANUAL)
336+
.build();
337+
}
338+
317339
private String buildUserFriendlyErrorMessage(Exception e, String generatedContent) {
318340
// Check the type of error and provide specific guidance
319341
if (e instanceof com.fasterxml.jackson.core.JsonParseException) {

0 commit comments

Comments (0)