From 924187b2f593f718c90ae76e6d60feafc775f422 Mon Sep 17 00:00:00 2001 From: Phillip Michelsen Date: Thu, 11 Sep 2025 08:29:12 +0000 Subject: [PATCH] Begun redesign of binance futures websocket. Added test provider for testing purposes. --- .../data_service/cmd/data_service/main.go | 6 +- .../data_service/internal/manager/helper.go | 4 - .../data_service/internal/manager/manager.go | 60 ++- .../data_service/internal/manager/session.go | 12 - .../provider/binance/futures_websocket.go | 410 ------------------ .../providers/binance/ws/binance_futures.go | 52 +++ .../provider/providers/binance/ws/shard.go | 12 + .../provider/providers/binance/ws/types.go | 0 .../provider/providers/test/test_provider.go | 149 +++++++ .../internal/server/gprc_control_server.go | 7 +- 10 files changed, 261 insertions(+), 451 deletions(-) delete mode 100644 services/data_service/internal/provider/binance/futures_websocket.go create mode 100644 services/data_service/internal/provider/providers/binance/ws/binance_futures.go create mode 100644 services/data_service/internal/provider/providers/binance/ws/shard.go create mode 100644 services/data_service/internal/provider/providers/binance/ws/types.go create mode 100644 services/data_service/internal/provider/providers/test/test_provider.go diff --git a/services/data_service/cmd/data_service/main.go b/services/data_service/cmd/data_service/main.go index 5f3ddf3..1ebcc06 100644 --- a/services/data_service/cmd/data_service/main.go +++ b/services/data_service/cmd/data_service/main.go @@ -9,7 +9,7 @@ import ( "github.com/lmittmann/tint" pb "gitlab.michelsen.id/phillmichelsen/tessera/pkg/pb/data_service" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/manager" - "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/provider/binance" + "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/provider/providers/test" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/router" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/server" "google.golang.org/grpc" @@ -57,8 +57,8 @@ func main() { // Setup r := router.NewRouter(2048) m := manager.NewManager(r) - binanceFutures := binance.NewFuturesWebsocket(r.IncomingChannel()) - if err := m.AddProvider("binance_futures_websocket", binanceFutures); err != nil { + testProvider := test.NewTestProvider(r.IncomingChannel(), time.Microsecond*50) + if err := m.AddProvider("test_provider", testProvider); err != nil { slog.Error("add provider failed", "err", err) os.Exit(1) } diff --git a/services/data_service/internal/manager/helper.go b/services/data_service/internal/manager/helper.go index 1f93814..0f95a5a 100644 --- a/services/data_service/internal/manager/helper.go +++ b/services/data_service/internal/manager/helper.go @@ -1,13 +1,9 @@ package manager import ( - "log/slog" - "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" ) -func lg() *slog.Logger { return slog.Default().With("cmp", "manager") } - func identifierSetDifferences(oldIDs, nextIDs []domain.Identifier) (toAdd, toDel []domain.Identifier) { oldSet := make(map[domain.Identifier]struct{}, len(oldIDs)) for _, id := range oldIDs { diff --git a/services/data_service/internal/manager/manager.go b/services/data_service/internal/manager/manager.go index 075d433..687ba93 100644 --- a/services/data_service/internal/manager/manager.go +++ b/services/data_service/internal/manager/manager.go @@ -42,7 +42,7 @@ func NewManager(r *router.Router) *Manager { go r.Run() go m.run() - lg().Info("manager started") + slog.Default().Info("manager started", slog.String("cmp", "manager")) return m } @@ -51,71 +51,85 @@ func NewManager(r *router.Router) *Manager { // AddProvider adds and starts a new provider. func (m *Manager) AddProvider(name string, p provider.Provider) error { - lg().Debug("add provider request", slog.String("name", name)) + slog.Default().Debug("add provider request", slog.String("cmp", "manager"), slog.String("name", name)) resp := make(chan addProviderResult, 1) m.cmdCh <- addProviderCmd{name: name, p: p, resp: resp} r := <-resp + + slog.Default().Info("provider added", slog.String("cmp", "manager"), slog.String("name", name)) return r.err } // RemoveProvider stops and removes a provider, cleaning up all sessions. func (m *Manager) RemoveProvider(name string) error { - lg().Debug("remove provider request", slog.String("name", name)) + slog.Default().Debug("remove provider request", slog.String("cmp", "manager"), slog.String("name", name)) resp := make(chan removeProviderResult, 1) m.cmdCh <- removeProviderCmd{name: name, resp: resp} r := <-resp + + slog.Default().Info("provider removed", slog.String("cmp", "manager"), slog.String("name", name)) return r.err } // NewSession creates a new session with the given idle timeout. func (m *Manager) NewSession(idleAfter time.Duration) uuid.UUID { - lg().Debug("new session request", slog.Duration("idle_after", idleAfter)) + slog.Default().Debug("new session request", slog.String("cmp", "manager"), slog.Duration("idle_after", idleAfter)) resp := make(chan newSessionResult, 1) m.cmdCh <- newSessionCmd{idleAfter: idleAfter, resp: resp} r := <-resp + + slog.Default().Info("new session created", slog.String("cmp", "manager"), slog.String("session", r.id.String())) return r.id } // AttachClient attaches a client to a session, creates and returns client channels for the session. func (m *Manager) AttachClient(id uuid.UUID, inBuf, outBuf int) (chan<- domain.Message, <-chan domain.Message, error) { - lg().Debug("attach client request", slog.String("session", id.String()), slog.Int("in_buf", inBuf), slog.Int("out_buf", outBuf)) + slog.Default().Debug("attach client request", slog.String("cmp", "manager"), slog.String("session", id.String()), slog.Int("in_buf", inBuf), slog.Int("out_buf", outBuf)) resp := make(chan attachResult, 1) m.cmdCh <- attachCmd{sid: id, inBuf: inBuf, outBuf: outBuf, resp: resp} r := <-resp + + slog.Default().Debug("client attached", slog.String("cmp", "manager"), slog.String("session", id.String())) return r.cin, r.cout, r.err } // DetachClient detaches the client from the session, closes client channels and arms timeout. func (m *Manager) DetachClient(id uuid.UUID) error { - lg().Debug("detach client request", slog.String("session", id.String())) + slog.Default().Debug("detach client request", slog.String("cmp", "manager"), slog.String("session", id.String())) resp := make(chan detachResult, 1) m.cmdCh <- detachCmd{sid: id, resp: resp} r := <-resp + + slog.Default().Debug("client detached", slog.String("cmp", "manager"), slog.String("session", id.String())) return r.err } // ConfigureSession sets the next set of identifiers for the session, starting and stopping streams as needed. func (m *Manager) ConfigureSession(id uuid.UUID, next []domain.Identifier) error { - lg().Debug("configure session request", slog.String("session", id.String()), slog.Int("idents", len(next))) + slog.Default().Debug("configure session request", slog.String("cmp", "manager"), slog.String("session", id.String()), slog.Int("idents", len(next))) resp := make(chan configureResult, 1) m.cmdCh <- configureCmd{sid: id, next: next, resp: resp} r := <-resp + + slog.Default().Debug("session configured", slog.String("cmp", "manager"), slog.String("session", id.String()), slog.String("err", fmt.Sprintf("%v", r.err))) return r.err } // CloseSession closes and removes the session, cleaning up all bindings. func (m *Manager) CloseSession(id uuid.UUID) error { - lg().Debug("close session request", slog.String("session", id.String())) + slog.Default().Debug("close session request", slog.String("cmp", "manager"), slog.String("session", id.String())) resp := make(chan closeSessionResult, 1) m.cmdCh <- closeSessionCmd{sid: id, resp: resp} r := <-resp + + slog.Default().Info("session closed", slog.String("cmp", "manager"), slog.String("session", id.String())) return r.err } @@ -146,12 +160,12 @@ func (m *Manager) run() { func (m *Manager) handleAddProvider(cmd addProviderCmd) { if _, ok := m.providers[cmd.name]; ok { - lg().Warn("provider already exists", slog.String("name", cmd.name)) + slog.Default().Warn("provider already exists", slog.String("cmp", "manager"), slog.String("name", cmd.name)) cmd.resp <- addProviderResult{err: fmt.Errorf("provider exists: %s", cmd.name)} return } if err := cmd.p.Start(); err != nil { - lg().Warn("failed to start provider", slog.String("name", cmd.name), slog.String("err", err.Error())) + slog.Default().Warn("failed to start provider", slog.String("cmp", "manager"), slog.String("name", cmd.name), slog.String("err", err.Error())) cmd.resp <- addProviderResult{err: fmt.Errorf("failed to start provider %s: %w", cmd.name, err)} return } @@ -191,6 +205,10 @@ func (m *Manager) handleAttach(cmd attachCmd) { s.attached = true s.disarmIdleTimer() + for id := range s.bound { + m.router.RegisterRoute(id, cout) + } + cmd.resp <- attachResult{cin: cin, cout: cout, err: nil} } @@ -205,6 +223,10 @@ func (m *Manager) handleDetach(cmd detachCmd) { return } + for id := range s.bound { + m.router.DeregisterRoute(id, s.outChannel) + } + s.clearChannels() s.armIdleTimer(func() { resp := make(chan closeSessionResult, 1) @@ -302,12 +324,18 @@ func (m *Manager) handleConfigure(cmd configureCmd) { removed = append(removed, id) } - // Update the router routes to reflect the new successful bindings - for _, id := range added { - m.router.RegisterRoute(id, s.outChannel) - } - for _, id := range removed { - m.router.DeregisterRoute(id, s.outChannel) + if s.attached { + if s.inChannel == nil || s.outChannel == nil { + errs = errors.Join(errs, fmt.Errorf("channels do not exist despite attached state")) // error should never be hit + slog.Default().Error("no channels despite attached state", slog.String("cmp", "manager"), slog.String("session", cmd.sid.String())) + } else { + for _, id := range added { + m.router.RegisterRoute(id, s.outChannel) + } + for _, id := range removed { + m.router.DeregisterRoute(id, s.outChannel) + } + } } cmd.resp <- configureResult{err: errs} diff --git a/services/data_service/internal/manager/session.go b/services/data_service/internal/manager/session.go index b4067bd..d63076b 100644 --- a/services/data_service/internal/manager/session.go +++ b/services/data_service/internal/manager/session.go @@ -72,15 +72,3 @@ func (s *session) clearChannels() { s.outChannel = nil } } - -func (m *Manager) getSessionChannels(sid uuid.UUID) (chan<- domain.Message, <-chan domain.Message, error) { - s, ok := m.sessions[sid] - if !ok { - return nil, nil, ErrSessionNotFound - } - if !s.attached { - return nil, nil, ErrClientNotAttached - } - - return s.inChannel, s.outChannel, nil -} diff --git a/services/data_service/internal/provider/binance/futures_websocket.go b/services/data_service/internal/provider/binance/futures_websocket.go deleted file mode 100644 index 7129f1e..0000000 --- a/services/data_service/internal/provider/binance/futures_websocket.go +++ /dev/null @@ -1,410 +0,0 @@ -package binance - -import ( - "context" - "encoding/json" - "errors" - "log/slog" - "sync" - "sync/atomic" - "time" - - "github.com/coder/websocket" - "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" -) - -const ( - endpoint = "wss://stream.binance.com:9443/stream" - cmpName = "binance_futures_websocket" - - // I/O limits - readLimitBytes = 8 << 20 - writeTimeout = 5 * time.Second - dialTimeout = 10 * time.Second - reconnectMaxBackoff = 30 * time.Second -) - -type wsReq struct { - Method string `json:"method"` - Params []string `json:"params,omitempty"` - ID uint64 `json:"id"` -} - -type wsAck struct { - Result any `json:"result"` - ID uint64 `json:"id"` -} - -type combinedEvent struct { - Stream string `json:"stream"` - Data json.RawMessage `json:"data"` -} - -type FuturesWebsocket struct { - out chan<- domain.Message - - mu sync.RWMutex - active map[string]bool - - connMu sync.Mutex - conn *websocket.Conn - cancel context.CancelFunc - - reqID atomic.Uint64 - pending map[uint64]chan error - pmu sync.Mutex - - // pumps - writer chan []byte - once sync.Once - stopCh chan struct{} -} - -func NewFuturesWebsocket(out chan<- domain.Message) *FuturesWebsocket { - return &FuturesWebsocket{ - out: out, - active: make(map[string]bool), - pending: make(map[uint64]chan error), - writer: make(chan []byte, 256), - stopCh: make(chan struct{}), - } -} - -func (p *FuturesWebsocket) Start() error { - var startErr error - p.once.Do(func() { - go p.run() - }) - return startErr -} - -func (p *FuturesWebsocket) Stop() { - close(p.stopCh) - p.connMu.Lock() - if p.cancel != nil { - p.cancel() - } - if p.conn != nil { - _ = p.conn.Close(websocket.StatusNormalClosure, "shutdown") - p.conn = nil - } - p.connMu.Unlock() - - // fail pending waiters - p.pmu.Lock() - for id, ch := range p.pending { - ch <- errors.New("provider stopped") - close(ch) - delete(p.pending, id) - } - p.pmu.Unlock() - - slog.Default().Info("stopped", "cmp", cmpName) -} - -func (p *FuturesWebsocket) StartStreams(keys []string) <-chan error { - ch := make(chan error, 1) - go func() { - defer close(ch) - if len(keys) == 0 { - ch <- nil - return - } - id, ack := p.sendReq("SUBSCRIBE", keys) - if ack == nil { - ch <- errors.New("not connected") - slog.Default().Error("subscribe failed; not connected", "cmp", cmpName, "keys", keys) - return - } - if err := <-ack; err != nil { - ch <- err - slog.Default().Error("subscribe NACK", "cmp", cmpName, "id", id, "keys", keys, "err", err) - return - } - p.mu.Lock() - for _, k := range keys { - p.active[k] = true - } - p.mu.Unlock() - slog.Default().Info("subscribed", "cmp", cmpName, "id", id, "keys", keys) - ch <- nil - }() - return ch -} - -func (p *FuturesWebsocket) StopStreams(keys []string) <-chan error { - ch := make(chan error, 1) - go func() { - defer close(ch) - if len(keys) == 0 { - ch <- nil - return - } - id, ack := p.sendReq("UNSUBSCRIBE", keys) - if ack == nil { - ch <- errors.New("not connected") - slog.Default().Error("unsubscribe failed; not connected", "cmp", cmpName, "keys", keys) - return - } - if err := <-ack; err != nil { - ch <- err - slog.Default().Error("unsubscribe NACK", "cmp", cmpName, "id", id, "keys", keys, "err", err) - return - } - p.mu.Lock() - for _, k := range keys { - delete(p.active, k) - } - p.mu.Unlock() - slog.Default().Info("unsubscribed", "cmp", cmpName, "id", id, "keys", keys) - ch <- nil - }() - return ch -} - -func (p *FuturesWebsocket) Fetch(key string) (domain.Message, error) { - return domain.Message{}, errors.New("not implemented") -} - -func (p *FuturesWebsocket) IsStreamActive(key string) bool { - p.mu.RLock() - defer p.mu.RUnlock() - return p.active[key] -} - -func (p *FuturesWebsocket) IsValidSubject(key string, _ bool) bool { - return len(key) > 0 -} - -// internal - -func (p *FuturesWebsocket) run() { - backoff := time.Second - - for { - // stop? - select { - case <-p.stopCh: - return - default: - } - - if err := p.connect(); err != nil { - slog.Default().Error("dial failed", "cmp", cmpName, "err", err) - time.Sleep(backoff) - if backoff < reconnectMaxBackoff { - backoff *= 2 - } - continue - } - backoff = time.Second - - // resubscribe existing keys - func() { - p.mu.RLock() - if len(p.active) > 0 { - keys := make([]string, 0, len(p.active)) - for k := range p.active { - keys = append(keys, k) - } - _, ack := p.sendReq("SUBSCRIBE", keys) - if ack != nil { - if err := <-ack; err != nil { - slog.Default().Warn("resubscribe error", "cmp", cmpName, "err", err) - } else { - slog.Default().Info("resubscribed", "cmp", cmpName, "count", len(keys)) - } - } - } - p.mu.RUnlock() - }() - - // run read and write pumps - ctx, cancel := context.WithCancel(context.Background()) - errc := make(chan error, 2) - go func() { errc <- p.readLoop(ctx) }() - go func() { errc <- p.writeLoop(ctx) }() - - // wait for failure or stop - var err error - select { - case <-p.stopCh: - cancel() - p.cleanupConn() - return - case err = <-errc: - cancel() - } - - // fail pendings on error - p.pmu.Lock() - for id, ch := range p.pending { - ch <- err - close(ch) - delete(p.pending, id) - } - p.pmu.Unlock() - - slog.Default().Error("ws loop error; reconnecting", "cmp", cmpName, "err", err) - p.cleanupConn() - } -} - -func (p *FuturesWebsocket) connect() error { - p.connMu.Lock() - defer p.connMu.Unlock() - if p.conn != nil { - return nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - - c, _, err := websocket.Dial(ctx, endpoint, &websocket.DialOptions{ - CompressionMode: websocket.CompressionDisabled, - OnPingReceived: func(ctx context.Context, _ []byte) bool { - slog.Default().Info("ping received", "cmp", cmpName) - return true - }, - }) - if err != nil { - cancel() - return err - } - - c.SetReadLimit(8 << 20) - - p.conn = c - p.cancel = cancel - slog.Default().Info("connected", "cmp", cmpName, "endpoint", endpoint) - return nil -} - -func (p *FuturesWebsocket) cleanupConn() { - p.connMu.Lock() - defer p.connMu.Unlock() - if p.cancel != nil { - p.cancel() - p.cancel = nil - } - if p.conn != nil { - _ = p.conn.Close(websocket.StatusAbnormalClosure, "reconnect") - p.conn = nil - } -} - -func (p *FuturesWebsocket) writeLoop(ctx context.Context) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - - case b := <-p.writer: - p.connMu.Lock() - c := p.conn - p.connMu.Unlock() - if c == nil { - return errors.New("conn nil") - } - wctx, cancel := context.WithTimeout(ctx, writeTimeout) - err := c.Write(wctx, websocket.MessageText, b) - cancel() - if err != nil { - return err - } - } - } -} - -func (p *FuturesWebsocket) readLoop(ctx context.Context) error { - slog.Default().Info("read loop started", "cmp", cmpName) - defer slog.Default().Info("read loop exited", "cmp", cmpName) - - for { - p.connMu.Lock() - c := p.conn - p.connMu.Unlock() - if c == nil { - return errors.New("conn nil") - } - - _, data, err := c.Read(ctx) - if err != nil { - return err - } - - // ACK - var ack wsAck - if json.Unmarshal(data, &ack) == nil && ack.ID != 0 { - p.pmu.Lock() - if ch, ok := p.pending[ack.ID]; ok { - if ack.Result == nil { - ch <- nil - slog.Default().Debug("ack ok", "cmp", cmpName, "id", ack.ID) - } else { - resb, _ := json.Marshal(ack.Result) - ch <- errors.New(string(resb)) - slog.Default().Warn("ack error", "cmp", cmpName, "id", ack.ID, "result", string(resb)) - } - close(ch) - delete(p.pending, ack.ID) - } else { - slog.Default().Warn("ack with unknown id", "cmp", cmpName, "id", ack.ID) - } - p.pmu.Unlock() - continue - } - - // Combined stream payload - var evt combinedEvent - if json.Unmarshal(data, &evt) == nil && evt.Stream != "" { - ident, _ := domain.RawID(cmpName, evt.Stream) - msg := domain.Message{ - Identifier: ident, - Payload: evt.Data, - } - select { - case p.out <- msg: - default: - slog.Default().Warn("dropping message since router buffer full", "cmp", cmpName, "stream", evt.Stream) - } - continue - } - - // Unknown frame - const maxSample = 512 - if len(data) > maxSample { - slog.Default().Debug("unparsed frame", "cmp", cmpName, "size", len(data)) - } else { - slog.Default().Debug("unparsed frame", "cmp", cmpName, "size", len(data), "body", string(data)) - } - } -} - -func (p *FuturesWebsocket) sendReq(method string, params []string) (uint64, <-chan error) { - p.connMu.Lock() - c := p.conn - p.connMu.Unlock() - if c == nil { - return 0, nil - } - - id := p.reqID.Add(1) - req := wsReq{Method: method, Params: params, ID: id} - b, _ := json.Marshal(req) - - ack := make(chan error, 1) - p.pmu.Lock() - p.pending[id] = ack - p.pmu.Unlock() - - // enqueue to single writer to avoid concurrent writes - select { - case p.writer <- b: - default: - // avoid blocking the caller; offload - go func() { p.writer <- b }() - } - - slog.Default().Debug("request enqueued", "cmp", cmpName, "id", id, "method", method, "params", params) - return id, ack -} diff --git a/services/data_service/internal/provider/providers/binance/ws/binance_futures.go b/services/data_service/internal/provider/providers/binance/ws/binance_futures.go new file mode 100644 index 0000000..f44db10 --- /dev/null +++ b/services/data_service/internal/provider/providers/binance/ws/binance_futures.go @@ -0,0 +1,52 @@ +package ws + +import ( + "fmt" + "time" + + "github.com/google/uuid" + "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" +) + +type BinanceFutures struct { + cfg config + shards map[uuid.UUID]*shard + streamAssignments map[string]*shard +} + +type config struct { + Endpoint string + MaxStreamsPerShard uint8 + BatchInterval time.Duration +} + +func NewBinanceFuturesWebsocket(cfg config) *BinanceFutures { + return &BinanceFutures{ + cfg: cfg, + shards: make(map[uuid.UUID]*shard), + } +} + +func (b *BinanceFutures) Start() error { + return nil +} + +func (b *BinanceFutures) Stop() { + return +} + +func (b *BinanceFutures) Subscribe(subject string) <-chan error { + return nil +} + +func (b *BinanceFutures) Unsubscribe(subject string) <-chan error { + return nil +} + +func (b *BinanceFutures) Fetch(subject string) (domain.Message, error) { + return domain.Message{}, fmt.Errorf("fetch not supported by provider") +} + +func (b *BinanceFutures) GetActiveStreams() []string { return nil } +func (b *BinanceFutures) IsStreamActive(key string) bool { return false } +func (b *BinanceFutures) IsValidSubject(key string, isFetch bool) bool { return false } diff --git a/services/data_service/internal/provider/providers/binance/ws/shard.go b/services/data_service/internal/provider/providers/binance/ws/shard.go new file mode 100644 index 0000000..e22942c --- /dev/null +++ b/services/data_service/internal/provider/providers/binance/ws/shard.go @@ -0,0 +1,12 @@ +package ws + +import ( + "github.com/coder/websocket" + "github.com/google/uuid" +) + +type shard struct { + ID uuid.UUID + conn websocket.Conn + activeStreams []string +} diff --git a/services/data_service/internal/provider/providers/binance/ws/types.go b/services/data_service/internal/provider/providers/binance/ws/types.go new file mode 100644 index 0000000..e69de29 diff --git a/services/data_service/internal/provider/providers/test/test_provider.go b/services/data_service/internal/provider/providers/test/test_provider.go new file mode 100644 index 0000000..75cd2d3 --- /dev/null +++ b/services/data_service/internal/provider/providers/test/test_provider.go @@ -0,0 +1,149 @@ +package test + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" +) + +type TestProvider struct { + mu sync.Mutex + streams map[string]*stream + outputChannel chan<- domain.Message + tickDuration time.Duration +} + +type stream struct { + cancel context.CancelFunc + done chan struct{} +} + +// NewTestProvider wires the outbound channel. +func NewTestProvider(out chan<- domain.Message, tickDuration time.Duration) *TestProvider { + return &TestProvider{ + streams: make(map[string]*stream), + outputChannel: out, + tickDuration: tickDuration, + } +} + +func (t *TestProvider) Start() error { return nil } + +func (t *TestProvider) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + for key, s := range t.streams { + s.cancel() + <-s.done + delete(t.streams, key) + } +} + +func (t *TestProvider) Subscribe(subject string) <-chan error { + errCh := make(chan error, 1) + + if !t.IsValidSubject(subject, false) { + errCh <- errors.New("invalid subject") + close(errCh) + return errCh + } + + t.mu.Lock() + // Already active: treat as success. + if _, ok := t.streams[subject]; ok { + t.mu.Unlock() + errCh <- nil + return errCh + } + + ctx, cancel := context.WithCancel(context.Background()) + s := &stream{cancel: cancel, done: make(chan struct{})} + t.streams[subject] = s + out := t.outputChannel + t.mu.Unlock() + + // Stream goroutine. + go func(subj string, s *stream) { + slog.Default().Debug("new stream routine started", slog.String("cmp", "test_provider"), slog.String("subject", subj)) + ticker := time.NewTicker(t.tickDuration) + ident, _ := domain.RawID("test_provider", subj) + defer func() { + ticker.Stop() + close(s.done) + }() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if out != nil { + msg := domain.Message{ + Identifier: ident, + Payload: []byte(time.Now().UTC().Format(time.RFC3339Nano)), + } + // Non-blocking send avoids deadlock if caller stops reading. + select { + case out <- msg: + default: + slog.Default().Warn("dropping message due to backpressure", "cmp", "test_provider", "subject", subj) + } + } + } + } + }(subject, s) + + // Signal successful subscription. + errCh <- nil + return errCh +} + +func (t *TestProvider) Unsubscribe(subject string) <-chan error { + errCh := make(chan error, 1) + + t.mu.Lock() + s, ok := t.streams[subject] + if !ok { + t.mu.Unlock() + errCh <- errors.New("not subscribed") + return errCh + } + delete(t.streams, subject) + t.mu.Unlock() + + go func() { + s.cancel() + <-s.done + errCh <- nil + }() + return errCh +} + +func (t *TestProvider) Fetch(subject string) (domain.Message, error) { + return domain.Message{}, fmt.Errorf("fetch not supported by provider") +} + +func (t *TestProvider) GetActiveStreams() []string { + t.mu.Lock() + defer t.mu.Unlock() + keys := make([]string, 0, len(t.streams)) + for k := range t.streams { + keys = append(keys, k) + } + return keys +} + +func (t *TestProvider) IsStreamActive(key string) bool { + t.mu.Lock() + _, ok := t.streams[key] + t.mu.Unlock() + return ok +} + +func (t *TestProvider) IsValidSubject(key string, _ bool) bool { + return key != "" +} diff --git a/services/data_service/internal/server/gprc_control_server.go b/services/data_service/internal/server/gprc_control_server.go index 8c206ef..0a5b1a8 100644 --- a/services/data_service/internal/server/gprc_control_server.go +++ b/services/data_service/internal/server/gprc_control_server.go @@ -24,10 +24,7 @@ func NewGRPCControlServer(m *manager.Manager) *GRPCControlServer { // StartStream creates a new session. It does NOT attach client channels. // Your streaming RPC should later call AttachClient(sessionID, opts). func (s *GRPCControlServer) StartStream(_ context.Context, req *pb.StartStreamRequest) (*pb.StartStreamResponse, error) { - sessionID, err := s.manager.NewSession(time.Duration(1) * time.Minute) // timeout set to 1 minute - if err != nil { - return nil, status.Errorf(codes.Internal, "new session: %v", err) - } + sessionID := s.manager.NewSession(time.Duration(1) * time.Minute) // timeout set to 1 minute return &pb.StartStreamResponse{StreamUuid: sessionID.String()}, nil } @@ -56,8 +53,6 @@ func (s *GRPCControlServer) ConfigureStream(_ context.Context, req *pb.Configure switch err { case manager.ErrSessionNotFound: return nil, status.Errorf(codes.NotFound, "session not found: %v", err) - case manager.ErrSessionClosed: - return nil, status.Errorf(codes.FailedPrecondition, "session closed: %v", err) default: return nil, status.Errorf(codes.Internal, "set subscriptions: %v", err) }