Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions dataproxy/logs/k8s_log_streamer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package logs
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @pingsutw because the TailLogs request and response types are different i just duplicated this here. I'll rip out all the equivalents in runs/ after the SDK cuts over to the new endpoint


import (
"bufio"
"context"
"fmt"
"io"
"strings"
"time"

"connectrpc.com/connect"
corev1 "k8s.io/api/core/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"

"google.golang.org/protobuf/types/known/timestamppb"

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/dataproxy"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/logs/dataplane"
)

const (
	// logBatchSize is the maximum number of log lines accumulated before a
	// batch is flushed to the client stream.
	logBatchSize = 100
	// defaultInitialLines bounds how much history the initial tail returns
	// when no container start time is available in the LogContext.
	defaultInitialLines = int64(1000)
)

// K8sLogStreamer streams logs directly from Kubernetes pods.
// Logs are read via the core/v1 pod log subresource (see TailLogs), not from
// any external log aggregation backend.
type K8sLogStreamer struct {
	// clientset is the Kubernetes API client used to look up pods and open
	// pod log streams.
	clientset kubernetes.Interface
}

// NewK8sLogStreamer creates a K8sLogStreamer backed by a clientset built from
// the given Kubernetes REST config. The config is copied and its Timeout is
// zeroed so long-lived log-follow streams are not cut short by a client-wide
// request timeout.
func NewK8sLogStreamer(k8sConfig *rest.Config) (*K8sLogStreamer, error) {
	restCfg := rest.CopyConfig(k8sConfig)
	restCfg.Timeout = 0

	cs, err := kubernetes.NewForConfig(restCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create kubernetes clientset: %w", err)
	}

	return &K8sLogStreamer{clientset: cs}, nil
}

// TailLogs streams log lines for the given LogContext from a Kubernetes pod.
//
// Behavior:
//   - Tails up to defaultInitialLines of history, or everything since the
//     container start time when the LogContext provides one.
//   - Follows the stream only while the pod is running; for pending or
//     terminated pods existing logs are returned and the stream ends.
//   - Lines are sent in batches of logBatchSize, with an early flush whenever
//     the reader has no more buffered data so clients see output promptly.
//
// Returns a connect NotFound error when the pod (or the primary pod/container
// in the LogContext) cannot be resolved, and an Internal error for other
// Kubernetes or stream-read failures.
func (s *K8sLogStreamer) TailLogs(ctx context.Context, logContext *core.LogContext, stream *connect.ServerStream[dataproxy.TailLogsResponse]) error {
	pod, container, err := GetPrimaryPodAndContainer(logContext)
	if err != nil {
		return connect.NewError(connect.CodeNotFound, err)
	}

	// Follow is set below based on the pod phase; setting it here would be a
	// dead store.
	tailLines := defaultInitialLines
	opts := &corev1.PodLogOptions{
		Container:  container.GetContainerName(),
		Timestamps: true,
		TailLines:  &tailLines,
	}

	// Set SinceTime from container start time if available.
	// When SinceTime is set, it takes precedence and we clear TailLines
	// to stream all logs from that point forward.
	if startTime := container.GetProcess().GetContainerStartTime(); startTime != nil {
		t := metav1.NewTime(startTime.AsTime())
		opts.SinceTime = &t
		opts.TailLines = nil
	}

	// Only follow logs when the pod is actively running. For pending or
	// terminated pods, disable follow so existing logs are returned immediately.
	podObj, err := s.clientset.CoreV1().Pods(pod.GetNamespace()).Get(ctx, pod.GetPodName(), metav1.GetOptions{})
	if err != nil {
		if k8serrors.IsNotFound(err) {
			return connect.NewError(connect.CodeNotFound, fmt.Errorf("pod %s not found in namespace %s", pod.GetPodName(), pod.GetNamespace()))
		}
		return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to get pod: %w", err))
	}
	opts.Follow = podObj.Status.Phase == corev1.PodRunning

	// Create a context without the incoming gRPC deadline so long-lived follow
	// streams are not killed by a short client/proxy timeout. Cancellation is
	// still propagated so the stream closes when the client disconnects.
	streamCtx, streamCancel := context.WithCancel(context.Background())
	defer streamCancel()
	stop := context.AfterFunc(ctx, streamCancel)
	defer stop()

	logStream, err := s.clientset.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetPodName(), opts).Stream(streamCtx)
	if err != nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to stream pod logs: %w", err))
	}
	defer logStream.Close()

	reader := bufio.NewReader(logStream)

	lines := make([]*dataplane.LogLine, 0, logBatchSize)
	var readErr error

	for {
		line, err := reader.ReadString('\n')
		if len(line) > 0 {
			// Trim trailing newline(s) including possible CRLF.
			line = strings.TrimRight(line, "\r\n")
			lines = append(lines, parseLogLine(line))

			if len(lines) >= logBatchSize {
				if sendErr := sendBatch(stream, lines); sendErr != nil {
					return sendErr
				}
				lines = make([]*dataplane.LogLine, 0, logBatchSize)
			}
		}
		if err != nil {
			// EOF just ends the stream; anything else is surfaced after the
			// final flush below.
			if err != io.EOF {
				readErr = err
			}
			break
		}

		// Flush buffered lines when no more data is immediately available.
		// Without this, lines sit in the buffer while ReadString blocks
		// waiting for the next newline (e.g. pod is sleeping).
		if len(lines) > 0 && reader.Buffered() == 0 {
			if sendErr := sendBatch(stream, lines); sendErr != nil {
				return sendErr
			}
			lines = make([]*dataplane.LogLine, 0, logBatchSize)
		}
	}

	// Send remaining lines.
	if len(lines) > 0 {
		if err := sendBatch(stream, lines); err != nil {
			return err
		}
	}

	// Return error for non-EOF read failures (unless context was canceled).
	if readErr != nil && ctx.Err() == nil {
		return connect.NewError(connect.CodeInternal, fmt.Errorf("error reading log stream: %w", readErr))
	}

	return nil
}

