From 40e5ce9708197d3e70a842e2ce5ff70b1d604670 Mon Sep 17 00:00:00 2001 From: Phillip Michelsen Date: Sun, 24 Aug 2025 14:02:21 +0700 Subject: [PATCH] Refactor session management: enhance session struct with internal channels, implement client attachment handling, and improve idle session management --- .../data_service/internal/manager/manager.go | 335 +++++++++++++----- 1 file changed, 253 insertions(+), 82 deletions(-) diff --git a/services/data_service/internal/manager/manager.go b/services/data_service/internal/manager/manager.go index e40e45f..9d5ae94 100644 --- a/services/data_service/internal/manager/manager.go +++ b/services/data_service/internal/manager/manager.go @@ -1,9 +1,11 @@ package manager import ( + "context" "errors" "fmt" "sync" + "time" "github.com/google/uuid" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" @@ -12,13 +14,26 @@ import ( ) var ( - ErrSessionNotFound = errors.New("session not found") - ErrSessionClosed = errors.New("session closed") - ErrOutboundFull = errors.New("session publish buffer full") - ErrInvalidIdentifier = errors.New("invalid identifier") - ErrUnknownProvider = errors.New("unknown provider") + ErrSessionNotFound = errors.New("session not found") + ErrSessionClosed = errors.New("session closed") + ErrInvalidIdentifier = errors.New("invalid identifier") + ErrUnknownProvider = errors.New("unknown provider") + ErrClientAlreadyBound = errors.New("client channels already bound") ) +const ( + defaultInternalBuf = 1024 + defaultClientBuf = 256 +) + +type ChannelOpts struct { + InBufSize int + OutBufSize int + // If true, drop to clientOut when its buffer is full. If false, block. + DropOutbound bool +} + +// Manager owns providers, sessions, and the router fanout. type Manager struct { providers map[string]provider.Provider providerStreams map[domain.Identifier]chan domain.Message @@ -31,13 +46,25 @@ type Manager struct { } type session struct { - id uuid.UUID - in chan domain.Message - mailbox chan domain.Message - out chan domain.Message - bound map[domain.Identifier]struct{} - dropOnFull bool - closed bool + id uuid.UUID + + // Stable internal channels. Only the session writes internalOut and reads internalIn. + internalIn chan domain.Message // forwarded into router.IncomingChannel() + internalOut chan domain.Message // registered as router route target + + // Current client attachment (optional). Created by GetChannels. + clientIn chan domain.Message // caller writes + clientOut chan domain.Message // caller reads + + // Cancels the permanent internalIn forwarder. + cancelInternal context.CancelFunc + // Cancels current client forwarders. + cancelClient context.CancelFunc + + bound map[domain.Identifier]struct{} + closed bool + idleAfter time.Duration + idleTimer *time.Timer } func NewManager(r *router.Router) *Manager { @@ -51,45 +78,137 @@ func NewManager(r *router.Router) *Manager { } } -func (m *Manager) NewSession(bufIn, bufOut int, dropOnFull bool) (uuid.UUID, chan<- domain.Message, <-chan domain.Message, error) { - if bufIn <= 0 { - bufIn = 1024 - } - if bufOut <= 0 { - bufOut = 1024 - } - +// NewSession creates a session with stable internal channels and a permanent +// forwarder that pipes internalIn into router.IncomingChannel(). +func (m *Manager) NewSession(idleAfter time.Duration) (uuid.UUID, error) { s := &session{ - id: uuid.New(), - in: make(chan domain.Message, bufIn), - mailbox: make(chan domain.Message, bufOut), - out: make(chan domain.Message, bufOut), - bound: make(map[domain.Identifier]struct{}), - dropOnFull: dropOnFull, + id: uuid.New(), + internalIn: make(chan domain.Message, defaultInternalBuf), + internalOut: make(chan domain.Message, defaultInternalBuf), + bound: make(map[domain.Identifier]struct{}), + idleAfter: idleAfter, } + ctx, cancel := context.WithCancel(context.Background()) + s.cancelInternal = cancel m.mu.Lock() m.sessions[s.id] = s incoming := m.router.IncomingChannel() m.mu.Unlock() - go func() { - for msg := range s.in { - incoming <- msg + // Permanent forwarder: internalIn -> router.Incoming + go func(ctx context.Context, in <-chan domain.Message) { + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-in: + if !ok { + return + } + // Place to filter, validate, meter, or throttle. + incoming <- msg + } } - }() + }(ctx, s.internalIn) - go func() { - for msg := range s.mailbox { - s.out <- msg - } - close(s.out) - }() - - return s.id, s.in, s.out, nil + return s.id, nil } -func (m *Manager) CloseSession(id uuid.UUID) error { +// GetChannels creates a fresh client attachment and hooks both directions: +// clientIn -> internalIn and internalOut -> clientOut. Only one attachment at a time. +func (m *Manager) GetChannels(id uuid.UUID, opts ChannelOpts) (chan<- domain.Message, <-chan domain.Message, error) { + if opts.InBufSize <= 0 { + opts.InBufSize = defaultClientBuf + } + if opts.OutBufSize <= 0 { + opts.OutBufSize = defaultClientBuf + } + + m.mu.Lock() + s, ok := m.sessions[id] + if !ok { + m.mu.Unlock() + return nil, nil, ErrSessionNotFound + } + if s.closed { + m.mu.Unlock() + return nil, nil, ErrSessionClosed + } + if s.clientIn != nil || s.clientOut != nil { + m.mu.Unlock() + return nil, nil, ErrClientAlreadyBound + } + + // Create attachment channels. + cin := make(chan domain.Message, opts.InBufSize) + cout := make(chan domain.Message, opts.OutBufSize) + s.clientIn, s.clientOut = cin, cout + + // Stop idle timer while attached. + if s.idleTimer != nil { + s.idleTimer.Stop() + s.idleTimer = nil + } + + internalIn := s.internalIn + internalOut := s.internalOut + + // Prepare per-attachment cancel. + attachCtx, attachCancel := context.WithCancel(context.Background()) + s.cancelClient = attachCancel + + m.mu.Unlock() + + // Forward clientIn -> internalIn + go func(ctx context.Context, src <-chan domain.Message, dst chan<- domain.Message) { + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-src: + if !ok { + // Client closed input; stop forwarding. + return + } + // Per-client checks could go here. + dst <- msg + } + } + }(attachCtx, cin, internalIn) + + // Forward internalOut -> clientOut + go func(ctx context.Context, src <-chan domain.Message, dst chan<- domain.Message, drop bool) { + defer close(dst) + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-src: + if !ok { + // Session is closing; signal EOF to client. + return + } + if drop { + select { + case dst <- msg: + default: + // Drop on client backpressure. Add metrics if desired. + } + } else { + dst <- msg + } + } + } + }(attachCtx, internalOut, cout, opts.DropOutbound) + + // Return directional views. + return (chan<- domain.Message)(cin), (<-chan domain.Message)(cout), nil +} + +// DetachClient cancels current client forwarders and clears the attachment. +// It starts the idle close timer if configured. +func (m *Manager) DetachClient(id uuid.UUID) error { m.mu.Lock() s, ok := m.sessions[id] if !ok { @@ -98,37 +217,44 @@ func (m *Manager) CloseSession(id uuid.UUID) error { } if s.closed { m.mu.Unlock() - return nil + return ErrSessionClosed } - s.closed = true - - var ids []domain.Identifier - for k := range s.bound { - ids = append(ids, k) - delete(s.bound, k) - } - delete(m.sessions, id) + // Capture and clear client state. + cancel := s.cancelClient + cin := s.clientIn + s.cancelClient = nil + s.clientIn, s.clientOut = nil, nil + after := s.idleAfter m.mu.Unlock() - for _, ident := range ids { - m.router.DeregisterRoute(ident, s.mailbox) - if ident.IsRaw() { - m.releaseRawStreamIfUnused(ident) - } + if cancel != nil { + cancel() + } + // Close clientIn to terminate clientIn->internalIn forwarder if client forgot. + if cin != nil { + close(cin) } - close(s.mailbox) - close(s.in) + if after > 0 { + m.mu.Lock() + ss, ok := m.sessions[id] + if ok && !ss.closed && ss.clientOut == nil && ss.idleTimer == nil { + ss.idleTimer = time.AfterFunc(after, func() { _ = m.CloseSession(id) }) + } + m.mu.Unlock() + } return nil } func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error { m.mu.Lock() s, ok := m.sessions[id] - m.mu.Unlock() if !ok { + m.mu.Unlock() return ErrSessionNotFound } + out := s.internalOut + m.mu.Unlock() for _, ident := range ids { m.mu.Lock() @@ -136,6 +262,7 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error { m.mu.Unlock() continue } + s.bound[ident] = struct{}{} m.mu.Unlock() if ident.IsRaw() { @@ -143,11 +270,7 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error { return err } } - - m.mu.Lock() - s.bound[ident] = struct{}{} - m.mu.Unlock() - m.router.RegisterRoute(ident, s.mailbox) + m.router.RegisterRoute(ident, out) } return nil } @@ -155,10 +278,12 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error { func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error { m.mu.Lock() s, ok := m.sessions[id] - m.mu.Unlock() if !ok { + m.mu.Unlock() return ErrSessionNotFound } + out := s.internalOut + m.mu.Unlock() for _, ident := range ids { m.mu.Lock() @@ -169,7 +294,7 @@ func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error { delete(s.bound, ident) m.mu.Unlock() - m.router.DeregisterRoute(ident, s.mailbox) + m.router.DeregisterRoute(ident, out) if ident.IsRaw() { m.releaseRawStreamIfUnused(ident) } @@ -188,23 +313,41 @@ func (m *Manager) SetSubscriptions(id uuid.UUID, next []domain.Identifier) error for k := range s.bound { old[k] = struct{}{} } + out := s.internalOut m.mu.Unlock() toAdd, toDel := m.identifierSetDifferences(old, next) - if len(toAdd) > 0 { - if err := m.Subscribe(id, toAdd...); err != nil { - return err + + for _, ident := range toAdd { + m.mu.Lock() + s.bound[ident] = struct{}{} + m.mu.Unlock() + + if ident.IsRaw() { + if err := m.provisionRawStream(ident); err != nil { + return err + } } + m.router.RegisterRoute(ident, out) } - if len(toDel) > 0 { - if err := m.Unsubscribe(id, toDel...); err != nil { - return err + + for _, ident := range toDel { + m.mu.Lock() + _, exists := s.bound[ident] + delete(s.bound, ident) + m.mu.Unlock() + + if exists { + m.router.DeregisterRoute(ident, out) + if ident.IsRaw() { + m.releaseRawStreamIfUnused(ident) + } } } return nil } -func (m *Manager) Publish(id uuid.UUID, msg domain.Message) error { +func (m *Manager) CloseSession(id uuid.UUID) error { m.mu.Lock() s, ok := m.sessions[id] if !ok { @@ -213,21 +356,49 @@ func (m *Manager) Publish(id uuid.UUID, msg domain.Message) error { } if s.closed { m.mu.Unlock() - return ErrSessionClosed + return nil } - ch := s.in - drop := s.dropOnFull + s.closed = true + if s.idleTimer != nil { + s.idleTimer.Stop() + s.idleTimer = nil + } + out := s.internalOut + ids := make([]domain.Identifier, 0, len(s.bound)) + for k := range s.bound { + ids = append(ids, k) + } + cancelInternal := s.cancelInternal + cancelClient := s.cancelClient + // Clear attachments before unlock to avoid races. + s.cancelClient = nil + cin := s.clientIn + s.clientIn, s.clientOut = nil, nil + delete(m.sessions, id) m.mu.Unlock() - if drop { - select { - case ch <- msg: - return nil - default: - return ErrOutboundFull + // Deregister all routes and release raw streams. + for _, ident := range ids { + m.router.DeregisterRoute(ident, out) + if ident.IsRaw() { + m.releaseRawStreamIfUnused(ident) } } - ch <- msg + + // Stop forwarders and close internal channels. + if cancelClient != nil { + cancelClient() + } + if cancelInternal != nil { + cancelInternal() + } + close(s.internalIn) + close(s.internalOut) // will close clientOut via forwarder + + // Close clientIn to ensure its forwarder exits even if client forgot. + if cin != nil { + close(cin) + } return nil } @@ -256,11 +427,10 @@ func (m *Manager) RemoveProvider(name string) error { if !ok { return fmt.Errorf("provider not found: %s", name) } + // Optional: implement full drain and cancel of all streams for this provider. return fmt.Errorf("RemoveProvider not implemented") } -// helpers - func (m *Manager) provisionRawStream(id domain.Identifier) error { providerName, subject, ok := id.ProviderSubject() if !ok || providerName == "" || subject == "" { @@ -295,6 +465,7 @@ func (m *Manager) provisionRawStream(id domain.Identifier) error { incoming := m.router.IncomingChannel() m.mu.Unlock() + // Provider stream -> router.Incoming go func(c chan domain.Message) { for msg := range c { incoming <- msg