Refactor identifier handling: replace Provider and Subject fields with a single Key field in Identifier struct, update related message structures, and adjust parsing logic

This commit is contained in:
2025-08-17 06:16:49 +00:00
parent ef4a28fb29
commit 2484a33945
9 changed files with 290 additions and 165 deletions

View File

@@ -21,11 +21,9 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
// Domain Models
type Identifier struct { type Identifier struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Provider string `protobuf:"bytes,1,opt,name=provider,proto3" json:"provider,omitempty"` // e.g., "binance" Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
Subject string `protobuf:"bytes,2,opt,name=subject,proto3" json:"subject,omitempty"` // e.g., "BTCUSDT"
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -60,16 +58,9 @@ func (*Identifier) Descriptor() ([]byte, []int) {
return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{0} return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{0}
} }
func (x *Identifier) GetProvider() string { func (x *Identifier) GetKey() string {
if x != nil { if x != nil {
return x.Provider return x.Key
}
return ""
}
func (x *Identifier) GetSubject() string {
if x != nil {
return x.Subject
} }
return "" return ""
} }
@@ -77,8 +68,8 @@ func (x *Identifier) GetSubject() string {
type Message struct { type Message struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Identifier *Identifier `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"` Identifier *Identifier `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"`
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` // JSON-encoded data Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"` // e.g., "json", "protobuf" Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -134,7 +125,6 @@ func (x *Message) GetEncoding() string {
return "" return ""
} }
// Control Requests and Responses
type StartStreamRequest struct { type StartStreamRequest struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
@@ -383,7 +373,6 @@ func (*StopStreamResponse) Descriptor() ([]byte, []int) {
return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{7} return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{7}
} }
// Stream Requests and Responses
type ConnectStreamRequest struct { type ConnectStreamRequest struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
StreamUuid string `protobuf:"bytes,1,opt,name=stream_uuid,json=streamUuid,proto3" json:"stream_uuid,omitempty"` StreamUuid string `protobuf:"bytes,1,opt,name=stream_uuid,json=streamUuid,proto3" json:"stream_uuid,omitempty"`
@@ -432,11 +421,10 @@ var File_pkg_pb_data_service_data_service_proto protoreflect.FileDescriptor
const file_pkg_pb_data_service_data_service_proto_rawDesc = "" + const file_pkg_pb_data_service_data_service_proto_rawDesc = "" +
"\n" + "\n" +
"&pkg/pb/data_service/data_service.proto\x12\fdata_service\"B\n" + "&pkg/pb/data_service/data_service.proto\x12\fdata_service\"\x1e\n" +
"\n" + "\n" +
"Identifier\x12\x1a\n" + "Identifier\x12\x10\n" +
"\bprovider\x18\x01 \x01(\tR\bprovider\x12\x18\n" + "\x03key\x18\x01 \x01(\tR\x03key\"y\n" +
"\asubject\x18\x02 \x01(\tR\asubject\"y\n" +
"\aMessage\x128\n" + "\aMessage\x128\n" +
"\n" + "\n" +
"identifier\x18\x01 \x01(\v2\x18.data_service.IdentifierR\n" + "identifier\x18\x01 \x01(\v2\x18.data_service.IdentifierR\n" +

View File

@@ -14,39 +14,26 @@ service DataServiceStreaming {
rpc ConnectStream(ConnectStreamRequest) returns (stream Message); rpc ConnectStream(ConnectStreamRequest) returns (stream Message);
} }
// Domain Models
message Identifier { message Identifier {
string provider = 1; // e.g., "binance" string key = 1;
string subject = 2; // e.g., "BTCUSDT"
} }
message Message { message Message {
Identifier identifier = 1; Identifier identifier = 1;
bytes payload = 2; // JSON-encoded data bytes payload = 2;
string encoding = 3; // e.g., "json", "protobuf" string encoding = 3;
} }
// Control Requests and Responses
message StartStreamRequest {} message StartStreamRequest {}
message StartStreamResponse { string stream_uuid = 1; }
message StartStreamResponse {
string stream_uuid = 1;
}
message ConfigureStreamRequest { message ConfigureStreamRequest {
string stream_uuid = 1; string stream_uuid = 1;
repeated Identifier identifiers = 2; repeated Identifier identifiers = 2;
} }
message ConfigureStreamResponse {} message ConfigureStreamResponse {}
message StopStreamRequest { message StopStreamRequest { string stream_uuid = 1; }
string stream_uuid = 1;
}
message StopStreamResponse {} message StopStreamResponse {}
// Stream Requests and Responses message ConnectStreamRequest { string stream_uuid = 1; }
message ConnectStreamRequest {
string stream_uuid = 1;
}

