diff --git a/services/data_service/internal/manager/manager.go b/services/data_service/internal/manager/manager.go index ef1fc8f..e40e45f 100644 --- a/services/data_service/internal/manager/manager.go +++ b/services/data_service/internal/manager/manager.go @@ -1,9 +1,9 @@ package manager import ( + "errors" "fmt" "sync" - "time" "github.com/google/uuid" "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain" @@ -11,249 +11,335 @@ import ( "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/router" ) -type Manager struct { - providers map[string]provider.Provider - providerStreams map[domain.Identifier]chan domain.Message +var ( + ErrSessionNotFound = errors.New("session not found") + ErrSessionClosed = errors.New("session closed") + ErrOutboundFull = errors.New("session publish buffer full") + ErrInvalidIdentifier = errors.New("invalid identifier") + ErrUnknownProvider = errors.New("unknown provider") +) - clientStreams map[uuid.UUID]*ClientStream +type Manager struct { + providers map[string]provider.Provider + providerStreams map[domain.Identifier]chan domain.Message + rawReferenceCount map[domain.Identifier]int + + sessions map[uuid.UUID]*session router *router.Router - - mu sync.Mutex + mu sync.Mutex } -type ClientStream struct { - UUID uuid.UUID - Identifiers []domain.Identifier - OutChannel chan domain.Message - Timer *time.Timer +type session struct { + id uuid.UUID + in chan domain.Message + mailbox chan domain.Message + out chan domain.Message + bound map[domain.Identifier]struct{} + dropOnFull bool + closed bool } -func NewManager(router *router.Router) *Manager { - go router.Run() +func NewManager(r *router.Router) *Manager { + go r.Run() return &Manager{ - providers: make(map[string]provider.Provider), - providerStreams: make(map[domain.Identifier]chan domain.Message), - clientStreams: make(map[uuid.UUID]*ClientStream), - router: router, + providers: make(map[string]provider.Provider), + providerStreams: make(map[domain.Identifier]chan domain.Message), + rawReferenceCount: make(map[domain.Identifier]int), + sessions: make(map[uuid.UUID]*session), + router: r, } } -func (m *Manager) StartClientStream() (uuid.UUID, error) { - m.mu.Lock() - defer m.mu.Unlock() - - streamID := uuid.New() - m.clientStreams[streamID] = &ClientStream{ - UUID: streamID, - Identifiers: nil, - OutChannel: nil, - Timer: time.AfterFunc(1*time.Minute, func() { - fmt.Printf("stream %s expired due to inactivity\n", streamID) - err := m.StopClientStream(streamID) - if err != nil { - fmt.Printf("failed to stop stream after timeout: %v\n", err) - } - }), +func (m *Manager) NewSession(bufIn, bufOut int, dropOnFull bool) (uuid.UUID, chan<- domain.Message, <-chan domain.Message, error) { + if bufIn <= 0 { + bufIn = 1024 + } + if bufOut <= 0 { + bufOut = 1024 } - return streamID, nil + s := &session{ + id: uuid.New(), + in: make(chan domain.Message, bufIn), + mailbox: make(chan domain.Message, bufOut), + out: make(chan domain.Message, bufOut), + bound: make(map[domain.Identifier]struct{}), + dropOnFull: dropOnFull, + } + + m.mu.Lock() + m.sessions[s.id] = s + incoming := m.router.IncomingChannel() + m.mu.Unlock() + + go func() { + for msg := range s.in { + incoming <- msg + } + }() + + go func() { + for msg := range s.mailbox { + s.out <- msg + } + close(s.out) + }() + + return s.id, s.in, s.out, nil } -func (m *Manager) ConfigureClientStream(streamID uuid.UUID, newIds []domain.Identifier) error { +func (m *Manager) CloseSession(id uuid.UUID) error { m.mu.Lock() - defer m.mu.Unlock() - - stream, ok := m.clientStreams[streamID] + s, ok := m.sessions[id] if !ok { - return fmt.Errorf("stream not found: %s", streamID) + m.mu.Unlock() + return ErrSessionNotFound } + if s.closed { + m.mu.Unlock() + return nil + } + s.closed = true - for _, id := range newIds { - if id.IsRaw() { - providerName, subject, ok := id.ProviderSubject() - if !ok || providerName == "" || subject == "" { - return fmt.Errorf("empty identifier: %v", id) - } - prov, exists := m.providers[providerName] - if !exists { - return fmt.Errorf("unknown provider: %s", providerName) - } - if !prov.IsValidSubject(subject, false) { - return fmt.Errorf("invalid subject %q for provider %s", subject, providerName) - } + var ids []domain.Identifier + for k := range s.bound { + ids = append(ids, k) + delete(s.bound, k) + } + delete(m.sessions, id) + m.mu.Unlock() + + for _, ident := range ids { + m.router.DeregisterRoute(ident, s.mailbox) + if ident.IsRaw() { + m.releaseRawStreamIfUnused(ident) } } - oldSet := make(map[domain.Identifier]struct{}, len(stream.Identifiers)) - for _, id := range stream.Identifiers { - oldSet[id] = struct{}{} - } - newSet := make(map[domain.Identifier]struct{}, len(newIds)) - for _, id := range newIds { - newSet[id] = struct{}{} + close(s.mailbox) + close(s.in) + return nil +} + +func (m *Manager) Subscribe(id uuid.UUID, ids ...domain.Identifier) error { + m.mu.Lock() + s, ok := m.sessions[id] + m.mu.Unlock() + if !ok { + return ErrSessionNotFound } - for _, id := range newIds { - if _, seen := oldSet[id]; !seen { - if id.IsRaw() { - if _, ok := m.providerStreams[id]; !ok { - ch := make(chan domain.Message, 64) - providerName, subject, _ := id.ProviderSubject() - if err := m.providers[providerName].RequestStream(subject, ch); err != nil { - return fmt.Errorf("provision %v: %w", id, err) - } - m.providerStreams[id] = ch + for _, ident := range ids { + m.mu.Lock() + if _, exists := s.bound[ident]; exists { + m.mu.Unlock() + continue + } + m.mu.Unlock() - incomingChannel := m.router.IncomingChannel() - go func(c chan domain.Message) { - for msg := range c { - incomingChannel <- msg - } - }(ch) - } - } - - if stream.OutChannel != nil { - m.router.RegisterRoute(id, stream.OutChannel) + if ident.IsRaw() { + if err := m.provisionRawStream(ident); err != nil { + return err } } + + m.mu.Lock() + s.bound[ident] = struct{}{} + m.mu.Unlock() + m.router.RegisterRoute(ident, s.mailbox) + } + return nil +} + +func (m *Manager) Unsubscribe(id uuid.UUID, ids ...domain.Identifier) error { + m.mu.Lock() + s, ok := m.sessions[id] + m.mu.Unlock() + if !ok { + return ErrSessionNotFound } - for _, oldId := range stream.Identifiers { - if _, keep := newSet[oldId]; !keep { - if stream.OutChannel != nil { - m.router.DeregisterRoute(oldId, stream.OutChannel) - } + for _, ident := range ids { + m.mu.Lock() + if _, exists := s.bound[ident]; !exists { + m.mu.Unlock() + continue } + delete(s.bound, ident) + m.mu.Unlock() + + m.router.DeregisterRoute(ident, s.mailbox) + if ident.IsRaw() { + m.releaseRawStreamIfUnused(ident) + } + } + return nil +} + +func (m *Manager) SetSubscriptions(id uuid.UUID, next []domain.Identifier) error { + m.mu.Lock() + s, ok := m.sessions[id] + if !ok { + m.mu.Unlock() + return ErrSessionNotFound + } + old := make(map[domain.Identifier]struct{}, len(s.bound)) + for k := range s.bound { + old[k] = struct{}{} + } + m.mu.Unlock() + + toAdd, toDel := m.identifierSetDifferences(old, next) + if len(toAdd) > 0 { + if err := m.Subscribe(id, toAdd...); err != nil { + return err + } + } + if len(toDel) > 0 { + if err := m.Unsubscribe(id, toDel...); err != nil { + return err + } + } + return nil +} + +func (m *Manager) Publish(id uuid.UUID, msg domain.Message) error { + m.mu.Lock() + s, ok := m.sessions[id] + if !ok { + m.mu.Unlock() + return ErrSessionNotFound + } + if s.closed { + m.mu.Unlock() + return ErrSessionClosed + } + ch := s.in + drop := s.dropOnFull + m.mu.Unlock() + + if drop { + select { + case ch <- msg: + return nil + default: + return ErrOutboundFull + } + } + ch <- msg + return nil +} + +func (m *Manager) AddProvider(name string, p provider.Provider) error { + m.mu.Lock() + if _, exists := m.providers[name]; exists { + m.mu.Unlock() + return fmt.Errorf("provider exists: %s", name) + } + m.mu.Unlock() + + if err := p.Start(); err != nil { + return fmt.Errorf("start provider %s: %w", name, err) } - stream.Identifiers = newIds + m.mu.Lock() + m.providers[name] = p + m.mu.Unlock() + return nil +} - used := make(map[domain.Identifier]bool) - for _, cs := range m.clientStreams { - for _, id := range cs.Identifiers { - if id.IsRaw() { - used[id] = true - } - } +func (m *Manager) RemoveProvider(name string) error { + m.mu.Lock() + _, ok := m.providers[name] + m.mu.Unlock() + if !ok { + return fmt.Errorf("provider not found: %s", name) } - for id, ch := range m.providerStreams { - if !used[id] { - providerName, subject, _ := id.ProviderSubject() - m.providers[providerName].CancelStream(subject) - close(ch) - delete(m.providerStreams, id) - } + return fmt.Errorf("RemoveProvider not implemented") +} + +// helpers + +func (m *Manager) provisionRawStream(id domain.Identifier) error { + providerName, subject, ok := id.ProviderSubject() + if !ok || providerName == "" || subject == "" { + return ErrInvalidIdentifier } + m.mu.Lock() + prov, exists := m.providers[providerName] + if !exists { + m.mu.Unlock() + return ErrUnknownProvider + } + if !prov.IsValidSubject(subject, false) { + m.mu.Unlock() + return fmt.Errorf("invalid subject %q for provider %s", subject, providerName) + } + + if ch, ok := m.providerStreams[id]; ok { + m.rawReferenceCount[id] = m.rawReferenceCount[id] + 1 + m.mu.Unlock() + _ = ch + return nil + } + + ch := make(chan domain.Message, 64) + if err := prov.RequestStream(subject, ch); err != nil { + m.mu.Unlock() + return fmt.Errorf("provision %v: %w", id, err) + } + m.providerStreams[id] = ch + m.rawReferenceCount[id] = 1 + incoming := m.router.IncomingChannel() + m.mu.Unlock() + + go func(c chan domain.Message) { + for msg := range c { + incoming <- msg + } + }(ch) + return nil } -func (m *Manager) StopClientStream(streamID uuid.UUID) error { - m.DisconnectClientStream(streamID) - - m.mu.Lock() - defer m.mu.Unlock() - - stream, ok := m.clientStreams[streamID] +func (m *Manager) releaseRawStreamIfUnused(id domain.Identifier) { + providerName, subject, ok := id.ProviderSubject() if !ok { - return fmt.Errorf("stream not found: %s", streamID) - } - - stream.Timer.Stop() - - delete(m.clientStreams, streamID) - - used := make(map[domain.Identifier]bool) - for _, s := range m.clientStreams { - for _, id := range s.Identifiers { - if id.IsRaw() { - used[id] = true - } - } - } - - for id, ch := range m.providerStreams { - if !used[id] { - providerName, subject, _ := id.ProviderSubject() - m.providers[providerName].CancelStream(subject) - close(ch) - delete(m.providerStreams, id) - } - } - - return nil -} - -func (m *Manager) ConnectClientStream(streamID uuid.UUID) (<-chan domain.Message, error) { - m.mu.Lock() - defer m.mu.Unlock() - - stream, ok := m.clientStreams[streamID] - if !ok { - return nil, fmt.Errorf("stream not found: %s", streamID) - } - - if stream.OutChannel != nil { - return nil, fmt.Errorf("stream already connected") - } - - ch := make(chan domain.Message, 128) - stream.OutChannel = ch - - for _, ident := range stream.Identifiers { - m.router.RegisterRoute(ident, ch) - } - - if stream.Timer != nil { - stream.Timer.Stop() - stream.Timer = nil - } - - return ch, nil -} - -func (m *Manager) DisconnectClientStream(streamID uuid.UUID) { - m.mu.Lock() - defer m.mu.Unlock() - - stream, ok := m.clientStreams[streamID] - if !ok || stream.OutChannel == nil { return } - for _, ident := range stream.Identifiers { - m.router.DeregisterRoute(ident, stream.OutChannel) - } - - close(stream.OutChannel) - stream.OutChannel = nil - - stream.Timer = time.AfterFunc(1*time.Minute, func() { - fmt.Printf("stream %s expired due to inactivity\n", streamID) - err := m.StopClientStream(streamID) - if err != nil { - fmt.Printf("failed to stop stream after disconnect: %v\n", err) - } - }) -} - -func (m *Manager) AddProvider(name string, p provider.Provider) { m.mu.Lock() - defer m.mu.Unlock() - - if _, exists := m.providers[name]; exists { - panic(fmt.Sprintf("provider %s already exists", name)) + rc := m.rawReferenceCount[id] - 1 + if rc <= 0 { + if ch, ok := m.providerStreams[id]; ok { + if prov, exists := m.providers[providerName]; exists { + prov.CancelStream(subject) + } + close(ch) + delete(m.providerStreams, id) + } + delete(m.rawReferenceCount, id) + m.mu.Unlock() + return } - - if err := p.Start(); err != nil { - panic(fmt.Errorf("failed to start provider %s: %w", name, err)) - } - - m.providers[name] = p + m.rawReferenceCount[id] = rc + m.mu.Unlock() } -func (m *Manager) RemoveProvider(_ string) { - panic("not implemented yet") +func (m *Manager) identifierSetDifferences(old map[domain.Identifier]struct{}, next []domain.Identifier) (toAdd, toDel []domain.Identifier) { + newSet := make(map[domain.Identifier]struct{}, len(next)) + for _, id := range next { + newSet[id] = struct{}{} + if _, ok := old[id]; !ok { + toAdd = append(toAdd, id) + } + } + for id := range old { + if _, ok := newSet[id]; !ok { + toDel = append(toDel, id) + } + } + return }