From 2484a33945dc4d9dda566f491d53ece49fe620dc Mon Sep 17 00:00:00 2001 From: Phillip Michelsen Date: Sun, 17 Aug 2025 06:16:49 +0000 Subject: [PATCH] Refactor identifier handling: replace Provider and Subject fields with a single Key field in Identifier struct, update related message structures, and adjust parsing logic --- pkg/pb/data_service/data_service.pb.go | 28 +-- pkg/pb/data_service/data_service.proto | 25 +-- services/data_service/cmd/stream_tap/main.go | 42 ++-- .../internal/domain/identifier.go | 181 +++++++++++++++++- .../data_service/internal/manager/manager.go | 88 +++++---- .../provider/binance/futures_websocket.go | 16 +- .../internal/server/gprc_control_server.go | 23 +-- .../internal/server/grpc_streaming_server.go | 20 +- .../server/socket_streaming_server.go | 32 ++-- 9 files changed, 290 insertions(+), 165 deletions(-) diff --git a/pkg/pb/data_service/data_service.pb.go b/pkg/pb/data_service/data_service.pb.go index abe760f..a514b76 100644 --- a/pkg/pb/data_service/data_service.pb.go +++ b/pkg/pb/data_service/data_service.pb.go @@ -21,11 +21,9 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// Domain Models type Identifier struct { state protoimpl.MessageState `protogen:"open.v1"` - Provider string `protobuf:"bytes,1,opt,name=provider,proto3" json:"provider,omitempty"` // e.g., "binance" - Subject string `protobuf:"bytes,2,opt,name=subject,proto3" json:"subject,omitempty"` // e.g., "BTCUSDT" + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -60,16 +58,9 @@ func (*Identifier) Descriptor() ([]byte, []int) { return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{0} } -func (x *Identifier) GetProvider() string { +func (x *Identifier) GetKey() string { if x != nil { - return x.Provider - } - return "" -} - -func (x *Identifier) GetSubject() string { - if x != nil { - return x.Subject + return x.Key } return "" } @@ -77,8 +68,8 @@ func (x *Identifier) GetSubject() string { type Message struct { state protoimpl.MessageState `protogen:"open.v1"` Identifier *Identifier `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"` - Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` // JSON-encoded data - Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"` // e.g., "json", "protobuf" + Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` + Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -134,7 +125,6 @@ func (x *Message) GetEncoding() string { return "" } -// Control Requests and Responses type StartStreamRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -383,7 +373,6 @@ func (*StopStreamResponse) Descriptor() ([]byte, []int) { return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{7} } -// Stream Requests and Responses type ConnectStreamRequest struct { state protoimpl.MessageState `protogen:"open.v1"` StreamUuid string `protobuf:"bytes,1,opt,name=stream_uuid,json=streamUuid,proto3" json:"stream_uuid,omitempty"` @@ -432,11 +421,10 @@ var File_pkg_pb_data_service_data_service_proto protoreflect.FileDescriptor const file_pkg_pb_data_service_data_service_proto_rawDesc = "" + "\n" + - "&pkg/pb/data_service/data_service.proto\x12\fdata_service\"B\n" + + "&pkg/pb/data_service/data_service.proto\x12\fdata_service\"\x1e\n" + "\n" + - "Identifier\x12\x1a\n" + - "\bprovider\x18\x01 \x01(\tR\bprovider\x12\x18\n" + - "\asubject\x18\x02 \x01(\tR\asubject\"y\n" + + "Identifier\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\"y\n" + "\aMessage\x128\n" + "\n" + "identifier\x18\x01 \x01(\v2\x18.data_service.IdentifierR\n" + diff --git a/pkg/pb/data_service/data_service.proto b/pkg/pb/data_service/data_service.proto index a6b2b0f..a99eb20 100644 --- a/pkg/pb/data_service/data_service.proto +++ b/pkg/pb/data_service/data_service.proto @@ -14,39 +14,26 @@ service DataServiceStreaming { rpc ConnectStream(ConnectStreamRequest) returns (stream Message); } -// Domain Models message Identifier { - string provider = 1; // e.g., "binance" - string subject = 2; // e.g., "BTCUSDT" + string key = 1; } message Message { Identifier identifier = 1; - bytes payload = 2; // JSON-encoded data - string encoding = 3; // e.g., "json", "protobuf" + bytes payload = 2; + string encoding = 3; } -// Control Requests and Responses message StartStreamRequest {} - -message StartStreamResponse { - string stream_uuid = 1; -} +message StartStreamResponse { string stream_uuid = 1; } message ConfigureStreamRequest { string stream_uuid = 1; repeated Identifier identifiers = 2; } - message ConfigureStreamResponse {} -message StopStreamRequest { - string stream_uuid = 1; -} - +message StopStreamRequest { string stream_uuid = 1; } message StopStreamResponse {} -// Stream Requests and Responses -message ConnectStreamRequest { - string stream_uuid = 1; -} +message ConnectStreamRequest { string stream_uuid = 1; } diff --git a/services/data_service/cmd/stream_tap/main.go b/services/data_service/cmd/stream_tap/main.go index 32d4c52..67e8cbe 100644 --- a/services/data_service/cmd/stream_tap/main.go +++ b/services/data_service/cmd/stream_tap/main.go @@ -28,7 +28,7 @@ func (i *idsFlag) Set(v string) error { return nil } -func parseID(s string) (provider, subject string, err error) { +func parseIDPair(s string) (provider, subject string, err error) { parts := strings.SplitN(s, ":", 2) if len(parts) != 2 || parts[0] == "" || parts[1] == "" { return "", "", fmt.Errorf("want provider:subject, got %q", s) @@ -36,13 +36,24 @@ func parseID(s string) (provider, subject string, err error) { return parts[0], parts[1], nil } +func toIdentifierKey(input string) (string, error) { + if strings.Contains(input, "::") { + return input, nil + } + prov, subj, err := parseIDPair(input) + if err != nil { + return "", err + } + return "raw::" + strings.ToLower(prov) + "." + subj, nil +} + func prettyOrRaw(b []byte, pretty bool) string { if !pretty || len(b) == 0 { return string(b) } var tmp any if err := json.Unmarshal(b, &tmp); err != nil { - return string(b) // not JSON + return string(b) } out, err := json.MarshalIndent(tmp, "", " ") if err != nil { @@ -51,7 +62,6 @@ func prettyOrRaw(b []byte, pretty bool) string { return string(out) } -// waitReady blocks until conn is READY or ctx times out/cancels. func waitReady(ctx context.Context, conn *grpc.ClientConn) error { for { s := conn.GetState() @@ -67,7 +77,6 @@ func waitReady(ctx context.Context, conn *grpc.ClientConn) error { } } -//goland:noinspection GoUnhandledErrorResult func main() { var ids idsFlag var ctlAddr string @@ -75,7 +84,7 @@ func main() { var pretty bool var timeout time.Duration - flag.Var(&ids, "id", "identifier in form provider:subject (repeatable)") + flag.Var(&ids, "id", "identifier (provider:subject or canonical key); repeatable") flag.StringVar(&ctlAddr, "ctl", "127.0.0.1:50051", "gRPC control address") flag.StringVar(&strAddr, "str", "127.0.0.1:50052", "gRPC streaming address") flag.BoolVar(&pretty, "pretty", true, "pretty-print JSON payloads when possible") @@ -83,15 +92,13 @@ func main() { flag.Parse() if len(ids) == 0 { - fmt.Fprintln(os.Stderr, "provide at least one --id provider:subject") + fmt.Fprintln(os.Stderr, "provide at least one --id (provider:subject or canonical key)") os.Exit(2) } - // Ctrl-C handling ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() - // ----- Control client ----- ccCtl, err := grpc.NewClient( ctlAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -101,8 +108,7 @@ func main() { os.Exit(1) } defer ccCtl.Close() - - ccCtl.Connect() // start dialing in background + ccCtl.Connect() ctlConnCtx, cancelCtlConn := context.WithTimeout(ctx, timeout) if err := waitReady(ctlConnCtx, ccCtl); err != nil { @@ -114,7 +120,6 @@ func main() { ctl := pb.NewDataServiceControlClient(ccCtl) - // Start stream ctxStart, cancelStart := context.WithTimeout(ctx, timeout) startResp, err := ctl.StartStream(ctxStart, &pb.StartStreamRequest{}) cancelStart() @@ -125,15 +130,14 @@ func main() { streamUUID := startResp.GetStreamUuid() fmt.Printf("stream: %s\n", streamUUID) - // Configure var pbIDs []*pb.Identifier for _, s := range ids { - prov, subj, err := parseID(s) + key, err := toIdentifierKey(s) if err != nil { fmt.Fprintf(os.Stderr, "bad --id: %v\n", err) os.Exit(2) } - pbIDs = append(pbIDs, &pb.Identifier{Provider: prov, Subject: subj}) + pbIDs = append(pbIDs, &pb.Identifier{Key: key}) } ctxCfg, cancelCfg := context.WithTimeout(ctx, timeout) @@ -148,7 +152,6 @@ func main() { } fmt.Printf("configured %d identifiers\n", len(pbIDs)) - // ----- Streaming client ----- ccStr, err := grpc.NewClient( strAddr, grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -158,7 +161,6 @@ func main() { os.Exit(1) } defer ccStr.Close() - ccStr.Connect() strConnCtx, cancelStrConn := context.WithTimeout(ctx, timeout) @@ -171,7 +173,6 @@ func main() { str := pb.NewDataServiceStreamingClient(ccStr) - // This context lives until Ctrl-C streamCtx, streamCancel := context.WithCancel(ctx) defer streamCancel() @@ -182,7 +183,6 @@ func main() { } fmt.Println("connected; streaming… (Ctrl-C to quit)") - // Receive loop until Ctrl-C for { select { case <-ctx.Done(): @@ -192,14 +192,14 @@ func main() { msg, err := stream.Recv() if err != nil { if ctx.Err() != nil { - return // normal shutdown + return } fmt.Fprintf(os.Stderr, "recv: %v\n", err) os.Exit(1) } id := msg.GetIdentifier() - fmt.Printf("[%s] %s bytes=%d enc=%s t=%s\n", - id.GetProvider(), id.GetSubject(), len(msg.GetPayload()), msg.GetEncoding(), + fmt.Printf("[%s] bytes=%d enc=%s t=%s\n", + id.GetKey(), len(msg.GetPayload()), msg.GetEncoding(), time.Now().Format(time.RFC3339Nano), ) fmt.Println(prettyOrRaw(msg.GetPayload(), pretty)) diff --git a/services/data_service/internal/domain/identifier.go b/services/data_service/internal/domain/identifier.go index 3a61744..867cc31 100644 --- a/services/data_service/internal/domain/identifier.go +++ b/services/data_service/internal/domain/identifier.go @@ -1,6 +1,181 @@ package domain -type Identifier struct { - Provider string - Subject string +import ( + "errors" + "fmt" + "regexp" + "sort" + "strings" +) + +const ( + prefixRaw = "raw::" + prefixInternal = "internal::" +) + +// Identifier is a canonical representation of a data stream identifier. +type Identifier struct{ key string } + +func (id Identifier) IsRaw() bool { return strings.HasPrefix(id.key, prefixRaw) } +func (id Identifier) IsInternal() bool { return strings.HasPrefix(id.key, prefixInternal) } +func (id Identifier) Key() string { return id.key } + +func (id Identifier) ProviderSubject() (provider, subject string, ok bool) { + if !id.IsRaw() { + return "", "", false + } + body := strings.TrimPrefix(id.key, prefixRaw) + prov, subj, ok := strings.Cut(body, ".") + return prov, subj, ok +} + +func (id Identifier) InternalParts() (venue, stream, symbol string, params map[string]string, ok bool) { + if !id.IsInternal() { + return "", "", "", nil, false + } + body := strings.TrimPrefix(id.key, prefixInternal) + before, bracket, _ := strings.Cut(body, "[") + parts := strings.Split(before, ".") + if len(parts) != 3 { + return "", "", "", nil, false + } + return parts[0], parts[1], parts[2], decodeParams(strings.TrimSuffix(bracket, "]")), true +} + +func RawID(provider, subject string) (Identifier, error) { + p := strings.ToLower(strings.TrimSpace(provider)) + s := strings.TrimSpace(subject) + + if err := validateComponent("provider", p, false); err != nil { + return Identifier{}, err + } + if err := validateComponent("subject", s, true); err != nil { + return Identifier{}, err + } + return Identifier{key: prefixRaw + p + "." + s}, nil +} + +func InternalID(venue, stream, symbol string, params map[string]string) (Identifier, error) { + v := strings.ToLower(strings.TrimSpace(venue)) + t := strings.ToLower(strings.TrimSpace(stream)) + sym := strings.ToUpper(strings.TrimSpace(symbol)) + + if err := validateComponent("venue", v, false); err != nil { + return Identifier{}, err + } + if err := validateComponent("stream", t, false); err != nil { + return Identifier{}, err + } + if err := validateComponent("symbol", sym, false); err != nil { + return Identifier{}, err + } + + paramStr, err := encodeParams(params) // "k=v;..." or "" + if err != nil { + return Identifier{}, err + } + if paramStr == "" { + paramStr = "[]" + } else { + paramStr = "[" + paramStr + "]" + } + return Identifier{key: prefixInternal + v + "." + t + "." + sym + paramStr}, nil +} + +func ParseIdentifier(s string) (Identifier, error) { + s = strings.TrimSpace(s) + switch { + case strings.HasPrefix(s, prefixRaw): + // raw::provider.subject + body := strings.TrimPrefix(s, prefixRaw) + prov, subj, ok := strings.Cut(body, ".") + if !ok { + return Identifier{}, errors.New("invalid raw identifier: missing '.'") + } + return RawID(prov, subj) + + case strings.HasPrefix(s, prefixInternal): + // internal::venue.stream.symbol[...] + body := strings.TrimPrefix(s, prefixInternal) + before, bracket, _ := strings.Cut(body, "[") + parts := strings.Split(before, ".") + if len(parts) != 3 { + return Identifier{}, errors.New("invalid internal identifier: need venue.stream.symbol") + } + params := decodeParams(strings.TrimSuffix(bracket, "]")) + return InternalID(parts[0], parts[1], parts[2], params) + } + return Identifier{}, errors.New("unknown identifier prefix") +} + +var ( + segDisallow = regexp.MustCompile(`[ \t\r\n\[\]]`) // forbid whitespace/brackets in fixed segments + dotDisallow = regexp.MustCompile(`[.]`) // fixed segments cannot contain '.' +) + +// allowAny=true (for subject) skips dot checks but still forbids whitespace/brackets. +func validateComponent(name, v string, allowAny bool) error { + if v == "" { + return fmt.Errorf("%s cannot be empty", name) + } + if allowAny { + if segDisallow.MatchString(v) { + return fmt.Errorf("%s contains illegal chars [] or whitespace", name) + } + return nil + } + if segDisallow.MatchString(v) || dotDisallow.MatchString(v) { + return fmt.Errorf("%s contains illegal chars (dot/brackets/whitespace)", name) + } + return nil +} + +// encodeParams renders sorted k=v pairs separated by ';'. +func encodeParams(params map[string]string) (string, error) { + if len(params) == 0 { + return "", nil + } + keys := make([]string, 0, len(params)) + for k := range params { + k = strings.ToLower(strings.TrimSpace(k)) + if k == "" { + continue + } + keys = append(keys, k) + } + sort.Strings(keys) + + out := make([]string, 0, len(keys)) + for _, k := range keys { + v := strings.TrimSpace(params[k]) + // prevent breaking delimiters + if strings.ContainsAny(k, ";]") || strings.ContainsAny(v, ";]") { + return "", fmt.Errorf("param %q contains illegal ';' or ']'", k) + } + out = append(out, k+"="+v) + } + return strings.Join(out, ";"), nil +} + +func decodeParams(s string) map[string]string { + s = strings.TrimSpace(s) + if s == "" { + return map[string]string{} + } + out := make(map[string]string, 4) + for _, p := range strings.Split(s, ";") { + if p == "" { + continue + } + kv := strings.SplitN(p, "=", 2) + if len(kv) != 2 { + continue + } + k := strings.ToLower(strings.TrimSpace(kv[0])) + v := strings.TrimSpace(kv[1]) + if k != "" { + out[k] = v + } + } + return out } diff --git a/services/data_service/internal/manager/manager.go b/services/data_service/internal/manager/manager.go index c350087..e0a6326 100644 --- a/services/data_service/internal/manager/manager.go +++ b/services/data_service/internal/manager/manager.go @@ -30,7 +30,7 @@ type ClientStream struct { } func NewManager(router *router.Router) *Manager { - go router.Run() // Start the router in a separate goroutine + go router.Run() return &Manager{ providers: make(map[string]provider.Provider), providerStreams: make(map[domain.Identifier]chan domain.Message), @@ -46,8 +46,8 @@ func (m *Manager) StartStream() (uuid.UUID, error) { streamID := uuid.New() m.clientStreams[streamID] = &ClientStream{ UUID: streamID, - Identifiers: nil, // start empty - OutChannel: nil, // not yet connected + Identifiers: nil, + OutChannel: nil, Timer: time.AfterFunc(1*time.Minute, func() { fmt.Printf("stream %s expired due to inactivity\n", streamID) err := m.StopStream(streamID) @@ -69,21 +69,22 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier return fmt.Errorf("stream not found: %s", streamID) } - // Validate new identifiers. for _, id := range newIds { - if id.Provider == "" || id.Subject == "" { - return fmt.Errorf("empty identifier: %v", id) - } - prov, exists := m.providers[id.Provider] - if !exists { - return fmt.Errorf("unknown provider: %s", id.Provider) - } - if !prov.IsValidSubject(id.Subject, false) { - return fmt.Errorf("invalid subject %q for provider %s", id.Subject, id.Provider) + if id.IsRaw() { + providerName, subject, ok := id.ProviderSubject() + if !ok || providerName == "" || subject == "" { + return fmt.Errorf("empty identifier: %v", id) + } + prov, exists := m.providers[providerName] + if !exists { + return fmt.Errorf("unknown provider: %s", providerName) + } + if !prov.IsValidSubject(subject, false) { + return fmt.Errorf("invalid subject %q for provider %s", subject, providerName) + } } } - // Generate old and new sets of identifiers oldSet := make(map[domain.Identifier]struct{}, len(stream.Identifiers)) for _, id := range stream.Identifiers { oldSet[id] = struct{}{} @@ -93,55 +94,54 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier newSet[id] = struct{}{} } - // Add identifiers that are in newIds but not in oldSet for _, id := range newIds { if _, seen := oldSet[id]; !seen { - // Provision the stream from the provider if needed - if _, ok := m.providerStreams[id]; !ok { - ch := make(chan domain.Message, 64) - if err := m.providers[id.Provider].RequestStream(id.Subject, ch); err != nil { - return fmt.Errorf("provision %v: %w", id, err) - } - m.providerStreams[id] = ch - - incomingChannel := m.router.IncomingChannel() - go func(c chan domain.Message) { - for msg := range c { - incomingChannel <- msg + if id.IsRaw() { + if _, ok := m.providerStreams[id]; !ok { + ch := make(chan domain.Message, 64) + providerName, subject, _ := id.ProviderSubject() + if err := m.providers[providerName].RequestStream(subject, ch); err != nil { + return fmt.Errorf("provision %v: %w", id, err) } - }(ch) + m.providerStreams[id] = ch + + incomingChannel := m.router.IncomingChannel() + go func(c chan domain.Message) { + for msg := range c { + incomingChannel <- msg + } + }(ch) + } } - // Register the new identifier with the router, only if there's an active output channel (meaning the stream is connected) if stream.OutChannel != nil { m.router.RegisterRoute(id, stream.OutChannel) } } } - // Remove identifiers that are in oldSet but not in newSet for _, oldId := range stream.Identifiers { if _, keep := newSet[oldId]; !keep { - // Deregister the identifier from the router, only if there's an active output channel (meaning the stream is connected) if stream.OutChannel != nil { m.router.DeregisterRoute(oldId, stream.OutChannel) } } } - // Set the new identifiers for the stream stream.Identifiers = newIds - // Clean up provider streams that are no longer used used := make(map[domain.Identifier]bool) for _, cs := range m.clientStreams { for _, id := range cs.Identifiers { - used[id] = true + if id.IsRaw() { + used[id] = true + } } } for id, ch := range m.providerStreams { if !used[id] { - m.providers[id.Provider].CancelStream(id.Subject) + providerName, subject, _ := id.ProviderSubject() + m.providers[providerName].CancelStream(subject) close(ch) delete(m.providerStreams, id) } @@ -165,18 +165,19 @@ func (m *Manager) StopStream(streamID uuid.UUID) error { delete(m.clientStreams, streamID) - // Find provider streams that are used by other client streams used := make(map[domain.Identifier]bool) for _, s := range m.clientStreams { for _, id := range s.Identifiers { - used[id] = true + if id.IsRaw() { + used[id] = true + } } } - // Cancel provider streams that are not used by any client stream for id, ch := range m.providerStreams { if !used[id] { - m.providers[id.Provider].CancelStream(id.Subject) + providerName, subject, _ := id.ProviderSubject() + m.providers[providerName].CancelStream(subject) close(ch) delete(m.providerStreams, id) } @@ -219,19 +220,16 @@ func (m *Manager) DisconnectStream(streamID uuid.UUID) { stream, ok := m.clientStreams[streamID] if !ok || stream.OutChannel == nil { - return // already disconnected or does not exist + return } - // Deregister all identifiers from the router for _, ident := range stream.Identifiers { m.router.DeregisterRoute(ident, stream.OutChannel) } - // Close the output channel close(stream.OutChannel) stream.OutChannel = nil - // Set up the expiry timer stream.Timer = time.AfterFunc(1*time.Minute, func() { fmt.Printf("stream %s expired due to inactivity\n", streamID) err := m.StopStream(streamID) @@ -256,6 +254,6 @@ func (m *Manager) AddProvider(name string, p provider.Provider) { m.providers[name] = p } -func (m *Manager) RemoveProvider(name string) { - panic("not implemented yet") // TODO: Implement provider removal logic +func (m *Manager) RemoveProvider(_ string) { + panic("not implemented yet") } diff --git a/services/data_service/internal/provider/binance/futures_websocket.go b/services/data_service/internal/provider/binance/futures_websocket.go index ab82d6e..6d2e517 100644 --- a/services/data_service/internal/provider/binance/futures_websocket.go +++ b/services/data_service/internal/provider/binance/futures_websocket.go @@ -109,7 +109,7 @@ func (b *FuturesWebsocket) IsValidSubject(subject string, isFetch bool) bool { if isFetch { return false } - return len(subject) > 0 // Extend with regex or lookup if needed + return len(subject) > 0 } func (b *FuturesWebsocket) readLoop() { @@ -138,13 +138,15 @@ func (b *FuturesWebsocket) readLoop() { continue } + id, err := domain.RawID("binance_futures_websocket", container.Stream) + if err != nil { + continue + } + msg := domain.Message{ - Identifier: domain.Identifier{ - Provider: "binance_futures_websocket", - Subject: container.Stream, - }, - Payload: []byte(container.Data), - Encoding: domain.EncodingJSON, + Identifier: id, + Payload: container.Data, + Encoding: domain.EncodingJSON, } select { diff --git a/services/data_service/internal/server/gprc_control_server.go b/services/data_service/internal/server/gprc_control_server.go index 5dd8dd6..c0aee6b 100644 --- a/services/data_service/internal/server/gprc_control_server.go +++ b/services/data_service/internal/server/gprc_control_server.go @@ -18,9 +18,7 @@ type GRPCControlServer struct { } func NewGRPCControlServer(m *manager.Manager) *GRPCControlServer { - return &GRPCControlServer{ - manager: m, - } + return &GRPCControlServer{manager: m} } func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequest) (*pb.StartStreamResponse, error) { @@ -28,7 +26,6 @@ func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequ if err != nil { return nil, fmt.Errorf("failed to start stream: %w", err) } - return &pb.StartStreamResponse{StreamUuid: streamID.String()}, nil } @@ -38,19 +35,18 @@ func (s *GRPCControlServer) ConfigureStream(_ context.Context, req *pb.Configure return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err) } - // Transform identifiers from protobuf to domain format var ids []domain.Identifier - for _, i := range req.Identifiers { - ids = append(ids, domain.Identifier{ - Provider: i.Provider, - Subject: i.Subject, - }) + for _, in := range req.Identifiers { + id, e := domain.ParseIdentifier(in.Key) + if e != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid identifier %q: %v", in.Key, e) + } + ids = append(ids, id) } if err := s.manager.ConfigureStream(streamID, ids); err != nil { return nil, status.Errorf(codes.InvalidArgument, "configure failed: %v", err) } - return &pb.ConfigureStreamResponse{}, nil } @@ -59,11 +55,8 @@ func (s *GRPCControlServer) StopStream(_ context.Context, req *pb.StopStreamRequ if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err) } - - err = s.manager.StopStream(streamID) // Should only error if the stream doesn't exist - if err != nil { + if err := s.manager.StopStream(streamID); err != nil { return nil, status.Errorf(codes.Internal, "failed to stop stream: %v", err) } - return &pb.StopStreamResponse{}, nil } diff --git a/services/data_service/internal/server/grpc_streaming_server.go b/services/data_service/internal/server/grpc_streaming_server.go index 225e2a4..120caa3 100644 --- a/services/data_service/internal/server/grpc_streaming_server.go +++ b/services/data_service/internal/server/grpc_streaming_server.go @@ -14,9 +14,7 @@ type GRPCStreamingServer struct { } func NewGRPCStreamingServer(m *manager.Manager) *GRPCStreamingServer { - return &GRPCStreamingServer{ - manager: m, - } + return &GRPCStreamingServer{manager: m} } func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream pb.DataServiceStreaming_ConnectStreamServer) error { @@ -39,17 +37,11 @@ func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream if !ok { return nil } - - err := stream.Send(&pb.Message{ - Identifier: &pb.Identifier{ - Provider: msg.Identifier.Provider, - Subject: msg.Identifier.Subject, - }, - Payload: msg.Payload, - Encoding: string(msg.Encoding), - }) - - if err != nil { + if err := stream.Send(&pb.Message{ + Identifier: &pb.Identifier{Key: msg.Identifier.Key()}, + Payload: msg.Payload, + Encoding: string(msg.Encoding), + }); err != nil { return err } } diff --git a/services/data_service/internal/server/socket_streaming_server.go b/services/data_service/internal/server/socket_streaming_server.go index a551391..423f508 100644 --- a/services/data_service/internal/server/socket_streaming_server.go +++ b/services/data_service/internal/server/socket_streaming_server.go @@ -7,6 +7,7 @@ import ( "io" "net" "strings" + "time" "github.com/google/uuid" pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service" @@ -22,7 +23,6 @@ func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer { return &SocketStreamingServer{manager: m} } -// Accepts connections and hands each off to handleConnection. func (s *SocketStreamingServer) Serve(lis net.Listener) error { for { conn, err := lis.Accept() @@ -43,16 +43,16 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) { } }() - // Low-latency socket hints (best-effort). if tc, ok := conn.(*net.TCPConn); ok { _ = tc.SetNoDelay(true) _ = tc.SetWriteBuffer(512 * 1024) _ = tc.SetReadBuffer(512 * 1024) + _ = tc.SetKeepAlive(true) + _ = tc.SetKeepAlivePeriod(30 * time.Second) } reader := bufio.NewReader(conn) - // Protocol header: first line is the stream UUID. raw, err := reader.ReadString('\n') if err != nil { fmt.Printf("read stream UUID error: %v\n", err) @@ -74,9 +74,8 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) { defer s.manager.DisconnectStream(streamUUID) writer := bufio.NewWriterSize(conn, 256*1024) - defer func(writer *bufio.Writer) { - err := writer.Flush() - if err != nil { + defer func(w *bufio.Writer) { + if err := w.Flush(); err != nil { fmt.Printf("final flush error: %v\n", err) } }(writer) @@ -85,27 +84,20 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) { batch := 0 for msg := range outCh { - // Build protobuf payload. - message := pb.Message{ - Identifier: &pb.Identifier{ - Provider: msg.Identifier.Provider, - Subject: msg.Identifier.Subject, - }, - Payload: msg.Payload, // []byte - Encoding: string(msg.Encoding), // e.g., "application/json" + m := pb.Message{ + Identifier: &pb.Identifier{Key: msg.Identifier.Key()}, + Payload: msg.Payload, + Encoding: string(msg.Encoding), } - // Marshal protobuf. - // Use MarshalAppend to reuse capacity and avoid an extra alloc. - size := proto.Size(&message) + size := proto.Size(&m) buf := make([]byte, 0, size) - b, err := proto.MarshalOptions{}.MarshalAppend(buf, &message) + b, err := proto.MarshalOptions{}.MarshalAppend(buf, &m) if err != nil { fmt.Printf("proto marshal error: %v\n", err) continue } - // Fixed 4-byte big-endian length prefix. var hdr [4]byte if len(b) > int(^uint32(0)) { fmt.Printf("message too large: %d bytes\n", len(b)) @@ -113,7 +105,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) { } binary.BigEndian.PutUint32(hdr[:], uint32(len(b))) - // Write frame: [len][bytes]. if _, err := writer.Write(hdr[:]); err != nil { if err == io.EOF { return @@ -139,7 +130,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) { } } - // Final flush when channel closes. if err := writer.Flush(); err != nil { fmt.Printf("final flush error: %v\n", err) }