Refactor identifier handling: replace Provider and Subject fields with a single Key field in Identifier struct, update related message structures, and adjust parsing logic
This commit is contained in:
@@ -21,11 +21,9 @@ const (
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
// Domain Models
|
||||
type Identifier struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Provider string `protobuf:"bytes,1,opt,name=provider,proto3" json:"provider,omitempty"` // e.g., "binance"
|
||||
Subject string `protobuf:"bytes,2,opt,name=subject,proto3" json:"subject,omitempty"` // e.g., "BTCUSDT"
|
||||
Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -60,16 +58,9 @@ func (*Identifier) Descriptor() ([]byte, []int) {
|
||||
return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (x *Identifier) GetProvider() string {
|
||||
func (x *Identifier) GetKey() string {
|
||||
if x != nil {
|
||||
return x.Provider
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *Identifier) GetSubject() string {
|
||||
if x != nil {
|
||||
return x.Subject
|
||||
return x.Key
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -77,8 +68,8 @@ func (x *Identifier) GetSubject() string {
|
||||
type Message struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
Identifier *Identifier `protobuf:"bytes,1,opt,name=identifier,proto3" json:"identifier,omitempty"`
|
||||
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` // JSON-encoded data
|
||||
Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"` // e.g., "json", "protobuf"
|
||||
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
|
||||
Encoding string `protobuf:"bytes,3,opt,name=encoding,proto3" json:"encoding,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
@@ -134,7 +125,6 @@ func (x *Message) GetEncoding() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Control Requests and Responses
|
||||
type StartStreamRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
@@ -383,7 +373,6 @@ func (*StopStreamResponse) Descriptor() ([]byte, []int) {
|
||||
return file_pkg_pb_data_service_data_service_proto_rawDescGZIP(), []int{7}
|
||||
}
|
||||
|
||||
// Stream Requests and Responses
|
||||
type ConnectStreamRequest struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
StreamUuid string `protobuf:"bytes,1,opt,name=stream_uuid,json=streamUuid,proto3" json:"stream_uuid,omitempty"`
|
||||
@@ -432,11 +421,10 @@ var File_pkg_pb_data_service_data_service_proto protoreflect.FileDescriptor
|
||||
|
||||
const file_pkg_pb_data_service_data_service_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"&pkg/pb/data_service/data_service.proto\x12\fdata_service\"B\n" +
|
||||
"&pkg/pb/data_service/data_service.proto\x12\fdata_service\"\x1e\n" +
|
||||
"\n" +
|
||||
"Identifier\x12\x1a\n" +
|
||||
"\bprovider\x18\x01 \x01(\tR\bprovider\x12\x18\n" +
|
||||
"\asubject\x18\x02 \x01(\tR\asubject\"y\n" +
|
||||
"Identifier\x12\x10\n" +
|
||||
"\x03key\x18\x01 \x01(\tR\x03key\"y\n" +
|
||||
"\aMessage\x128\n" +
|
||||
"\n" +
|
||||
"identifier\x18\x01 \x01(\v2\x18.data_service.IdentifierR\n" +
|
||||
|
||||
@@ -14,39 +14,26 @@ service DataServiceStreaming {
|
||||
rpc ConnectStream(ConnectStreamRequest) returns (stream Message);
|
||||
}
|
||||
|
||||
// Domain Models
|
||||
message Identifier {
|
||||
string provider = 1; // e.g., "binance"
|
||||
string subject = 2; // e.g., "BTCUSDT"
|
||||
string key = 1;
|
||||
}
|
||||
|
||||
message Message {
|
||||
Identifier identifier = 1;
|
||||
bytes payload = 2; // JSON-encoded data
|
||||
string encoding = 3; // e.g., "json", "protobuf"
|
||||
bytes payload = 2;
|
||||
string encoding = 3;
|
||||
}
|
||||
|
||||
// Control Requests and Responses
|
||||
message StartStreamRequest {}
|
||||
|
||||
message StartStreamResponse {
|
||||
string stream_uuid = 1;
|
||||
}
|
||||
message StartStreamResponse { string stream_uuid = 1; }
|
||||
|
||||
message ConfigureStreamRequest {
|
||||
string stream_uuid = 1;
|
||||
repeated Identifier identifiers = 2;
|
||||
}
|
||||
|
||||
message ConfigureStreamResponse {}
|
||||
|
||||
message StopStreamRequest {
|
||||
string stream_uuid = 1;
|
||||
}
|
||||
|
||||
message StopStreamRequest { string stream_uuid = 1; }
|
||||
message StopStreamResponse {}
|
||||
|
||||
// Stream Requests and Responses
|
||||
message ConnectStreamRequest {
|
||||
string stream_uuid = 1;
|
||||
}
|
||||
message ConnectStreamRequest { string stream_uuid = 1; }
|
||||
|
||||
@@ -28,7 +28,7 @@ func (i *idsFlag) Set(v string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseID(s string) (provider, subject string, err error) {
|
||||
func parseIDPair(s string) (provider, subject string, err error) {
|
||||
parts := strings.SplitN(s, ":", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return "", "", fmt.Errorf("want provider:subject, got %q", s)
|
||||
@@ -36,13 +36,24 @@ func parseID(s string) (provider, subject string, err error) {
|
||||
return parts[0], parts[1], nil
|
||||
}
|
||||
|
||||
func toIdentifierKey(input string) (string, error) {
|
||||
if strings.Contains(input, "::") {
|
||||
return input, nil
|
||||
}
|
||||
prov, subj, err := parseIDPair(input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "raw::" + strings.ToLower(prov) + "." + subj, nil
|
||||
}
|
||||
|
||||
func prettyOrRaw(b []byte, pretty bool) string {
|
||||
if !pretty || len(b) == 0 {
|
||||
return string(b)
|
||||
}
|
||||
var tmp any
|
||||
if err := json.Unmarshal(b, &tmp); err != nil {
|
||||
return string(b) // not JSON
|
||||
return string(b)
|
||||
}
|
||||
out, err := json.MarshalIndent(tmp, "", " ")
|
||||
if err != nil {
|
||||
@@ -51,7 +62,6 @@ func prettyOrRaw(b []byte, pretty bool) string {
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// waitReady blocks until conn is READY or ctx times out/cancels.
|
||||
func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||
for {
|
||||
s := conn.GetState()
|
||||
@@ -67,7 +77,6 @@ func waitReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||
}
|
||||
}
|
||||
|
||||
//goland:noinspection GoUnhandledErrorResult
|
||||
func main() {
|
||||
var ids idsFlag
|
||||
var ctlAddr string
|
||||
@@ -75,7 +84,7 @@ func main() {
|
||||
var pretty bool
|
||||
var timeout time.Duration
|
||||
|
||||
flag.Var(&ids, "id", "identifier in form provider:subject (repeatable)")
|
||||
flag.Var(&ids, "id", "identifier (provider:subject or canonical key); repeatable")
|
||||
flag.StringVar(&ctlAddr, "ctl", "127.0.0.1:50051", "gRPC control address")
|
||||
flag.StringVar(&strAddr, "str", "127.0.0.1:50052", "gRPC streaming address")
|
||||
flag.BoolVar(&pretty, "pretty", true, "pretty-print JSON payloads when possible")
|
||||
@@ -83,15 +92,13 @@ func main() {
|
||||
flag.Parse()
|
||||
|
||||
if len(ids) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "provide at least one --id provider:subject")
|
||||
fmt.Fprintln(os.Stderr, "provide at least one --id (provider:subject or canonical key)")
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
// Ctrl-C handling
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
||||
defer cancel()
|
||||
|
||||
// ----- Control client -----
|
||||
ccCtl, err := grpc.NewClient(
|
||||
ctlAddr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
@@ -101,8 +108,7 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
defer ccCtl.Close()
|
||||
|
||||
ccCtl.Connect() // start dialing in background
|
||||
ccCtl.Connect()
|
||||
|
||||
ctlConnCtx, cancelCtlConn := context.WithTimeout(ctx, timeout)
|
||||
if err := waitReady(ctlConnCtx, ccCtl); err != nil {
|
||||
@@ -114,7 +120,6 @@ func main() {
|
||||
|
||||
ctl := pb.NewDataServiceControlClient(ccCtl)
|
||||
|
||||
// Start stream
|
||||
ctxStart, cancelStart := context.WithTimeout(ctx, timeout)
|
||||
startResp, err := ctl.StartStream(ctxStart, &pb.StartStreamRequest{})
|
||||
cancelStart()
|
||||
@@ -125,15 +130,14 @@ func main() {
|
||||
streamUUID := startResp.GetStreamUuid()
|
||||
fmt.Printf("stream: %s\n", streamUUID)
|
||||
|
||||
// Configure
|
||||
var pbIDs []*pb.Identifier
|
||||
for _, s := range ids {
|
||||
prov, subj, err := parseID(s)
|
||||
key, err := toIdentifierKey(s)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "bad --id: %v\n", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
pbIDs = append(pbIDs, &pb.Identifier{Provider: prov, Subject: subj})
|
||||
pbIDs = append(pbIDs, &pb.Identifier{Key: key})
|
||||
}
|
||||
|
||||
ctxCfg, cancelCfg := context.WithTimeout(ctx, timeout)
|
||||
@@ -148,7 +152,6 @@ func main() {
|
||||
}
|
||||
fmt.Printf("configured %d identifiers\n", len(pbIDs))
|
||||
|
||||
// ----- Streaming client -----
|
||||
ccStr, err := grpc.NewClient(
|
||||
strAddr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
@@ -158,7 +161,6 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
defer ccStr.Close()
|
||||
|
||||
ccStr.Connect()
|
||||
|
||||
strConnCtx, cancelStrConn := context.WithTimeout(ctx, timeout)
|
||||
@@ -171,7 +173,6 @@ func main() {
|
||||
|
||||
str := pb.NewDataServiceStreamingClient(ccStr)
|
||||
|
||||
// This context lives until Ctrl-C
|
||||
streamCtx, streamCancel := context.WithCancel(ctx)
|
||||
defer streamCancel()
|
||||
|
||||
@@ -182,7 +183,6 @@ func main() {
|
||||
}
|
||||
fmt.Println("connected; streaming… (Ctrl-C to quit)")
|
||||
|
||||
// Receive loop until Ctrl-C
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -192,14 +192,14 @@ func main() {
|
||||
msg, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return // normal shutdown
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "recv: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
id := msg.GetIdentifier()
|
||||
fmt.Printf("[%s] %s bytes=%d enc=%s t=%s\n",
|
||||
id.GetProvider(), id.GetSubject(), len(msg.GetPayload()), msg.GetEncoding(),
|
||||
fmt.Printf("[%s] bytes=%d enc=%s t=%s\n",
|
||||
id.GetKey(), len(msg.GetPayload()), msg.GetEncoding(),
|
||||
time.Now().Format(time.RFC3339Nano),
|
||||
)
|
||||
fmt.Println(prettyOrRaw(msg.GetPayload(), pretty))
|
||||
|
||||
@@ -1,6 +1,181 @@
|
||||
package domain
|
||||
|
||||
type Identifier struct {
|
||||
Provider string
|
||||
Subject string
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
prefixRaw = "raw::"
|
||||
prefixInternal = "internal::"
|
||||
)
|
||||
|
||||
// Identifier is a canonical representation of a data stream identifier.
|
||||
type Identifier struct{ key string }
|
||||
|
||||
func (id Identifier) IsRaw() bool { return strings.HasPrefix(id.key, prefixRaw) }
|
||||
func (id Identifier) IsInternal() bool { return strings.HasPrefix(id.key, prefixInternal) }
|
||||
func (id Identifier) Key() string { return id.key }
|
||||
|
||||
func (id Identifier) ProviderSubject() (provider, subject string, ok bool) {
|
||||
if !id.IsRaw() {
|
||||
return "", "", false
|
||||
}
|
||||
body := strings.TrimPrefix(id.key, prefixRaw)
|
||||
prov, subj, ok := strings.Cut(body, ".")
|
||||
return prov, subj, ok
|
||||
}
|
||||
|
||||
func (id Identifier) InternalParts() (venue, stream, symbol string, params map[string]string, ok bool) {
|
||||
if !id.IsInternal() {
|
||||
return "", "", "", nil, false
|
||||
}
|
||||
body := strings.TrimPrefix(id.key, prefixInternal)
|
||||
before, bracket, _ := strings.Cut(body, "[")
|
||||
parts := strings.Split(before, ".")
|
||||
if len(parts) != 3 {
|
||||
return "", "", "", nil, false
|
||||
}
|
||||
return parts[0], parts[1], parts[2], decodeParams(strings.TrimSuffix(bracket, "]")), true
|
||||
}
|
||||
|
||||
func RawID(provider, subject string) (Identifier, error) {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
s := strings.TrimSpace(subject)
|
||||
|
||||
if err := validateComponent("provider", p, false); err != nil {
|
||||
return Identifier{}, err
|
||||
}
|
||||
if err := validateComponent("subject", s, true); err != nil {
|
||||
return Identifier{}, err
|
||||
}
|
||||
return Identifier{key: prefixRaw + p + "." + s}, nil
|
||||
}
|
||||
|
||||
func InternalID(venue, stream, symbol string, params map[string]string) (Identifier, error) {
|
||||
v := strings.ToLower(strings.TrimSpace(venue))
|
||||
t := strings.ToLower(strings.TrimSpace(stream))
|
||||
sym := strings.ToUpper(strings.TrimSpace(symbol))
|
||||
|
||||
if err := validateComponent("venue", v, false); err != nil {
|
||||
return Identifier{}, err
|
||||
}
|
||||
if err := validateComponent("stream", t, false); err != nil {
|
||||
return Identifier{}, err
|
||||
}
|
||||
if err := validateComponent("symbol", sym, false); err != nil {
|
||||
return Identifier{}, err
|
||||
}
|
||||
|
||||
paramStr, err := encodeParams(params) // "k=v;..." or ""
|
||||
if err != nil {
|
||||
return Identifier{}, err
|
||||
}
|
||||
if paramStr == "" {
|
||||
paramStr = "[]"
|
||||
} else {
|
||||
paramStr = "[" + paramStr + "]"
|
||||
}
|
||||
return Identifier{key: prefixInternal + v + "." + t + "." + sym + paramStr}, nil
|
||||
}
|
||||
|
||||
func ParseIdentifier(s string) (Identifier, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
switch {
|
||||
case strings.HasPrefix(s, prefixRaw):
|
||||
// raw::provider.subject
|
||||
body := strings.TrimPrefix(s, prefixRaw)
|
||||
prov, subj, ok := strings.Cut(body, ".")
|
||||
if !ok {
|
||||
return Identifier{}, errors.New("invalid raw identifier: missing '.'")
|
||||
}
|
||||
return RawID(prov, subj)
|
||||
|
||||
case strings.HasPrefix(s, prefixInternal):
|
||||
// internal::venue.stream.symbol[...]
|
||||
body := strings.TrimPrefix(s, prefixInternal)
|
||||
before, bracket, _ := strings.Cut(body, "[")
|
||||
parts := strings.Split(before, ".")
|
||||
if len(parts) != 3 {
|
||||
return Identifier{}, errors.New("invalid internal identifier: need venue.stream.symbol")
|
||||
}
|
||||
params := decodeParams(strings.TrimSuffix(bracket, "]"))
|
||||
return InternalID(parts[0], parts[1], parts[2], params)
|
||||
}
|
||||
return Identifier{}, errors.New("unknown identifier prefix")
|
||||
}
|
||||
|
||||
var (
|
||||
segDisallow = regexp.MustCompile(`[ \t\r\n\[\]]`) // forbid whitespace/brackets in fixed segments
|
||||
dotDisallow = regexp.MustCompile(`[.]`) // fixed segments cannot contain '.'
|
||||
)
|
||||
|
||||
// allowAny=true (for subject) skips dot checks but still forbids whitespace/brackets.
|
||||
func validateComponent(name, v string, allowAny bool) error {
|
||||
if v == "" {
|
||||
return fmt.Errorf("%s cannot be empty", name)
|
||||
}
|
||||
if allowAny {
|
||||
if segDisallow.MatchString(v) {
|
||||
return fmt.Errorf("%s contains illegal chars [] or whitespace", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if segDisallow.MatchString(v) || dotDisallow.MatchString(v) {
|
||||
return fmt.Errorf("%s contains illegal chars (dot/brackets/whitespace)", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeParams renders sorted k=v pairs separated by ';'.
|
||||
func encodeParams(params map[string]string) (string, error) {
|
||||
if len(params) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
keys := make([]string, 0, len(params))
|
||||
for k := range params {
|
||||
k = strings.ToLower(strings.TrimSpace(k))
|
||||
if k == "" {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
out := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
v := strings.TrimSpace(params[k])
|
||||
// prevent breaking delimiters
|
||||
if strings.ContainsAny(k, ";]") || strings.ContainsAny(v, ";]") {
|
||||
return "", fmt.Errorf("param %q contains illegal ';' or ']'", k)
|
||||
}
|
||||
out = append(out, k+"="+v)
|
||||
}
|
||||
return strings.Join(out, ";"), nil
|
||||
}
|
||||
|
||||
func decodeParams(s string) map[string]string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return map[string]string{}
|
||||
}
|
||||
out := make(map[string]string, 4)
|
||||
for _, p := range strings.Split(s, ";") {
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
if len(kv) != 2 {
|
||||
continue
|
||||
}
|
||||
k := strings.ToLower(strings.TrimSpace(kv[0]))
|
||||
v := strings.TrimSpace(kv[1])
|
||||
if k != "" {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ type ClientStream struct {
|
||||
}
|
||||
|
||||
func NewManager(router *router.Router) *Manager {
|
||||
go router.Run() // Start the router in a separate goroutine
|
||||
go router.Run()
|
||||
return &Manager{
|
||||
providers: make(map[string]provider.Provider),
|
||||
providerStreams: make(map[domain.Identifier]chan domain.Message),
|
||||
@@ -46,8 +46,8 @@ func (m *Manager) StartStream() (uuid.UUID, error) {
|
||||
streamID := uuid.New()
|
||||
m.clientStreams[streamID] = &ClientStream{
|
||||
UUID: streamID,
|
||||
Identifiers: nil, // start empty
|
||||
OutChannel: nil, // not yet connected
|
||||
Identifiers: nil,
|
||||
OutChannel: nil,
|
||||
Timer: time.AfterFunc(1*time.Minute, func() {
|
||||
fmt.Printf("stream %s expired due to inactivity\n", streamID)
|
||||
err := m.StopStream(streamID)
|
||||
@@ -69,21 +69,22 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
|
||||
return fmt.Errorf("stream not found: %s", streamID)
|
||||
}
|
||||
|
||||
// Validate new identifiers.
|
||||
for _, id := range newIds {
|
||||
if id.Provider == "" || id.Subject == "" {
|
||||
return fmt.Errorf("empty identifier: %v", id)
|
||||
}
|
||||
prov, exists := m.providers[id.Provider]
|
||||
if !exists {
|
||||
return fmt.Errorf("unknown provider: %s", id.Provider)
|
||||
}
|
||||
if !prov.IsValidSubject(id.Subject, false) {
|
||||
return fmt.Errorf("invalid subject %q for provider %s", id.Subject, id.Provider)
|
||||
if id.IsRaw() {
|
||||
providerName, subject, ok := id.ProviderSubject()
|
||||
if !ok || providerName == "" || subject == "" {
|
||||
return fmt.Errorf("empty identifier: %v", id)
|
||||
}
|
||||
prov, exists := m.providers[providerName]
|
||||
if !exists {
|
||||
return fmt.Errorf("unknown provider: %s", providerName)
|
||||
}
|
||||
if !prov.IsValidSubject(subject, false) {
|
||||
return fmt.Errorf("invalid subject %q for provider %s", subject, providerName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate old and new sets of identifiers
|
||||
oldSet := make(map[domain.Identifier]struct{}, len(stream.Identifiers))
|
||||
for _, id := range stream.Identifiers {
|
||||
oldSet[id] = struct{}{}
|
||||
@@ -93,55 +94,54 @@ func (m *Manager) ConfigureStream(streamID uuid.UUID, newIds []domain.Identifier
|
||||
newSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
// Add identifiers that are in newIds but not in oldSet
|
||||
for _, id := range newIds {
|
||||
if _, seen := oldSet[id]; !seen {
|
||||
// Provision the stream from the provider if needed
|
||||
if _, ok := m.providerStreams[id]; !ok {
|
||||
ch := make(chan domain.Message, 64)
|
||||
if err := m.providers[id.Provider].RequestStream(id.Subject, ch); err != nil {
|
||||
return fmt.Errorf("provision %v: %w", id, err)
|
||||
}
|
||||
m.providerStreams[id] = ch
|
||||
|
||||
incomingChannel := m.router.IncomingChannel()
|
||||
go func(c chan domain.Message) {
|
||||
for msg := range c {
|
||||
incomingChannel <- msg
|
||||
if id.IsRaw() {
|
||||
if _, ok := m.providerStreams[id]; !ok {
|
||||
ch := make(chan domain.Message, 64)
|
||||
providerName, subject, _ := id.ProviderSubject()
|
||||
if err := m.providers[providerName].RequestStream(subject, ch); err != nil {
|
||||
return fmt.Errorf("provision %v: %w", id, err)
|
||||
}
|
||||
}(ch)
|
||||
m.providerStreams[id] = ch
|
||||
|
||||
incomingChannel := m.router.IncomingChannel()
|
||||
go func(c chan domain.Message) {
|
||||
for msg := range c {
|
||||
incomingChannel <- msg
|
||||
}
|
||||
}(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Register the new identifier with the router, only if there's an active output channel (meaning the stream is connected)
|
||||
if stream.OutChannel != nil {
|
||||
m.router.RegisterRoute(id, stream.OutChannel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove identifiers that are in oldSet but not in newSet
|
||||
for _, oldId := range stream.Identifiers {
|
||||
if _, keep := newSet[oldId]; !keep {
|
||||
// Deregister the identifier from the router, only if there's an active output channel (meaning the stream is connected)
|
||||
if stream.OutChannel != nil {
|
||||
m.router.DeregisterRoute(oldId, stream.OutChannel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the new identifiers for the stream
|
||||
stream.Identifiers = newIds
|
||||
|
||||
// Clean up provider streams that are no longer used
|
||||
used := make(map[domain.Identifier]bool)
|
||||
for _, cs := range m.clientStreams {
|
||||
for _, id := range cs.Identifiers {
|
||||
used[id] = true
|
||||
if id.IsRaw() {
|
||||
used[id] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
for id, ch := range m.providerStreams {
|
||||
if !used[id] {
|
||||
m.providers[id.Provider].CancelStream(id.Subject)
|
||||
providerName, subject, _ := id.ProviderSubject()
|
||||
m.providers[providerName].CancelStream(subject)
|
||||
close(ch)
|
||||
delete(m.providerStreams, id)
|
||||
}
|
||||
@@ -165,18 +165,19 @@ func (m *Manager) StopStream(streamID uuid.UUID) error {
|
||||
|
||||
delete(m.clientStreams, streamID)
|
||||
|
||||
// Find provider streams that are used by other client streams
|
||||
used := make(map[domain.Identifier]bool)
|
||||
for _, s := range m.clientStreams {
|
||||
for _, id := range s.Identifiers {
|
||||
used[id] = true
|
||||
if id.IsRaw() {
|
||||
used[id] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel provider streams that are not used by any client stream
|
||||
for id, ch := range m.providerStreams {
|
||||
if !used[id] {
|
||||
m.providers[id.Provider].CancelStream(id.Subject)
|
||||
providerName, subject, _ := id.ProviderSubject()
|
||||
m.providers[providerName].CancelStream(subject)
|
||||
close(ch)
|
||||
delete(m.providerStreams, id)
|
||||
}
|
||||
@@ -219,19 +220,16 @@ func (m *Manager) DisconnectStream(streamID uuid.UUID) {
|
||||
|
||||
stream, ok := m.clientStreams[streamID]
|
||||
if !ok || stream.OutChannel == nil {
|
||||
return // already disconnected or does not exist
|
||||
return
|
||||
}
|
||||
|
||||
// Deregister all identifiers from the router
|
||||
for _, ident := range stream.Identifiers {
|
||||
m.router.DeregisterRoute(ident, stream.OutChannel)
|
||||
}
|
||||
|
||||
// Close the output channel
|
||||
close(stream.OutChannel)
|
||||
stream.OutChannel = nil
|
||||
|
||||
// Set up the expiry timer
|
||||
stream.Timer = time.AfterFunc(1*time.Minute, func() {
|
||||
fmt.Printf("stream %s expired due to inactivity\n", streamID)
|
||||
err := m.StopStream(streamID)
|
||||
@@ -256,6 +254,6 @@ func (m *Manager) AddProvider(name string, p provider.Provider) {
|
||||
m.providers[name] = p
|
||||
}
|
||||
|
||||
func (m *Manager) RemoveProvider(name string) {
|
||||
panic("not implemented yet") // TODO: Implement provider removal logic
|
||||
func (m *Manager) RemoveProvider(_ string) {
|
||||
panic("not implemented yet")
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func (b *FuturesWebsocket) IsValidSubject(subject string, isFetch bool) bool {
|
||||
if isFetch {
|
||||
return false
|
||||
}
|
||||
return len(subject) > 0 // Extend with regex or lookup if needed
|
||||
return len(subject) > 0
|
||||
}
|
||||
|
||||
func (b *FuturesWebsocket) readLoop() {
|
||||
@@ -138,13 +138,15 @@ func (b *FuturesWebsocket) readLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
id, err := domain.RawID("binance_futures_websocket", container.Stream)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
msg := domain.Message{
|
||||
Identifier: domain.Identifier{
|
||||
Provider: "binance_futures_websocket",
|
||||
Subject: container.Stream,
|
||||
},
|
||||
Payload: []byte(container.Data),
|
||||
Encoding: domain.EncodingJSON,
|
||||
Identifier: id,
|
||||
Payload: container.Data,
|
||||
Encoding: domain.EncodingJSON,
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
@@ -18,9 +18,7 @@ type GRPCControlServer struct {
|
||||
}
|
||||
|
||||
func NewGRPCControlServer(m *manager.Manager) *GRPCControlServer {
|
||||
return &GRPCControlServer{
|
||||
manager: m,
|
||||
}
|
||||
return &GRPCControlServer{manager: m}
|
||||
}
|
||||
|
||||
func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequest) (*pb.StartStreamResponse, error) {
|
||||
@@ -28,7 +26,6 @@ func (s *GRPCControlServer) StartStream(_ context.Context, _ *pb.StartStreamRequ
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start stream: %w", err)
|
||||
}
|
||||
|
||||
return &pb.StartStreamResponse{StreamUuid: streamID.String()}, nil
|
||||
}
|
||||
|
||||
@@ -38,19 +35,18 @@ func (s *GRPCControlServer) ConfigureStream(_ context.Context, req *pb.Configure
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
|
||||
}
|
||||
|
||||
// Transform identifiers from protobuf to domain format
|
||||
var ids []domain.Identifier
|
||||
for _, i := range req.Identifiers {
|
||||
ids = append(ids, domain.Identifier{
|
||||
Provider: i.Provider,
|
||||
Subject: i.Subject,
|
||||
})
|
||||
for _, in := range req.Identifiers {
|
||||
id, e := domain.ParseIdentifier(in.Key)
|
||||
if e != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid identifier %q: %v", in.Key, e)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
if err := s.manager.ConfigureStream(streamID, ids); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "configure failed: %v", err)
|
||||
}
|
||||
|
||||
return &pb.ConfigureStreamResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -59,11 +55,8 @@ func (s *GRPCControlServer) StopStream(_ context.Context, req *pb.StopStreamRequ
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid stream_uuid %q: %v", req.StreamUuid, err)
|
||||
}
|
||||
|
||||
err = s.manager.StopStream(streamID) // Should only error if the stream doesn't exist
|
||||
if err != nil {
|
||||
if err := s.manager.StopStream(streamID); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to stop stream: %v", err)
|
||||
}
|
||||
|
||||
return &pb.StopStreamResponse{}, nil
|
||||
}
|
||||
|
||||
@@ -14,9 +14,7 @@ type GRPCStreamingServer struct {
|
||||
}
|
||||
|
||||
func NewGRPCStreamingServer(m *manager.Manager) *GRPCStreamingServer {
|
||||
return &GRPCStreamingServer{
|
||||
manager: m,
|
||||
}
|
||||
return &GRPCStreamingServer{manager: m}
|
||||
}
|
||||
|
||||
func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream pb.DataServiceStreaming_ConnectStreamServer) error {
|
||||
@@ -39,17 +37,11 @@ func (s *GRPCStreamingServer) ConnectStream(req *pb.ConnectStreamRequest, stream
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := stream.Send(&pb.Message{
|
||||
Identifier: &pb.Identifier{
|
||||
Provider: msg.Identifier.Provider,
|
||||
Subject: msg.Identifier.Subject,
|
||||
},
|
||||
Payload: msg.Payload,
|
||||
Encoding: string(msg.Encoding),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
if err := stream.Send(&pb.Message{
|
||||
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
|
||||
Payload: msg.Payload,
|
||||
Encoding: string(msg.Encoding),
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service"
|
||||
@@ -22,7 +23,6 @@ func NewSocketStreamingServer(m *manager.Manager) *SocketStreamingServer {
|
||||
return &SocketStreamingServer{manager: m}
|
||||
}
|
||||
|
||||
// Accepts connections and hands each off to handleConnection.
|
||||
func (s *SocketStreamingServer) Serve(lis net.Listener) error {
|
||||
for {
|
||||
conn, err := lis.Accept()
|
||||
@@ -43,16 +43,16 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Low-latency socket hints (best-effort).
|
||||
if tc, ok := conn.(*net.TCPConn); ok {
|
||||
_ = tc.SetNoDelay(true)
|
||||
_ = tc.SetWriteBuffer(512 * 1024)
|
||||
_ = tc.SetReadBuffer(512 * 1024)
|
||||
_ = tc.SetKeepAlive(true)
|
||||
_ = tc.SetKeepAlivePeriod(30 * time.Second)
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
// Protocol header: first line is the stream UUID.
|
||||
raw, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
fmt.Printf("read stream UUID error: %v\n", err)
|
||||
@@ -74,9 +74,8 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
|
||||
defer s.manager.DisconnectStream(streamUUID)
|
||||
|
||||
writer := bufio.NewWriterSize(conn, 256*1024)
|
||||
defer func(writer *bufio.Writer) {
|
||||
err := writer.Flush()
|
||||
if err != nil {
|
||||
defer func(w *bufio.Writer) {
|
||||
if err := w.Flush(); err != nil {
|
||||
fmt.Printf("final flush error: %v\n", err)
|
||||
}
|
||||
}(writer)
|
||||
@@ -85,27 +84,20 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
|
||||
batch := 0
|
||||
|
||||
for msg := range outCh {
|
||||
// Build protobuf payload.
|
||||
message := pb.Message{
|
||||
Identifier: &pb.Identifier{
|
||||
Provider: msg.Identifier.Provider,
|
||||
Subject: msg.Identifier.Subject,
|
||||
},
|
||||
Payload: msg.Payload, // []byte
|
||||
Encoding: string(msg.Encoding), // e.g., "application/json"
|
||||
m := pb.Message{
|
||||
Identifier: &pb.Identifier{Key: msg.Identifier.Key()},
|
||||
Payload: msg.Payload,
|
||||
Encoding: string(msg.Encoding),
|
||||
}
|
||||
|
||||
// Marshal protobuf.
|
||||
// Use MarshalAppend to reuse capacity and avoid an extra alloc.
|
||||
size := proto.Size(&message)
|
||||
size := proto.Size(&m)
|
||||
buf := make([]byte, 0, size)
|
||||
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &message)
|
||||
b, err := proto.MarshalOptions{}.MarshalAppend(buf, &m)
|
||||
if err != nil {
|
||||
fmt.Printf("proto marshal error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Fixed 4-byte big-endian length prefix.
|
||||
var hdr [4]byte
|
||||
if len(b) > int(^uint32(0)) {
|
||||
fmt.Printf("message too large: %d bytes\n", len(b))
|
||||
@@ -113,7 +105,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
|
||||
}
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(b)))
|
||||
|
||||
// Write frame: [len][bytes].
|
||||
if _, err := writer.Write(hdr[:]); err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
@@ -139,7 +130,6 @@ func (s *SocketStreamingServer) handleConnection(conn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
// Final flush when channel closes.
|
||||
if err := writer.Flush(); err != nil {
|
||||
fmt.Printf("final flush error: %v\n", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user