Implement socket streaming server and optimize connection handling; adjust test provider timing
This commit is contained in:
@@ -57,7 +57,7 @@ func main() {
|
|||||||
// Setup
|
// Setup
|
||||||
r := router.NewRouter(2048)
|
r := router.NewRouter(2048)
|
||||||
m := manager.NewManager(r)
|
m := manager.NewManager(r)
|
||||||
testProvider := test.NewTestProvider(r.IncomingChannel(), time.Microsecond*50)
|
testProvider := test.NewTestProvider(r.IncomingChannel(), time.Microsecond*13)
|
||||||
if err := m.AddProvider("test_provider", testProvider); err != nil {
|
if err := m.AddProvider("test_provider", testProvider); err != nil {
|
||||||
slog.Error("add provider failed", "err", err)
|
slog.Error("add provider failed", "err", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
@@ -97,5 +97,20 @@ func main() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Socket Streaming Server
|
||||||
|
socketStreamingServer := server.NewSocketStreamingServer(m)
|
||||||
|
go func() {
|
||||||
|
lis, err := net.Listen("tcp", ":50060")
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("listen failed", "cmp", "socket-streaming", "addr", ":50060", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
slog.Info("listening", "cmp", "socket-streaming", "addr", ":50060")
|
||||||
|
if err := socketStreamingServer.Serve(lis); err != nil {
|
||||||
|
slog.Error("serve failed", "cmp", "socket-streaming", "err", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
select {}
|
select {}
|
||||||
}
|
}
|
||||||
|
|||||||
331
services/data_service/cmd/stream_tap_v2/main.go
Normal file
331
services/data_service/cmd/stream_tap_v2/main.go
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/connectivity"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type idsFlag []string
|
||||||
|
|
||||||
|
func (i *idsFlag) String() string { return strings.Join(*i, ",") }
|
||||||
|
func (i *idsFlag) Set(v string) error {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
*i = append(*i, v)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseIDPair(s string) (provider, subject string, err error) {
|
||||||
|
parts := strings.SplitN(s, ":", 2)
|
||||||
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
|
return "", "", fmt.Errorf("want provider:subject, got %q", s)
|
||||||
|
}
|
||||||
|
return parts[0], parts[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toIdentifierKey(input string) (string, error) {
|
||||||
|
if strings.Contains(input, "::") {
|
||||||
|
return input, nil
|
||||||
|
}
|
||||||
|
prov, subj, err := parseIDPair(input)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "raw::" + strings.ToLower(prov) + "." + subj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||||
|
for {
|
||||||
|
s := conn.GetState()
|
||||||
|
if s == connectivity.Ready {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if !conn.WaitForStateChange(ctx, s) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("WaitForStateChange returned without state change")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type streamStats struct {
|
||||||
|
TotalMsgs int64
|
||||||
|
TotalBytes int64
|
||||||
|
TickMsgs int64
|
||||||
|
TickBytes int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type stats struct {
|
||||||
|
TotalMsgs int64
|
||||||
|
TotalBytes int64
|
||||||
|
ByStream map[string]*streamStats
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var ids idsFlag
|
||||||
|
var ctlAddr string
|
||||||
|
var strAddr string
|
||||||
|
var timeout time.Duration
|
||||||
|
var refresh time.Duration
|
||||||
|
|
||||||
|
flag.Var(&ids, "id", "identifier (provider:subject or canonical key); repeatable")
|
||||||
|
flag.StringVar(&ctlAddr, "ctl", "127.0.0.1:50051", "gRPC control address")
|
||||||
|
flag.StringVar(&strAddr, "str", "127.0.0.1:50060", "socket streaming address host:port")
|
||||||
|
flag.DurationVar(&timeout, "timeout", 10*time.Second, "start/config/connect timeout")
|
||||||
|
flag.DurationVar(&refresh, "refresh", 1*time.Second, "dashboard refresh interval")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
if len(ids) == 0 {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "provide at least one --id (provider:subject or canonical key)")
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Control channel
|
||||||
|
ccCtl, err := grpc.NewClient(
|
||||||
|
ctlAddr,
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "new control client: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer ccCtl.Close()
|
||||||
|
ccCtl.Connect()
|
||||||
|
|
||||||
|
ctlConnCtx, cancelCtlConn := context.WithTimeout(ctx, timeout)
|
||||||
|
if err := waitReady(ctlConnCtx, ccCtl); err != nil {
|
||||||
|
cancelCtlConn()
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "connect control: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
cancelCtlConn()
|
||||||
|
|
||||||
|
ctl := pb.NewDataServiceControlClient(ccCtl)
|
||||||
|
|
||||||
|
// Start stream
|
||||||
|
ctxStart, cancelStart := context.WithTimeout(ctx, timeout)
|
||||||
|
startResp, err := ctl.StartStream(ctxStart, &pb.StartStreamRequest{})
|
||||||
|
cancelStart()
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "StartStream: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
streamUUID := startResp.GetStreamUuid()
|
||||||
|
fmt.Printf("stream: %s\n", streamUUID)
|
||||||
|
|
||||||
|
// Configure identifiers
|
||||||
|
var pbIDs []*pb.Identifier
|
||||||
|
orderedIDs := make([]string, 0, len(ids))
|
||||||
|
for _, s := range ids {
|
||||||
|
key, err := toIdentifierKey(s)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "bad --id: %v\n", err)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
pbIDs = append(pbIDs, &pb.Identifier{Key: key})
|
||||||
|
orderedIDs = append(orderedIDs, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctxCfg, cancelCfg := context.WithTimeout(ctx, timeout)
|
||||||
|
_, err = ctl.ConfigureStream(ctxCfg, &pb.ConfigureStreamRequest{
|
||||||
|
StreamUuid: streamUUID,
|
||||||
|
Identifiers: pbIDs,
|
||||||
|
})
|
||||||
|
cancelCfg()
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "ConfigureStream: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Printf("configured %d identifiers\n", len(pbIDs))
|
||||||
|
|
||||||
|
// Socket streaming connection
|
||||||
|
d := net.Dialer{Timeout: timeout, KeepAlive: 30 * time.Second}
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", strAddr)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "dial socket: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if tc, ok := conn.(*net.TCPConn); ok {
|
||||||
|
_ = tc.SetNoDelay(true)
|
||||||
|
_ = tc.SetWriteBuffer(512 * 1024)
|
||||||
|
_ = tc.SetReadBuffer(512 * 1024)
|
||||||
|
_ = tc.SetKeepAlive(true)
|
||||||
|
_ = tc.SetKeepAlivePeriod(30 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the stream UUID followed by '\n' per socket server contract.
|
||||||
|
if _, err := io.WriteString(conn, streamUUID+"\n"); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "send stream UUID: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
fmt.Println("connected; streaming… (Ctrl-C to quit)")
|
||||||
|
|
||||||
|
// Receiver goroutine → channel
|
||||||
|
type msgWrap struct {
|
||||||
|
idKey string
|
||||||
|
size int
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
msgCh := make(chan msgWrap, 1024)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(msgCh)
|
||||||
|
r := bufio.NewReaderSize(conn, 256*1024)
|
||||||
|
var hdr [4]byte
|
||||||
|
for {
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(120 * time.Second)); err != nil {
|
||||||
|
msgCh <- msgWrap{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.ReadFull(r, hdr[:]); err != nil {
|
||||||
|
msgCh <- msgWrap{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n := binary.BigEndian.Uint32(hdr[:])
|
||||||
|
if n == 0 || n > 64*1024*1024 {
|
||||||
|
msgCh <- msgWrap{err: fmt.Errorf("invalid frame length: %d", n)}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
buf := make([]byte, n)
|
||||||
|
if _, err := io.ReadFull(r, buf); err != nil {
|
||||||
|
msgCh <- msgWrap{err: err}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var m pb.Message
|
||||||
|
if err := proto.Unmarshal(buf, &m); err != nil {
|
||||||
|
msgCh <- msgWrap{err: fmt.Errorf("unmarshal: %w", err)}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := m.GetIdentifier().GetKey()
|
||||||
|
msgCh <- msgWrap{idKey: id, size: len(m.GetPayload())}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Stats and dashboard
|
||||||
|
st := &stats{ByStream: make(map[string]*streamStats)}
|
||||||
|
seen := make(map[string]bool, len(orderedIDs))
|
||||||
|
for _, id := range orderedIDs {
|
||||||
|
seen[id] = true
|
||||||
|
}
|
||||||
|
tick := time.NewTicker(refresh)
|
||||||
|
defer tick.Stop()
|
||||||
|
|
||||||
|
clear := func() { fmt.Print("\033[H\033[2J") }
|
||||||
|
header := func() {
|
||||||
|
fmt.Printf("stream: %s now: %s refresh: %s\n",
|
||||||
|
streamUUID, time.Now().Format(time.RFC3339), refresh)
|
||||||
|
fmt.Println("--------------------------------------------------------------------------------------")
|
||||||
|
fmt.Printf("%-56s %10s %14s %12s %16s\n", "identifier", "msgs/s", "bytes/s", "total", "total_bytes")
|
||||||
|
fmt.Println("--------------------------------------------------------------------------------------")
|
||||||
|
}
|
||||||
|
|
||||||
|
printAndReset := func() {
|
||||||
|
clear()
|
||||||
|
header()
|
||||||
|
|
||||||
|
var totMsgsPS, totBytesPS float64
|
||||||
|
for _, id := range orderedIDs {
|
||||||
|
s, ok := st.ByStream[id]
|
||||||
|
var msgsPS, bytesPS float64
|
||||||
|
var totMsgs, totBytes int64
|
||||||
|
if ok {
|
||||||
|
msgsPS = float64(atomic.SwapInt64(&s.TickMsgs, 0)) / refresh.Seconds()
|
||||||
|
bytesPS = float64(atomic.SwapInt64(&s.TickBytes, 0)) / refresh.Seconds()
|
||||||
|
totMsgs = atomic.LoadInt64(&s.TotalMsgs)
|
||||||
|
totBytes = atomic.LoadInt64(&s.TotalBytes)
|
||||||
|
}
|
||||||
|
totMsgsPS += msgsPS
|
||||||
|
totBytesPS += bytesPS
|
||||||
|
fmt.Printf("%-56s %10d %14d %12d %16d\n",
|
||||||
|
id,
|
||||||
|
int64(math.Round(msgsPS)),
|
||||||
|
int64(math.Round(bytesPS)),
|
||||||
|
totMsgs,
|
||||||
|
totBytes,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("--------------------------------------------------------------------------------------")
|
||||||
|
fmt.Printf("%-56s %10d %14d %12d %16d\n",
|
||||||
|
"TOTAL",
|
||||||
|
int64(math.Round(totMsgsPS)),
|
||||||
|
int64(math.Round(totBytesPS)),
|
||||||
|
atomic.LoadInt64(&st.TotalMsgs),
|
||||||
|
atomic.LoadInt64(&st.TotalBytes),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
fmt.Println("\nshutting down")
|
||||||
|
return
|
||||||
|
|
||||||
|
case <-tick.C:
|
||||||
|
printAndReset()
|
||||||
|
|
||||||
|
case mw, ok := <-msgCh:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if mw.err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ne, ok := mw.err.(net.Error); ok && ne.Timeout() {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "recv timeout")
|
||||||
|
} else if mw.err == io.EOF {
|
||||||
|
_, _ = fmt.Fprintln(os.Stderr, "server closed stream")
|
||||||
|
} else {
|
||||||
|
_, _ = fmt.Fprintf(os.Stderr, "recv: %v\n", mw.err)
|
||||||
|
}
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !seen[mw.idKey] {
|
||||||
|
seen[mw.idKey] = true
|
||||||
|
orderedIDs = append(orderedIDs, mw.idKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
atomic.AddInt64(&st.TotalMsgs, 1)
|
||||||
|
atomic.AddInt64(&st.TotalBytes, int64(mw.size))
|
||||||
|
|
||||||
|
ss := st.ByStream[mw.idKey]
|
||||||
|
if ss == nil {
|
||||||
|
ss = &streamStats{}
|
||||||
|
st.ByStream[mw.idKey] = ss
|
||||||
|
}
|
||||||
|
atomic.AddInt64(&ss.TotalMsgs, 1)
|
||||||
|
atomic.AddInt64(&ss.TotalBytes, int64(mw.size))
|
||||||
|
atomic.AddInt64(&ss.TickMsgs, 1)
|
||||||
|
atomic.AddInt64(&ss.TickBytes, int64(mw.size))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
242
services/data_service/internal/server/socket_streaming_server.go
Normal file
242
services/data_service/internal/server/socket_streaming_server.go
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
|
||||||
|
"gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/manager"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SocketStreamingServer struct {
|
||||||
|
manager *manager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
|
||||||
|
return &SocketStreamingServer{manager: m}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SocketStreamingServer) Serve(lis net.Listener) error {
|
||||||
|
for {
|
||||||
|
conn, err := lis.Accept()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("accept error: %v\n", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go s.handleConnection(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
fmt.Printf("conn close error: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Println("connection closed")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if tc, ok := conn.(*net.TCPConn); ok {
|
||||||
|
_ = tc.SetNoDelay(true) // low latency
|
||||||
|
_ = tc.SetWriteBuffer(2 * 1024 * 1024) // bigger kernel sndbuf
|
||||||
|
_ = tc.SetReadBuffer(256 * 1024)
|
||||||
|
_ = tc.SetKeepAlive(true)
|
||||||
|
_ = tc.SetKeepAlivePeriod(30 * time.Second)
|
||||||
|
// Note: avoid SetLinger>0; default is fine.
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReaderSize(conn, 64*1024)
|
||||||
|
line, err := reader.ReadBytes('\n')
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("read stream UUID error: %v\n", err)
|
||||||
|
_, _ = fmt.Fprint(conn, "Failed to read stream UUID\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
streamUUID, err := uuid.Parse(string(trimLineEnding(line)))
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprint(conn, "Invalid stream UUID\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Give the socket server room before router drops. Make out chan larger.
|
||||||
|
// Tune per your pressure. (in=256, out=8192 as example)
|
||||||
|
_, out, err := s.manager.AttachClient(streamUUID, 256, 8192)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(conn, "Failed to attach to stream: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = s.manager.DetachClient(streamUUID) }()
|
||||||
|
|
||||||
|
// Large bufio writer to reduce syscalls.
|
||||||
|
writer := bufio.NewWriterSize(conn, 1*1024*1024)
|
||||||
|
defer func() {
|
||||||
|
if err := writer.Flush(); err != nil {
|
||||||
|
fmt.Printf("final flush error: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// ---- Throughput optimizations ----
|
||||||
|
const (
|
||||||
|
maxBatchMsgs = 128 // cap number of msgs per batch
|
||||||
|
maxBatchBytes = 1 * 1024 * 1024 // cap bytes per batch
|
||||||
|
idleFlush = 2 * time.Millisecond // small idle flush timer
|
||||||
|
)
|
||||||
|
var (
|
||||||
|
hdr [4]byte
|
||||||
|
batchBuf = &bytes.Buffer{}
|
||||||
|
bufPool = sync.Pool{New: func() any { return make([]byte, 64*1024) }}
|
||||||
|
timer = time.NewTimer(idleFlush)
|
||||||
|
timerAlive = true
|
||||||
|
)
|
||||||
|
|
||||||
|
stopTimer := func() {
|
||||||
|
if timerAlive && timer.Stop() {
|
||||||
|
// drain if fired
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
timerAlive = false
|
||||||
|
}
|
||||||
|
resetTimer := func() {
|
||||||
|
if !timerAlive {
|
||||||
|
timer.Reset(idleFlush)
|
||||||
|
timerAlive = true
|
||||||
|
} else {
|
||||||
|
// re-arm
|
||||||
|
stopTimer()
|
||||||
|
timer.Reset(idleFlush)
|
||||||
|
timerAlive = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Main loop: drain out channel into a single write.
|
||||||
|
for {
|
||||||
|
// Block for at least one message or close.
|
||||||
|
msg, ok := <-out
|
||||||
|
if !ok {
|
||||||
|
_ = writer.Flush()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
batchBuf.Reset()
|
||||||
|
bytesInBatch := 0
|
||||||
|
msgsInBatch := 0
|
||||||
|
|
||||||
|
// Start with the message we just popped.
|
||||||
|
{
|
||||||
|
m := pb.Message{
|
||||||
|
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
|
||||||
|
Payload: msg.Payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use pooled scratch to avoid per-message allocs in Marshal.
|
||||||
|
scratch := bufPool.Get().([]byte)[:0]
|
||||||
|
b, err := proto.MarshalOptions{}.MarshalAppend(scratch, &m)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("proto marshal error: %v\n", err)
|
||||||
|
bufPool.Put(scratch[:0])
|
||||||
|
// skip message
|
||||||
|
} else {
|
||||||
|
binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
|
||||||
|
_, _ = batchBuf.Write(hdr[:])
|
||||||
|
_, _ = batchBuf.Write(b)
|
||||||
|
bytesInBatch += 4 + len(b)
|
||||||
|
msgsInBatch++
|
||||||
|
bufPool.Put(b[:0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Opportunistically drain without blocking.
|
||||||
|
drain := true
|
||||||
|
resetTimer()
|
||||||
|
for drain && msgsInBatch < maxBatchMsgs && bytesInBatch < maxBatchBytes {
|
||||||
|
select {
|
||||||
|
case msg, ok = <-out:
|
||||||
|
if !ok {
|
||||||
|
// peer closed while batching; flush what we have.
|
||||||
|
if batchBuf.Len() > 0 {
|
||||||
|
if _, err := writer.Write(batchBuf.Bytes()); err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("write error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := writer.Flush(); err != nil {
|
||||||
|
fmt.Printf("flush error: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m := pb.Message{
|
||||||
|
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
|
||||||
|
Payload: msg.Payload,
|
||||||
|
}
|
||||||
|
scratch := bufPool.Get().([]byte)[:0]
|
||||||
|
b, err := proto.MarshalOptions{}.MarshalAppend(scratch, &m)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("proto marshal error: %v\n", err)
|
||||||
|
bufPool.Put(scratch[:0])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
|
||||||
|
_, _ = batchBuf.Write(hdr[:])
|
||||||
|
_, _ = batchBuf.Write(b)
|
||||||
|
bytesInBatch += 4 + len(b)
|
||||||
|
msgsInBatch++
|
||||||
|
bufPool.Put(b[:0])
|
||||||
|
case <-timer.C:
|
||||||
|
timerAlive = false
|
||||||
|
// idle window hit; stop draining further this round
|
||||||
|
drain = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single write for the whole batch.
|
||||||
|
// Avoid per-message SetWriteDeadline. Let TCP handle buffering.
|
||||||
|
if _, err := writer.Write(batchBuf.Bytes()); err != nil {
|
||||||
|
if err == io.EOF {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("write error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush when batch is sizable or we saw the idle timer.
|
||||||
|
// This keeps latency low without flushing every message.
|
||||||
|
if msgsInBatch >= maxBatchMsgs ||
|
||||||
|
bytesInBatch >= maxBatchBytes ||
|
||||||
|
!timerAlive {
|
||||||
|
if err := writer.Flush(); err != nil {
|
||||||
|
fmt.Printf("flush error: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// trimLineEnding trims a single trailing '\n' and optional '\r' before it.
|
||||||
|
func trimLineEnding(b []byte) []byte {
|
||||||
|
n := len(b)
|
||||||
|
if n == 0 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
if b[n-1] == '\n' {
|
||||||
|
n--
|
||||||
|
if n > 0 && b[n-1] == '\r' {
|
||||||
|
n--
|
||||||
|
}
|
||||||
|
return b[:n]
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
|
|
||||||
"gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/manager"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SocketStreamingServer struct {
|
|
||||||
manager *manager.Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
|
|
||||||
return &SocketStreamingServer{manager: m}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SocketStreamingServer) Serve(lis net.Listener) error {
|
|
||||||
for {
|
|
||||||
conn, err := lis.Accept()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("accept error: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go s.handleConnection(conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
|
|
||||||
defer func() {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
fmt.Printf("conn close error: %v\n", err)
|
|
||||||
} else {
|
|
||||||
fmt.Println("connection closed")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if tc, ok := conn.(*net.TCPConn); ok {
|
|
||||||
_ = tc.SetNoDelay(true)
|
|
||||||
_ = tc.SetWriteBuffer(512 * 1024)
|
|
||||||
_ = tc.SetReadBuffer(512 * 1024)
|
|
||||||
_ = tc.SetKeepAlive(true)
|
|
||||||
_ = tc.SetKeepAlivePeriod(30 * time.Second)
|
|
||||||
}
|
|
||||||
|
|
||||||
reader := bufio.NewReader(conn)
|
|
||||||
|
|
||||||
raw, err := reader.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("read stream UUID error: %v\n", err)
|
|
||||||
_, _ = fmt.Fprint(conn, "Failed to read stream UUID\n")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
streamUUIDStr := strings.TrimSpace(raw)
|
|
||||||
streamUUID, err := uuid.Parse(streamUUIDStr)
|
|
||||||
if err != nil {
|
|
||||||
_, _ = fmt.Fprint(conn, "Invalid stream UUID\n")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
outCh, err := s.manager.ConnectClientStream(streamUUID)
|
|
||||||
if err != nil {
|
|
||||||
_, _ = fmt.Fprintf(conn, "Failed to connect to stream: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer s.manager.DisconnectClientStream(streamUUID)
|
|
||||||
|
|
||||||
writer := bufio.NewWriterSize(conn, 256*1024)
|
|
||||||
defer func(w *bufio.Writer) {
|
|
||||||
if err := w.Flush(); err != nil {
|
|
||||||
fmt.Printf("final flush error: %v\n", err)
|
|
||||||
}
|
|
||||||
}(writer)
|
|
||||||
|
|
||||||
const flushEvery = 32
|
|
||||||
batch := 0
|
|
||||||
|
|
||||||
for msg := range outCh {
|
|
||||||
m := pb.Message{
|
|
||||||
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
|
|
||||||
Payload: msg.Payload,
|
|
||||||
Encoding: string(msg.Encoding),
|
|
||||||
}
|
|
||||||
|
|
||||||
size := proto.Size(&m)
|
|
||||||
buf := make([]byte, 0, size)
|
|
||||||
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &m)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("proto marshal error: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var hdr [4]byte
|
|
||||||
if len(b) > int(^uint32(0)) {
|
|
||||||
fmt.Printf("message too large: %d bytes\n", len(b))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
|
|
||||||
|
|
||||||
if _, err := writer.Write(hdr[:]); err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Printf("write len error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, err := writer.Write(b); err != nil {
|
|
||||||
if err == io.EOF {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Printf("write body error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
batch++
|
|
||||||
if batch >= flushEvery {
|
|
||||||
if err := writer.Flush(); err != nil {
|
|
||||||
fmt.Printf("flush error: %v\n", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
batch = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := writer.Flush(); err != nil {
|
|
||||||
fmt.Printf("final flush error: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user