From 78c76323949ab72b6fb61b30f39034e25ad9215a Mon Sep 17 00:00:00 2001
From: Phillip Michelsen
Date: Fri, 12 Sep 2025 03:47:30 +0000
Subject: [PATCH] Implement socket streaming server and optimize connection handling; adjust test provider timing

---
 .../data_service/cmd/data_service/main.go     |  17 +-
 .../data_service/cmd/stream_tap_v2/main.go    | 331 ++++++++++++++++++
 .../server/socket_streaming_server.go         | 242 +++++++++++++
 .../server/socket_streaming_server.go.bak     | 136 -------
 4 files changed, 589 insertions(+), 137 deletions(-)
 create mode 100644 services/data_service/cmd/stream_tap_v2/main.go
 create mode 100644 services/data_service/internal/server/socket_streaming_server.go
 delete mode 100644 services/data_service/internal/server/socket_streaming_server.go.bak

diff --git a/services/data_service/cmd/data_service/main.go b/services/data_service/cmd/data_service/main.go
index 1ebcc06..dee5271 100644
--- a/services/data_service/cmd/data_service/main.go
+++ b/services/data_service/cmd/data_service/main.go
@@ -57,7 +57,7 @@ func main() {
     // Setup
     r := router.NewRouter(2048)
     m := manager.NewManager(r)
-    testProvider := test.NewTestProvider(r.IncomingChannel(), time.Microsecond*50)
+    testProvider := test.NewTestProvider(r.IncomingChannel(), time.Microsecond*13)
     if err := m.AddProvider("test_provider", testProvider); err != nil {
         slog.Error("add provider failed", "err", err)
         os.Exit(1)
     }
@@ -97,5 +97,20 @@
         }
     }()
 
+    // Socket Streaming Server
+    socketStreamingServer := server.NewSocketStreamingServer(m)
+    go func() {
+        lis, err := net.Listen("tcp", ":50060")
+        if err != nil {
+            slog.Error("listen failed", "cmp", "socket-streaming", "addr", ":50060", "err", err)
+            os.Exit(1)
+        }
+        slog.Info("listening", "cmp", "socket-streaming", "addr", ":50060")
+        if err := socketStreamingServer.Serve(lis); err != nil {
+            slog.Error("serve failed", "cmp", "socket-streaming", "err", err)
+            os.Exit(1)
+        }
+    }()
+
     select {}
 }
diff --git a/services/data_service/cmd/stream_tap_v2/main.go b/services/data_service/cmd/stream_tap_v2/main.go
new file mode 100644
index 0000000..e2e9941
--- /dev/null
+++ b/services/data_service/cmd/stream_tap_v2/main.go
@@ -0,0 +1,331 @@
+package main
+
+import (
+    "bufio"
+    "context"
+    "encoding/binary"
+    "flag"
+    "fmt"
+    "io"
+    "math"
+    "net"
+    "os"
+    "os/signal"
+    "strings"
+    "sync/atomic"
+    "syscall"
+    "time"
+
+    pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
+    "google.golang.org/grpc"
+    "google.golang.org/grpc/connectivity"
+    "google.golang.org/grpc/credentials/insecure"
+    "google.golang.org/protobuf/proto"
+)
+
+type idsFlag []string
+
+func (i *idsFlag) String() string { return strings.Join(*i, ",") }
+func (i *idsFlag) Set(v string) error {
+    if v == "" {
+        return nil
+    }
+    *i = append(*i, v)
+    return nil
+}
+
+func parseIDPair(s string) (provider, subject string, err error) {
+    parts := strings.SplitN(s, ":", 2)
+    if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
+        return "", "", fmt.Errorf("want provider:subject, got %q", s)
+    }
+    return parts[0], parts[1], nil
+}
+
+func toIdentifierKey(input string) (string, error) {
+    if strings.Contains(input, "::") {
+        return input, nil
+    }
+    prov, subj, err := parseIDPair(input)
+    if err != nil {
+        return "", err
+    }
+    return "raw::" + strings.ToLower(prov) + "." + subj, nil
+}
+
+func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
+    for {
+        s := conn.GetState()
+        if s == connectivity.Ready {
+            return nil
+        }
+        if !conn.WaitForStateChange(ctx, s) {
+            if ctx.Err() != nil {
+                return ctx.Err()
+            }
+            return fmt.Errorf("WaitForStateChange returned without state change")
+        }
+    }
+}
+
+type streamStats struct {
+    TotalMsgs  int64
+    TotalBytes int64
+    TickMsgs   int64
+    TickBytes  int64
+}
+
+type stats struct {
+    TotalMsgs  int64
+    TotalBytes int64
+    ByStream   map[string]*streamStats
+}
+
+func main() {
+    var ids idsFlag
+    var ctlAddr string
+    var strAddr string
+    var timeout time.Duration
+    var refresh time.Duration
+
+    flag.Var(&ids, "id", "identifier (provider:subject or canonical key); repeatable")
+    flag.StringVar(&ctlAddr, "ctl", "127.0.0.1:50051", "gRPC control address")
+    flag.StringVar(&strAddr, "str", "127.0.0.1:50060", "socket streaming address host:port")
+    flag.DurationVar(&timeout, "timeout", 10*time.Second, "start/config/connect timeout")
+    flag.DurationVar(&refresh, "refresh", 1*time.Second, "dashboard refresh interval")
+    flag.Parse()
+
+    if len(ids) == 0 {
+        _, _ = fmt.Fprintln(os.Stderr, "provide at least one --id (provider:subject or canonical key)")
+        os.Exit(2)
+    }
+
+    ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+    defer cancel()
+
+    // Control channel
+    ccCtl, err := grpc.NewClient(
+        ctlAddr,
+        grpc.WithTransportCredentials(insecure.NewCredentials()),
+    )
+    if err != nil {
+        _, _ = fmt.Fprintf(os.Stderr, "new control client: %v\n", err)
+        os.Exit(1)
+    }
+    defer ccCtl.Close()
+    ccCtl.Connect()
+
+    ctlConnCtx, cancelCtlConn := context.WithTimeout(ctx, timeout)
+    if err := waitReady(ctlConnCtx, ccCtl); err != nil {
+        cancelCtlConn()
+        _, _ = fmt.Fprintf(os.Stderr, "connect control: %v\n", err)
+        os.Exit(1)
+    }
+    cancelCtlConn()
+
+    ctl := pb.NewDataServiceControlClient(ccCtl)
+
+    // Start stream
+    ctxStart, cancelStart := context.WithTimeout(ctx, timeout)
+    startResp, err := ctl.StartStream(ctxStart, &pb.StartStreamRequest{})
+    cancelStart()
+    if err != nil {
+        _, _ = fmt.Fprintf(os.Stderr, "StartStream: %v\n", err)
+        os.Exit(1)
+    }
+    streamUUID := startResp.GetStreamUuid()
+    fmt.Printf("stream: %s\n", streamUUID)
+
+    // Configure identifiers
+    var pbIDs []*pb.Identifier
+    orderedIDs := make([]string, 0, len(ids))
+    for _, s := range ids {
+        key, err := toIdentifierKey(s)
+        if err != nil {
+            _, _ = fmt.Fprintf(os.Stderr, "bad --id: %v\n", err)
+            os.Exit(2)
+        }
+        pbIDs = append(pbIDs, &pb.Identifier{Key: key})
+        orderedIDs = append(orderedIDs, key)
+    }
+
+    ctxCfg, cancelCfg := context.WithTimeout(ctx, timeout)
+    _, err = ctl.ConfigureStream(ctxCfg, &pb.ConfigureStreamRequest{
+        StreamUuid:  streamUUID,
+        Identifiers: pbIDs,
+    })
+    cancelCfg()
+    if err != nil {
+        _, _ = fmt.Fprintf(os.Stderr, "ConfigureStream: %v\n", err)
+        os.Exit(1)
+    }
+    fmt.Printf("configured %d identifiers\n", len(pbIDs))
+
+    // Socket streaming connection
+    d := net.Dialer{Timeout: timeout, KeepAlive: 30 * time.Second}
+    conn, err := d.DialContext(ctx, "tcp", strAddr)
+    if err != nil {
+        _, _ = fmt.Fprintf(os.Stderr, "dial socket: %v\n", err)
+        os.Exit(1)
+    }
+    defer conn.Close()
+
+    if tc, ok := conn.(*net.TCPConn); ok {
+        _ = tc.SetNoDelay(true)
+        _ = tc.SetWriteBuffer(512 * 1024)
+        _ = tc.SetReadBuffer(512 * 1024)
+        _ = tc.SetKeepAlive(true)
+        _ = tc.SetKeepAlivePeriod(30 * time.Second)
+    }
+
+    // Send the stream UUID followed by '\n' per socket server contract.
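+    // Framing note (from the reader below): after the UUID line, the server is
+    // expected to reply with repeated frames of a 4-byte big-endian length
+    // prefix followed by one serialized pb.Message.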
+    if _, err := io.WriteString(conn, streamUUID+"\n"); err != nil {
+        _, _ = fmt.Fprintf(os.Stderr, "send stream UUID: %v\n", err)
+        os.Exit(1)
+    }
+    fmt.Println("connected; streaming… (Ctrl-C to quit)")
+
+    // Receiver goroutine → channel
+    type msgWrap struct {
+        idKey string
+        size  int
+        err   error
+    }
+    msgCh := make(chan msgWrap, 1024)
+
+    go func() {
+        defer close(msgCh)
+        r := bufio.NewReaderSize(conn, 256*1024)
+        var hdr [4]byte
+        for {
+            if err := conn.SetReadDeadline(time.Now().Add(120 * time.Second)); err != nil {
+                msgCh <- msgWrap{err: err}
+                return
+            }
+
+            if _, err := io.ReadFull(r, hdr[:]); err != nil {
+                msgCh <- msgWrap{err: err}
+                return
+            }
+            n := binary.BigEndian.Uint32(hdr[:])
+            if n == 0 || n > 64*1024*1024 {
+                msgCh <- msgWrap{err: fmt.Errorf("invalid frame length: %d", n)}
+                return
+            }
+            buf := make([]byte, n)
+            if _, err := io.ReadFull(r, buf); err != nil {
+                msgCh <- msgWrap{err: err}
+                return
+            }
+
+            var m pb.Message
+            if err := proto.Unmarshal(buf, &m); err != nil {
+                msgCh <- msgWrap{err: fmt.Errorf("unmarshal: %w", err)}
+                return
+            }
+            id := m.GetIdentifier().GetKey()
+            msgCh <- msgWrap{idKey: id, size: len(m.GetPayload())}
+        }
+    }()
+
+    // Stats and dashboard
+    st := &stats{ByStream: make(map[string]*streamStats)}
+    seen := make(map[string]bool, len(orderedIDs))
+    for _, id := range orderedIDs {
+        seen[id] = true
+    }
+    tick := time.NewTicker(refresh)
+    defer tick.Stop()
+
+    clear := func() { fmt.Print("\033[H\033[2J") }
+    header := func() {
+        fmt.Printf("stream: %s now: %s refresh: %s\n",
+            streamUUID, time.Now().Format(time.RFC3339), refresh)
+        fmt.Println("--------------------------------------------------------------------------------------")
+        fmt.Printf("%-56s %10s %14s %12s %16s\n", "identifier", "msgs/s", "bytes/s", "total", "total_bytes")
+        fmt.Println("--------------------------------------------------------------------------------------")
+    }
+
+    printAndReset := func() {
+        clear()
+        header()
+
+        var totMsgsPS, totBytesPS float64
+        for _, id := range orderedIDs {
+            s, ok := st.ByStream[id]
+            var msgsPS, bytesPS float64
+            var totMsgs, totBytes int64
+            if ok {
+                msgsPS = float64(atomic.SwapInt64(&s.TickMsgs, 0)) / refresh.Seconds()
+                bytesPS = float64(atomic.SwapInt64(&s.TickBytes, 0)) / refresh.Seconds()
+                totMsgs = atomic.LoadInt64(&s.TotalMsgs)
+                totBytes = atomic.LoadInt64(&s.TotalBytes)
+            }
+            totMsgsPS += msgsPS
+            totBytesPS += bytesPS
+            fmt.Printf("%-56s %10d %14d %12d %16d\n",
+                id,
+                int64(math.Round(msgsPS)),
+                int64(math.Round(bytesPS)),
+                totMsgs,
+                totBytes,
+            )
+        }
+
+        fmt.Println("--------------------------------------------------------------------------------------")
+        fmt.Printf("%-56s %10d %14d %12d %16d\n",
+            "TOTAL",
+            int64(math.Round(totMsgsPS)),
+            int64(math.Round(totBytesPS)),
+            atomic.LoadInt64(&st.TotalMsgs),
+            atomic.LoadInt64(&st.TotalBytes),
+        )
+    }
+
+    for {
+        select {
+        case <-ctx.Done():
+            fmt.Println("\nshutting down")
+            return
+
+        case <-tick.C:
+            printAndReset()
+
+        case mw, ok := <-msgCh:
+            if !ok {
+                return
+            }
+            if mw.err != nil {
+                if ctx.Err() != nil {
+                    return
+                }
+                if ne, ok := mw.err.(net.Error); ok && ne.Timeout() {
+                    _, _ = fmt.Fprintln(os.Stderr, "recv timeout")
+                } else if mw.err == io.EOF {
+                    _, _ = fmt.Fprintln(os.Stderr, "server closed stream")
+                } else {
+                    _, _ = fmt.Fprintf(os.Stderr, "recv: %v\n", mw.err)
+                }
+                os.Exit(1)
+            }
+
+            if !seen[mw.idKey] {
+                seen[mw.idKey] = true
+                orderedIDs = append(orderedIDs, mw.idKey)
+            }
+
+            atomic.AddInt64(&st.TotalMsgs, 1)
+            atomic.AddInt64(&st.TotalBytes, int64(mw.size))
+
+            ss := st.ByStream[mw.idKey]
+            if ss == nil {
+                ss = &streamStats{}
+                st.ByStream[mw.idKey] = ss
+            }
+            atomic.AddInt64(&ss.TotalMsgs, 1)
+            atomic.AddInt64(&ss.TotalBytes, int64(mw.size))
+            atomic.AddInt64(&ss.TickMsgs, 1)
+            atomic.AddInt64(&ss.TickBytes, int64(mw.size))
+        }
+    }
+}
diff --git a/services/data_service/internal/server/socket_streaming_server.go b/services/data_service/internal/server/socket_streaming_server.go
new file mode 100644
index 0000000..f34b9a3
--- /dev/null
+++ b/services/data_service/internal/server/socket_streaming_server.go
@@ -0,0 +1,242 @@
+package server
+
+import (
+    "bufio"
+    "bytes"
+    "encoding/binary"
+    "fmt"
+    "io"
+    "net"
+    "sync"
+    "time"
+
+    "github.com/google/uuid"
+    pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
+    "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/manager"
+    "google.golang.org/protobuf/proto"
+)
+
+type SocketStreamingServer struct {
+    manager *manager.Manager
+}
+
+func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
+    return &SocketStreamingServer{manager: m}
+}
+
+func (s *SocketStreamingServer) Serve(lis net.Listener) error {
+    for {
+        conn, err := lis.Accept()
+        if err != nil {
+            fmt.Printf("accept error: %v\n", err)
+            continue
+        }
+        go s.handleConnection(conn)
+    }
+}
+
+func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
+    defer func() {
+        if err := conn.Close(); err != nil {
+            fmt.Printf("conn close error: %v\n", err)
+        } else {
+            fmt.Println("connection closed")
+        }
+    }()
+
+    if tc, ok := conn.(*net.TCPConn); ok {
+        _ = tc.SetNoDelay(true)                // low latency
+        _ = tc.SetWriteBuffer(2 * 1024 * 1024) // bigger kernel sndbuf
+        _ = tc.SetReadBuffer(256 * 1024)
+        _ = tc.SetKeepAlive(true)
+        _ = tc.SetKeepAlivePeriod(30 * time.Second)
+        // Note: avoid SetLinger>0; default is fine.
+    }
+
+    reader := bufio.NewReaderSize(conn, 64*1024)
+    line, err := reader.ReadBytes('\n')
+    if err != nil {
+        fmt.Printf("read stream UUID error: %v\n", err)
+        _, _ = fmt.Fprint(conn, "Failed to read stream UUID\n")
+        return
+    }
+    streamUUID, err := uuid.Parse(string(trimLineEnding(line)))
+    if err != nil {
+        _, _ = fmt.Fprint(conn, "Invalid stream UUID\n")
+        return
+    }
+
+    // Give the socket server room before router drops. Make out chan larger.
+    // Tune per your pressure. (in=256, out=8192 as example)
+    _, out, err := s.manager.AttachClient(streamUUID, 256, 8192)
+    if err != nil {
+        _, _ = fmt.Fprintf(conn, "Failed to attach to stream: %v\n", err)
+        return
+    }
+    defer func() { _ = s.manager.DetachClient(streamUUID) }()
+
+    // Large bufio writer to reduce syscalls.
+    writer := bufio.NewWriterSize(conn, 1*1024*1024)
+    defer func() {
+        if err := writer.Flush(); err != nil {
+            fmt.Printf("final flush error: %v\n", err)
+        }
+    }()
+
+    // ---- Throughput optimizations ----
+    const (
+        maxBatchMsgs  = 128                  // cap number of msgs per batch
+        maxBatchBytes = 1 * 1024 * 1024      // cap bytes per batch
+        idleFlush     = 2 * time.Millisecond // small idle flush timer
+    )
+    var (
+        hdr        [4]byte
+        batchBuf   = &bytes.Buffer{}
+        bufPool    = sync.Pool{New: func() any { return make([]byte, 64*1024) }}
+        timer      = time.NewTimer(idleFlush)
+        timerAlive = true
+    )
+
+    stopTimer := func() {
+        if timerAlive && timer.Stop() {
+            // drain if fired
+            select {
+            case <-timer.C:
+            default:
+            }
+        }
+        timerAlive = false
+    }
+    resetTimer := func() {
+        if !timerAlive {
+            timer.Reset(idleFlush)
+            timerAlive = true
+        } else {
+            // re-arm
+            stopTimer()
+            timer.Reset(idleFlush)
+            timerAlive = true
+        }
+    }
+
+    // Main loop: drain out channel into a single write.
+    for {
+        // Block for at least one message or close.
+        msg, ok := <-out
+        if !ok {
+            _ = writer.Flush()
+            return
+        }
+
+        batchBuf.Reset()
+        bytesInBatch := 0
+        msgsInBatch := 0
+
+        // Start with the message we just popped.
+        {
+            m := pb.Message{
+                Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
+                Payload:    msg.Payload,
+            }
+
+            // Use pooled scratch to avoid per-message allocs in Marshal.
+            scratch := bufPool.Get().([]byte)[:0]
+            b, err := proto.MarshalOptions{}.MarshalAppend(scratch, &m)
+            if err != nil {
+                fmt.Printf("proto marshal error: %v\n", err)
+                bufPool.Put(scratch[:0])
+                // skip message
+            } else {
+                binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
+                _, _ = batchBuf.Write(hdr[:])
+                _, _ = batchBuf.Write(b)
+                bytesInBatch += 4 + len(b)
+                msgsInBatch++
+                bufPool.Put(b[:0])
+            }
+        }
+
+        // Opportunistically drain without blocking.
+        drain := true
+        resetTimer()
+        for drain && msgsInBatch < maxBatchMsgs && bytesInBatch < maxBatchBytes {
+            select {
+            case msg, ok = <-out:
+                if !ok {
+                    // peer closed while batching; flush what we have.
+                    if batchBuf.Len() > 0 {
+                        if _, err := writer.Write(batchBuf.Bytes()); err != nil {
+                            if err == io.EOF {
+                                return
+                            }
+                            fmt.Printf("write error: %v\n", err)
+                            return
+                        }
+                        if err := writer.Flush(); err != nil {
+                            fmt.Printf("flush error: %v\n", err)
+                        }
+                    }
+                    return
+                }
+                m := pb.Message{
+                    Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
+                    Payload:    msg.Payload,
+                }
+                scratch := bufPool.Get().([]byte)[:0]
+                b, err := proto.MarshalOptions{}.MarshalAppend(scratch, &m)
+                if err != nil {
+                    fmt.Printf("proto marshal error: %v\n", err)
+                    bufPool.Put(scratch[:0])
+                    continue
+                }
+                binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
+                _, _ = batchBuf.Write(hdr[:])
+                _, _ = batchBuf.Write(b)
+                bytesInBatch += 4 + len(b)
+                msgsInBatch++
+                bufPool.Put(b[:0])
+            case <-timer.C:
+                timerAlive = false
+                // idle window hit; stop draining further this round
+                drain = false
+            }
+        }
+
+        // Single write for the whole batch.
+        // Avoid per-message SetWriteDeadline. Let TCP handle buffering.
+        if _, err := writer.Write(batchBuf.Bytes()); err != nil {
+            if err == io.EOF {
+                return
+            }
+            fmt.Printf("write error: %v\n", err)
+            return
+        }
+
+        // Flush when batch is sizable or we saw the idle timer.
+        // This keeps latency low without flushing every message.
+        if msgsInBatch >= maxBatchMsgs ||
+            bytesInBatch >= maxBatchBytes ||
+            !timerAlive {
+            if err := writer.Flush(); err != nil {
+                fmt.Printf("flush error: %v\n", err)
+                return
+            }
+        }
+    }
+}
+
+// trimLineEnding trims a single trailing '\n' and optional '\r' before it.
+func trimLineEnding(b []byte) []byte {
+    n := len(b)
+    if n == 0 {
+        return b
+    }
+    if b[n-1] == '\n' {
+        n--
+        if n > 0 && b[n-1] == '\r' {
+            n--
+        }
+        return b[:n]
+    }
+    return b
+}
diff --git a/services/data_service/internal/server/socket_streaming_server.go.bak b/services/data_service/internal/server/socket_streaming_server.go.bak
deleted file mode 100644
index 5401027..0000000
--- a/services/data_service/internal/server/socket_streaming_server.go.bak
+++ /dev/null
@@ -1,136 +0,0 @@
-package server
-
-import (
-    "bufio"
-    "encoding/binary"
-    "fmt"
-    "io"
-    "net"
-    "strings"
-    "time"
-
-    "github.com/google/uuid"
-    pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
-    "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/manager"
-    "google.golang.org/protobuf/proto"
-)
-
-type SocketStreamingServer struct {
-    manager *manager.Manager
-}
-
-func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
-    return &SocketStreamingServer{manager: m}
-}
-
-func (s *SocketStreamingServer) Serve(lis net.Listener) error {
-    for {
-        conn, err := lis.Accept()
-        if err != nil {
-            fmt.Printf("accept error: %v\n", err)
-            continue
-        }
-        go s.handleConnection(conn)
-    }
-}
-
-func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
-    defer func() {
-        if err := conn.Close(); err != nil {
-            fmt.Printf("conn close error: %v\n", err)
-        } else {
-            fmt.Println("connection closed")
-        }
-    }()
-
-    if tc, ok := conn.(*net.TCPConn); ok {
-        _ = tc.SetNoDelay(true)
-        _ = tc.SetWriteBuffer(512 * 1024)
-        _ = tc.SetReadBuffer(512 * 1024)
-        _ = tc.SetKeepAlive(true)
-        _ = tc.SetKeepAlivePeriod(30 * time.Second)
-    }
-
-    reader := bufio.NewReader(conn)
-
-    raw, err := reader.ReadString('\n')
-    if err != nil {
-        fmt.Printf("read stream UUID error: %v\n", err)
-        _, _ = fmt.Fprint(conn, "Failed to read stream UUID\n")
-        return
-    }
-    streamUUIDStr := strings.TrimSpace(raw)
-    streamUUID, err := uuid.Parse(streamUUIDStr)
-    if err != nil {
-        _, _ = fmt.Fprint(conn, "Invalid stream UUID\n")
-        return
-    }
-
-    outCh, err := s.manager.ConnectClientStream(streamUUID)
-    if err != nil {
-        _, _ = fmt.Fprintf(conn, "Failed to connect to stream: %v\n", err)
-        return
-    }
-    defer s.manager.DisconnectClientStream(streamUUID)
-
-    writer := bufio.NewWriterSize(conn, 256*1024)
-    defer func(w *bufio.Writer) {
-        if err := w.Flush(); err != nil {
-            fmt.Printf("final flush error: %v\n", err)
-        }
-    }(writer)
-
-    const flushEvery = 32
-    batch := 0
-
-    for msg := range outCh {
-        m := pb.Message{
-            Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
-            Payload:    msg.Payload,
-            Encoding:   string(msg.Encoding),
-        }
-
-        size := proto.Size(&m)
-        buf := make([]byte, 0, size)
-        b, err := proto.MarshalOptions{}.MarshalAppend(buf, &m)
-        if err != nil {
-            fmt.Printf("proto marshal error: %v\n", err)
-            continue
-        }
-
-        var hdr [4]byte
-        if len(b) > int(^uint32(0)) {
-            fmt.Printf("message too large: %d bytes\n", len(b))
-            continue
-        }
-        binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
-
-        if _, err := writer.Write(hdr[:]); err != nil {
-            if err == io.EOF {
-                return
-            }
-            fmt.Printf("write len error: %v\n", err)
-            return
-        }
-        if _, err := writer.Write(b); err != nil {
-            if err == io.EOF {
-                return
-            }
-            fmt.Printf("write body error: %v\n", err)
-            return
-        }
-
-        batch++
-        if batch >= flushEvery {
-            if err := writer.Flush(); err != nil {
-                fmt.Printf("flush error: %v\n", err)
-                return
-            }
-            batch = 0
-        }
-    }
-
-    if err := writer.Flush(); err != nil {
-        fmt.Printf("final flush error: %v\n", err)
-    }
-}