// sendBatch sends one TailLogsResponse carrying the given lines. Extracted so
// the three flush sites in TailLogs share a single construction of the
// response payload.
func sendBatch(stream *connect.ServerStream[dataproxy.TailLogsResponse], lines []*dataplane.LogLine) error {
	return stream.Send(&dataproxy.TailLogsResponse{
		Logs: []*dataproxy.TailLogsResponse_Logs{
			{Lines: lines},
		},
	})
}

// parseLogLine converts one Kubernetes log line into a LogLine.
// With the Timestamps pod-log option set, K8s prefixes each line with an
// RFC3339Nano timestamp and a space: "2006-01-02T15:04:05.999999999Z message".
// If the prefix parses as a timestamp it is split off into the LogLine's
// Timestamp field; otherwise the whole line becomes the message and the
// timestamp is left unset.
func parseLogLine(line string) *dataplane.LogLine {
	ts, msg, found := strings.Cut(line, " ")
	if found && ts != "" {
		if parsed, err := time.Parse(time.RFC3339Nano, ts); err == nil {
			return &dataplane.LogLine{
				Originator: dataplane.LogLineOriginator_USER,
				Timestamp:  timestamppb.New(parsed),
				Message:    msg,
			}
		}
	}

	return &dataplane.LogLine{
		Originator: dataplane.LogLineOriginator_USER,
		Message:    line,
	}
}
182 changes: 182 additions & 0 deletions dataproxy/logs/k8s_log_streamer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package logs

import (
"context"
"testing"
"time"

"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/logs/dataplane"
)

// A line with a valid RFC3339Nano prefix is split into timestamp and message.
func TestParseLogLine_WithTimestamp(t *testing.T) {
	got := parseLogLine("2024-01-15T10:30:00.123456789Z Hello, world!")

	wantTime := time.Date(2024, 1, 15, 10, 30, 0, 123456789, time.UTC)
	assert.Equal(t, "Hello, world!", got.Message)
	assert.NotNil(t, got.Timestamp)
	assert.Equal(t, wantTime, got.Timestamp.AsTime())
	assert.Equal(t, dataplane.LogLineOriginator_USER, got.Originator)
}

// A line with no timestamp prefix is passed through unchanged.
func TestParseLogLine_WithoutTimestamp(t *testing.T) {
	got := parseLogLine("just a plain log message")

	assert.Equal(t, "just a plain log message", got.Message)
	assert.Nil(t, got.Timestamp)
	assert.Equal(t, dataplane.LogLineOriginator_USER, got.Originator)
}

// A first token that does not parse as a timestamp leaves the line intact.
func TestParseLogLine_MalformedTimestamp(t *testing.T) {
	got := parseLogLine("not-a-timestamp some message")

	assert.Equal(t, "not-a-timestamp some message", got.Message)
	assert.Nil(t, got.Timestamp)
}

// A timestamp followed only by the separator space yields an empty message.
func TestParseLogLine_EmptyMessage(t *testing.T) {
	got := parseLogLine("2024-01-15T10:30:00Z ")

	assert.Empty(t, got.Message)
	assert.NotNil(t, got.Timestamp)
}

