Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions internal/integration/unified/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,26 @@ type operation struct {

// execute runs the operation and verifies the returned result and/or error. If the result needs to be saved as
// an entity, it also updates the entityMap associated with ctx to do so.
func (op *operation) execute(ctx context.Context, loopDone <-chan struct{}) error {
func (op *operation) execute(ctx context.Context, loopDone <-chan struct{}) (*operationResult, error) {
res, err := op.run(ctx, loopDone)
if err != nil {
return fmt.Errorf("execution failed: %v", err)
return nil, fmt.Errorf("execution failed: %v", err)
}

if op.IgnoreResultAndError {
return nil
return nil, nil
}

if err := verifyOperationError(ctx, op.ExpectedError, res); err != nil {
return fmt.Errorf("error verification failed: %v", err)
return nil, fmt.Errorf("error verification failed: %v", err)
}

if op.ExpectedResult != nil {
if err := verifyOperationResult(ctx, *op.ExpectedResult, res); err != nil {
return fmt.Errorf("result verification failed: %v", err)
return nil, fmt.Errorf("result verification failed: %v", err)
}
}
return nil
return res, nil
}

// isCreateView will return true if the operation is to create a collection with a view.
Expand Down Expand Up @@ -125,8 +125,9 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat
case "startTransaction":
return executeStartTransaction(ctx, op)
case "withTransaction":
// executeWithTransaction internally verifies results/errors for each operation, so it doesn't return a result.
return newEmptyResult(), executeWithTransaction(ctx, op, loopDone)
// executeWithTransaction internally verifies results/errors for each operation.
// The error from WithTransaction() is wrapped in the result.
return executeWithTransaction(ctx, op, loopDone)
case "getSnapshotTime":
// executeGetSnapshotTime stores the snapshot time of the session on
// the entity map for subsequent use.
Expand Down
30 changes: 20 additions & 10 deletions internal/integration/unified/session_operation_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,38 +81,48 @@ func executeStartTransaction(ctx context.Context, operation *operation) (*operat
return newErrorResult(sess.StartTransaction(opts)), nil
}

func executeWithTransaction(ctx context.Context, op *operation, loopDone <-chan struct{}) error {
func executeWithTransaction(ctx context.Context, op *operation, loopDone <-chan struct{}) (*operationResult, error) {
sess, err := entities(ctx).session(op.Object)
if err != nil {
return err
return nil, err
}

// Process the "callback" argument. This is an array of operation objects, each of which should be executed inside
// the transaction.
callback, err := op.Arguments.LookupErr("callback")
if err != nil {
return newMissingArgumentError("callback")
return nil, newMissingArgumentError("callback")
}
var operations []*operation
if err := callback.Unmarshal(&operations); err != nil {
return fmt.Errorf("error transforming callback option to slice of operations: %v", err)
return nil, fmt.Errorf("error transforming callback option to slice of operations: %v", err)
}

// Remove the "callback" field and process the other options.
var temp transactionOptions
if err := bson.Unmarshal(removeFieldsFromDocument(op.Arguments, "callback"), &temp); err != nil {
return fmt.Errorf("error unmarshalling arguments to transactionOptions: %v", err)
return nil, fmt.Errorf("error unmarshalling arguments to transactionOptions: %v", err)
}

_, err = sess.WithTransaction(ctx, func(ctx context.Context) (any, error) {
_, withTransErr := sess.WithTransaction(ctx, func(ctx context.Context) (any, error) {
var cbErr error
for idx, oper := range operations {
if err := oper.execute(ctx, loopDone); err != nil {
return nil, fmt.Errorf("error executing operation %q at index %d: %v", oper.Name, idx, err)
res, execErr := oper.execute(ctx, loopDone)
if execErr != nil {
// Capture the error but continue executing the remaining operations in the callback.
err = fmt.Errorf("error executing operation %q at index %d: %v", oper.Name, idx, execErr)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify: we capture err here because we want to get a "3rd type of result" from this function, i.e. errors that are not related to WithTransaction logic, right? So we do want to get the error, but we don't want the driver to handle it as if it were a transaction error.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, we need to capture the error while still completing all operations in the transaction.

return nil, nil
Comment on lines +113 to +114
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

executeWithTransaction callback: on execErr, the callback currently returns (nil, nil), which causes WithTransaction to commit even though the callback did not successfully execute all operations, and relies on mutating the outer err to fail later. It would be safer to return the wrapped execution error from the callback so WithTransaction aborts/cleans up consistently and avoids committing partial work.

Suggested change
err = fmt.Errorf("error executing operation %q at index %d: %v", oper.Name, idx, execErr)
return nil, nil
return nil, fmt.Errorf("error executing operation %q at index %d: %w", oper.Name, idx, execErr)

Copilot uses AI. Check for mistakes.
}
if cbErr == nil && res != nil {
cbErr = res.Err
}
}
Comment on lines +108 to 119
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

executeWithTransaction callback: once an operation inside the callback produces a non-nil operation error (res.Err), the callback should return that error immediately so subsequent operations are not executed. The current logic records the first res.Err in cbErr but continues executing later operations, which can change transactional behavior compared to a real application callback and can cause extra side effects or different errors.

Copilot uses AI. Check for mistakes.
return nil, nil
return nil, cbErr
}, temp.TransactionOptionsBuilder)
return err
if err != nil {
return nil, err
}
return &operationResult{Err: withTransErr}, nil
}

func executeGetSnapshotTime(ctx context.Context, op *operation) (*operationResult, error) {
Expand Down
5 changes: 3 additions & 2 deletions internal/integration/unified/testrunner_operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ func executeTestRunnerOperation(ctx context.Context, op *operation, loopDone <-c
return fmt.Errorf("run on unknown thread: %s", thread)
}
routine.(*backgroundRoutine).addTask(threadOp.Name, func() error {
return threadOp.execute(ctx, loopDone)
_, execErr := threadOp.execute(ctx, loopDone)
return execErr
})
return nil
case "waitForThread":
Expand Down Expand Up @@ -323,7 +324,7 @@ func executeLoop(ctx context.Context, args *loopArgs, loopDone <-chan struct{})
if operation.Name == "loop" {
return fmt.Errorf("loop sub-operations should not include loop")
}
loopErr = operation.execute(ctx, loopDone)
_, loopErr = operation.execute(ctx, loopDone)

// if the operation errors, stop this loop
if loopErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/integration/unified/unified_spec_runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (tc *TestCase) Run(ls LoggerSkipper) error {
}

for idx, operation := range tc.Operations {
if err := operation.execute(testCtx, tc.loopDone); err != nil {
if _, err := operation.execute(testCtx, tc.loopDone); err != nil {
if isSkipTestError(err) {
ls.Skip(err)
}
Expand Down
1 change: 0 additions & 1 deletion internal/spectest/skip.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,6 @@ var skipTests = map[string][]string{
"TestUnifiedSpec/client-side-operations-timeout/tests/close-cursors.json/timeoutMS_is_refreshed_for_close",
"TestUnifiedSpec/client-side-operations-timeout/tests/convenient-transactions.json/withTransaction_raises_a_client-side_error_if_timeoutMS_is_overridden_inside_the_callback",
"TestUnifiedSpec/client-side-operations-timeout/tests/convenient-transactions.json/timeoutMS_is_not_refreshed_for_each_operation_in_the_callback",
"TestUnifiedSpec/client-side-operations-timeout/tests/convenient-transactions.json/withTransaction_surfaces_a_timeout_after_exhausting_transient_transaction_retries,_retaining_the_last_transient_error_as_the_timeout_cause.",
"TestUnifiedSpec/client-side-operations-timeout/tests/cursors.json/find_errors_if_timeoutMode_is_set_and_timeoutMS_is_not",
"TestUnifiedSpec/client-side-operations-timeout/tests/cursors.json/collection_aggregate_errors_if_timeoutMode_is_set_and_timeoutMS_is_not",
"TestUnifiedSpec/client-side-operations-timeout/tests/cursors.json/database_aggregate_errors_if_timeoutMode_is_set_and_timeoutMS_is_not",
Expand Down
29 changes: 29 additions & 0 deletions mongo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,35 @@ func (bwe ClientBulkWriteException) Error() string {
return "bulk write exception: " + strings.Join(causes, ", ")
}

// TimeoutError represents an error that occurred due to a timeout.
type TimeoutError struct {
Wrapped error
Comment on lines +831 to +832
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does TimeoutError need to be exported? If not, we shouldn't export it.

Additionally, we need to update IsTimeout to return true if the error is a timeoutError, including a test to confirm IsTimeout works with the new error type.

}

// Error implements the error interface.
func (e TimeoutError) Error() string {
Comment thread
matthewdale marked this conversation as resolved.
const timeoutMsg = "operation timed out"
if e.Wrapped == nil {
return timeoutMsg
}
return fmt.Sprintf("%s: %v", timeoutMsg, e.Wrapped.Error())
}

// Unwrap returns the error held in the Wrapped field, which is nil when the
// TimeoutError was created without an underlying cause. Providing Unwrap lets
// errors.Is and errors.As examine the wrapped error.
func (e TimeoutError) Unwrap() error {
return e.Wrapped
}

// HasErrorLabel reports whether the error carries the given label. A
// TimeoutError always carries "ExceededTimeLimitError"; for any other label the
// question is forwarded to the first LabeledError found in the wrapped error's
// chain, and the answer is false when no such error exists.
func (e TimeoutError) HasErrorLabel(label string) bool {
	if label == "ExceededTimeLimitError" {
		return true
	}

	var labeled LabeledError
	if !errors.As(e.Wrapped, &labeled) {
		return false
	}
	return labeled.HasErrorLabel(label)
}
Comment on lines +830 to +857
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a unit test for TimeoutError to ensure it integrates with existing timeout detection and labeling (e.g., IsTimeout should return true via the ExceededTimeLimitError label, and HasErrorLabel should delegate to the wrapped error for other labels). There are already table-driven tests for IsTimeout in mongo/errors_test.go that can be extended.

Copilot uses AI. Check for mistakes.

// returnResult is used to determine if a function calling processWriteError should return
// the result or return nil. Since the processWriteError function is used by many different
// methods, both *One and *Many, we need a way to differentiate if the method should return
Expand Down
55 changes: 55 additions & 0 deletions mongo/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,11 @@ func TestIsTimeout(t *testing.T) {
}),
result: true,
},
{
name: "timeout error",
err: TimeoutError{},
result: true,
},
{
name: "other error",
err: errors.New("foo"),
Expand All @@ -679,6 +684,56 @@ func TestIsTimeout(t *testing.T) {
}
}

// TestTimeoutError checks that a TimeoutError is recognized by IsTimeout,
// renders the expected message, and reports the expected error labels, both
// with and without a wrapped error.
func TestTimeoutError(t *testing.T) {
	tests := []struct {
		name   string
		err    TimeoutError
		errMsg string
		labels []string
	}{
		{
			name: "TimeoutError without wrapped error",
			err: TimeoutError{
				Wrapped: nil,
			},
			errMsg: "operation timed out",
			labels: []string{"ExceededTimeLimitError"},
		},
		{
			name: "TimeoutError with wrapped LabeledError",
			err: TimeoutError{
				Wrapped: CommandError{
					Code:    100,
					Message: "",
					Labels:  []string{"other"},
					Name:    "blah",
					Wrapped: context.DeadlineExceeded,
					Raw:     nil,
				},
			},
			errMsg: "operation timed out: (blah): context deadline exceeded",
			labels: []string{"ExceededTimeLimitError", "other"},
		},
		{
			name: "TimeoutError with wrapped non-LabeledError",
			err: TimeoutError{
				Wrapped: context.DeadlineExceeded,
			},
			errMsg: "operation timed out: context deadline exceeded",
			labels: []string{"ExceededTimeLimitError"},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			assert.True(t, IsTimeout(tc.err), "expected a timeout error")
			// Pass the expected message first so a failure report labels the
			// expected and actual values the right way round.
			assert.Equal(t, tc.errMsg, tc.err.Error(), "expected error message %q, got %q", tc.errMsg, tc.err.Error())
			for _, label := range tc.labels {
				assert.True(t, tc.err.HasErrorLabel(label), "expected label %q", label)
			}
		})
	}
}

func TestServerError_ErrorCodes(t *testing.T) {
tests := []struct {
name string
Expand Down
17 changes: 8 additions & 9 deletions mongo/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ func (s *Session) WithTransaction(
}
backoff := expDur * time.Duration(jitter.Int63n(512)) / 512
if time.Since(startTime)+backoff > transTimeout {
return nil, err
return nil, TimeoutError{Wrapped: err}
}
sleep := time.NewTimer(backoff)
select {
case <-timeout.C:
sleep.Stop()
return nil, err
return nil, TimeoutError{Wrapped: err}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spec says we have to distinguish between CSOT and non-CSOT errors:
Note 1: When the TIMEOUT_MS (calculated in step [1.3]) is reached we MUST report a timeout error wrapping the last error that was encountered which triggered the retry behavior. If timeoutMS is set, then timeout error is a special type which is defined in CSOT. If timeoutMS is not set, then propagate it as timeout error if the language allows to expose the underlying error as a cause of a timeout error. If timeout error is thrown then it SHOULD expose error label(s) from the transient error.
https://github.com/mongodb/specifications/blob/master/source/transactions-convenient-api/transactions-convenient-api.md#sequence-of-actions

In the current approach we always return TimeoutErrors which are indistinguishable.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tadjik1 The non-CSOT path in WithTransaction uses a manual wall-clock check against the 120-second limit, not a context deadline. AFAIK, context.DeadlineExceeded is always CSOT in the Go Driver.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, since the original error is wrapped in the new type, the user can always trace back and distinguish the errors.

Copy link
Copy Markdown
Contributor

@matthewdale matthewdale Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably beyond the scope of this PR, but moving forward we should avoid picking which error to return and instead return all errors joined using errutil.Join

E.g.

var errs []error
// ...

select {
case <-timeout.C:
	errs = append(errs, errors.New("default WithTransaction timeout reached"))
	return nil, errutil.Join(errs...)
default:
}
// ...

case <-sleep.C:
}
if expDur < backoffMax {
Expand All @@ -178,7 +178,7 @@ func (s *Session) WithTransaction(

select {
case <-timeout.C:
return nil, err
return nil, TimeoutError{Wrapped: err}
default:
}
Comment on lines 179 to 183
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WithTransaction now wraps some timeout paths with TimeoutError, but other uses of the same overall timeout timer still return the raw err (e.g., the commit retry loop). This can lead to inconsistent timeout signaling/labeling for callers. If the goal is to surface a timeout error for WithTransaction timeouts, consider wrapping all returns caused by the overall timeout timer consistently.

Copilot uses AI. Check for mistakes.

Expand Down Expand Up @@ -217,15 +217,14 @@ func (s *Session) WithTransaction(
return res, nil
}

select {
case <-timeout.C:
return res, err
default:
}

var cerr CommandError
if errors.As(err, &cerr) {
if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() {
select {
case <-timeout.C:
return res, TimeoutError{Wrapped: err}
default:
Comment on lines +223 to +226
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the intent of this change to only respect the default WithTransaction timeout if the error is a CommandError with label "UnknownTransactionCommitResult"?

}
continue
}
if cerr.HasErrorLabel(driver.TransientTransactionError) {
Expand Down
Loading