Refactor identifier handling: replace Provider and Subject fields with a single Key field in Identifier struct, update related message structures, and adjust parsing logic

This commit is contained in:
2025-08-17 06:16:49 +00:00
parent ef4a28fb29
commit 2484a33945
9 changed files with 290 additions and 165 deletions

View File

@@ -21,11 +21,9 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Domain Models
type Identifier struct {
state protoimpl.MessageState `protogen:"open.v1"`
Provider string `protobuf:"bytes,1,opt,name=provider,proto3" json:"provider,omitempty"` // e.g., "binance"
Subject string `protobuf:"bytes,2,opt,name=subject,proto3" json:"subject,omitempty"` // e.g., "BTCUSDT"
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -60,16 +58,9 @@ func (*Identifier) Descriptor() ([]byte, []int) {
return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{0}
}
func (x *Identifier) GetProvider() string {
func (x *Identifier) GetKey() string {
if x != nil {
return x.Provider
}
return ""
}
func (x *Identifier) GetSubject() string {
if x != nil {
return x.Subject
return x.Key
}
return ""
}
@@ -77,8 +68,8 @@ func (x *Identifier) GetSubject() string {
type Message struct {
state protoimpl.MessageState `protogen:"open.v1"`
Identifier *Identifier `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"`
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` // JSON-encoded data
Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"` // e.g., "json", "protobuf"
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -134,7 +125,6 @@ func (x *Message) GetEncoding() string {
return ""
}
// Control Requests and Responses
type StartStreamRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -383,7 +373,6 @@ func (*StopStreamResponse) Descriptor() ([]byte, []int) {
return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{7}
}
// Stream Requests and Responses
type ConnectStreamRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
StreamUuid string `protobuf:"bytes,1,opt,name=stream_uuid,json=streamUuid,proto3" json:"stream_uuid,omitempty"`
@@ -432,11 +421,10 @@ var File_pkg_pb_data_service_data_service_proto protoreflect.FileDescriptor
const file_pkg_pb_data_service_data_service_proto_rawDesc = "" +
"\n" +
"&pkg/pb/data_service/data_service.proto\x12\fdata_service\"B\n" +
"&pkg/pb/data_service/data_service.proto\x12\fdata_service\"\x1e\n" +
"\n" +
"Identifier\x12\x1a\n" +
"\bprovider\x18\x01 \x01(\tR\bprovider\x12\x18\n" +
"\asubject\x18\x02 \x01(\tR\asubject\"y\n" +
"Identifier\x12\x10\n" +
"\x03key\x18\x01 \x01(\tR\x03key\"y\n" +
"\aMessage\x128\n" +
"\n" +
"identifier\x18\x01 \x01(\v2\x18.data_service.IdentifierR\n" +

View File

@@ -14,39 +14,26 @@ service DataServiceStreaming {
rpc ConnectStream(ConnectStreamRequest) returns (stream Message);
}
// Domain Models
message Identifier {
string provider = 1; // e.g., "binance"
string subject = 2; // e.g., "BTCUSDT"
string key = 1;
}
message Message {
Identifier identifier = 1;
bytes payload = 2; // JSON-encoded data
string encoding = 3; // e.g., "json", "protobuf"
bytes payload = 2;
string encoding = 3;
}
// Control Requests and Responses
message StartStreamRequest {}
message StartStreamResponse {
string stream_uuid = 1;
}
message StartStreamResponse { string stream_uuid = 1; }
message ConfigureStreamRequest {
string stream_uuid = 1;
repeated Identifier identifiers = 2;
}
message ConfigureStreamResponse {}
message StopStreamRequest {
string stream_uuid = 1;
}
message StopStreamRequest { string stream_uuid = 1; }
message StopStreamResponse {}
// Stream Requests and Responses
message ConnectStreamRequest {
string stream_uuid = 1;
}
message ConnectStreamRequest { string stream_uuid = 1; }

View File

@@ -28,7 +28,7 @@ func (i *idsFlag) Set(v string) error {
return nil
}
func parseID(s string) (provider, subject string, err error) {
func parseIDPair(s string) (provider, subject string, err error) {
parts := strings.SplitN(s, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", fmt.Errorf("want provider:subject, got %q", s)
@@ -36,13 +36,24 @@ func parseID(s string) (provider, subject string, err error) {
return parts[0], parts[1], nil
}
func toIdentifierKey(input string) (string, error) {
if strings.Contains(input, "::") {
return input, nil
}
prov, subj, err := parseIDPair(input)
if err != nil {
return "", err
}
return "raw::" + strings.ToLower(prov) + "." + subj, nil
}
func prettyOrRaw(b []byte, pretty bool) string {
if !pretty || len(b) == 0 {
return string(b)
}
var tmp any
if err := json.Unmarshal(b, &tmp); err != nil {
return string(b) // not JSON
return string(b)
}
out, err := json.MarshalIndent(tmp, "", " ")
if err != nil {
@@ -51,7 +62,6 @@ func prettyOrRaw(b []byte, pretty bool) string {
return string(out)
}
// waitReady blocks until conn is READY or ctx times out/cancels.
func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
for {
s := conn.GetState()
@@ -67,7 +77,6 @@ func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
}
}
//goland:noinspection GoUnhandledErrorResult
func main() {
var ids idsFlag
var ctlAddr string
@@ -75,7 +84,7 @@ func main() {
var pretty bool
var timeout time.Duration
flag.Var(&ids, "id", "identifier in form provider:subject (repeatable)")
flag.Var(&ids, "id", "identifier (provider:subject or canonical key); repeatable")
flag.StringVar(&ctlAddr, "ctl", "127.0.0.1:50051", "gRPC control address")
flag.StringVar(&strAddr, "str", "127.0.0.1:50052", "gRPC streaming address")
flag.BoolVar(&pretty, "pretty", true, "pretty-print JSON payloads when possible")
@@ -83,15 +92,13 @@ func main() {
flag.Parse()
if len(ids) == 0 {
fmt.Fprintln(os.Stderr, "provide at least one --id provider:subject")
fmt.Fprintln(os.Stderr, "provide at least one --id (provider:subject or canonical key)")
os.Exit(2)
}
// Ctrl-C handling
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
// ----- Control client -----
ccCtl, err := grpc.NewClient(
ctlAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
@@ -101,8 +108,7 @@ func main() {
os.Exit(1)
}
defer ccCtl.Close()
ccCtl.Connect() // start dialing in background
ccCtl.Connect()
ctlConnCtx, cancelCtlConn := context.WithTimeout(ctx, timeout)
if err := waitReady(ctlConnCtx, ccCtl); err != nil {
@@ -114,7 +120,6 @@ func main() {
ctl := pb.NewDataServiceControlClient(ccCtl)
// Start stream
ctxStart, cancelStart := context.WithTimeout(ctx, timeout)
startResp, err := ctl.StartStream(ctxStart, &pb.StartStreamRequest{})
cancelStart()
@@ -125,15 +130,14 @@ func main() {
streamUUID := startResp.GetStreamUuid()
fmt.Printf("stream: %s\n", streamUUID)
// Configure
var pbIDs []*pb.Identifier
for _, s := range ids {
prov, subj, err := parseID(s)
key, err := toIdentifierKey(s)
if err != nil {
fmt.Fprintf(os.Stderr, "bad --id: %v\n", err)
os.Exit(2)
}
pbIDs = append(pbIDs, &pb.Identifier{Provider: prov, Subject: subj})
pbIDs = append(pbIDs, &pb.Identifier{Key: key})
}
ctxCfg, cancelCfg := context.WithTimeout(ctx, timeout)
@@ -148,7 +152,6 @@ func main() {
}
fmt.Printf("configured %d identifiers\n", len(pbIDs))
// ----- Streaming client -----
ccStr, err := grpc.NewClient(
strAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
@@ -158,7 +161,6 @@ func main() {
os.Exit(1)
}
defer ccStr.Close()
ccStr.Connect()
strConnCtx, cancelStrConn := context.WithTimeout(ctx, timeout)
@@ -171,7 +173,6 @@ func main() {
str := pb.NewDataServiceStreamingClient(ccStr)
// This context lives until Ctrl-C
streamCtx, streamCancel := context.WithCancel(ctx)
defer streamCancel()
@@ -182,7 +183,6 @@ func main() {
}
fmt.Println("connected; streaming… (Ctrl-C to quit)")
// Receive loop until Ctrl-C
for {
select {
case <-ctx.Done():
@@ -192,14 +192,14 @@ func main() {
msg, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
return // normal shutdown
return
}
fmt.Fprintf(os.Stderr, "recv: %v\n", err)
os.Exit(1)
}
id := msg.GetIdentifier()
fmt.Printf("[%s] %s bytes=%d enc=%s t=%s\n",
id.GetProvider(), id.GetSubject(), len(msg.GetPayload()), msg.GetEncoding(),
fmt.Printf("[%s] bytes=%d enc=%s t=%s\n",
id.GetKey(), len(msg.GetPayload()), msg.GetEncoding(),
time.Now().Format(time.RFC3339Nano),
)
fmt.Println(prettyOrRaw(msg.GetPayload(), pretty))

View File

@@ -1,6 +1,181 @@
package domain
type Identifier struct {
Provider string
Subject string
import (
"errors"
"fmt"
"regexp"
"sort"
"strings"
)
const (
prefixRaw = "raw::"
prefixInternal = "internal::"
)
// Identifier is a canonical representation of a data stream identifier.
type Identifier struct{ key string }
func (id Identifier) IsRaw() bool { return strings.HasPrefix(id.key, prefixRaw) }
func (id Identifier) IsInternal() bool { return strings.HasPrefix(id.key, prefixInternal) }
func (id Identifier) Key() string { return id.key }
func (id Identifier) ProviderSubject() (provider, subject string, ok bool) {
if !id.IsRaw() {
return "", "", false
}
body := strings.TrimPrefix(id.key, prefixRaw)
prov, subj, ok := strings.Cut(body, ".")
return prov, subj, ok
}
func (id Identifier) InternalParts() (venue, stream, symbol string, params map[string]string, ok bool) {
if !id.IsInternal() {
return "", "", "", nil, false
}
body := strings.TrimPrefix(id.key, prefixInternal)
before, bracket, _ := strings.Cut(body, "[")
parts := strings.Split(before, ".")
if len(parts) != 3 {
return "", "", "", nil, false
}
return parts[0], parts[1], parts[2], decodeParams(strings.TrimSuffix(bracket, "]")), true
}
func RawID(provider, subject string) (Identifier, error) {
p := strings.ToLower(strings.TrimSpace(provider))
s := strings.TrimSpace(subject)
if err := validateComponent("provider", p, false); err != nil {
return Identifier{}, err
}
if err := validateComponent("subject", s, true); err != nil {
return Identifier{}, err
}
return Identifier{key: prefixRaw + p + "." + s}, nil
}
func InternalID(venue, stream, symbol string, params map[string]string) (Identifier, error) {
v := strings.ToLower(strings.TrimSpace(venue))
t := strings.ToLower(strings.TrimSpace(stream))
sym := strings.ToUpper(strings.TrimSpace(symbol))
if err := validateComponent("venue", v, false); err != nil {
return Identifier{}, err
}
if err := validateComponent("stream", t, false); err != nil {
return Identifier{}, err
}
if err := validateComponent("symbol", sym, false); err != nil {
return Identifier{}, err
}
paramStr, err := encodeParams(params) // "k=v;..." or ""
if err != nil {
return Identifier{}, err
}
if paramStr == "" {
paramStr = "[]"
} else {
paramStr = "[" + paramStr + "]"
}
return Identifier{key: prefixInternal + v + "." + t + "." + sym + paramStr}, nil
}
func ParseIdentifier(s string) (Identifier, error) {
s = strings.TrimSpace(s)
switch {
case strings.HasPrefix(s, prefixRaw):
// raw::provider.subject
body := strings.TrimPrefix(s, prefixRaw)
prov, subj, ok := strings.Cut(body, ".")
if !ok {
return Identifier{}, errors.New("invalid raw identifier: missing '.'")
}
return RawID(prov, subj)
case strings.HasPrefix(s, prefixInternal):
// internal::venue.stream.symbol[...]
body := strings.TrimPrefix(s, prefixInternal)
before, bracket, _ := strings.Cut(body, "[")
parts := strings.Split(before, ".")
if len(parts) != 3 {
return Identifier{}, errors.New("invalid internal identifier: need venue.stream.symbol")
}
params := decodeParams(strings.TrimSuffix(bracket, "]"))
return InternalID(parts[0], parts[1], parts[2], params)
}
return Identifier{}, errors.New("unknown identifier prefix")
}
var (
segDisallow = regexp.MustCompile(`[ \t\r\n\[\]]`) // forbid whitespace/brackets in fixed segments
dotDisallow = regexp.MustCompile(`[.]`) // fixed segments cannot contain '.'
)
// allowAny=true (for subject) skips dot checks but still forbids whitespace/brackets.
func validateComponent(name, v string, allowAny bool) error {
if v == "" {
return fmt.Errorf("%s cannot be empty", name)
}
if allowAny {
if segDisallow.MatchString(v) {
return fmt.Errorf("%s contains illegal chars [] or whitespace", name)
}
return nil
}
if segDisallow.MatchString(v) || dotDisallow.MatchString(v) {
return fmt.Errorf("%s contains illegal chars (dot/brackets/whitespace)", name)
}
return nil
}
// encodeParams renders sorted k=v pairs separated by ';'.
func encodeParams(params map[string]string) (string, error) {
if len(params) == 0 {
return "", nil
}
keys := make([]string, 0, len(params))
for k := range params {
k = strings.ToLower(strings.TrimSpace(k))
if k == "" {
continue
}
keys = append(keys, k)
}
sort.Strings(keys)
out := make([]string, 0, len(keys))
for _, k := range keys {
v := strings.TrimSpace(params[k])
// prevent breaking delimiters
if strings.ContainsAny(k, ";]") || strings.ContainsAny(v, ";]") {
return "", fmt.Errorf("param %q contains illegal ';' or ']'", k)
}
out = append(out, k+"="+v)
}
return strings.Join(out, ";"), nil
}
func decodeParams(s string) map[string]string {
s = strings.TrimSpace(s)
if s == "" {
return map[string]string{}
}
out := make(map[string]string, 4)
for _, p := range strings.Split(s, ";") {
if p == "" {
continue
}
kv := strings.SplitN(p, "=", 2)
if len(kv) != 2 {
continue
}
k := strings.ToLower(strings.TrimSpace(kv[0]))
v := strings.TrimSpace(kv[1])
if k != "" {
out[k] = v
}
}
return out
}

View File

@@ -30,7 +30,7 @@ type ClientStream struct {
}
func NewManager(router *router.Router) *Manager {
go router.Run() // Start the router in a separate goroutine
go router.Run()
return &Manager{
providers: make(map[string]provider.Provider),
providerStreams: make(map[domain.Identifier]chan domain.Message),
@@ -46,8 +46,8 @@ func (m *Manager) StartStream() (uuid.UUID, error) {
streamID := uuid.New()
m.clientStreams[streamID] = &ClientStream{
UUID: streamID,
Identifiers: nil, // start empty
OutChannel: nil, // not yet connected
Identifiers: nil,
OutChannel: nil,
Timer: time.AfterFunc(1*time.Minute, func() {
fmt.Printf("stream %s expired due to inactivity\n", streamID)
err := m.StopStream(streamID)
@@ -69,21 +69,22 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
return fmt.Errorf("stream not found: %s", streamID)
}
// Validate new identifiers.
for _, id := range newIds {
if id.Provider == "" || id.Subject == "" {
return fmt.Errorf("empty identifier: %v", id)
}
prov, exists := m.providers[id.Provider]
if !exists {
return fmt.Errorf("unknown provider: %s", id.Provider)
}
if !prov.IsValidSubject(id.Subject, false) {
return fmt.Errorf("invalid subject %q for provider %s", id.Subject, id.Provider)
if id.IsRaw() {
providerName, subject, ok := id.ProviderSubject()
if !ok || providerName == "" || subject == "" {
return fmt.Errorf("empty identifier: %v", id)
}
prov, exists := m.providers[providerName]
if !exists {
return fmt.Errorf("unknown provider: %s", providerName)
}
if !prov.IsValidSubject(subject, false) {
return fmt.Errorf("invalid subject %q for provider %s", subject, providerName)
}
}
}
// Generate old and new sets of identifiers
oldSet := make(map[domain.Identifier]struct{}, len(stream.Identifiers))
for _, id := range stream.Identifiers {
oldSet[id] = struct{}{}
@@ -93,55 +94,54 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
newSet[id] = struct{}{}
}
// Add identifiers that are in newIds but not in oldSet
for _, id := range newIds {
if _, seen := oldSet[id]; !seen {
// Provision the stream from the provider if needed
if _, ok := m.providerStreams[id]; !ok {
ch := make(chan domain.Message, 64)
if err := m.providers[id.Provider].RequestStream(id.Subject, ch); err != nil {
return fmt.Errorf("provision %v: %w", id, err)
}
m.providerStreams[id] = ch
incomingChannel := m.router.IncomingChannel()
go func(c chan domain.Message) {
for msg := range c {
incomingChannel <- msg
if id.IsRaw() {
if _, ok := m.providerStreams[id]; !ok {
ch := make(chan domain.Message, 64)
providerName, subject, _ := id.ProviderSubject()
if err := m.providers[providerName].RequestStream(subject, ch); err != nil {
return fmt.Errorf("provision %v: %w", id, err)
}
}(ch)
m.providerStreams[id] = ch
incomingChannel := m.router.IncomingChannel()
go func(c chan domain.Message) {
for msg := range c {
incomingChannel <- msg
}
}(ch)
}
}
// Register the new identifier with the router, only if there's an active output channel (meaning the stream is connected)
if stream.OutChannel != nil {
m.router.RegisterRoute(id, stream.OutChannel)
}
}
}
// Remove identifiers that are in oldSet but not in newSet
for _, oldId := range stream.Identifiers {
if _, keep := newSet[oldId]; !keep {
// Deregister the identifier from the router, only if there's an active output channel (meaning the stream is connected)
if stream.OutChannel != nil {
m.router.DeregisterRoute(oldId, stream.OutChannel)
}
}
}
// Set the new identifiers for the stream
stream.Identifiers = newIds
// Clean up provider streams that are no longer used
used := make(map[domain.Identifier]bool)
for _, cs := range m.clientStreams {
for _, id := range cs.Identifiers {
used[id] = true
if id.IsRaw() {
used[id] = true
}
}
}
for id, ch := range m.providerStreams {
if !used[id] {
m.providers[id.Provider].CancelStream(id.Subject)
providerName, subject, _ := id.ProviderSubject()
m.providers[providerName].CancelStream(subject)
close(ch)
delete(m.providerStreams, id)
}
@@ -165,18 +165,19 @@ func (m *Manager) StopStream(streamID uuid.UUID) error {
delete(m.clientStreams, streamID)
// Find provider streams that are used by other client streams
used := make(map[domain.Identifier]bool)
for _, s := range m.clientStreams {
for _, id := range s.Identifiers {
used[id] = true
if id.IsRaw() {
used[id] = true
}
}
}
// Cancel provider streams that are not used by any client stream
for id, ch := range m.providerStreams {
if !used[id] {
m.providers[id.Provider].CancelStream(id.Subject)
providerName, subject, _ := id.ProviderSubject()
m.providers[providerName].CancelStream(subject)
close(ch)
delete(m.providerStreams, id)
}
@@ -219,19 +220,16 @@ func (m *Manager) DisconnectStream(streamID uuid.UUID) {
stream, ok := m.clientStreams[streamID]
if !ok || stream.OutChannel == nil {
return // already disconnected or does not exist
return
}
// Deregister all identifiers from the router
for _, ident := range stream.Identifiers {
m.router.DeregisterRoute(ident, stream.OutChannel)
}
// Close the output channel
close(stream.OutChannel)
stream.OutChannel = nil
// Set up the expiry timer
stream.Timer = time.AfterFunc(1*time.Minute, func() {
fmt.Printf("stream %s expired due to inactivity\n", streamID)
err := m.StopStream(streamID)
@@ -256,6 +254,6 @@ func (m *Manager) AddProvider(name string, p provider.Provider) {
m.providers[name] = p
}
func (m *Manager) RemoveProvider(name string) {
panic("not implemented yet") // TODO: Implement provider removal logic
func (m *Manager) RemoveProvider(_ string) {
panic("not implemented yet")
}

View File

@@ -109,7 +109,7 @@ func (b *FuturesWebsocket) IsValidSubject(subject string, isFetch bool) bool {
if isFetch {
return false
}
return len(subject) > 0 // Extend with regex or lookup if needed
return len(subject) > 0
}
func (b *FuturesWebsocket) readLoop() {
@@ -138,13 +138,15 @@ func (b *FuturesWebsocket) readLoop() {
continue
}
id, err := domain.RawID("binance_futures_websocket", container.Stream)
if err != nil {
continue
}
msg := domain.Message{
Identifier: domain.Identifier{
Provider: "binance_futures_websocket",
Subject: container.Stream,
},
Payload: []byte(container.Data),
Encoding: domain.EncodingJSON,
Identifier: id,
Payload: container.Data,
Encoding: domain.EncodingJSON,
}
select {

View File

@@ -18,9 +18,7 @@ type GRPCControlServer struct {
}
func NewGRPCControlServer(m *manager.Manager) *GRPCControlServer {
return &GRPCControlServer{
manager: m,
}
return &GRPCControlServer{manager: m}
}
func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequest) (*pb.StartStreamResponse, error) {
@@ -28,7 +26,6 @@ func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequ
if err != nil {
return nil, fmt.Errorf("failed to start stream: %w", err)
}
return &pb.StartStreamResponse{StreamUuid: streamID.String()}, nil
}
@@ -38,19 +35,18 @@ func (s *GRPCControlServer) ConfigureStream(_ context.Context, req *pb.Configure
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
}
// Transform identifiers from protobuf to domain format
var ids []domain.Identifier
for _, i := range req.Identifiers {
ids = append(ids, domain.Identifier{
Provider: i.Provider,
Subject: i.Subject,
})
for _, in := range req.Identifiers {
id, e := domain.ParseIdentifier(in.Key)
if e != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identifier %q: %v", in.Key, e)
}
ids = append(ids, id)
}
if err := s.manager.ConfigureStream(streamID, ids); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "configure failed: %v", err)
}
return &pb.ConfigureStreamResponse{}, nil
}
@@ -59,11 +55,8 @@ func (s *GRPCControlServer) StopStream(_ context.Context, req *pb.StopStreamRequ
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
}
err = s.manager.StopStream(streamID) // Should only error if the stream doesn't exist
if err != nil {
if err := s.manager.StopStream(streamID); err != nil {
return nil, status.Errorf(codes.Internal, "failed to stop stream: %v", err)
}
return &pb.StopStreamResponse{}, nil
}

View File

@@ -14,9 +14,7 @@ type GRPCStreamingServer struct {
}
func NewGRPCStreamingServer(m *manager.Manager) *GRPCStreamingServer {
return &GRPCStreamingServer{
manager: m,
}
return &GRPCStreamingServer{manager: m}
}
func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream pb.DataServiceStreaming_ConnectStreamServer) error {
@@ -39,17 +37,11 @@ func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream
if !ok {
return nil
}
err := stream.Send(&pb.Message{
Identifier: &pb.Identifier{
Provider: msg.Identifier.Provider,
Subject: msg.Identifier.Subject,
},
Payload: msg.Payload,
Encoding: string(msg.Encoding),
})
if err != nil {
if err := stream.Send(&pb.Message{
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
Payload: msg.Payload,
Encoding: string(msg.Encoding),
}); err != nil {
return err
}
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net"
"strings"
"time"
"github.com/google/uuid"
pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
@@ -22,7 +23,6 @@ func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
return &SocketStreamingServer{manager: m}
}
// Accepts connections and hands each off to handleConnection.
func (s *SocketStreamingServer) Serve(lis net.Listener) error {
for {
conn, err := lis.Accept()
@@ -43,16 +43,16 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
}
}()
// Low-latency socket hints (best-effort).
if tc, ok := conn.(*net.TCPConn); ok {
_ = tc.SetNoDelay(true)
_ = tc.SetWriteBuffer(512 * 1024)
_ = tc.SetReadBuffer(512 * 1024)
_ = tc.SetKeepAlive(true)
_ = tc.SetKeepAlivePeriod(30 * time.Second)
}
reader := bufio.NewReader(conn)
// Protocol header: first line is the stream UUID.
raw, err := reader.ReadString('\n')
if err != nil {
fmt.Printf("read stream UUID error: %v\n", err)
@@ -74,9 +74,8 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
defer s.manager.DisconnectStream(streamUUID)
writer := bufio.NewWriterSize(conn, 256*1024)
defer func(writer *bufio.Writer) {
err := writer.Flush()
if err != nil {
defer func(w *bufio.Writer) {
if err := w.Flush(); err != nil {
fmt.Printf("final flush error: %v\n", err)
}
}(writer)
@@ -85,27 +84,20 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
batch := 0
for msg := range outCh {
// Build protobuf payload.
message := pb.Message{
Identifier: &pb.Identifier{
Provider: msg.Identifier.Provider,
Subject: msg.Identifier.Subject,
},
Payload: msg.Payload, // []byte
Encoding: string(msg.Encoding), // e.g., "application/json"
m := pb.Message{
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
Payload: msg.Payload,
Encoding: string(msg.Encoding),
}
// Marshal protobuf.
// Use MarshalAppend to reuse capacity and avoid an extra alloc.
size := proto.Size(&message)
size := proto.Size(&m)
buf := make([]byte, 0, size)
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &message)
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &m)
if err != nil {
fmt.Printf("proto marshal error: %v\n", err)
continue
}
// Fixed 4-byte big-endian length prefix.
var hdr [4]byte
if len(b) > int(^uint32(0)) {
fmt.Printf("message too large: %d bytes\n", len(b))
@@ -113,7 +105,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
}
binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
// Write frame: [len][bytes].
if _, err := writer.Write(hdr[:]); err != nil {
if err == io.EOF {
return
@@ -139,7 +130,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
}
}
// Final flush when channel closes.
if err := writer.Flush(); err != nil {
fmt.Printf("final flush error: %v\n", err)
}