// The primary pod and its primary container are resolved from a LogContext
// that also carries a non-primary sidecar container.
func TestGetPrimaryPodAndContainer_HappyPath(t *testing.T) {
	input := &core.LogContext{
		PrimaryPodName: "my-pod",
		Pods: []*core.PodLogContext{
			{
				PodName:              "my-pod",
				Namespace:            "default",
				PrimaryContainerName: "main",
				Containers: []*core.ContainerContext{
					{ContainerName: "main"},
					{ContainerName: "sidecar"},
				},
			},
		},
	}

	gotPod, gotContainer, err := GetPrimaryPodAndContainer(input)
	assert.NoError(t, err)
	assert.Equal(t, "my-pod", gotPod.GetPodName())
	assert.Equal(t, "default", gotPod.GetNamespace())
	assert.Equal(t, "main", gotContainer.GetContainerName())
}

// An unset primary pod name is rejected with a descriptive error.
func TestGetPrimaryPodAndContainer_EmptyPodName(t *testing.T) {
	input := &core.LogContext{PrimaryPodName: ""}

	_, _, err := GetPrimaryPodAndContainer(input)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "primary pod name is empty")
}

// A primary pod name with no matching entry in Pods is an error.
func TestGetPrimaryPodAndContainer_PodNotFound(t *testing.T) {
	input := &core.LogContext{
		PrimaryPodName: "missing-pod",
		Pods: []*core.PodLogContext{
			{PodName: "other-pod"},
		},
	}

	_, _, err := GetPrimaryPodAndContainer(input)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "not found in log context")
}

// A primary container name absent from the pod's container list is an error.
func TestGetPrimaryPodAndContainer_ContainerNotFound(t *testing.T) {
	input := &core.LogContext{
		PrimaryPodName: "my-pod",
		Pods: []*core.PodLogContext{
			{
				PodName:              "my-pod",
				PrimaryContainerName: "missing-container",
				Containers: []*core.ContainerContext{
					{ContainerName: "other"},
				},
			},
		},
	}

	_, _, err := GetPrimaryPodAndContainer(input)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "primary container")
}

// newTestLogContext builds a minimal single-pod, single-container LogContext
// for TailLogs tests.
func newTestLogContext(podName, namespace, containerName string) *core.LogContext {
	podCtx := &core.PodLogContext{
		PodName:              podName,
		Namespace:            namespace,
		PrimaryContainerName: containerName,
		Containers:           []*core.ContainerContext{{ContainerName: containerName}},
	}
	return &core.LogContext{
		PrimaryPodName: podName,
		Pods:           []*core.PodLogContext{podCtx},
	}
}

// TailLogs surfaces a connect NotFound error when the referenced pod does not
// exist in the cluster.
func TestTailLogs_PodNotFound(t *testing.T) {
	emptyCluster := fake.NewSimpleClientset() // no pods registered

	s := &K8sLogStreamer{clientset: emptyCluster}

	err := s.TailLogs(context.Background(), newTestLogContext("missing-pod", "default", "main"), nil)
	require.Error(t, err)
	assert.Equal(t, connect.CodeNotFound, connect.CodeOf(err))
	assert.ErrorContains(t, err, "not found")
}

// TestTailLogs_FollowSetBasedOnPodPhase documents the intended mapping from
// pod phase to the Follow log option: only a Running pod should be followed.
//
// NOTE(review): this test is tautological — it re-states the production
// expression (phase == corev1.PodRunning) inline rather than invoking
// TailLogs and observing the PodLogOptions it builds, so it cannot catch a
// regression in TailLogs itself. Consider asserting against the fake
// clientset's recorded actions or factoring the phase check into a helper
// that both TailLogs and this test call.
func TestTailLogs_FollowSetBasedOnPodPhase(t *testing.T) {
	tests := []struct {
		name       string
		phase      corev1.PodPhase
		wantFollow bool
	}{
		{"running pod should follow", corev1.PodRunning, true},
		{"succeeded pod should not follow", corev1.PodSucceeded, false},
		{"failed pod should not follow", corev1.PodFailed, false},
		{"pending pod should not follow", corev1.PodPending, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			clientset := fake.NewSimpleClientset(&corev1.Pod{
				ObjectMeta: metav1.ObjectMeta{
					Name:      "my-pod",
					Namespace: "default",
				},
				Status: corev1.PodStatus{
					Phase: tt.phase,
				},
			})

			// Verify we can fetch the pod and the phase is correct.
			podObj, err := clientset.CoreV1().Pods("default").Get(context.Background(), "my-pod", metav1.GetOptions{})
			require.NoError(t, err)
			assert.Equal(t, tt.phase, podObj.Status.Phase)

			// Verify the follow logic: Follow should only be true when phase is Running.
			gotFollow := podObj.Status.Phase == corev1.PodRunning
			assert.Equal(t, tt.wantFollow, gotFollow)
		})
	}
}
Loading
Loading