Refactor session management: enhance session struct with internal channels, implement client attachment handling, and improve idle session management
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain"
|
||||
@@ -12,13 +14,26 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
ErrSessionClosed = errors.New("session closed")
|
||||
ErrOutboundFull = errors.New("session publish buffer full")
|
||||
ErrInvalidIdentifier = errors.New("invalid identifier")
|
||||
ErrUnknownProvider = errors.New("unknown provider")
|
||||
ErrSessionNotFound = errors.New("session not found")
|
||||
ErrSessionClosed = errors.New("session closed")
|
||||
ErrInvalidIdentifier = errors.New("invalid identifier")
|
||||
ErrUnknownProvider = errors.New("unknown provider")
|
||||
ErrClientAlreadyBound = errors.New("client channels already bound")
|
||||
)
|
||||
|
||||
const (
|
||||
defaultInternalBuf = 1024
|
||||
defaultClientBuf = 256
|
||||
)
|
||||
|
||||
type ChannelOpts struct {
|
||||
InBufSize int
|
||||
OutBufSize int
|
||||
// If true, drop to clientOut when its buffer is full. If false, block.
|
||||
DropOutbound bool
|
||||
}
|
||||
|
||||
// Manager owns providers, sessions, and the router fanout.
|
||||
type Manager struct {
|
||||
providers map[string]provider.Provider
|
||||
providerStreams map[domain.Identifier]chan domain.Message
|
||||
@@ -31,13 +46,25 @@ type Manager struct {
|
||||
}
|
||||
|
||||
type session struct {
|
||||
id uuid.UUID
|
||||
in chan domain.Message
|
||||
mailbox chan domain.Message
|
||||
out chan domain.Message
|
||||
bound map[domain.Identifier]struct{}
|
||||
dropOnFull bool
|
||||
closed bool
|
||||
id uuid.UUID
|
||||
|
||||
// Stable internal channels. Only the session writes internalOut and reads internalIn.
|
||||
internalIn chan domain.Message // forwarded into router.IncomingChannel()
|
||||
internalOut chan domain.Message // registered as router route target
|
||||
|
||||
// Current client attachment (optional). Created by GetChannels.
|
||||
clientIn chan domain.Message // caller writes
|
||||
clientOut chan domain.Message // caller reads
|
||||
|
||||
// Cancels the permanent internalIn forwarder.
|
||||
cancelInternal context.CancelFunc
|
||||
// Cancels current client forwarders.
|
||||
cancelClient context.CancelFunc
|
||||
|
||||
bound map[domain.Identifier]struct{}
|
||||
closed bool
|
||||
idleAfter time.Duration
|
||||
idleTimer *time.Timer
|
||||
}
|
||||
|
||||
func NewManager(r *router.Router) *Manager {
|
||||
@@ -51,45 +78,137 @@ func NewManager(r *router.Router) *Manager {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) NewSession(bufIn, bufOut int, dropOnFull bool) (uuid.UUID, chan<- domain.Message, <-chan domain.Message, error) {
|
||||
if bufIn <= 0 {
|
||||
bufIn = 1024
|
||||
}
|
||||
if bufOut <= 0 {
|
||||
bufOut = 1024
|
||||
}
|
||||
|
||||
// NewSession creates a session with stable internal channels and a permanent
|
||||
// forwarder that pipes internalIn into router.IncomingChannel().
|
||||
func (m *Manager) NewSession(idleAfter time.Duration) (uuid.UUID, error) {
|
||||
s := &session{
|
||||
id: uuid.New(),
|
||||
in: make(chan domain.Message, bufIn),
|
||||
mailbox: make(chan domain.Message, bufOut),
|
||||
out: make(chan domain.Message, bufOut),
|
||||
bound: make(map[domain.Identifier]struct{}),
|
||||
dropOnFull: dropOnFull,
|
||||
id: uuid.New(),
|
||||
internalIn: make(chan domain.Message, defaultInternalBuf),
|
||||
internalOut: make(chan domain.Message, defaultInternalBuf),
|
||||
bound: make(map[domain.Identifier]struct{}),
|
||||
idleAfter: idleAfter,
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.cancelInternal = cancel
|
||||
|
||||
m.mu.Lock()
|
||||
m.sessions[s.id] = s
|
||||
incoming := m.router.IncomingChannel()
|
||||
m.mu.Unlock()
|
||||
|
||||
go func() {
|
||||
for msg := range s.in {
|
||||
incoming <- msg
|
||||
// Permanent forwarder: internalIn -> router.Incoming
|
||||
go func(ctx context.Context, in <-chan domain.Message) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-in:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// Place to filter, validate, meter, or throttle.
|
||||
incoming <- msg
|
||||
}
|
||||
}
|
||||
}()
|
||||
}(ctx, s.internalIn)
|
||||
|
||||
go func() {
|
||||
for msg := range s.mailbox {
|
||||
s.out <- msg
|
||||
}
|
||||
close(s.out)
|
||||
}()
|
||||
|
||||
return s.id, s.in, s.out, nil
|
||||
return s.id, nil
|
||||
}
|
||||
|
||||
func (m *Manager) CloseSession(id uuid.UUID) error {
|
||||
// GetChannels creates a fresh client attachment and hooks both directions:
|
||||
// clientIn -> internalIn and internalOut -> clientOut. Only one attachment at a time.
|
||||
func (m *Manager) GetChannels(id uuid.UUID, opts ChannelOpts) (chan<- domain.Message, <-chan domain.Message, error) {
|
||||
if opts.InBufSize <= 0 {
|
||||
opts.InBufSize = defaultClientBuf
|
||||
}
|
||||
if opts.OutBufSize <= 0 {
|
||||
opts.OutBufSize = defaultClientBuf
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
s, ok := m.sessions[id]
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return nil, nil, ErrSessionNotFound
|
||||
}
|
||||
if s.closed {
|
||||
m.mu.Unlock()
|
||||
return nil, nil, ErrSessionClosed
|
||||
}
|
||||
if s.clientIn != nil || s.clientOut != nil {
|
||||
m.mu.Unlock()
|
||||
return nil, nil, ErrClientAlreadyBound
|
||||
}
|
||||
|
||||
// Create attachment channels.
|
||||
cin := make(chan domain.Message, opts.InBufSize)
|
||||
cout := make(chan domain.Message, opts.OutBufSize)
|
||||
s.clientIn, s.clientOut = cin, cout
|
||||
|
||||
// Stop idle timer while attached.
|
||||
if s.idleTimer != nil {
|
||||
s.idleTimer.Stop()
|
||||
s.idleTimer = nil
|
||||
}
|
||||
|
||||
internalIn := s.internalIn
|
||||
internalOut := s.internalOut
|
||||
|
||||
// Prepare per-attachment cancel.
|
||||
attachCtx, attachCancel := context.WithCancel(context.Background())
|
||||
s.cancelClient = attachCancel
|
||||
|
||||
m.mu.Unlock()
|
||||
|
||||
// Forward clientIn -> internalIn
|
||||
go func(ctx context.Context, src <-chan domain.Message, dst chan<- domain.Message) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-src:
|
||||
if !ok {
|
||||
// Client closed input; stop forwarding.
|
||||
return
|
||||
}
|
||||
// Per-client checks could go here.
|
||||
dst <- msg
|
||||
}
|
||||
}
|
||||
}(attachCtx, cin, internalIn)
|
||||
|
||||
// Forward internalOut -> clientOut
|
||||
go func(ctx context.Context, src <-chan domain.Message, dst chan<- domain.Message, drop bool) {
|
||||
defer close(dst)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-src:
|
||||
if !ok {
|
||||
// Session is closing; signal EOF to client.
|
||||
return
|
||||
}
|
||||
if drop {
|
||||
select {
|
||||
case dst <- msg:
|
||||
default:
|
||||
// Drop on client backpressure. Add metrics if desired.
|
||||
}
|
||||
} else {
|
||||
dst <- msg
|
||||
}
|
||||
}
|
||||
}
|
||||
}(attachCtx, internalOut, cout, opts.DropOutbound)
|
||||
|
||||
// Return directional views.
|
||||
return (chan<- domain.Message)(cin), (<-chan domain.Message)(cout), nil
|
||||
}
|
||||
|
||||
// DetachClient cancels current client forwarders and clears the attachment.
|
||||
// It starts the idle close timer if configured.
|
||||
func (m *Manager) DetachClient(id uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
s, ok := m.sessions[id]
|
||||
if !ok {
|
||||
@@ -98,37 +217,44 @@ func (m *Manager) CloseSession(id uuid.UUID) error {
|
||||
}
|
||||
if s.closed {
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
return ErrSessionClosed
|
||||
}
|
||||
s.closed = true
|
||||
|
||||
var ids []domain.Identifier
|
||||
for k := range s.bound {
|
||||
ids = append(ids, k)
|
||||
delete(s.bound, k)
|
||||
}
|
||||
delete(m.sessions, id)
|
||||
// Capture and clear client state.
|
||||
cancel := s.cancelClient
|
||||
cin := s.clientIn
|
||||
s.cancelClient = nil
|
||||
s.clientIn, s.clientOut = nil, nil
|
||||
after := s.idleAfter
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, ident := range ids {
|
||||
m.router.DeregisterRoute(ident, s.mailbox)
|
||||
if ident.IsRaw() {
|
||||
m.releaseRawStreamIfUnused(ident)
|
||||
}
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
// Close clientIn to terminate clientIn->internalIn forwarder if client forgot.
|
||||
if cin != nil {
|
||||
close(cin)
|
||||
}
|
||||
|
||||
close(s.mailbox)
|
||||
close(s.in)
|
||||
if after > 0 {
|
||||
m.mu.Lock()
|
||||
ss, ok := m.sessions[id]
|
||||
if ok && !ss.closed && ss.clientOut == nil && ss.idleTimer == nil {
|
||||
ss.idleTimer = time.AfterFunc(after, func() { _ = m.CloseSession(id) })
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||
m.mu.Lock()
|
||||
s, ok := m.sessions[id]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
out := s.internalOut
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, ident := range ids {
|
||||
m.mu.Lock()
|
||||
@@ -136,6 +262,7 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||
m.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
s.bound[ident] = struct{}{}
|
||||
m.mu.Unlock()
|
||||
|
||||
if ident.IsRaw() {
|
||||
@@ -143,11 +270,7 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
s.bound[ident] = struct{}{}
|
||||
m.mu.Unlock()
|
||||
m.router.RegisterRoute(ident, s.mailbox)
|
||||
m.router.RegisterRoute(ident, out)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -155,10 +278,12 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||
func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||
m.mu.Lock()
|
||||
s, ok := m.sessions[id]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
m.mu.Unlock()
|
||||
return ErrSessionNotFound
|
||||
}
|
||||
out := s.internalOut
|
||||
m.mu.Unlock()
|
||||
|
||||
for _, ident := range ids {
|
||||
m.mu.Lock()
|
||||
@@ -169,7 +294,7 @@ func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||
delete(s.bound, ident)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.router.DeregisterRoute(ident, s.mailbox)
|
||||
m.router.DeregisterRoute(ident, out)
|
||||
if ident.IsRaw() {
|
||||
m.releaseRawStreamIfUnused(ident)
|
||||
}
|
||||
@@ -188,23 +313,41 @@ func (m *Manager) SetSubscriptions(id uuid.UUID, next []domain.Identifier) error
|
||||
for k := range s.bound {
|
||||
old[k] = struct{}{}
|
||||
}
|
||||
out := s.internalOut
|
||||
m.mu.Unlock()
|
||||
|
||||
toAdd, toDel := m.identifierSetDifferences(old, next)
|
||||
if len(toAdd) > 0 {
|
||||
if err := m.Subscribe(id, toAdd...); err != nil {
|
||||
return err
|
||||
|
||||
for _, ident := range toAdd {
|
||||
m.mu.Lock()
|
||||
s.bound[ident] = struct{}{}
|
||||
m.mu.Unlock()
|
||||
|
||||
if ident.IsRaw() {
|
||||
if err := m.provisionRawStream(ident); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.router.RegisterRoute(ident, out)
|
||||
}
|
||||
if len(toDel) > 0 {
|
||||
if err := m.Unsubscribe(id, toDel...); err != nil {
|
||||
return err
|
||||
|
||||
for _, ident := range toDel {
|
||||
m.mu.Lock()
|
||||
_, exists := s.bound[ident]
|
||||
delete(s.bound, ident)
|
||||
m.mu.Unlock()
|
||||
|
||||
if exists {
|
||||
m.router.DeregisterRoute(ident, out)
|
||||
if ident.IsRaw() {
|
||||
m.releaseRawStreamIfUnused(ident)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) Publish(id uuid.UUID, msg domain.Message) error {
|
||||
func (m *Manager) CloseSession(id uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
s, ok := m.sessions[id]
|
||||
if !ok {
|
||||
@@ -213,21 +356,49 @@ func (m *Manager) Publish(id uuid.UUID, msg domain.Message) error {
|
||||
}
|
||||
if s.closed {
|
||||
m.mu.Unlock()
|
||||
return ErrSessionClosed
|
||||
return nil
|
||||
}
|
||||
ch := s.in
|
||||
drop := s.dropOnFull
|
||||
s.closed = true
|
||||
if s.idleTimer != nil {
|
||||
s.idleTimer.Stop()
|
||||
s.idleTimer = nil
|
||||
}
|
||||
out := s.internalOut
|
||||
ids := make([]domain.Identifier, 0, len(s.bound))
|
||||
for k := range s.bound {
|
||||
ids = append(ids, k)
|
||||
}
|
||||
cancelInternal := s.cancelInternal
|
||||
cancelClient := s.cancelClient
|
||||
// Clear attachments before unlock to avoid races.
|
||||
s.cancelClient = nil
|
||||
cin := s.clientIn
|
||||
s.clientIn, s.clientOut = nil, nil
|
||||
delete(m.sessions, id)
|
||||
m.mu.Unlock()
|
||||
|
||||
if drop {
|
||||
select {
|
||||
case ch <- msg:
|
||||
return nil
|
||||
default:
|
||||
return ErrOutboundFull
|
||||
// Deregister all routes and release raw streams.
|
||||
for _, ident := range ids {
|
||||
m.router.DeregisterRoute(ident, out)
|
||||
if ident.IsRaw() {
|
||||
m.releaseRawStreamIfUnused(ident)
|
||||
}
|
||||
}
|
||||
ch <- msg
|
||||
|
||||
// Stop forwarders and close internal channels.
|
||||
if cancelClient != nil {
|
||||
cancelClient()
|
||||
}
|
||||
if cancelInternal != nil {
|
||||
cancelInternal()
|
||||
}
|
||||
close(s.internalIn)
|
||||
close(s.internalOut) // will close clientOut via forwarder
|
||||
|
||||
// Close clientIn to ensure its forwarder exits even if client forgot.
|
||||
if cin != nil {
|
||||
close(cin)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -256,11 +427,10 @@ func (m *Manager) RemoveProvider(name string) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("provider not found: %s", name)
|
||||
}
|
||||
// Optional: implement full drain and cancel of all streams for this provider.
|
||||
return fmt.Errorf("RemoveProvider not implemented")
|
||||
}
|
||||
|
||||
// helpers
|
||||
|
||||
func (m *Manager) provisionRawStream(id domain.Identifier) error {
|
||||
providerName, subject, ok := id.ProviderSubject()
|
||||
if !ok || providerName == "" || subject == "" {
|
||||
@@ -295,6 +465,7 @@ func (m *Manager) provisionRawStream(id domain.Identifier) error {
|
||||
incoming := m.router.IncomingChannel()
|
||||
m.mu.Unlock()
|
||||
|
||||
// Provider stream -> router.Incoming
|
||||
go func(c chan domain.Message) {
|
||||
for msg := range c {
|
||||
incoming <- msg
|
||||
|
||||
Reference in New Issue
Block a user