Refactor session management: enhance session struct with internal channels, implement client attachment handling, and improve idle session management
This commit is contained in:
@@ -1,9 +1,11 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain"
|
"gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain"
|
||||||
@@ -14,11 +16,24 @@ import (
|
|||||||
var (
|
var (
|
||||||
ErrSessionNotFound = errors.New("session not found")
|
ErrSessionNotFound = errors.New("session not found")
|
||||||
ErrSessionClosed = errors.New("session closed")
|
ErrSessionClosed = errors.New("session closed")
|
||||||
ErrOutboundFull = errors.New("session publish buffer full")
|
|
||||||
ErrInvalidIdentifier = errors.New("invalid identifier")
|
ErrInvalidIdentifier = errors.New("invalid identifier")
|
||||||
ErrUnknownProvider = errors.New("unknown provider")
|
ErrUnknownProvider = errors.New("unknown provider")
|
||||||
|
ErrClientAlreadyBound = errors.New("client channels already bound")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultInternalBuf = 1024
|
||||||
|
defaultClientBuf = 256
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChannelOpts struct {
|
||||||
|
InBufSize int
|
||||||
|
OutBufSize int
|
||||||
|
// If true, drop to clientOut when its buffer is full. If false, block.
|
||||||
|
DropOutbound bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager owns providers, sessions, and the router fanout.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
providers map[string]provider.Provider
|
providers map[string]provider.Provider
|
||||||
providerStreams map[domain.Identifier]chan domain.Message
|
providerStreams map[domain.Identifier]chan domain.Message
|
||||||
@@ -32,12 +47,24 @@ type Manager struct {
|
|||||||
|
|
||||||
type session struct {
|
type session struct {
|
||||||
id uuid.UUID
|
id uuid.UUID
|
||||||
in chan domain.Message
|
|
||||||
mailbox chan domain.Message
|
// Stable internal channels. Only the session writes internalOut and reads internalIn.
|
||||||
out chan domain.Message
|
internalIn chan domain.Message // forwarded into router.IncomingChannel()
|
||||||
|
internalOut chan domain.Message // registered as router route target
|
||||||
|
|
||||||
|
// Current client attachment (optional). Created by GetChannels.
|
||||||
|
clientIn chan domain.Message // caller writes
|
||||||
|
clientOut chan domain.Message // caller reads
|
||||||
|
|
||||||
|
// Cancels the permanent internalIn forwarder.
|
||||||
|
cancelInternal context.CancelFunc
|
||||||
|
// Cancels current client forwarders.
|
||||||
|
cancelClient context.CancelFunc
|
||||||
|
|
||||||
bound map[domain.Identifier]struct{}
|
bound map[domain.Identifier]struct{}
|
||||||
dropOnFull bool
|
|
||||||
closed bool
|
closed bool
|
||||||
|
idleAfter time.Duration
|
||||||
|
idleTimer *time.Timer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(r *router.Router) *Manager {
|
func NewManager(r *router.Router) *Manager {
|
||||||
@@ -51,45 +78,137 @@ func NewManager(r *router.Router) *Manager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) NewSession(bufIn, bufOut int, dropOnFull bool) (uuid.UUID, chan<- domain.Message, <-chan domain.Message, error) {
|
// NewSession creates a session with stable internal channels and a permanent
|
||||||
if bufIn <= 0 {
|
// forwarder that pipes internalIn into router.IncomingChannel().
|
||||||
bufIn = 1024
|
func (m *Manager) NewSession(idleAfter time.Duration) (uuid.UUID, error) {
|
||||||
}
|
|
||||||
if bufOut <= 0 {
|
|
||||||
bufOut = 1024
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &session{
|
s := &session{
|
||||||
id: uuid.New(),
|
id: uuid.New(),
|
||||||
in: make(chan domain.Message, bufIn),
|
internalIn: make(chan domain.Message, defaultInternalBuf),
|
||||||
mailbox: make(chan domain.Message, bufOut),
|
internalOut: make(chan domain.Message, defaultInternalBuf),
|
||||||
out: make(chan domain.Message, bufOut),
|
|
||||||
bound: make(map[domain.Identifier]struct{}),
|
bound: make(map[domain.Identifier]struct{}),
|
||||||
dropOnFull: dropOnFull,
|
idleAfter: idleAfter,
|
||||||
}
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
s.cancelInternal = cancel
|
||||||
|
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
m.sessions[s.id] = s
|
m.sessions[s.id] = s
|
||||||
incoming := m.router.IncomingChannel()
|
incoming := m.router.IncomingChannel()
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
go func() {
|
// Permanent forwarder: internalIn -> router.Incoming
|
||||||
for msg := range s.in {
|
go func(ctx context.Context, in <-chan domain.Message) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case msg, ok := <-in:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Place to filter, validate, meter, or throttle.
|
||||||
incoming <- msg
|
incoming <- msg
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for msg := range s.mailbox {
|
|
||||||
s.out <- msg
|
|
||||||
}
|
}
|
||||||
close(s.out)
|
}(ctx, s.internalIn)
|
||||||
}()
|
|
||||||
|
|
||||||
return s.id, s.in, s.out, nil
|
return s.id, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) CloseSession(id uuid.UUID) error {
|
// GetChannels creates a fresh client attachment and hooks both directions:
|
||||||
|
// clientIn -> internalIn and internalOut -> clientOut. Only one attachment at a time.
|
||||||
|
func (m *Manager) GetChannels(id uuid.UUID, opts ChannelOpts) (chan<- domain.Message, <-chan domain.Message, error) {
|
||||||
|
if opts.InBufSize <= 0 {
|
||||||
|
opts.InBufSize = defaultClientBuf
|
||||||
|
}
|
||||||
|
if opts.OutBufSize <= 0 {
|
||||||
|
opts.OutBufSize = defaultClientBuf
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
s, ok := m.sessions[id]
|
||||||
|
if !ok {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return nil, nil, ErrSessionNotFound
|
||||||
|
}
|
||||||
|
if s.closed {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return nil, nil, ErrSessionClosed
|
||||||
|
}
|
||||||
|
if s.clientIn != nil || s.clientOut != nil {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return nil, nil, ErrClientAlreadyBound
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create attachment channels.
|
||||||
|
cin := make(chan domain.Message, opts.InBufSize)
|
||||||
|
cout := make(chan domain.Message, opts.OutBufSize)
|
||||||
|
s.clientIn, s.clientOut = cin, cout
|
||||||
|
|
||||||
|
// Stop idle timer while attached.
|
||||||
|
if s.idleTimer != nil {
|
||||||
|
s.idleTimer.Stop()
|
||||||
|
s.idleTimer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
internalIn := s.internalIn
|
||||||
|
internalOut := s.internalOut
|
||||||
|
|
||||||
|
// Prepare per-attachment cancel.
|
||||||
|
attachCtx, attachCancel := context.WithCancel(context.Background())
|
||||||
|
s.cancelClient = attachCancel
|
||||||
|
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Forward clientIn -> internalIn
|
||||||
|
go func(ctx context.Context, src <-chan domain.Message, dst chan<- domain.Message) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case msg, ok := <-src:
|
||||||
|
if !ok {
|
||||||
|
// Client closed input; stop forwarding.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Per-client checks could go here.
|
||||||
|
dst <- msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(attachCtx, cin, internalIn)
|
||||||
|
|
||||||
|
// Forward internalOut -> clientOut
|
||||||
|
go func(ctx context.Context, src <-chan domain.Message, dst chan<- domain.Message, drop bool) {
|
||||||
|
defer close(dst)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case msg, ok := <-src:
|
||||||
|
if !ok {
|
||||||
|
// Session is closing; signal EOF to client.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if drop {
|
||||||
|
select {
|
||||||
|
case dst <- msg:
|
||||||
|
default:
|
||||||
|
// Drop on client backpressure. Add metrics if desired.
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dst <- msg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(attachCtx, internalOut, cout, opts.DropOutbound)
|
||||||
|
|
||||||
|
// Return directional views.
|
||||||
|
return (chan<- domain.Message)(cin), (<-chan domain.Message)(cout), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetachClient cancels current client forwarders and clears the attachment.
|
||||||
|
// It starts the idle close timer if configured.
|
||||||
|
func (m *Manager) DetachClient(id uuid.UUID) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
s, ok := m.sessions[id]
|
s, ok := m.sessions[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -98,37 +217,44 @@ func (m *Manager) CloseSession(id uuid.UUID) error {
|
|||||||
}
|
}
|
||||||
if s.closed {
|
if s.closed {
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
return nil
|
return ErrSessionClosed
|
||||||
}
|
}
|
||||||
s.closed = true
|
// Capture and clear client state.
|
||||||
|
cancel := s.cancelClient
|
||||||
var ids []domain.Identifier
|
cin := s.clientIn
|
||||||
for k := range s.bound {
|
s.cancelClient = nil
|
||||||
ids = append(ids, k)
|
s.clientIn, s.clientOut = nil, nil
|
||||||
delete(s.bound, k)
|
after := s.idleAfter
|
||||||
}
|
|
||||||
delete(m.sessions, id)
|
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
for _, ident := range ids {
|
if cancel != nil {
|
||||||
m.router.DeregisterRoute(ident, s.mailbox)
|
cancel()
|
||||||
if ident.IsRaw() {
|
|
||||||
m.releaseRawStreamIfUnused(ident)
|
|
||||||
}
|
}
|
||||||
|
// Close clientIn to terminate clientIn->internalIn forwarder if client forgot.
|
||||||
|
if cin != nil {
|
||||||
|
close(cin)
|
||||||
}
|
}
|
||||||
|
|
||||||
close(s.mailbox)
|
if after > 0 {
|
||||||
close(s.in)
|
m.mu.Lock()
|
||||||
|
ss, ok := m.sessions[id]
|
||||||
|
if ok && !ss.closed && ss.clientOut == nil && ss.idleTimer == nil {
|
||||||
|
ss.idleTimer = time.AfterFunc(after, func() { _ = m.CloseSession(id) })
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
s, ok := m.sessions[id]
|
s, ok := m.sessions[id]
|
||||||
m.mu.Unlock()
|
|
||||||
if !ok {
|
if !ok {
|
||||||
|
m.mu.Unlock()
|
||||||
return ErrSessionNotFound
|
return ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
out := s.internalOut
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
for _, ident := range ids {
|
for _, ident := range ids {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -136,6 +262,7 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
|||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
s.bound[ident] = struct{}{}
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
if ident.IsRaw() {
|
if ident.IsRaw() {
|
||||||
@@ -143,11 +270,7 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
m.router.RegisterRoute(ident, out)
|
||||||
m.mu.Lock()
|
|
||||||
s.bound[ident] = struct{}{}
|
|
||||||
m.mu.Unlock()
|
|
||||||
m.router.RegisterRoute(ident, s.mailbox)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -155,10 +278,12 @@ func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
|||||||
func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
s, ok := m.sessions[id]
|
s, ok := m.sessions[id]
|
||||||
m.mu.Unlock()
|
|
||||||
if !ok {
|
if !ok {
|
||||||
|
m.mu.Unlock()
|
||||||
return ErrSessionNotFound
|
return ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
out := s.internalOut
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
for _, ident := range ids {
|
for _, ident := range ids {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -169,7 +294,7 @@ func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error {
|
|||||||
delete(s.bound, ident)
|
delete(s.bound, ident)
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
m.router.DeregisterRoute(ident, s.mailbox)
|
m.router.DeregisterRoute(ident, out)
|
||||||
if ident.IsRaw() {
|
if ident.IsRaw() {
|
||||||
m.releaseRawStreamIfUnused(ident)
|
m.releaseRawStreamIfUnused(ident)
|
||||||
}
|
}
|
||||||
@@ -188,23 +313,41 @@ func (m *Manager) SetSubscriptions(id uuid.UUID, next []domain.Identifier) error
|
|||||||
for k := range s.bound {
|
for k := range s.bound {
|
||||||
old[k] = struct{}{}
|
old[k] = struct{}{}
|
||||||
}
|
}
|
||||||
|
out := s.internalOut
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
toAdd, toDel := m.identifierSetDifferences(old, next)
|
toAdd, toDel := m.identifierSetDifferences(old, next)
|
||||||
if len(toAdd) > 0 {
|
|
||||||
if err := m.Subscribe(id, toAdd...); err != nil {
|
for _, ident := range toAdd {
|
||||||
|
m.mu.Lock()
|
||||||
|
s.bound[ident] = struct{}{}
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if ident.IsRaw() {
|
||||||
|
if err := m.provisionRawStream(ident); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(toDel) > 0 {
|
m.router.RegisterRoute(ident, out)
|
||||||
if err := m.Unsubscribe(id, toDel...); err != nil {
|
}
|
||||||
return err
|
|
||||||
|
for _, ident := range toDel {
|
||||||
|
m.mu.Lock()
|
||||||
|
_, exists := s.bound[ident]
|
||||||
|
delete(s.bound, ident)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
m.router.DeregisterRoute(ident, out)
|
||||||
|
if ident.IsRaw() {
|
||||||
|
m.releaseRawStreamIfUnused(ident)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Publish(id uuid.UUID, msg domain.Message) error {
|
func (m *Manager) CloseSession(id uuid.UUID) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
s, ok := m.sessions[id]
|
s, ok := m.sessions[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -213,21 +356,49 @@ func (m *Manager) Publish(id uuid.UUID, msg domain.Message) error {
|
|||||||
}
|
}
|
||||||
if s.closed {
|
if s.closed {
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
return ErrSessionClosed
|
return nil
|
||||||
}
|
}
|
||||||
ch := s.in
|
s.closed = true
|
||||||
drop := s.dropOnFull
|
if s.idleTimer != nil {
|
||||||
|
s.idleTimer.Stop()
|
||||||
|
s.idleTimer = nil
|
||||||
|
}
|
||||||
|
out := s.internalOut
|
||||||
|
ids := make([]domain.Identifier, 0, len(s.bound))
|
||||||
|
for k := range s.bound {
|
||||||
|
ids = append(ids, k)
|
||||||
|
}
|
||||||
|
cancelInternal := s.cancelInternal
|
||||||
|
cancelClient := s.cancelClient
|
||||||
|
// Clear attachments before unlock to avoid races.
|
||||||
|
s.cancelClient = nil
|
||||||
|
cin := s.clientIn
|
||||||
|
s.clientIn, s.clientOut = nil, nil
|
||||||
|
delete(m.sessions, id)
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
if drop {
|
// Deregister all routes and release raw streams.
|
||||||
select {
|
for _, ident := range ids {
|
||||||
case ch <- msg:
|
m.router.DeregisterRoute(ident, out)
|
||||||
return nil
|
if ident.IsRaw() {
|
||||||
default:
|
m.releaseRawStreamIfUnused(ident)
|
||||||
return ErrOutboundFull
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ch <- msg
|
|
||||||
|
// Stop forwarders and close internal channels.
|
||||||
|
if cancelClient != nil {
|
||||||
|
cancelClient()
|
||||||
|
}
|
||||||
|
if cancelInternal != nil {
|
||||||
|
cancelInternal()
|
||||||
|
}
|
||||||
|
close(s.internalIn)
|
||||||
|
close(s.internalOut) // will close clientOut via forwarder
|
||||||
|
|
||||||
|
// Close clientIn to ensure its forwarder exits even if client forgot.
|
||||||
|
if cin != nil {
|
||||||
|
close(cin)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,11 +427,10 @@ func (m *Manager) RemoveProvider(name string) error {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("provider not found: %s", name)
|
return fmt.Errorf("provider not found: %s", name)
|
||||||
}
|
}
|
||||||
|
// Optional: implement full drain and cancel of all streams for this provider.
|
||||||
return fmt.Errorf("RemoveProvider not implemented")
|
return fmt.Errorf("RemoveProvider not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// helpers
|
|
||||||
|
|
||||||
func (m *Manager) provisionRawStream(id domain.Identifier) error {
|
func (m *Manager) provisionRawStream(id domain.Identifier) error {
|
||||||
providerName, subject, ok := id.ProviderSubject()
|
providerName, subject, ok := id.ProviderSubject()
|
||||||
if !ok || providerName == "" || subject == "" {
|
if !ok || providerName == "" || subject == "" {
|
||||||
@@ -295,6 +465,7 @@ func (m *Manager) provisionRawStream(id domain.Identifier) error {
|
|||||||
incoming := m.router.IncomingChannel()
|
incoming := m.router.IncomingChannel()
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Provider stream -> router.Incoming
|
||||||
go func(c chan domain.Message) {
|
go func(c chan domain.Message) {
|
||||||
for msg := range c {
|
for msg := range c {
|
||||||
incoming <- msg
|
incoming <- msg
|
||||||
|
|||||||
Reference in New Issue
Block a user