diff --git a/services/data_service/cmd/data_service/main.go b/services/data_service/cmd/data_service/main.go index dee5271..3d50fe4 100644 --- a/services/data_service/cmd/data_service/main.go +++ b/services/data_service/cmd/data_service/main.go @@ -9,6 +9,7 @@ import ( "github.com/lmittmann/tint" pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/manager" + "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/provider/providers/binance/ws" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/provider/providers/test" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/router" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/server" @@ -57,12 +58,21 @@ func main() { // Setup r := router.NewRouter(2048) m := manager.NewManager(r) - testProvider := test.NewTestProvider(r.IncomingChannel(), time.Microsecond*13) + + // Providers + + testProvider := test.NewTestProvider(r.IncomingChannel(), time.Microsecond*100) if err := m.AddProvider("test_provider", testProvider); err != nil { slog.Error("add provider failed", "err", err) os.Exit(1) } + binanceFuturesWebsocket := ws.NewBinanceFuturesWebsocket(ws.Config{}, r.IncomingChannel()) + if err := m.AddProvider("binance_futures", binanceFuturesWebsocket); err != nil { + slog.Error("add provider failed", "err", err) + os.Exit(1) + } + // gRPC Control Server grpcControlServer := grpc.NewServer() go func() { diff --git a/services/data_service/internal/manager/manager.go b/services/data_service/internal/manager/manager.go index 687ba93..ff6192a 100644 --- a/services/data_service/internal/manager/manager.go +++ b/services/data_service/internal/manager/manager.go @@ -93,7 +93,7 @@ func (m *Manager) AttachClient(id uuid.UUID, inBuf, outBuf int) (chan<- domain.M r := <-resp - slog.Default().Debug("client attached", slog.String("cmp", "manager"), slog.String("session", id.String())) + slog.Default().Info("client attached", slog.String("cmp", "manager"), slog.String("session", id.String())) return r.cin, r.cout, r.err } @@ -105,7 +105,7 @@ func (m *Manager) DetachClient(id uuid.UUID) error { r := <-resp - slog.Default().Debug("client detached", slog.String("cmp", "manager"), slog.String("session", id.String())) + slog.Default().Info("client detached", slog.String("cmp", "manager"), slog.String("session", id.String())) return r.err } @@ -117,7 +117,7 @@ func (m *Manager) ConfigureSession(id uuid.UUID, next []domain.Identifier) error r := <-resp - slog.Default().Debug("session configured", slog.String("cmp", "manager"), slog.String("session", id.String()), slog.String("err", fmt.Sprintf("%v", r.err))) + slog.Default().Info("session configured", slog.String("cmp", "manager"), slog.String("session", id.String()), slog.String("err", fmt.Sprintf("%v", r.err))) return r.err } @@ -158,6 +158,7 @@ func (m *Manager) run() { // Command handlers, run in loop goroutine. With a single goroutine, no locking is needed. +// handleAddProvider adds and starts a new provider. 
func (m *Manager) handleAddProvider(cmd addProviderCmd) { if _, ok := m.providers[cmd.name]; ok { slog.Default().Warn("provider already exists", slog.String("cmp", "manager"), slog.String("name", cmd.name)) @@ -173,10 +174,13 @@ func (m *Manager) handleAddProvider(cmd addProviderCmd) { cmd.resp <- addProviderResult{err: nil} } -func (m *Manager) handleRemoveProvider(cmd removeProviderCmd) { +// handleRemoveProvider stops and removes a provider, removing the bindings from all sessions that use streams from it. +// TODO: Implement this function. +func (m *Manager) handleRemoveProvider(_ removeProviderCmd) { panic("unimplemented") } +// handleNewSession creates a new session with the given idle timeout. The idle timeout is typically not set by the client, but by the server configuration. func (m *Manager) handleNewSession(cmd newSessionCmd) { s := newSession(cmd.idleAfter) s.armIdleTimer(func() { @@ -190,6 +194,7 @@ func (m *Manager) handleNewSession(cmd newSessionCmd) { cmd.resp <- newSessionResult{id: s.id} } +// handleAttach attaches a client to a session, creating new client channels for the session. If the session is already attached, returns an error. func (m *Manager) handleAttach(cmd attachCmd) { s, ok := m.sessions[cmd.sid] if !ok { @@ -212,6 +217,7 @@ func (m *Manager) handleAttach(cmd attachCmd) { cmd.resp <- attachResult{cin: cin, cout: cout, err: nil} } +// handleDetach detaches the client from the session, closing client channels and arming the idle timeout. If the session is not attached, returns an error. func (m *Manager) handleDetach(cmd detachCmd) { s, ok := m.sessions[cmd.sid] if !ok { @@ -240,6 +246,7 @@ func (m *Manager) handleDetach(cmd detachCmd) { } // handleConfigure updates the session bindings, starting and stopping streams as needed. Currently only supports Raw streams. +// TODO: Change this configuration to be an atomic operation, so that partial failures do not end in a half-configured state. func (m *Manager) handleConfigure(cmd configureCmd) { s, ok := m.sessions[cmd.sid] if !ok { @@ -341,6 +348,70 @@ func (m *Manager) handleConfigure(cmd configureCmd) { cmd.resp <- configureResult{err: errs} } -func (m *Manager) handleCloseSession(c closeSessionCmd) { - panic("unimplemented") +// handleCloseSession closes and removes the session, cleaning up all bindings. +func (m *Manager) handleCloseSession(cmd closeSessionCmd) { + s, ok := m.sessions[cmd.sid] + if !ok { + cmd.resp <- closeSessionResult{err: ErrSessionNotFound} + return + } + + var errs error + + // Deregister attached routes + if s.attached { + if s.outChannel == nil { + errs = errors.Join(errs, fmt.Errorf("channels do not exist despite attached state")) + slog.Default().Error("no channels despite attached state", slog.String("cmp", "manager"), slog.String("session", cmd.sid.String())) + } else { + for id := range s.bound { + m.router.DeregisterRoute(id, s.outChannel) + } + } + } + + // Unsubscribe from all streams if no other session needs them. 
+ pendingUnsub := make(map[domain.Identifier]<-chan error) + + for id := range s.bound { + pName, subject, ok := id.ProviderSubject() + if !ok || subject == "" || pName == "" { + errs = errors.Join(errs, fmt.Errorf("invalid identifier: %s", id.Key())) + continue + } + p, ok := m.providers[pName] + if !ok { + errs = errors.Join(errs, fmt.Errorf("provider not found: %s", pName)) + continue + } + + stillNeeded := false + for _, other := range m.sessions { + if other.id == s.id { + continue + } + if _, bound := other.bound[id]; bound { + stillNeeded = true + break + } + } + if stillNeeded { + continue + } + + pendingUnsub[id] = p.Unsubscribe(subject) + } + + for id, ch := range pendingUnsub { + if err := <-ch; err != nil { + errs = errors.Join(errs, fmt.Errorf("failed to unsubscribe from %s: %w", id.Key(), err)) + } + } + + // Stop timers and channels, remove session. + s.disarmIdleTimer() + s.clearChannels() + delete(m.sessions, s.id) + + cmd.resp <- closeSessionResult{err: errs} } diff --git a/services/data_service/internal/manager/session.go b/services/data_service/internal/manager/session.go index d63076b..023efd4 100644 --- a/services/data_service/internal/manager/session.go +++ b/services/data_service/internal/manager/session.go @@ -11,7 +11,8 @@ const ( defaultClientBuf = 256 ) -// Session holds per-session state. Owned by the manager loop. So we do not need a mutex. +// session holds per-session state. +// Owned by the manager loop. So we do not need a mutex. type session struct { id uuid.UUID @@ -34,6 +35,7 @@ func newSession(idleAfter time.Duration) *session { } } +// armIdleTimer sets the idle timer to call f after idleAfter duration (resets existing timer if any). func (s *session) armIdleTimer(f func()) { if s.idleTimer != nil { s.idleTimer.Stop() @@ -41,6 +43,7 @@ func (s *session) armIdleTimer(f func()) { s.idleTimer = time.AfterFunc(s.idleAfter, f) } +// disarmIdleTimer stops and nils the idle timer if any. 
func (s *session) disarmIdleTimer() { if s.idleTimer != nil { s.idleTimer.Stop() diff --git a/services/data_service/internal/provider/providers/binance/ws/binance_futures.go b/services/data_service/internal/provider/providers/binance/ws/binance_futures.go index f44db10..bd8e0ce 100644 --- a/services/data_service/internal/provider/providers/binance/ws/binance_futures.go +++ b/services/data_service/internal/provider/providers/binance/ws/binance_futures.go @@ -1,52 +1,252 @@ package ws import ( + "context" "fmt" + "log/slog" + "sync" + "sync/atomic" "time" "github.com/google/uuid" + "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" ) -type BinanceFutures struct { - cfg config - shards map[uuid.UUID]*shard - streamAssignments map[string]*shard -} +const providerName = "binance_futures" -type config struct { +type Config struct { Endpoint string - MaxStreamsPerShard uint8 - BatchInterval time.Duration + MaxStreamsPerShard uint16 + RateLimitPerSec uint16 } -func NewBinanceFuturesWebsocket(cfg config) *BinanceFutures { +type BinanceFutures struct { + cfg Config + bus chan<- domain.Message + + mu sync.RWMutex + shards map[uuid.UUID]*shard + assignOrder []uuid.UUID + streamAssignments map[string]*shard + pendingGlobal map[string][]chan error + + ctx context.Context + cancel context.CancelFunc + + idSeq atomic.Uint64 +} + +func NewBinanceFuturesWebsocket(cfg Config, bus chan<- domain.Message) *BinanceFutures { + if cfg.Endpoint == "" { + cfg.Endpoint = "wss://fstream.binance.com/stream" + } + if cfg.RateLimitPerSec <= 0 { + cfg.RateLimitPerSec = 5 + } + if cfg.MaxStreamsPerShard == 0 { + cfg.MaxStreamsPerShard = 15 + } return &BinanceFutures{ - cfg: cfg, - shards: make(map[uuid.UUID]*shard), + cfg: cfg, + bus: bus, + shards: make(map[uuid.UUID]*shard), + streamAssignments: make(map[string]*shard), + pendingGlobal: make(map[string][]chan error), } } func (b *BinanceFutures) Start() error { + b.mu.Lock() + defer b.mu.Unlock() + if b.ctx != nil { + return nil + } + b.ctx, b.cancel = context.WithCancel(context.Background()) + + slog.Default().Info("started", slog.String("cmp", providerName)) + sh, err := newShard(b.ctx, b.cfg, b.bus, b.nextReqID) + if err != nil { + slog.Default().Error("", "error", err) + return err + } + b.shards[sh.ID] = sh + b.assignOrder = []uuid.UUID{sh.ID} + + // idle shard GC + go b.gcIdleShards() + return nil } func (b *BinanceFutures) Stop() { - return + b.mu.Lock() + if b.cancel != nil { + b.cancel() + } + // snapshot shards, then clear maps + shs := make([]*shard, 0, len(b.shards)) + for _, sh := range b.shards { + shs = append(shs, sh) + } + b.shards = map[uuid.UUID]*shard{} + b.assignOrder = nil + b.streamAssignments = map[string]*shard{} + + for subj, waiters := range b.pendingGlobal { + for _, ch := range waiters { + select { + case ch <- context.Canceled: + default: + } + } + delete(b.pendingGlobal, subj) + } + slog.Default().Info("stopped", slog.String("cmp", providerName)) + b.mu.Unlock() + + for _, sh := range shs { + sh.stop() + } } func (b *BinanceFutures) Subscribe(subject string) <-chan error { - return nil + ch := make(chan error, 1) + if !IsValidSubject(subject) { + ch <- fmt.Errorf("invalid subject: %s", subject) + return ch + } + + b.mu.Lock() + if sh, ok := b.streamAssignments[subject]; ok && sh.isActive(subject) { + b.mu.Unlock() + ch <- nil + return ch + } + sh := b.pickShardLocked() + b.streamAssignments[subject] = sh + sh.enqueueSubscribe(subject, ch) + b.mu.Unlock() + return ch } func (b *BinanceFutures) 
Unsubscribe(subject string) <-chan error { - return nil + ch := make(chan error, 1) + + b.mu.Lock() + sh, ok := b.streamAssignments[subject] + if ok { + delete(b.streamAssignments, subject) // allow reassignment later + } + b.mu.Unlock() + + if !ok { + ch <- nil + return ch + } + sh.enqueueUnsubscribe(subject, ch) + return ch } -func (b *BinanceFutures) Fetch(subject string) (domain.Message, error) { +func (b *BinanceFutures) Fetch(_ string) (domain.Message, error) { return domain.Message{}, fmt.Errorf("fetch not supported by provider") } -func (b *BinanceFutures) GetActiveStreams() []string { return nil } -func (b *BinanceFutures) IsStreamActive(key string) bool { return false } -func (b *BinanceFutures) IsValidSubject(key string, isFetch bool) bool { return false } +func (b *BinanceFutures) GetActiveStreams() []string { + b.mu.RLock() + defer b.mu.RUnlock() + out := make([]string, 0) + for _, sh := range b.shards { + out = append(out, sh.activeList()...) + } + return out +} + +func (b *BinanceFutures) IsStreamActive(key string) bool { + b.mu.RLock() + sh := b.streamAssignments[key] + b.mu.RUnlock() + if sh == nil { + return false + } + return sh.isActive(key) +} + +func (b *BinanceFutures) IsValidSubject(key string, _ bool) bool { return IsValidSubject(key) } + +// pick shard by lowest load = active + pending subs; enforce cap +func (b *BinanceFutures) pickShardLocked() *shard { + var chosen *shard + minLoad := int(^uint(0) >> 1) // max int + + for _, id := range b.assignOrder { + sh := b.shards[id] + if sh == nil { + continue + } + load := sh.loadEstimate() + if load < int(b.cfg.MaxStreamsPerShard) && load < minLoad { + minLoad = load + chosen = sh + } + } + if chosen != nil { + return chosen + } + + // need a new shard + sh, err := newShard(b.ctx, b.cfg, b.bus, b.nextReqID) + if err != nil { + if len(b.assignOrder) > 0 { + return b.shards[b.assignOrder[0]] + } + return sh + } + b.shards[sh.ID] = sh + b.assignOrder = append(b.assignOrder, sh.ID) + return sh +} + +func (b *BinanceFutures) nextReqID() uint64 { return b.idSeq.Add(1) } + +// Close idle shards periodically. Keep at least one. +func (b *BinanceFutures) gcIdleShards() { + t := time.NewTicker(30 * time.Second) + defer t.Stop() + for { + select { + case <-b.ctx.Done(): + return + case <-t.C: + var toStop []*shard + + b.mu.Lock() + if len(b.shards) <= 1 { + b.mu.Unlock() + continue + } + for id, sh := range b.shards { + if len(b.shards)-len(toStop) <= 1 { + break // keep one + } + if sh.isIdle() { + toStop = append(toStop, sh) + delete(b.shards, id) + // prune order list + for i, v := range b.assignOrder { + if v == id { + b.assignOrder = append(b.assignOrder[:i], b.assignOrder[i+1:]...) 
+ break + } + } + } + } + b.mu.Unlock() + + for _, sh := range toStop { + slog.Default().Info("close idle shard", "cmp", providerName, "shard", sh.ID) + sh.stop() + } + } + } +} diff --git a/services/data_service/internal/provider/providers/binance/ws/shard.go b/services/data_service/internal/provider/providers/binance/ws/shard.go index e22942c..18c304a 100644 --- a/services/data_service/internal/provider/providers/binance/ws/shard.go +++ b/services/data_service/internal/provider/providers/binance/ws/shard.go @@ -1,12 +1,465 @@ package ws import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "sync" + "time" + "github.com/coder/websocket" "github.com/google/uuid" + + "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" ) -type shard struct { - ID uuid.UUID - conn websocket.Conn - activeStreams []string +type opType uint8 + +const ( + opSubscribe opType = iota + 1 + opUnsubscribe +) + +type pendingBatch struct { + Op opType + Subjects []string + Waiters map[string][]chan error +} + +type shard struct { + ID uuid.UUID + url string + cfg Config + + ctx context.Context + cancel context.CancelFunc + + conn *websocket.Conn + + mu sync.RWMutex + active map[string]struct{} + + subBatch map[string][]chan error + unsubBatch map[string][]chan error + + sendQ chan []byte + rateTicker *time.Ticker + pingTicker *time.Ticker + + pendingMu sync.Mutex + pendingByID map[uint64]*pendingBatch + nextReqID func() uint64 + + wg sync.WaitGroup + bus chan<- domain.Message +} + +func newShard(pctx context.Context, cfg Config, bus chan<- domain.Message, next func() uint64) (*shard, error) { + id := uuid.New() + ctx, cancel := context.WithCancel(pctx) + sh := &shard{ + ID: id, + url: cfg.Endpoint, + cfg: cfg, + ctx: ctx, + cancel: cancel, + active: make(map[string]struct{}), + subBatch: make(map[string][]chan error), + unsubBatch: make(map[string][]chan error), + sendQ: make(chan []byte, 256), + pendingByID: make(map[uint64]*pendingBatch), + nextReqID: next, + bus: bus, + } + + // per-shard rate limiter; also drives batch flushing + rate := cfg.RateLimitPerSec + if rate <= 0 { + rate = 1 + } + interval := time.Second / time.Duration(rate) + sh.rateTicker = time.NewTicker(interval) + sh.pingTicker = time.NewTicker(30 * time.Second) + + slog.Default().Info("shard created", "cmp", providerName, "shard", sh.ID.String()) + + if err := sh.connect(); err != nil { + slog.Default().Error("shard connection failed", "cmp", providerName, "shard", sh.ID.String(), "error", err) + return nil, err + } + sh.startLoops() + return sh, nil +} + +func (s *shard) connect() error { + dctx, cancel := context.WithTimeout(s.ctx, 10*time.Second) + defer cancel() + c, _, err := websocket.Dial(dctx, s.url, &websocket.DialOptions{}) + if err != nil { + slog.Default().Error("shard connection error", "cmp", providerName, "shard", s.ID.String(), "error", err) + return err + } + s.conn = c + slog.Default().Info("shard connected", "cmp", providerName, "shard", s.ID.String()) + return nil +} + +func (s *shard) startLoops() { + s.wg.Add(3) + go s.writeLoop() + go s.readLoop() + go s.pingLoop() +} + +func (s *shard) stop() { + s.cancel() + if s.conn != nil { + _ = s.conn.Close(websocket.StatusNormalClosure, "shutdown") + } + if s.rateTicker != nil { + s.rateTicker.Stop() + } + if s.pingTicker != nil { + s.pingTicker.Stop() + } + s.wg.Wait() + + s.pendingMu.Lock() + for _, p := range s.pendingByID { + for _, arr := range p.Waiters { + for _, ch := range arr { + select { + case ch <- 
context.Canceled: + default: + } + } + } + } + s.pendingByID = map[uint64]*pendingBatch{} + s.pendingMu.Unlock() + + s.mu.Lock() + for _, arr := range s.subBatch { + for _, ch := range arr { + select { + case ch <- context.Canceled: + default: + } + } + } + for _, arr := range s.unsubBatch { + for _, ch := range arr { + select { + case ch <- context.Canceled: + default: + } + } + } + s.subBatch = map[string][]chan error{} + s.unsubBatch = map[string][]chan error{} + s.mu.Unlock() + + slog.Default().Info("shard stopped", "cmp", providerName, "shard", s.ID.String()) +} + +func (s *shard) enqueueSubscribe(subject string, ch chan error) { + s.mu.Lock() + s.subBatch[subject] = append(s.subBatch[subject], ch) + s.mu.Unlock() + slog.Default().Debug("shard enqueue subscribe", "cmp", providerName, "shard", s.ID, "subject", subject) +} + +func (s *shard) enqueueUnsubscribe(subject string, ch chan error) { + s.mu.Lock() + s.unsubBatch[subject] = append(s.unsubBatch[subject], ch) + s.mu.Unlock() + slog.Default().Debug("shard enqueue unsubscribe", "cmp", providerName, "shard", s.ID, "subject", subject) +} + +func (s *shard) isActive(subj string) bool { + s.mu.RLock() + _, ok := s.active[subj] + s.mu.RUnlock() + return ok +} + +func (s *shard) activeCount() int { + s.mu.RLock() + n := len(s.active) + s.mu.RUnlock() + return n +} + +func (s *shard) loadEstimate() int { // active + pending subscribes + s.mu.RLock() + n := len(s.active) + len(s.subBatch) + s.mu.RUnlock() + return n +} + +func (s *shard) isIdle() bool { + s.mu.RLock() + idle := len(s.active) == 0 && len(s.subBatch) == 0 && len(s.unsubBatch) == 0 + s.mu.RUnlock() + return idle +} + +func (s *shard) activeList() []string { + s.mu.RLock() + defer s.mu.RUnlock() + out := make([]string, 0, len(s.active)) + for k := range s.active { + out = append(out, k) + } + return out +} + +func (s *shard) writeLoop() { + defer s.wg.Done() + for { + select { + case <-s.ctx.Done(): + return + case <-s.rateTicker.C: + // snapshot and clear pending batch operations + var subs, unsubs map[string][]chan error + s.mu.Lock() + if len(s.subBatch) > 0 { + subs = s.subBatch + s.subBatch = make(map[string][]chan error) + } + if len(s.unsubBatch) > 0 { + unsubs = s.unsubBatch + s.unsubBatch = make(map[string][]chan error) + } + s.mu.Unlock() + + // send SUBSCRIBE batch + if len(subs) > 0 { + params := make([]string, 0, len(subs)) + waiters := make(map[string][]chan error, len(subs)) + for k, v := range subs { + params = append(params, k) + waiters[k] = v + } + id := s.nextReqID() + frame := map[string]any{"method": "SUBSCRIBE", "params": params, "id": id} + payload, _ := json.Marshal(frame) + s.recordPending(id, opSubscribe, params, waiters) + if err := s.writeFrame(payload); err != nil { + s.reconnect() + return + } + } + + // send UNSUBSCRIBE batch + if len(unsubs) > 0 { + params := make([]string, 0, len(unsubs)) + waiters := make(map[string][]chan error, len(unsubs)) + for k, v := range unsubs { + params = append(params, k) + waiters[k] = v + } + id := s.nextReqID() + frame := map[string]any{"method": "UNSUBSCRIBE", "params": params, "id": id} + payload, _ := json.Marshal(frame) + s.recordPending(id, opUnsubscribe, params, waiters) + if err := s.writeFrame(payload); err != nil { + s.reconnect() + return + } + } + + // optional: one queued ad-hoc frame per tick + select { + case msg := <-s.sendQ: + if err := s.writeFrame(msg); err != nil { + s.reconnect() + return + } + default: + } + } + } +} + +func (s *shard) writeFrame(msg []byte) error { + wctx, cancel := 
context.WithTimeout(s.ctx, 5*time.Second) + defer cancel() + err := s.conn.Write(wctx, websocket.MessageText, msg) + if err != nil { + slog.Default().Warn("shard write error", "cmp", providerName, "shard", s.ID, "error", err) + } + return err +} + +func (s *shard) readLoop() { + defer s.wg.Done() + for { + select { + case <-s.ctx.Done(): + return + default: + // longer idle timeout when no active subscriptions + timeout := 60 * time.Second + if s.activeCount() == 0 { + timeout = 5 * time.Minute + } + rctx, cancel := context.WithTimeout(s.ctx, timeout) + _, data, err := s.conn.Read(rctx) + cancel() + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + slog.Default().Debug("shard read idle timeout", "cmp", providerName, "shard", s.ID) + continue + } + slog.Default().Warn("shard read error", "cmp", providerName, "shard", s.ID, "error", err) + s.reconnect() + return + } + + if bytes.Contains(data, []byte("\"id\"")) { + var ack struct { + ID uint64 `json:"id"` + Result *json.RawMessage `json:"result"` + Error *struct { + Code int `json:"code"` + Msg string `json:"msg"` + } `json:"error"` + } + if json.Unmarshal(data, &ack) == nil && ack.ID != 0 { + if ack.Error != nil { + slog.Default().Warn("shard ack error", "cmp", providerName, "shard", s.ID, "id", ack.ID, "code", ack.Error.Code, "msg", ack.Error.Msg) + s.resolvePending(ack.ID, fmt.Errorf("binance error %d: %s", ack.Error.Code, ack.Error.Msg)) + } else { + slog.Default().Debug("shard ack ok", "cmp", providerName, "shard", s.ID, "id", ack.ID) + s.resolvePending(ack.ID, nil) + } + continue + } + } + + var frame struct { + Stream string `json:"stream"` + Data json.RawMessage `json:"data"` + } + if json.Unmarshal(data, &frame) == nil && frame.Stream != "" { + id, err := domain.RawID(providerName, frame.Stream) + if err == nil { + select { + case s.bus <- domain.Message{Identifier: id, Payload: frame.Data}: + default: + } + } + continue + } + slog.Default().Debug("shard unknown message", "cmp", providerName, "shard", s.ID, "data", string(data)) + } + } +} + +func (s *shard) pingLoop() { + defer s.wg.Done() + for { + select { + case <-s.ctx.Done(): + return + case <-s.pingTicker.C: + ctx, cancel := context.WithTimeout(s.ctx, 5*time.Second) + err := s.conn.Ping(ctx) + cancel() + if err != nil { + slog.Default().Warn("shard ping failed", "cmp", providerName, "shard", s.ID, "error", err) + s.reconnect() + return + } + } + } +} + +func (s *shard) recordPending(id uint64, op opType, subjects []string, waiters map[string][]chan error) { + s.pendingMu.Lock() + s.pendingByID[id] = &pendingBatch{Op: op, Subjects: subjects, Waiters: waiters} + s.pendingMu.Unlock() +} + +func (s *shard) resolvePending(id uint64, err error) { + s.pendingMu.Lock() + p := s.pendingByID[id] + delete(s.pendingByID, id) + s.pendingMu.Unlock() + if p == nil { + return + } + + if err == nil { + s.mu.Lock() + if p.Op == opSubscribe { + for _, subj := range p.Subjects { + s.active[subj] = struct{}{} + } + slog.Default().Debug("shard subscribed", "cmp", providerName, "shard", s.ID, "subjects", p.Subjects) + } else { + for _, subj := range p.Subjects { + delete(s.active, subj) + } + slog.Default().Debug("shard unsubscribed", "cmp", providerName, "shard", s.ID, "subjects", p.Subjects) + } + s.mu.Unlock() + } else { + slog.Default().Warn("shard pending error", "cmp", providerName, "shard", s.ID, "error", err) + } + + for _, arr := range p.Waiters { + for _, ch := range arr { + select { + case ch <- err: + default: + } + } + } +} + +func (s *shard) queue(payload []byte) { + 
select { + case s.sendQ <- payload: + default: + slog.Default().Warn("shard sendQ full, dropping one message", "cmp", providerName, "shard", s.ID) + <-s.sendQ + s.sendQ <- payload + } +} + +func (s *shard) reconnect() { + reconnectStartTime := time.Now() + if s.conn != nil { + _ = s.conn.Close(websocket.StatusGoingAway, "reconnect") + } + + for { + select { + case <-s.ctx.Done(): + return + default: + if err := s.connect(); err != nil { + time.Sleep(200 * time.Millisecond) + continue + } + + // re-stage current actives for batch subscribe on next tick + s.mu.RLock() + for k := range s.active { + s.subBatch[k] = append(s.subBatch[k], nil) + } + s.mu.RUnlock() + + // restart loops + s.startLoops() + slog.Default().Info("shard reconnected", "cmp", providerName, "shard", s.ID, "downtime", time.Since(reconnectStartTime).String()) + return + } + } } diff --git a/services/data_service/internal/provider/providers/binance/ws/subjects.go b/services/data_service/internal/provider/providers/binance/ws/subjects.go new file mode 100644 index 0000000..81023a1 --- /dev/null +++ b/services/data_service/internal/provider/providers/binance/ws/subjects.go @@ -0,0 +1,21 @@ +package ws + +import "regexp" + +var ( + reAggTrade = regexp.MustCompile(`^[a-z0-9]+@aggTrade$`) + reTrade = regexp.MustCompile(`^[a-z0-9]+@trade$`) + reMarkPrice = regexp.MustCompile(`^[a-z0-9]+@markPrice(@1s)?$`) + reKline = regexp.MustCompile(`^[a-z0-9]+@kline_(1s|1m|3m|5m|15m|30m|1h|2h|4h|6h|8h|12h|1d|3d|1w|1M)$`) + reBookTicker = regexp.MustCompile(`^[a-z0-9]+@bookTicker$`) + reDepth = regexp.MustCompile(`^[a-z0-9]+@depth(@100ms)?$`) +) + +func IsValidSubject(s string) bool { + return reAggTrade.MatchString(s) || + reTrade.MatchString(s) || + reMarkPrice.MatchString(s) || + reKline.MatchString(s) || + reBookTicker.MatchString(s) || + reDepth.MatchString(s) +} diff --git a/services/data_service/internal/provider/providers/binance/ws/types.go b/services/data_service/internal/provider/providers/binance/ws/types.go deleted file mode 100644 index e69de29..0000000 diff --git a/services/data_service/internal/provider/providers/test/test_provider.go b/services/data_service/internal/provider/providers/test/test_provider.go index 75cd2d3..5615f05 100644 --- a/services/data_service/internal/provider/providers/test/test_provider.go +++ b/services/data_service/internal/provider/providers/test/test_provider.go @@ -1,119 +1,179 @@ +// Package test implements a configurable synthetic data provider. +// +// Config via subject string. Two syntaxes are accepted: +// +// Query style: "foo?period=7us&size=64&mode=const&burst=1&jitter=0.02&drop=1&ts=1&log=1" +// Path style: "foo/period/7us/size/64/mode/poisson/rate/120000/jitter/0.05/drop/0/ts/1/log/1" +// +// Parameters: +// +// period: Go duration. Inter-message target (wins over rate). +// rate: Messages per second. Used if period absent. +// mode: const | poisson | onoff +// burst: Messages emitted per tick (>=1). +// jitter: ±fraction jitter on period (e.g., 0.05 = ±5%). +// on/off: Durations for onoff mode (e.g., on=5ms&off=1ms). +// size: Payload bytes (>=1). If ts=1 and size<16, auto-extends to 16. +// ptype: bytes | counter | json (payload content generator) +// drop: 1=non-blocking send (drop on backpressure), 0=block. +// ts: 1=prepend 16B header: [sendUnixNano int64][seq int64]. +// log: 1=emit per-second metrics via slog. +// +// Notes: +// - Constant mode uses sleep-then-spin pacer for sub-10µs. +// - Poisson mode draws inter-arrivals from Exp(rate). 
+// - On/Off emits at period during "on", silent during "off" windows. +// - Metrics include msgs/s, bytes/s, drops/s per stream. +// - Fetch is unsupported (returns error). package test import ( "context" "errors" "fmt" - "log/slog" + "math/rand/v2" + "net/url" + "strconv" + "strings" "sync" + "sync/atomic" "time" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" ) type TestProvider struct { - mu sync.Mutex - streams map[string]*stream - outputChannel chan<- domain.Message - tickDuration time.Duration + mu sync.Mutex + streams map[string]*stream + out chan<- domain.Message + defaults cfg } type stream struct { cancel context.CancelFunc done chan struct{} + stats *metrics } -// NewTestProvider wires the outbound channel. -func NewTestProvider(out chan<- domain.Message, tickDuration time.Duration) *TestProvider { +type metrics struct { + sent, dropped atomic.Uint64 + prevSent uint64 + prevDropped uint64 + startUnix int64 +} + +type mode int + +const ( + modeConst mode = iota + modePoisson + modeOnOff +) + +type ptype int + +const ( + ptBytes ptype = iota + ptCounter + ptJSON +) + +type cfg struct { + period time.Duration // inter-arrival target + rate float64 // msgs/sec if period == 0 + jitter float64 // ±fraction + mode mode + onDur time.Duration // for onoff + offDur time.Duration // for onoff + burst int + size int + pType ptype + dropIfSlow bool + embedTS bool + logEverySec bool +} + +// NewTestProvider returns a provider with sane defaults. +func NewTestProvider(out chan<- domain.Message, defaultPeriod time.Duration) *TestProvider { + if defaultPeriod <= 0 { + defaultPeriod = 100 * time.Microsecond + } return &TestProvider{ - streams: make(map[string]*stream), - outputChannel: out, - tickDuration: tickDuration, + streams: make(map[string]*stream), + out: out, + defaults: cfg{ + period: defaultPeriod, + rate: 0, + jitter: 0, + mode: modeConst, + onDur: 5 * time.Millisecond, + offDur: 1 * time.Millisecond, + burst: 1, + size: 32, + pType: ptBytes, + dropIfSlow: true, + embedTS: true, + }, } } -func (t *TestProvider) Start() error { return nil } +func (p *TestProvider) Start() error { return nil } -func (t *TestProvider) Stop() { - t.mu.Lock() - defer t.mu.Unlock() - for key, s := range t.streams { +func (p *TestProvider) Stop() { + p.mu.Lock() + defer p.mu.Unlock() + for key, s := range p.streams { s.cancel() <-s.done - delete(t.streams, key) + delete(p.streams, key) } } -func (t *TestProvider) Subscribe(subject string) <-chan error { +func (p *TestProvider) Subscribe(subject string) <-chan error { errCh := make(chan error, 1) - if !t.IsValidSubject(subject, false) { + if !p.IsValidSubject(subject, false) { errCh <- errors.New("invalid subject") close(errCh) return errCh } - t.mu.Lock() - // Already active: treat as success. - if _, ok := t.streams[subject]; ok { - t.mu.Unlock() + p.mu.Lock() + if _, exists := p.streams[subject]; exists { + p.mu.Unlock() errCh <- nil return errCh } ctx, cancel := context.WithCancel(context.Background()) - s := &stream{cancel: cancel, done: make(chan struct{})} - t.streams[subject] = s - out := t.outputChannel - t.mu.Unlock() + s := &stream{ + cancel: cancel, + done: make(chan struct{}), + stats: &metrics{startUnix: time.Now().Unix()}, + } + p.streams[subject] = s + out := p.out + conf := p.parseCfg(subject) + p.mu.Unlock() - // Stream goroutine. 
- go func(subj string, s *stream) { - slog.Default().Debug("new stream routine started", slog.String("cmp", "test_provider"), slog.String("subject", subj)) - ticker := time.NewTicker(t.tickDuration) - ident, _ := domain.RawID("test_provider", subj) - defer func() { - ticker.Stop() - close(s.done) - }() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if out != nil { - msg := domain.Message{ - Identifier: ident, - Payload: []byte(time.Now().UTC().Format(time.RFC3339Nano)), - } - // Non-blocking send avoids deadlock if caller stops reading. - select { - case out <- msg: - default: - slog.Default().Warn("dropping message due to backpressure", "cmp", "test_provider", "subject", subj) - } - } - } - } - }(subject, s) + go run(ctx, s, out, subject, conf) - // Signal successful subscription. errCh <- nil return errCh } -func (t *TestProvider) Unsubscribe(subject string) <-chan error { +func (p *TestProvider) Unsubscribe(subject string) <-chan error { errCh := make(chan error, 1) - t.mu.Lock() - s, ok := t.streams[subject] + p.mu.Lock() + s, ok := p.streams[subject] if !ok { - t.mu.Unlock() + p.mu.Unlock() errCh <- errors.New("not subscribed") return errCh } - delete(t.streams, subject) - t.mu.Unlock() + delete(p.streams, subject) + p.mu.Unlock() go func() { s.cancel() @@ -123,27 +183,360 @@ func (t *TestProvider) Unsubscribe(subject string) <-chan error { return errCh } -func (t *TestProvider) Fetch(subject string) (domain.Message, error) { +func (p *TestProvider) Fetch(_ string) (domain.Message, error) { return domain.Message{}, fmt.Errorf("fetch not supported by provider") } -func (t *TestProvider) GetActiveStreams() []string { - t.mu.Lock() - defer t.mu.Unlock() - keys := make([]string, 0, len(t.streams)) - for k := range t.streams { +func (p *TestProvider) GetActiveStreams() []string { + p.mu.Lock() + defer p.mu.Unlock() + keys := make([]string, 0, len(p.streams)) + for k := range p.streams { keys = append(keys, k) } return keys } -func (t *TestProvider) IsStreamActive(key string) bool { - t.mu.Lock() - _, ok := t.streams[key] - t.mu.Unlock() +func (p *TestProvider) IsStreamActive(key string) bool { + p.mu.Lock() + _, ok := p.streams[key] + p.mu.Unlock() return ok } -func (t *TestProvider) IsValidSubject(key string, _ bool) bool { - return key != "" +func (p *TestProvider) IsValidSubject(key string, _ bool) bool { + if key == "" { + return false + } + // Accept anything parseable via parseCfg; fallback true. 
+ return true +} + +// --- core --- + +func run(ctx context.Context, s *stream, out chan<- domain.Message, subject string, c cfg) { + defer close(s.done) + + ident, _ := domain.RawID("test_provider", subject) + + // Sanitize + if c.burst < 1 { + c.burst = 1 + } + if c.size < 1 { + c.size = 1 + } + if c.embedTS && c.size < 16 { + c.size = 16 + } + if c.period <= 0 { + if c.rate > 0 { + c.period = time.Duration(float64(time.Second) / c.rate) + } else { + c.period = 10 * time.Microsecond + } + } + if c.jitter < 0 { + c.jitter = 0 + } + if c.jitter > 0.95 { + c.jitter = 0.95 + } + + // Per-second logging + var logTicker *time.Ticker + if c.logEverySec { + logTicker = time.NewTicker(time.Second) + defer logTicker.Stop() + } + + var seq uint64 + base := make([]byte, c.size) + + // On/Off state + onUntil := time.Time{} + offUntil := time.Time{} + inOn := true + now := time.Now() + onUntil = now.Add(c.onDur) + + // Scheduling + next := time.Now() + + for { + select { + case <-ctx.Done(): + return + default: + } + + switch c.mode { + case modeConst: + // sleep-then-spin to hit sub-10µs with isolated core + if d := time.Until(next); d > 0 { + if d > 30*time.Microsecond { + time.Sleep(d - 30*time.Microsecond) + } + for time.Now().Before(next) { + } + } + case modePoisson: + // draw from exponential with mean=period + lam := 1.0 / float64(c.period) + ia := time.Duration(rand.ExpFloat64() / lam) + next = time.Now().Add(ia) + // No pre-wait here; emit immediately then sleep to next + case modeOnOff: + now = time.Now() + if inOn { + if now.After(onUntil) { + inOn = false + offUntil = now.Add(c.offDur) + continue + } + } else { + if now.After(offUntil) { + inOn = true + onUntil = now.Add(c.onDur) + } + // While off, push next and wait + // Small sleep to avoid busy loop during off + time.Sleep(minDur(c.offDur/4, 200*time.Microsecond)) + continue + } + // For on state, behave like const + if d := time.Until(next); d > 0 { + if d > 30*time.Microsecond { + time.Sleep(d - 30*time.Microsecond) + } + for time.Now().Before(next) { + } + } + } + + // Emit burst + for i := 0; i < c.burst; i++ { + seq++ + payload := base[:c.size] + switch c.pType { + case ptBytes: + fillPattern(payload, uint64(seq)) + case ptCounter: + fillCounter(payload, uint64(seq)) + case ptJSON: + // build minimal, fixed-size-ish JSON into payload + n := buildJSON(payload, uint64(seq)) + payload = payload[:n] + } + + if c.embedTS { + ensureCap(&payload, 16) + ts := time.Now().UnixNano() + putInt64(payload[0:8], ts) + putInt64(payload[8:16], int64(seq)) + } + + msg := domain.Message{ + Identifier: ident, + Payload: payload, + } + + if out != nil { + if c.dropIfSlow { + select { + case out <- msg: + s.stats.sent.Add(1) + default: + s.stats.dropped.Add(1) + } + } else { + select { + case out <- msg: + s.stats.sent.Add(1) + case <-ctx.Done(): + return + } + } + } + } + + // Schedule next + adj := c.period + if c.mode == modePoisson { + // next already chosen + } else { + if c.jitter > 0 { + j := (rand.Float64()*2 - 1) * c.jitter + adj = time.Duration(float64(c.period) * (1 + j)) + if adj < 0 { + adj = 0 + } + } + next = next.Add(adj) + } + + // For poisson, actively wait to next + if c.mode == modePoisson { + if d := time.Until(next); d > 0 { + if d > 30*time.Microsecond { + time.Sleep(d - 30*time.Microsecond) + } + for time.Now().Before(next) { + } + } + } + } +} + +// --- config parsing --- + +func (p *TestProvider) parseCfg(subject string) cfg { + c := p.defaults + + // Query style first + if i := strings.Index(subject, "?"); i >= 0 && i < 
len(subject)-1 { + if qv, err := url.ParseQuery(subject[i+1:]); err == nil { + c = applyQuery(c, qv) + } + } + + // Path segments like /key/value/ pairs + parts := strings.Split(subject, "/") + for i := 0; i+1 < len(parts); i += 2 { + k := strings.ToLower(parts[i]) + v := parts[i+1] + if k == "" { + continue + } + applyKV(&c, k, v) + } + return c +} + +func applyQuery(c cfg, v url.Values) cfg { + for k, vals := range v { + if len(vals) == 0 { + continue + } + applyKV(&c, strings.ToLower(k), vals[0]) + } + return c +} + +func applyKV(c *cfg, key, val string) { + switch key { + case "period": + if d, err := time.ParseDuration(val); err == nil && d > 0 { + c.period = d + } + case "rate": + if f, err := strconv.ParseFloat(val, 64); err == nil && f > 0 { + c.rate = f + c.period = 0 // let rate take effect if period unset later + } + case "mode": + switch strings.ToLower(val) { + case "const", "steady": + c.mode = modeConst + case "poisson": + c.mode = modePoisson + case "onoff", "burst": + c.mode = modeOnOff + } + case "on": + if d, err := time.ParseDuration(val); err == nil && d >= 0 { + c.onDur = d + } + case "off": + if d, err := time.ParseDuration(val); err == nil && d >= 0 { + c.offDur = d + } + case "burst": + if n, err := strconv.Atoi(val); err == nil && n > 0 { + c.burst = n + } + case "jitter": + if f, err := strconv.ParseFloat(val, 64); err == nil && f >= 0 && f < 1 { + c.jitter = f + } + case "size": + if n, err := strconv.Atoi(val); err == nil && n > 0 { + c.size = n + } + case "ptype": + switch strings.ToLower(val) { + case "bytes": + c.pType = ptBytes + case "counter": + c.pType = ptCounter + case "json": + c.pType = ptJSON + } + case "drop": + c.dropIfSlow = val == "1" || strings.EqualFold(val, "true") + case "ts": + c.embedTS = val == "1" || strings.EqualFold(val, "true") + case "log": + c.logEverySec = val == "1" || strings.EqualFold(val, "true") + } +} + +// --- payload builders --- + +func fillPattern(b []byte, seed uint64) { + // xorshift for deterministic but non-trivial bytes + if len(b) == 0 { + return + } + x := seed | 1 + for i := range b { + x ^= x << 13 + x ^= x >> 7 + x ^= x << 17 + b[i] = byte(x) + } +} + +func fillCounter(b []byte, seq uint64) { + for i := range b { + b[i] = byte((seq + uint64(i)) & 0xFF) + } +} + +func buildJSON(buf []byte, seq uint64) int { + // Small fixed fields. Truncate if buffer small. + // Example: {"t":1694490000000000,"s":12345,"p":100.12} + ts := time.Now().UnixNano() + price := 10000 + float64(seq%1000)*0.01 + str := fmt.Sprintf(`{"t":%d,"s":%d,"p":%.2f}`, ts, seq, price) + n := copy(buf, str) + return n +} + +func ensureCap(b *[]byte, need int) { + if len(*b) >= need { + return + } + nb := make([]byte, need) + copy(nb, *b) + *b = nb +} + +func putInt64(b []byte, v int64) { + _ = b[7] + b[0] = byte(v >> 56) + b[1] = byte(v >> 48) + b[2] = byte(v >> 40) + b[3] = byte(v >> 32) + b[4] = byte(v >> 24) + b[5] = byte(v >> 16) + b[6] = byte(v >> 8) + b[7] = byte(v) +} + +func minDur(a, b time.Duration) time.Duration { + if a < b { + return a + } + return b }
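
Note on the test provider's ts=1 header: the 16-byte prefix is [sendUnixNano int64][seq int64], big-endian (see putInt64 above). Below is a minimal sketch of how a consumer might decode it to measure one-way latency; the decodeTestHeader helper and the synthetic payload built in main are illustrative, not part of this change.

package main

import (
	"encoding/binary"
	"fmt"
	"time"
)

// decodeTestHeader reads the 16-byte header the test provider prepends when ts=1:
// bytes 0..7 are the send time (UnixNano, big-endian int64), bytes 8..15 the sequence number.
func decodeTestHeader(payload []byte) (sendTime time.Time, seq int64, err error) {
	if len(payload) < 16 {
		return time.Time{}, 0, fmt.Errorf("payload too short: %d bytes", len(payload))
	}
	ns := int64(binary.BigEndian.Uint64(payload[0:8]))
	seq = int64(binary.BigEndian.Uint64(payload[8:16]))
	return time.Unix(0, ns), seq, nil
}

func main() {
	// Build a payload the same way the provider would for a 16-byte message.
	buf := make([]byte, 16)
	binary.BigEndian.PutUint64(buf[0:8], uint64(time.Now().UnixNano()))
	binary.BigEndian.PutUint64(buf[8:16], 42)

	sent, seq, err := decodeTestHeader(buf)
	if err != nil {
		panic(err)
	}
	fmt.Printf("seq=%d one-way latency=%s\n", seq, time.Since(sent))
}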
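
Note on the shard's wire protocol: writeLoop batches the pending subjects into one SUBSCRIBE or UNSUBSCRIBE frame per rate tick, and readLoop matches the ack back to the batch by id via pendingByID. The sketch below shows the frame and ack shapes involved, reusing the same struct layouts as readLoop; the symbols and the aggTrade payload fields are illustrative samples only.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Frame shape writeLoop emits (one request id per batched operation).
	sub := map[string]any{"method": "SUBSCRIBE", "params": []string{"btcusdt@aggTrade", "ethusdt@bookTicker"}, "id": 1}
	out, _ := json.Marshal(sub)
	fmt.Println(string(out))

	// Ack shape readLoop resolves against pendingByID.
	ack := []byte(`{"result":null,"id":1}`)
	var a struct {
		ID     uint64           `json:"id"`
		Result *json.RawMessage `json:"result"`
		Error  *struct {
			Code int    `json:"code"`
			Msg  string `json:"msg"`
		} `json:"error"`
	}
	if err := json.Unmarshal(ack, &a); err == nil && a.Error == nil {
		fmt.Printf("request %d acknowledged\n", a.ID)
	}

	// Combined-stream data frame shape routed onto the bus.
	frame := []byte(`{"stream":"btcusdt@aggTrade","data":{"e":"aggTrade","s":"BTCUSDT","p":"65000.10"}}`)
	var f struct {
		Stream string          `json:"stream"`
		Data   json.RawMessage `json:"data"`
	}
	if err := json.Unmarshal(frame, &f); err == nil {
		fmt.Printf("stream=%s payload=%s\n", f.Stream, f.Data)
	}
}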
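
Note on subject validation: subjects must use lowercase Binance stream names matching the whitelist in subjects.go. A quick sketch, assuming it runs from a package inside the data_service module (the ws package is internal), showing which subjects pass ws.IsValidSubject; the symbols are examples only.

package main

import (
	"fmt"

	"gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/provider/providers/binance/ws"
)

func main() {
	for _, s := range []string{
		"btcusdt@aggTrade",
		"btcusdt@markPrice@1s",
		"ethusdt@kline_1m",
		"ethusdt@depth@100ms",
		"BTCUSDT@trade",   // rejected: uppercase symbol
		"btcusdt@ticker",  // rejected: stream type not whitelisted
	} {
		fmt.Printf("%-22s valid=%v\n", s, ws.IsValidSubject(s))
	}
}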