View File

@@ -28,7 +28,7 @@ func (i *idsFlag) Set(v string) error {
return nil return nil
} }
func parseID(s string) (provider, subject string, err error) { func parseIDPair(s string) (provider, subject string, err error) {
parts := strings.SplitN(s, ":", 2) parts := strings.SplitN(s, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" { if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return "", "", fmt.Errorf("want provider:subject, got %q", s) return "", "", fmt.Errorf("want provider:subject, got %q", s)
@@ -36,13 +36,24 @@ func parseID(s string) (provider, subject string, err error) {
return parts[0], parts[1], nil return parts[0], parts[1], nil
} }
func toIdentifierKey(input string) (string, error) {
if strings.Contains(input, "::") {
return input, nil
}
prov, subj, err := parseIDPair(input)
if err != nil {
return "", err
}
return "raw::" + strings.ToLower(prov) + "." + subj, nil
}
func prettyOrRaw(b []byte, pretty bool) string { func prettyOrRaw(b []byte, pretty bool) string {
if !pretty || len(b) == 0 { if !pretty || len(b) == 0 {
return string(b) return string(b)
} }
var tmp any var tmp any
if err := json.Unmarshal(b, &tmp); err != nil { if err := json.Unmarshal(b, &tmp); err != nil {
return string(b) // not JSON return string(b)
} }
out, err := json.MarshalIndent(tmp, "", " ") out, err := json.MarshalIndent(tmp, "", " ")
if err != nil { if err != nil {
@@ -51,7 +62,6 @@ func prettyOrRaw(b []byte, pretty bool) string {
return string(out) return string(out)
} }
// waitReady blocks until conn is READY or ctx times out/cancels.
func waitReady(ctx context.Context, conn *grpc.ClientConn) error { func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
for { for {
s := conn.GetState() s := conn.GetState()
@@ -67,7 +77,6 @@ func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
} }
} }
//goland:noinspection GoUnhandledErrorResult
func main() { func main() {
var ids idsFlag var ids idsFlag
var ctlAddr string var ctlAddr string
@@ -75,7 +84,7 @@ func main() {
var pretty bool var pretty bool
var timeout time.Duration var timeout time.Duration
flag.Var(&ids, "id", "identifier in form provider:subject (repeatable)") flag.Var(&ids, "id", "identifier (provider:subject or canonical key); repeatable")
flag.StringVar(&ctlAddr, "ctl", "127.0.0.1:50051", "gRPC control address") flag.StringVar(&ctlAddr, "ctl", "127.0.0.1:50051", "gRPC control address")
flag.StringVar(&strAddr, "str", "127.0.0.1:50052", "gRPC streaming address") flag.StringVar(&strAddr, "str", "127.0.0.1:50052", "gRPC streaming address")
flag.BoolVar(&pretty, "pretty", true, "pretty-print JSON payloads when possible") flag.BoolVar(&pretty, "pretty", true, "pretty-print JSON payloads when possible")
@@ -83,15 +92,13 @@ func main() {
flag.Parse() flag.Parse()
if len(ids) == 0 { if len(ids) == 0 {
fmt.Fprintln(os.Stderr, "provide at least one --id provider:subject") fmt.Fprintln(os.Stderr, "provide at least one --id (provider:subject or canonical key)")
os.Exit(2) os.Exit(2)
} }
// Ctrl-C handling
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel() defer cancel()
// ----- Control client -----
ccCtl, err := grpc.NewClient( ccCtl, err := grpc.NewClient(
ctlAddr, ctlAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
@@ -101,8 +108,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
defer ccCtl.Close() defer ccCtl.Close()
ccCtl.Connect()
ccCtl.Connect() // start dialing in background
ctlConnCtx, cancelCtlConn := context.WithTimeout(ctx, timeout) ctlConnCtx, cancelCtlConn := context.WithTimeout(ctx, timeout)
if err := waitReady(ctlConnCtx, ccCtl); err != nil { if err := waitReady(ctlConnCtx, ccCtl); err != nil {
@@ -114,7 +120,6 @@ func main() {
ctl := pb.NewDataServiceControlClient(ccCtl) ctl := pb.NewDataServiceControlClient(ccCtl)
// Start stream
ctxStart, cancelStart := context.WithTimeout(ctx, timeout) ctxStart, cancelStart := context.WithTimeout(ctx, timeout)
startResp, err := ctl.StartStream(ctxStart, &pb.StartStreamRequest{}) startResp, err := ctl.StartStream(ctxStart, &pb.StartStreamRequest{})
cancelStart() cancelStart()
@@ -125,15 +130,14 @@ func main() {
streamUUID := startResp.GetStreamUuid() streamUUID := startResp.GetStreamUuid()
fmt.Printf("stream: %s\n", streamUUID) fmt.Printf("stream: %s\n", streamUUID)
// Configure
var pbIDs []*pb.Identifier var pbIDs []*pb.Identifier
for _, s := range ids { for _, s := range ids {
prov, subj, err := parseID(s) key, err := toIdentifierKey(s)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "bad --id: %v\n", err) fmt.Fprintf(os.Stderr, "bad --id: %v\n", err)
os.Exit(2) os.Exit(2)
} }
pbIDs = append(pbIDs, &pb.Identifier{Provider: prov, Subject: subj}) pbIDs = append(pbIDs, &pb.Identifier{Key: key})
} }
ctxCfg, cancelCfg := context.WithTimeout(ctx, timeout) ctxCfg, cancelCfg := context.WithTimeout(ctx, timeout)
@@ -148,7 +152,6 @@ func main() {
} }
fmt.Printf("configured %d identifiers\n", len(pbIDs)) fmt.Printf("configured %d identifiers\n", len(pbIDs))
// ----- Streaming client -----
ccStr, err := grpc.NewClient( ccStr, err := grpc.NewClient(
strAddr, strAddr,
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
@@ -158,7 +161,6 @@ func main() {
os.Exit(1) os.Exit(1)
} }
defer ccStr.Close() defer ccStr.Close()
ccStr.Connect() ccStr.Connect()
strConnCtx, cancelStrConn := context.WithTimeout(ctx, timeout) strConnCtx, cancelStrConn := context.WithTimeout(ctx, timeout)
@@ -171,7 +173,6 @@ func main() {
str := pb.NewDataServiceStreamingClient(ccStr) str := pb.NewDataServiceStreamingClient(ccStr)
// This context lives until Ctrl-C
streamCtx, streamCancel := context.WithCancel(ctx) streamCtx, streamCancel := context.WithCancel(ctx)
defer streamCancel() defer streamCancel()
@@ -182,7 +183,6 @@ func main() {
} }
fmt.Println("connected; streaming… (Ctrl-C to quit)") fmt.Println("connected; streaming… (Ctrl-C to quit)")
// Receive loop until Ctrl-C
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -192,14 +192,14 @@ func main() {
msg, err := stream.Recv() msg, err := stream.Recv()
if err != nil { if err != nil {
if ctx.Err() != nil { if ctx.Err() != nil {
return // normal shutdown return
} }
fmt.Fprintf(os.Stderr, "recv: %v\n", err) fmt.Fprintf(os.Stderr, "recv: %v\n", err)
os.Exit(1) os.Exit(1)
} }
id := msg.GetIdentifier() id := msg.GetIdentifier()
fmt.Printf("[%s] %s bytes=%d enc=%s t=%s\n", fmt.Printf("[%s] bytes=%d enc=%s t=%s\n",
id.GetProvider(), id.GetSubject(), len(msg.GetPayload()), msg.GetEncoding(), id.GetKey(), len(msg.GetPayload()), msg.GetEncoding(),
time.Now().Format(time.RFC3339Nano), time.Now().Format(time.RFC3339Nano),
) )
fmt.Println(prettyOrRaw(msg.GetPayload(), pretty)) fmt.Println(prettyOrRaw(msg.GetPayload(), pretty))

View File

@@ -1,6 +1,181 @@
package domain package domain
type Identifier struct { import (
Provider string "errors"
Subject string "fmt"
"regexp"
"sort"
"strings"
)
const (
prefixRaw = "raw::"
prefixInternal = "internal::"
)
// Identifier is a canonical representation of a data stream identifier.
type Identifier struct{ key string }
func (id Identifier) IsRaw() bool { return strings.HasPrefix(id.key, prefixRaw) }
func (id Identifier) IsInternal() bool { return strings.HasPrefix(id.key, prefixInternal) }
func (id Identifier) Key() string { return id.key }
func (id Identifier) ProviderSubject() (provider, subject string, ok bool) {
if !id.IsRaw() {
return "", "", false
}
body := strings.TrimPrefix(id.key, prefixRaw)
prov, subj, ok := strings.Cut(body, ".")
return prov, subj, ok
}
func (id Identifier) InternalParts() (venue, stream, symbol string, params map[string]string, ok bool) {
if !id.IsInternal() {
return "", "", "", nil, false
}
body := strings.TrimPrefix(id.key, prefixInternal)
before, bracket, _ := strings.Cut(body, "[")
parts := strings.Split(before, ".")
if len(parts) != 3 {
return "", "", "", nil, false
}
return parts[0], parts[1], parts[2], decodeParams(strings.TrimSuffix(bracket, "]")), true
}
func RawID(provider, subject string) (Identifier, error) {
p := strings.ToLower(strings.TrimSpace(provider))
s := strings.TrimSpace(subject)
if err := validateComponent("provider", p, false); err != nil {
return Identifier{}, err
}
if err := validateComponent("subject", s, true); err != nil {
return Identifier{}, err
}
return Identifier{key: prefixRaw + p + "." + s}, nil
}
func InternalID(venue, stream, symbol string, params map[string]string) (Identifier, error) {
v := strings.ToLower(strings.TrimSpace(venue))
t := strings.ToLower(strings.TrimSpace(stream))
sym := strings.ToUpper(strings.TrimSpace(symbol))
if err := validateComponent("venue", v, false); err != nil {
return Identifier{}, err
}
if err := validateComponent("stream", t, false); err != nil {
return Identifier{}, err
}
if err := validateComponent("symbol", sym, false); err != nil {
return Identifier{}, err
}
paramStr, err := encodeParams(params) // "k=v;..." or ""
if err != nil {
return Identifier{}, err
}
if paramStr == "" {
paramStr = "[]"
} else {
paramStr = "[" + paramStr + "]"
}
return Identifier{key: prefixInternal + v + "." + t + "." + sym + paramStr}, nil
}
func ParseIdentifier(s string) (Identifier, error) {
s = strings.TrimSpace(s)
switch {
case strings.HasPrefix(s, prefixRaw):
// raw::provider.subject
body := strings.TrimPrefix(s, prefixRaw)
prov, subj, ok := strings.Cut(body, ".")
if !ok {
return Identifier{}, errors.New("invalid raw identifier: missing '.'")
}
return RawID(prov, subj)
case strings.HasPrefix(s, prefixInternal):
// internal::venue.stream.symbol[...]
body := strings.TrimPrefix(s, prefixInternal)
before, bracket, _ := strings.Cut(body, "[")
parts := strings.Split(before, ".")
if len(parts) != 3 {
return Identifier{}, errors.New("invalid internal identifier: need venue.stream.symbol")
}
params := decodeParams(strings.TrimSuffix(bracket, "]"))
return InternalID(parts[0], parts[1], parts[2], params)
}
return Identifier{}, errors.New("unknown identifier prefix")
}
var (
segDisallow = regexp.MustCompile(`[ \t\r\n\[\]]`) // forbid whitespace/brackets in fixed segments
dotDisallow = regexp.MustCompile(`[.]`) // fixed segments cannot contain '.'
)
// allowAny=true (for subject) skips dot checks but still forbids whitespace/brackets.
func validateComponent(name, v string, allowAny bool) error {
if v == "" {
return fmt.Errorf("%s cannot be empty", name)
}
if allowAny {
if segDisallow.MatchString(v) {
return fmt.Errorf("%s contains illegal chars [] or whitespace", name)
}
return nil
}
if segDisallow.MatchString(v) || dotDisallow.MatchString(v) {
return fmt.Errorf("%s contains illegal chars (dot/brackets/whitespace)", name)
}
return nil
}
// encodeParams renders sorted k=v pairs separated by ';'.
func encodeParams(params map[string]string) (string, error) {
if len(params) == 0 {
return "", nil
}
keys := make([]string, 0, len(params))
for k := range params {
k = strings.ToLower(strings.TrimSpace(k))
if k == "" {
continue
}
keys = append(keys, k)
}
sort.Strings(keys)
out := make([]string, 0, len(keys))
for _, k := range keys {
v := strings.TrimSpace(params[k])
// prevent breaking delimiters
if strings.ContainsAny(k, ";]") || strings.ContainsAny(v, ";]") {
return "", fmt.Errorf("param %q contains illegal ';' or ']'", k)
}
out = append(out, k+"="+v)
}
return strings.Join(out, ";"), nil
}
func decodeParams(s string) map[string]string {
s = strings.TrimSpace(s)
if s == "" {
return map[string]string{}
}
out := make(map[string]string, 4)
for _, p := range strings.Split(s, ";") {
if p == "" {
continue
}
kv := strings.SplitN(p, "=", 2)
if len(kv) != 2 {
continue
}
k := strings.ToLower(strings.TrimSpace(kv[0]))
v := strings.TrimSpace(kv[1])
if k != "" {
out[k] = v
}
}
return out
} }

View File

@@ -30,7 +30,7 @@ type ClientStream struct {
} }
func NewManager(router *router.Router) *Manager { func NewManager(router *router.Router) *Manager {
go router.Run() // Start the router in a separate goroutine go router.Run()
return &Manager{ return &Manager{
providers: make(map[string]provider.Provider), providers: make(map[string]provider.Provider),
providerStreams: make(map[domain.Identifier]chan domain.Message), providerStreams: make(map[domain.Identifier]chan domain.Message),
@@ -46,8 +46,8 @@ func (m *Manager) StartStream() (uuid.UUID, error) {
streamID := uuid.New() streamID := uuid.New()
m.clientStreams[streamID] = &ClientStream{ m.clientStreams[streamID] = &ClientStream{
UUID: streamID, UUID: streamID,
Identifiers: nil, // start empty Identifiers: nil,
OutChannel: nil, // not yet connected OutChannel: nil,
Timer: time.AfterFunc(1*time.Minute, func() { Timer: time.AfterFunc(1*time.Minute, func() {
fmt.Printf("stream %s expired due to inactivity\n", streamID) fmt.Printf("stream %s expired due to inactivity\n", streamID)
err := m.StopStream(streamID) err := m.StopStream(streamID)
@@ -69,21 +69,22 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
return fmt.Errorf("stream not found: %s", streamID) return fmt.Errorf("stream not found: %s", streamID)
} }
// Validate new identifiers.
for _, id := range newIds { for _, id := range newIds {
if id.Provider == "" || id.Subject == "" { if id.IsRaw() {
providerName, subject, ok := id.ProviderSubject()
if !ok || providerName == "" || subject == "" {
return fmt.Errorf("empty identifier: %v", id) return fmt.Errorf("empty identifier: %v", id)
} }
prov, exists := m.providers[id.Provider] prov, exists := m.providers[providerName]
if !exists { if !exists {
return fmt.Errorf("unknown provider: %s", id.Provider) return fmt.Errorf("unknown provider: %s", providerName)
}
if !prov.IsValidSubject(subject, false) {
return fmt.Errorf("invalid subject %q for provider %s", subject, providerName)
} }
if !prov.IsValidSubject(id.Subject, false) {
return fmt.Errorf("invalid subject %q for provider %s", id.Subject, id.Provider)
} }
} }
// Generate old and new sets of identifiers
oldSet := make(map[domain.Identifier]struct{}, len(stream.Identifiers)) oldSet := make(map[domain.Identifier]struct{}, len(stream.Identifiers))
for _, id := range stream.Identifiers { for _, id := range stream.Identifiers {
oldSet[id] = struct{}{} oldSet[id] = struct{}{}
@@ -93,13 +94,13 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
newSet[id] = struct{}{} newSet[id] = struct{}{}
} }
// Add identifiers that are in newIds but not in oldSet
for _, id := range newIds { for _, id := range newIds {
if _, seen := oldSet[id]; !seen { if _, seen := oldSet[id]; !seen {
// Provision the stream from the provider if needed if id.IsRaw() {
if _, ok := m.providerStreams[id]; !ok { if _, ok := m.providerStreams[id]; !ok {
ch := make(chan domain.Message, 64) ch := make(chan domain.Message, 64)
if err := m.providers[id.Provider].RequestStream(id.Subject, ch); err != nil { providerName, subject, _ := id.ProviderSubject()
if err := m.providers[providerName].RequestStream(subject, ch); err != nil {
return fmt.Errorf("provision %v: %w", id, err) return fmt.Errorf("provision %v: %w", id, err)
} }
m.providerStreams[id] = ch m.providerStreams[id] = ch
@@ -111,37 +112,36 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
} }
}(ch) }(ch)
} }
}
// Register the new identifier with the router, only if there's an active output channel (meaning the stream is connected)
if stream.OutChannel != nil { if stream.OutChannel != nil {
m.router.RegisterRoute(id, stream.OutChannel) m.router.RegisterRoute(id, stream.OutChannel)
} }
} }
} }
// Remove identifiers that are in oldSet but not in newSet
for _, oldId := range stream.Identifiers { for _, oldId := range stream.Identifiers {
if _, keep := newSet[oldId]; !keep { if _, keep := newSet[oldId]; !keep {
// Deregister the identifier from the router, only if there's an active output channel (meaning the stream is connected)
if stream.OutChannel != nil { if stream.OutChannel != nil {
m.router.DeregisterRoute(oldId, stream.OutChannel) m.router.DeregisterRoute(oldId, stream.OutChannel)
} }
} }
} }
// Set the new identifiers for the stream
stream.Identifiers = newIds stream.Identifiers = newIds
// Clean up provider streams that are no longer used
used := make(map[domain.Identifier]bool) used := make(map[domain.Identifier]bool)
for _, cs := range m.clientStreams { for _, cs := range m.clientStreams {
for _, id := range cs.Identifiers { for _, id := range cs.Identifiers {
if id.IsRaw() {
used[id] = true used[id] = true
} }
} }
}
for id, ch := range m.providerStreams { for id, ch := range m.providerStreams {
if !used[id] { if !used[id] {
m.providers[id.Provider].CancelStream(id.Subject) providerName, subject, _ := id.ProviderSubject()
m.providers[providerName].CancelStream(subject)
close(ch) close(ch)
delete(m.providerStreams, id) delete(m.providerStreams, id)
} }
@@ -165,18 +165,19 @@ func (m *Manager) StopStream(streamID uuid.UUID) error {
delete(m.clientStreams, streamID) delete(m.clientStreams, streamID)
// Find provider streams that are used by other client streams
used := make(map[domain.Identifier]bool) used := make(map[domain.Identifier]bool)
for _, s := range m.clientStreams { for _, s := range m.clientStreams {
for _, id := range s.Identifiers { for _, id := range s.Identifiers {
if id.IsRaw() {
used[id] = true used[id] = true
} }
} }
}
// Cancel provider streams that are not used by any client stream
for id, ch := range m.providerStreams { for id, ch := range m.providerStreams {
if !used[id] { if !used[id] {
m.providers[id.Provider].CancelStream(id.Subject) providerName, subject, _ := id.ProviderSubject()
m.providers[providerName].CancelStream(subject)
close(ch) close(ch)
delete(m.providerStreams, id) delete(m.providerStreams, id)
} }
@@ -219,19 +220,16 @@ func (m *Manager) DisconnectStream(streamID uuid.UUID) {
stream, ok := m.clientStreams[streamID] stream, ok := m.clientStreams[streamID]
if !ok || stream.OutChannel == nil { if !ok || stream.OutChannel == nil {
return // already disconnected or does not exist return
} }
// Deregister all identifiers from the router
for _, ident := range stream.Identifiers { for _, ident := range stream.Identifiers {
m.router.DeregisterRoute(ident, stream.OutChannel) m.router.DeregisterRoute(ident, stream.OutChannel)
} }
// Close the output channel
close(stream.OutChannel) close(stream.OutChannel)
stream.OutChannel = nil stream.OutChannel = nil
// Set up the expiry timer
stream.Timer = time.AfterFunc(1*time.Minute, func() { stream.Timer = time.AfterFunc(1*time.Minute, func() {
fmt.Printf("stream %s expired due to inactivity\n", streamID) fmt.Printf("stream %s expired due to inactivity\n", streamID)
err := m.StopStream(streamID) err := m.StopStream(streamID)
@@ -256,6 +254,6 @@ func (m *Manager) AddProvider(name string, p provider.Provider) {
m.providers[name] = p m.providers[name] = p
} }
func (m *Manager) RemoveProvider(name string) { func (m *Manager) RemoveProvider(_ string) {
panic("not implemented yet") // TODO: Implement provider removal logic panic("not implemented yet")
} }

View File

@@ -109,7 +109,7 @@ func (b *FuturesWebsocket) IsValidSubject(subject string, isFetch bool) bool {
if isFetch { if isFetch {
return false return false
} }
return len(subject) > 0 // Extend with regex or lookup if needed return len(subject) > 0
} }
func (b *FuturesWebsocket) readLoop() { func (b *FuturesWebsocket) readLoop() {
@@ -138,12 +138,14 @@ func (b *FuturesWebsocket) readLoop() {
continue continue
} }
id, err := domain.RawID("binance_futures_websocket", container.Stream)
if err != nil {
continue
}
msg := domain.Message{ msg := domain.Message{
Identifier: domain.Identifier{ Identifier: id,
Provider: "binance_futures_websocket", Payload: container.Data,
Subject: container.Stream,
},
Payload: []byte(container.Data),
Encoding: domain.EncodingJSON, Encoding: domain.EncodingJSON,
} }

View File

@@ -18,9 +18,7 @@ type GRPCControlServer struct {
} }
func NewGRPCControlServer(m *manager.Manager) *GRPCControlServer { func NewGRPCControlServer(m *manager.Manager) *GRPCControlServer {
return &GRPCControlServer{ return &GRPCControlServer{manager: m}
manager: m,
}
} }
func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequest) (*pb.StartStreamResponse, error) { func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequest) (*pb.StartStreamResponse, error) {
@@ -28,7 +26,6 @@ func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequ
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to start stream: %w", err) return nil, fmt.Errorf("failed to start stream: %w", err)
} }
return &pb.StartStreamResponse{StreamUuid: streamID.String()}, nil return &pb.StartStreamResponse{StreamUuid: streamID.String()}, nil
} }
@@ -38,19 +35,18 @@ func (s *GRPCControlServer) ConfigureStream(_ context.Context, req *pb.Configure
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err) return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
} }
// Transform identifiers from protobuf to domain format
var ids []domain.Identifier var ids []domain.Identifier
for _, i := range req.Identifiers { for _, in := range req.Identifiers {
ids = append(ids, domain.Identifier{ id, e := domain.ParseIdentifier(in.Key)
Provider: i.Provider, if e != nil {
Subject: i.Subject, return nil, status.Errorf(codes.InvalidArgument, "invalid identifier %q: %v", in.Key, e)
}) }
ids = append(ids, id)
} }
if err := s.manager.ConfigureStream(streamID, ids); err != nil { if err := s.manager.ConfigureStream(streamID, ids); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "configure failed: %v", err) return nil, status.Errorf(codes.InvalidArgument, "configure failed: %v", err)
} }
return &pb.ConfigureStreamResponse{}, nil return &pb.ConfigureStreamResponse{}, nil
} }
@@ -59,11 +55,8 @@ func (s *GRPCControlServer) StopStream(_ context.Context, req *pb.StopStreamRequ
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err) return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
} }
if err := s.manager.StopStream(streamID); err != nil {
err = s.manager.StopStream(streamID) // Should only error if the stream doesn't exist
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to stop stream: %v", err) return nil, status.Errorf(codes.Internal, "failed to stop stream: %v", err)
} }
return &pb.StopStreamResponse{}, nil return &pb.StopStreamResponse{}, nil
} }

View File

@@ -14,9 +14,7 @@ type GRPCStreamingServer struct {
} }
func NewGRPCStreamingServer(m *manager.Manager) *GRPCStreamingServer { func NewGRPCStreamingServer(m *manager.Manager) *GRPCStreamingServer {
return &GRPCStreamingServer{ return &GRPCStreamingServer{manager: m}
manager: m,
}
} }
func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream pb.DataServiceStreaming_ConnectStreamServer) error { func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream pb.DataServiceStreaming_ConnectStreamServer) error {
@@ -39,17 +37,11 @@ func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream
if !ok { if !ok {
return nil return nil
} }
if err := stream.Send(&pb.Message{
err := stream.Send(&pb.Message{ Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
Identifier: &pb.Identifier{
Provider: msg.Identifier.Provider,
Subject: msg.Identifier.Subject,
},
Payload: msg.Payload, Payload: msg.Payload,
Encoding: string(msg.Encoding), Encoding: string(msg.Encoding),
}) }); err != nil {
if err != nil {
return err return err
} }
} }

View File

@@ -7,6 +7,7 @@ import (
"io" "io"
"net" "net"
"strings" "strings"
"time"
"github.com/google/uuid" "github.com/google/uuid"
pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service" pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
@@ -22,7 +23,6 @@ func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
return &SocketStreamingServer{manager: m} return &SocketStreamingServer{manager: m}
} }
// Accepts connections and hands each off to handleConnection.
func (s *SocketStreamingServer) Serve(lis net.Listener) error { func (s *SocketStreamingServer) Serve(lis net.Listener) error {
for { for {
conn, err := lis.Accept() conn, err := lis.Accept()
@@ -43,16 +43,16 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
} }
}() }()
// Low-latency socket hints (best-effort).
if tc, ok := conn.(*net.TCPConn); ok { if tc, ok := conn.(*net.TCPConn); ok {
_ = tc.SetNoDelay(true) _ = tc.SetNoDelay(true)
_ = tc.SetWriteBuffer(512 * 1024) _ = tc.SetWriteBuffer(512 * 1024)
_ = tc.SetReadBuffer(512 * 1024) _ = tc.SetReadBuffer(512 * 1024)
_ = tc.SetKeepAlive(true)
_ = tc.SetKeepAlivePeriod(30 * time.Second)
} }
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
// Protocol header: first line is the stream UUID.
raw, err := reader.ReadString('\n') raw, err := reader.ReadString('\n')
if err != nil { if err != nil {
fmt.Printf("read stream UUID error: %v\n", err) fmt.Printf("read stream UUID error: %v\n", err)
@@ -74,9 +74,8 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
defer s.manager.DisconnectStream(streamUUID) defer s.manager.DisconnectStream(streamUUID)
writer := bufio.NewWriterSize(conn, 256*1024) writer := bufio.NewWriterSize(conn, 256*1024)
defer func(writer *bufio.Writer) { defer func(w *bufio.Writer) {
err := writer.Flush() if err := w.Flush(); err != nil {
if err != nil {
fmt.Printf("final flush error: %v\n", err) fmt.Printf("final flush error: %v\n", err)
} }
}(writer) }(writer)
@@ -85,27 +84,20 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
batch := 0 batch := 0
for msg := range outCh { for msg := range outCh {
// Build protobuf payload. m := pb.Message{
message := pb.Message{ Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
Identifier: &pb.Identifier{ Payload: msg.Payload,
Provider: msg.Identifier.Provider, Encoding: string(msg.Encoding),
Subject: msg.Identifier.Subject,
},
Payload: msg.Payload, // []byte
Encoding: string(msg.Encoding), // e.g., "application/json"
} }
// Marshal protobuf. size := proto.Size(&m)
// Use MarshalAppend to reuse capacity and avoid an extra alloc.
size := proto.Size(&message)
buf := make([]byte, 0, size) buf := make([]byte, 0, size)
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &message) b, err := proto.MarshalOptions{}.MarshalAppend(buf, &m)
if err != nil { if err != nil {
fmt.Printf("proto marshal error: %v\n", err) fmt.Printf("proto marshal error: %v\n", err)
continue continue
} }
// Fixed 4-byte big-endian length prefix.
var hdr [4]byte var hdr [4]byte
if len(b) > int(^uint32(0)) { if len(b) > int(^uint32(0)) {
fmt.Printf("message too large: %d bytes\n", len(b)) fmt.Printf("message too large: %d bytes\n", len(b))
@@ -113,7 +105,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
} }
binary.BigEndian.PutUint32(hdr[:], uint32(len(b))) binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
// Write frame: [len][bytes].
if _, err := writer.Write(hdr[:]); err != nil { if _, err := writer.Write(hdr[:]); err != nil {
if err == io.EOF { if err == io.EOF {
return return
@@ -139,7 +130,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
} }
} }
// Final flush when channel closes.
if err := writer.Flush(); err != nil { if err := writer.Flush(); err != nil {
fmt.Printf("final flush error: %v\n", err) fmt.Printf("final flush error: %v\n", err)
} }