Refactor identifier handling: replace Provider and Subject fields with a single Key field in Identifier struct, update related message structures, and adjust parsing logic

This commit is contained in:
2025-08-17 06:16:49 +00:00
parent ef4a28fb29
commit 2484a33945
9 changed files with 290 additions and 165 deletions

View File

@@ -1,6 +1,181 @@
package domain
type Identifier struct {
Provider string
Subject string
import (
"errors"
"fmt"
"regexp"
"sort"
"strings"
)
// Key prefixes that mark which family an identifier key belongs to.
const (
	prefixRaw      = "raw::"
	prefixInternal = "internal::"
)

// Identifier is a canonical representation of a data stream identifier.
// It wraps a single string key; the key's prefix tells raw identifiers
// apart from internal ones. The struct is comparable, so it can be used
// as a map key.
type Identifier struct {
	key string
}

// IsRaw reports whether the key begins with the "raw::" prefix.
func (id Identifier) IsRaw() bool {
	return strings.HasPrefix(id.key, prefixRaw)
}

// IsInternal reports whether the key begins with the "internal::" prefix.
func (id Identifier) IsInternal() bool {
	return strings.HasPrefix(id.key, prefixInternal)
}

// Key returns the underlying key string, prefix included.
func (id Identifier) Key() string {
	return id.key
}
// ProviderSubject splits a raw identifier into provider and subject.
//
// For a non-raw identifier it returns two empty strings and ok == false.
// For a raw identifier the text after the "raw::" prefix is cut at the first
// '.': everything before it is the provider, everything after it is the
// subject (which may itself contain dots). If the body holds no '.', ok is
// false, provider holds the whole body and subject is empty.
func (id Identifier) ProviderSubject() (provider, subject string, ok bool) {
	if id.IsRaw() {
		rest := strings.TrimPrefix(id.key, prefixRaw)
		provider, subject, ok = strings.Cut(rest, ".")
	}
	return provider, subject, ok
}
// InternalParts breaks an internal identifier into venue, stream, symbol and
// its decoded parameters.
//
// ok is false (and every other result is zero) when the identifier is not
// internal or when the text before the first '[' does not consist of exactly
// three dot-separated segments. The text after the first '[' has one trailing
// ']' stripped and is handed to decodeParams.
func (id Identifier) InternalParts() (venue, stream, symbol string, params map[string]string, ok bool) {
	if !id.IsInternal() {
		return "", "", "", nil, false
	}

	// IsInternal guarantees the prefix is present, so slicing it off is safe.
	rest := id.key[len(prefixInternal):]
	head, tail, _ := strings.Cut(rest, "[")

	segments := strings.Split(head, ".")
	if len(segments) != 3 {
		return "", "", "", nil, false
	}

	venue, stream, symbol = segments[0], segments[1], segments[2]
	params = decodeParams(strings.TrimSuffix(tail, "]"))
	return venue, stream, symbol, params, true
}
// RawID builds a raw identifier of the form "raw::<provider>.<subject>".
//
// The provider is trimmed and lower-cased; the subject is only trimmed.
// Both are checked with validateComponent: the provider as a fixed segment
// (no dots), the subject with allowAny set (dots permitted). The first
// validation error is returned together with a zero Identifier.
func RawID(provider, subject string) (Identifier, error) {
	prov := strings.ToLower(strings.TrimSpace(provider))
	if err := validateComponent("provider", prov, false); err != nil {
		return Identifier{}, err
	}

	subj := strings.TrimSpace(subject)
	if err := validateComponent("subject", subj, true); err != nil {
		return Identifier{}, err
	}

	return Identifier{key: prefixRaw + prov + "." + subj}, nil
}
// InternalID builds an internal identifier of the form
// "internal::<venue>.<stream>.<SYMBOL>[<params>]".
//
// Venue and stream are trimmed and lower-cased, the symbol is trimmed and
// upper-cased. Each of the three is validated as a fixed segment, in that
// order, and the first failure is returned. Parameters are rendered by
// encodeParams; an empty parameter list still yields a trailing "[]".
func InternalID(venue, stream, symbol string, params map[string]string) (Identifier, error) {
	ven := strings.ToLower(strings.TrimSpace(venue))
	str := strings.ToLower(strings.TrimSpace(stream))
	sym := strings.ToUpper(strings.TrimSpace(symbol))

	fixed := []struct{ name, value string }{
		{"venue", ven},
		{"stream", str},
		{"symbol", sym},
	}
	for _, seg := range fixed {
		if err := validateComponent(seg.name, seg.value, false); err != nil {
			return Identifier{}, err
		}
	}

	encoded, err := encodeParams(params) // "k=v;..." or ""
	if err != nil {
		return Identifier{}, err
	}

	// The brackets are always emitted, so no params renders as "[]".
	return Identifier{key: prefixInternal + ven + "." + str + "." + sym + "[" + encoded + "]"}, nil
}
// ParseIdentifier parses the textual form of an identifier and returns its
// canonical Identifier.
//
// Surrounding whitespace is ignored. Two shapes are recognised:
//
//	raw::provider.subject
//	internal::venue.stream.symbol[k=v;...]
//
// The pieces are re-validated and normalised by passing them through RawID or
// InternalID, so the returned key may differ in case or spacing from the
// input. Any other prefix yields an error.
func ParseIdentifier(s string) (Identifier, error) {
	trimmed := strings.TrimSpace(s)

	if strings.HasPrefix(trimmed, prefixRaw) {
		rest := trimmed[len(prefixRaw):]
		provider, subject, found := strings.Cut(rest, ".")
		if !found {
			return Identifier{}, errors.New("invalid raw identifier: missing '.'")
		}
		return RawID(provider, subject)
	}

	if strings.HasPrefix(trimmed, prefixInternal) {
		rest := trimmed[len(prefixInternal):]
		// The bracket section is optional here; when absent, tail is empty.
		head, tail, _ := strings.Cut(rest, "[")
		segments := strings.Split(head, ".")
		if len(segments) != 3 {
			return Identifier{}, errors.New("invalid internal identifier: need venue.stream.symbol")
		}
		decoded := decodeParams(strings.TrimSuffix(tail, "]"))
		return InternalID(segments[0], segments[1], segments[2], decoded)
	}

	return Identifier{}, errors.New("unknown identifier prefix")
}
var (
	// segDisallow matches characters forbidden in every component:
	// space, tab, CR, LF and square brackets.
	segDisallow = regexp.MustCompile(`[ \t\r\n\[\]]`)
	// dotDisallow matches '.', which is additionally forbidden in fixed segments.
	dotDisallow = regexp.MustCompile(`[.]`)
)

// validateComponent checks a single identifier component.
//
// name is used only to label the error message. An empty v is always
// rejected. With allowAny set (used for subjects) dots are permitted and only
// whitespace/brackets are rejected; otherwise dots are rejected as well.
// It returns nil when v is acceptable.
func validateComponent(name, v string, allowAny bool) error {
	switch {
	case v == "":
		return fmt.Errorf("%s cannot be empty", name)
	case allowAny && segDisallow.MatchString(v):
		return fmt.Errorf("%s contains illegal chars [] or whitespace", name)
	case allowAny:
		return nil
	case segDisallow.MatchString(v) || dotDisallow.MatchString(v):
		return fmt.Errorf("%s contains illegal chars (dot/brackets/whitespace)", name)
	default:
		return nil
	}
}
// encodeParams renders params as "k=v" pairs sorted by key and joined by ';'.
//
// Keys are trimmed and lower-cased, values are trimmed. Entries whose
// normalized key is empty are skipped. A nil or empty map yields "".
//
// It returns an error when
//   - two input keys normalize to the same name (the result would otherwise
//     depend on map iteration order), or
//   - a normalized key or value contains ';' or ']', which would break the
//     bracketed "k=v;..." encoding.
func encodeParams(params map[string]string) (string, error) {
	if len(params) == 0 {
		return "", nil
	}

	// Carry the value alongside its normalized key. Looking the value up
	// again with the normalized key would miss any entry whose original key
	// had different case or surrounding whitespace.
	normalized := make(map[string]string, len(params))
	keys := make([]string, 0, len(params))
	for rawKey, rawVal := range params {
		k := strings.ToLower(strings.TrimSpace(rawKey))
		if k == "" {
			continue
		}
		if _, dup := normalized[k]; dup {
			return "", fmt.Errorf("duplicate param %q after normalization", k)
		}
		normalized[k] = strings.TrimSpace(rawVal)
		keys = append(keys, k)
	}

	// Sort so the encoding is canonical regardless of map iteration order.
	sort.Strings(keys)

	pairs := make([]string, 0, len(keys))
	for _, k := range keys {
		v := normalized[k]
		// prevent breaking delimiters
		if strings.ContainsAny(k, ";]") || strings.ContainsAny(v, ";]") {
			return "", fmt.Errorf("param %q contains illegal ';' or ']'", k)
		}
		pairs = append(pairs, k+"="+v)
	}
	return strings.Join(pairs, ";"), nil
}
// decodeParams parses a "k=v;k2=v2" parameter string into a map.
//
// The input is trimmed and split on ';'. Each piece is cut at its first '=';
// pieces without '=' (including empty ones) are ignored. Keys are trimmed and
// lower-cased, values are trimmed, and entries with an empty key are dropped.
// When a key repeats, the last occurrence wins. The result is never nil.
func decodeParams(s string) map[string]string {
	decoded := map[string]string{}

	trimmed := strings.TrimSpace(s)
	if trimmed == "" {
		return decoded
	}

	for _, pair := range strings.Split(trimmed, ";") {
		rawKey, rawVal, found := strings.Cut(pair, "=")
		if !found {
			// Covers both empty pieces and pieces lacking '='.
			continue
		}
		key := strings.ToLower(strings.TrimSpace(rawKey))
		if key == "" {
			continue
		}
		decoded[key] = strings.TrimSpace(rawVal)
	}
	return decoded
}

View File

@@ -30,7 +30,7 @@ type ClientStream struct {
}
func NewManager(router *router.Router) *Manager {
go router.Run() // Start the router in a separate goroutine
go router.Run()
return &Manager{
providers: make(map[string]provider.Provider),
providerStreams: make(map[domain.Identifier]chan domain.Message),
@@ -46,8 +46,8 @@ func (m *Manager) StartStream() (uuid.UUID, error) {
streamID := uuid.New()
m.clientStreams[streamID] = &ClientStream{
UUID: streamID,
Identifiers: nil, // start empty
OutChannel: nil, // not yet connected
Identifiers: nil,
OutChannel: nil,
Timer: time.AfterFunc(1*time.Minute, func() {
fmt.Printf("stream %s expired due to inactivity\n", streamID)
err := m.StopStream(streamID)
@@ -69,21 +69,22 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
return fmt.Errorf("stream not found: %s", streamID)
}
// Validate new identifiers.
for _, id := range newIds {
if id.Provider == "" || id.Subject == "" {
return fmt.Errorf("empty identifier: %v", id)
}
prov, exists := m.providers[id.Provider]
if !exists {
return fmt.Errorf("unknown provider: %s", id.Provider)
}
if !prov.IsValidSubject(id.Subject, false) {
return fmt.Errorf("invalid subject %q for provider %s", id.Subject, id.Provider)
if id.IsRaw() {
providerName, subject, ok := id.ProviderSubject()
if !ok || providerName == "" || subject == "" {
return fmt.Errorf("empty identifier: %v", id)
}
prov, exists := m.providers[providerName]
if !exists {
return fmt.Errorf("unknown provider: %s", providerName)
}
if !prov.IsValidSubject(subject, false) {
return fmt.Errorf("invalid subject %q for provider %s", subject, providerName)
}
}
}
// Generate old and new sets of identifiers
oldSet := make(map[domain.Identifier]struct{}, len(stream.Identifiers))
for _, id := range stream.Identifiers {
oldSet[id] = struct{}{}
@@ -93,55 +94,54 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
newSet[id] = struct{}{}
}
// Add identifiers that are in newIds but not in oldSet
for _, id := range newIds {
if _, seen := oldSet[id]; !seen {
// Provision the stream from the provider if needed
if _, ok := m.providerStreams[id]; !ok {
ch := make(chan domain.Message, 64)
if err := m.providers[id.Provider].RequestStream(id.Subject, ch); err != nil {
return fmt.Errorf("provision %v: %w", id, err)
}
m.providerStreams[id] = ch
incomingChannel := m.router.IncomingChannel()
go func(c chan domain.Message) {
for msg := range c {
incomingChannel <- msg
if id.IsRaw() {
if _, ok := m.providerStreams[id]; !ok {
ch := make(chan domain.Message, 64)
providerName, subject, _ := id.ProviderSubject()
if err := m.providers[providerName].RequestStream(subject, ch); err != nil {
return fmt.Errorf("provision %v: %w", id, err)
}
}(ch)
m.providerStreams[id] = ch
incomingChannel := m.router.IncomingChannel()
go func(c chan domain.Message) {
for msg := range c {
incomingChannel <- msg
}
}(ch)
}
}
// Register the new identifier with the router, only if there's an active output channel (meaning the stream is connected)
if stream.OutChannel != nil {
m.router.RegisterRoute(id, stream.OutChannel)
}
}
}
// Remove identifiers that are in oldSet but not in newSet
for _, oldId := range stream.Identifiers {
if _, keep := newSet[oldId]; !keep {
// Deregister the identifier from the router, only if there's an active output channel (meaning the stream is connected)
if stream.OutChannel != nil {
m.router.DeregisterRoute(oldId, stream.OutChannel)
}
}
}
// Set the new identifiers for the stream
stream.Identifiers = newIds
// Clean up provider streams that are no longer used
used := make(map[domain.Identifier]bool)
for _, cs := range m.clientStreams {
for _, id := range cs.Identifiers {
used[id] = true
if id.IsRaw() {
used[id] = true
}
}
}
for id, ch := range m.providerStreams {
if !used[id] {
m.providers[id.Provider].CancelStream(id.Subject)
providerName, subject, _ := id.ProviderSubject()
m.providers[providerName].CancelStream(subject)
close(ch)
delete(m.providerStreams, id)
}
@@ -165,18 +165,19 @@ func (m *Manager) StopStream(streamID uuid.UUID) error {
delete(m.clientStreams, streamID)
// Find provider streams that are used by other client streams
used := make(map[domain.Identifier]bool)
for _, s := range m.clientStreams {
for _, id := range s.Identifiers {
used[id] = true
if id.IsRaw() {
used[id] = true
}
}
}
// Cancel provider streams that are not used by any client stream
for id, ch := range m.providerStreams {
if !used[id] {
m.providers[id.Provider].CancelStream(id.Subject)
providerName, subject, _ := id.ProviderSubject()
m.providers[providerName].CancelStream(subject)
close(ch)
delete(m.providerStreams, id)
}
@@ -219,19 +220,16 @@ func (m *Manager) DisconnectStream(streamID uuid.UUID) {
stream, ok := m.clientStreams[streamID]
if !ok || stream.OutChannel == nil {
return // already disconnected or does not exist
return
}
// Deregister all identifiers from the router
for _, ident := range stream.Identifiers {
m.router.DeregisterRoute(ident, stream.OutChannel)
}
// Close the output channel
close(stream.OutChannel)
stream.OutChannel = nil
// Set up the expiry timer
stream.Timer = time.AfterFunc(1*time.Minute, func() {
fmt.Printf("stream %s expired due to inactivity\n", streamID)
err := m.StopStream(streamID)
@@ -256,6 +254,6 @@ func (m *Manager) AddProvider(name string, p provider.Provider) {
m.providers[name] = p
}
func (m *Manager) RemoveProvider(name string) {
panic("not implemented yet") // TODO: Implement provider removal logic
func (m *Manager) RemoveProvider(_ string) {
panic("not implemented yet")
}

View File

@@ -109,7 +109,7 @@ func (b *FuturesWebsocket) IsValidSubject(subject string, isFetch bool) bool {
if isFetch {
return false
}
return len(subject) > 0 // Extend with regex or lookup if needed
return len(subject) > 0
}
func (b *FuturesWebsocket) readLoop() {
@@ -138,13 +138,15 @@ func (b *FuturesWebsocket) readLoop() {
continue
}
id, err := domain.RawID("binance_futures_websocket", container.Stream)
if err != nil {
continue
}
msg := domain.Message{
Identifier: domain.Identifier{
Provider: "binance_futures_websocket",
Subject: container.Stream,
},
Payload: []byte(container.Data),
Encoding: domain.EncodingJSON,
Identifier: id,
Payload: container.Data,
Encoding: domain.EncodingJSON,
}
select {

View File

@@ -18,9 +18,7 @@ type GRPCControlServer struct {
}
func NewGRPCControlServer(m *manager.Manager) *GRPCControlServer {
return &GRPCControlServer{
manager: m,
}
return &GRPCControlServer{manager: m}
}
func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequest) (*pb.StartStreamResponse, error) {
@@ -28,7 +26,6 @@ func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequ
if err != nil {
return nil, fmt.Errorf("failed to start stream: %w", err)
}
return &pb.StartStreamResponse{StreamUuid: streamID.String()}, nil
}
@@ -38,19 +35,18 @@ func (s *GRPCControlServer) ConfigureStream(_ context.Context, req *pb.Configure
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
}
// Transform identifiers from protobuf to domain format
var ids []domain.Identifier
for _, i := range req.Identifiers {
ids = append(ids, domain.Identifier{
Provider: i.Provider,
Subject: i.Subject,
})
for _, in := range req.Identifiers {
id, e := domain.ParseIdentifier(in.Key)
if e != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identifier %q: %v", in.Key, e)
}
ids = append(ids, id)
}
if err := s.manager.ConfigureStream(streamID, ids); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "configure failed: %v", err)
}
return &pb.ConfigureStreamResponse{}, nil
}
@@ -59,11 +55,8 @@ func (s *GRPCControlServer) StopStream(_ context.Context, req *pb.StopStreamRequ
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
}
err = s.manager.StopStream(streamID) // Should only error if the stream doesn't exist
if err != nil {
if err := s.manager.StopStream(streamID); err != nil {
return nil, status.Errorf(codes.Internal, "failed to stop stream: %v", err)
}
return &pb.StopStreamResponse{}, nil
}

View File

@@ -14,9 +14,7 @@ type GRPCStreamingServer struct {
}
func NewGRPCStreamingServer(m *manager.Manager) *GRPCStreamingServer {
return &GRPCStreamingServer{
manager: m,
}
return &GRPCStreamingServer{manager: m}
}
func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream pb.DataServiceStreaming_ConnectStreamServer) error {
@@ -39,17 +37,11 @@ func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream
if !ok {
return nil
}
err := stream.Send(&pb.Message{
Identifier: &pb.Identifier{
Provider: msg.Identifier.Provider,
Subject: msg.Identifier.Subject,
},
Payload: msg.Payload,
Encoding: string(msg.Encoding),
})
if err != nil {
if err := stream.Send(&pb.Message{
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
Payload: msg.Payload,
Encoding: string(msg.Encoding),
}); err != nil {
return err
}
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"net"
"strings"
"time"
"github.com/google/uuid"
pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
@@ -22,7 +23,6 @@ func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
return &SocketStreamingServer{manager: m}
}
// Accepts connections and hands each off to handleConnection.
func (s *SocketStreamingServer) Serve(lis net.Listener) error {
for {
conn, err := lis.Accept()
@@ -43,16 +43,16 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
}
}()
// Low-latency socket hints (best-effort).
if tc, ok := conn.(*net.TCPConn); ok {
_ = tc.SetNoDelay(true)
_ = tc.SetWriteBuffer(512 * 1024)
_ = tc.SetReadBuffer(512 * 1024)
_ = tc.SetKeepAlive(true)
_ = tc.SetKeepAlivePeriod(30 * time.Second)
}
reader := bufio.NewReader(conn)
// Protocol header: first line is the stream UUID.
raw, err := reader.ReadString('\n')
if err != nil {
fmt.Printf("read stream UUID error: %v\n", err)
@@ -74,9 +74,8 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
defer s.manager.DisconnectStream(streamUUID)
writer := bufio.NewWriterSize(conn, 256*1024)
defer func(writer *bufio.Writer) {
err := writer.Flush()
if err != nil {
defer func(w *bufio.Writer) {
if err := w.Flush(); err != nil {
fmt.Printf("final flush error: %v\n", err)
}
}(writer)
@@ -85,27 +84,20 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
batch := 0
for msg := range outCh {
// Build protobuf payload.
message := pb.Message{
Identifier: &pb.Identifier{
Provider: msg.Identifier.Provider,
Subject: msg.Identifier.Subject,
},
Payload: msg.Payload, // []byte
Encoding: string(msg.Encoding), // e.g., "application/json"
m := pb.Message{
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
Payload: msg.Payload,
Encoding: string(msg.Encoding),
}
// Marshal protobuf.
// Use MarshalAppend to reuse capacity and avoid an extra alloc.
size := proto.Size(&message)
size := proto.Size(&m)
buf := make([]byte, 0, size)
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &message)
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &m)
if err != nil {
fmt.Printf("proto marshal error: %v\n", err)
continue
}
// Fixed 4-byte big-endian length prefix.
var hdr [4]byte
if len(b) > int(^uint32(0)) {
fmt.Printf("message too large: %d bytes\n", len(b))
@@ -113,7 +105,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
}
binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
// Write frame: [len][bytes].
if _, err := writer.Write(hdr[:]); err != nil {
if err == io.EOF {
return
@@ -139,7 +130,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
}
}
// Final flush when channel closes.
if err := writer.Flush(); err != nil {
fmt.Printf("final flush error: %v\n", err)
}