Refactor Binance Futures Websocket: enhance configuration management, implement shard handling, and improve subscription logic

2025-09-17 10:57:21 +00:00
parent 78c7632394
commit 7109acc207
8 changed files with 1255 additions and 104 deletions


@@ -1,12 +1,465 @@
package ws

import (
    "bytes"
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "log/slog"
    "sync"
    "time"

    "github.com/coder/websocket"
    "github.com/google/uuid"

    "gitlab.michelsen.id/phillmichelsen/tessera/services/data_service/internal/domain"
)
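
// opType distinguishes the two batched control operations sent to the
// exchange.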
type opType uint8

const (
    opSubscribe opType = iota + 1
    opUnsubscribe
)
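
// pendingBatch is an in-flight SUBSCRIBE or UNSUBSCRIBE request, keyed by
// request id, with the waiter channels to notify once the ack arrives.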
type pendingBatch struct {
    Op       opType
    Subjects []string
    Waiters  map[string][]chan error
}
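
// shard owns a single websocket connection. It tracks the streams active
// on that connection, stages pending subscribe/unsubscribe batches, and
// correlates request ids with their acks.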
type shard struct {
    ID  uuid.UUID
    url string
    cfg Config

    ctx    context.Context
    cancel context.CancelFunc
    conn   *websocket.Conn

    mu         sync.RWMutex
    active     map[string]struct{}
    subBatch   map[string][]chan error
    unsubBatch map[string][]chan error

    sendQ      chan []byte
    rateTicker *time.Ticker
    pingTicker *time.Ticker

    pendingMu   sync.Mutex
    pendingByID map[uint64]*pendingBatch
    nextReqID   func() uint64

    wg  sync.WaitGroup
    bus chan<- domain.Message
}
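
// newShard dials the endpoint immediately and starts the read, write, and
// ping loops. The rate ticker doubles as the flush interval for batched
// control frames.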
func newShard(pctx context.Context, cfg Config, bus chan<- domain.Message, next func() uint64) (*shard, error) {
    id := uuid.New()
    ctx, cancel := context.WithCancel(pctx)
    sh := &shard{
        ID:          id,
        url:         cfg.Endpoint,
        cfg:         cfg,
        ctx:         ctx,
        cancel:      cancel,
        active:      make(map[string]struct{}),
        subBatch:    make(map[string][]chan error),
        unsubBatch:  make(map[string][]chan error),
        sendQ:       make(chan []byte, 256),
        pendingByID: make(map[uint64]*pendingBatch),
        nextReqID:   next,
        bus:         bus,
    }
    // Per-shard rate limiter; its tick also drives batch flushing.
    rate := cfg.RateLimitPerSec
    if rate <= 0 {
        rate = 1
    }
    interval := time.Second / time.Duration(rate)
    sh.rateTicker = time.NewTicker(interval)
    sh.pingTicker = time.NewTicker(30 * time.Second)
    slog.Default().Info("shard created", "cmp", providerName, "shard", sh.ID.String())
    if err := sh.connect(); err != nil {
        slog.Default().Error("shard connection failed", "cmp", providerName, "shard", sh.ID.String(), "error", err)
        // Release what was acquired before the failed dial.
        sh.rateTicker.Stop()
        sh.pingTicker.Stop()
        cancel()
        return nil, err
    }
    sh.startLoops()
    return sh, nil
}
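
// connect dials the shard's endpoint with a 10s timeout and installs the
// resulting connection; callers are responsible for (re)starting the loops.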
func (s *shard) connect() error {
    dctx, cancel := context.WithTimeout(s.ctx, 10*time.Second)
    defer cancel()
    c, _, err := websocket.Dial(dctx, s.url, &websocket.DialOptions{})
    if err != nil {
        slog.Default().Error("shard connection error", "cmp", providerName, "shard", s.ID.String(), "error", err)
        return err
    }
    s.conn = c
    slog.Default().Info("shard connected", "cmp", providerName, "shard", s.ID.String())
    return nil
}

func (s *shard) startLoops() {
    s.wg.Add(3)
    go s.writeLoop()
    go s.readLoop()
    go s.pingLoop()
}
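
// stop cancels the shard context, closes the connection, stops the
// tickers, waits for all loops to exit, and then fails every outstanding
// waiter with context.Canceled.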
func (s *shard) stop() {
    s.cancel()
    if s.conn != nil {
        _ = s.conn.Close(websocket.StatusNormalClosure, "shutdown")
    }
    if s.rateTicker != nil {
        s.rateTicker.Stop()
    }
    if s.pingTicker != nil {
        s.pingTicker.Stop()
    }
    s.wg.Wait()
    // Fail every in-flight request waiter.
    s.pendingMu.Lock()
    for _, p := range s.pendingByID {
        for _, arr := range p.Waiters {
            for _, ch := range arr {
                select {
                case ch <- context.Canceled:
                default:
                }
            }
        }
    }
    s.pendingByID = map[uint64]*pendingBatch{}
    s.pendingMu.Unlock()
    // Fail every waiter still queued in unflushed batches.
    s.mu.Lock()
    for _, arr := range s.subBatch {
        for _, ch := range arr {
            select {
            case ch <- context.Canceled:
            default:
            }
        }
    }
    for _, arr := range s.unsubBatch {
        for _, ch := range arr {
            select {
            case ch <- context.Canceled:
            default:
            }
        }
    }
    s.subBatch = map[string][]chan error{}
    s.unsubBatch = map[string][]chan error{}
    s.mu.Unlock()
    slog.Default().Info("shard stopped", "cmp", providerName, "shard", s.ID.String())
}
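
// enqueueSubscribe stages a subject for the next SUBSCRIBE batch flush;
// the caller's channel receives the ack result. enqueueUnsubscribe is the
// UNSUBSCRIBE counterpart.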
func (s *shard) enqueueSubscribe(subject string, ch chan error) {
    s.mu.Lock()
    s.subBatch[subject] = append(s.subBatch[subject], ch)
    s.mu.Unlock()
    slog.Default().Debug("shard enqueue subscribe", "cmp", providerName, "shard", s.ID, "subject", subject)
}

func (s *shard) enqueueUnsubscribe(subject string, ch chan error) {
    s.mu.Lock()
    s.unsubBatch[subject] = append(s.unsubBatch[subject], ch)
    s.mu.Unlock()
    slog.Default().Debug("shard enqueue unsubscribe", "cmp", providerName, "shard", s.ID, "subject", subject)
}

func (s *shard) isActive(subj string) bool {
    s.mu.RLock()
    _, ok := s.active[subj]
    s.mu.RUnlock()
    return ok
}

func (s *shard) activeCount() int {
    s.mu.RLock()
    n := len(s.active)
    s.mu.RUnlock()
    return n
}

// loadEstimate counts active streams plus pending subscribes.
func (s *shard) loadEstimate() int {
    s.mu.RLock()
    n := len(s.active) + len(s.subBatch)
    s.mu.RUnlock()
    return n
}

func (s *shard) isIdle() bool {
    s.mu.RLock()
    idle := len(s.active) == 0 && len(s.subBatch) == 0 && len(s.unsubBatch) == 0
    s.mu.RUnlock()
    return idle
}

func (s *shard) activeList() []string {
    s.mu.RLock()
    defer s.mu.RUnlock()
    out := make([]string, 0, len(s.active))
    for k := range s.active {
        out = append(out, k)
    }
    return out
}
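
// writeLoop is the sole writer of control frames. Each rate-ticker tick
// flushes at most one SUBSCRIBE batch, one UNSUBSCRIBE batch, and one
// ad-hoc queued frame, which caps the outgoing request rate per shard.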
func (s *shard) writeLoop() {
    defer s.wg.Done()
    for {
        select {
        case <-s.ctx.Done():
            return
        case <-s.rateTicker.C:
            // Snapshot and clear pending batch operations.
            var subs, unsubs map[string][]chan error
            s.mu.Lock()
            if len(s.subBatch) > 0 {
                subs = s.subBatch
                s.subBatch = make(map[string][]chan error)
            }
            if len(s.unsubBatch) > 0 {
                unsubs = s.unsubBatch
                s.unsubBatch = make(map[string][]chan error)
            }
            s.mu.Unlock()
            // Send SUBSCRIBE batch.
            if len(subs) > 0 {
                params := make([]string, 0, len(subs))
                waiters := make(map[string][]chan error, len(subs))
                for k, v := range subs {
                    params = append(params, k)
                    waiters[k] = v
                }
                id := s.nextReqID()
                frame := map[string]any{"method": "SUBSCRIBE", "params": params, "id": id}
                payload, _ := json.Marshal(frame)
                s.recordPending(id, opSubscribe, params, waiters)
                if err := s.writeFrame(payload); err != nil {
                    s.reconnect()
                    return
                }
            }
            // Send UNSUBSCRIBE batch.
            if len(unsubs) > 0 {
                params := make([]string, 0, len(unsubs))
                waiters := make(map[string][]chan error, len(unsubs))
                for k, v := range unsubs {
                    params = append(params, k)
                    waiters[k] = v
                }
                id := s.nextReqID()
                frame := map[string]any{"method": "UNSUBSCRIBE", "params": params, "id": id}
                payload, _ := json.Marshal(frame)
                s.recordPending(id, opUnsubscribe, params, waiters)
                if err := s.writeFrame(payload); err != nil {
                    s.reconnect()
                    return
                }
            }
            // Optional: at most one queued ad-hoc frame per tick.
            select {
            case msg := <-s.sendQ:
                if err := s.writeFrame(msg); err != nil {
                    s.reconnect()
                    return
                }
            default:
            }
        }
    }
}
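
// writeFrame writes one text frame with a 5s deadline.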
func (s *shard) writeFrame(msg []byte) error {
    wctx, cancel := context.WithTimeout(s.ctx, 5*time.Second)
    defer cancel()
    err := s.conn.Write(wctx, websocket.MessageText, msg)
    if err != nil {
        slog.Default().Warn("shard write error", "cmp", providerName, "shard", s.ID, "error", err)
    }
    return err
}
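
// readLoop handles two inbound frame shapes: acks to control requests,
// matched back to their waiters by id, and combined-stream data frames,
// which are forwarded to the bus.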
func (s *shard) readLoop() {
    defer s.wg.Done()
    for {
        select {
        case <-s.ctx.Done():
            return
        default:
            // Longer idle timeout when no active subscriptions.
            timeout := 60 * time.Second
            if s.activeCount() == 0 {
                timeout = 5 * time.Minute
            }
            rctx, cancel := context.WithTimeout(s.ctx, timeout)
            _, data, err := s.conn.Read(rctx)
            cancel()
            if err != nil {
                if errors.Is(err, context.DeadlineExceeded) {
                    slog.Default().Debug("shard read idle timeout", "cmp", providerName, "shard", s.ID)
                    continue
                }
                slog.Default().Warn("shard read error", "cmp", providerName, "shard", s.ID, "error", err)
                s.reconnect()
                return
            }
            // Ack frames echo the request id they answer.
            if bytes.Contains(data, []byte("\"id\"")) {
                var ack struct {
                    ID     uint64           `json:"id"`
                    Result *json.RawMessage `json:"result"`
                    Error  *struct {
                        Code int    `json:"code"`
                        Msg  string `json:"msg"`
                    } `json:"error"`
                }
                if json.Unmarshal(data, &ack) == nil && ack.ID != 0 {
                    if ack.Error != nil {
                        slog.Default().Warn("shard ack error", "cmp", providerName, "shard", s.ID, "id", ack.ID, "code", ack.Error.Code, "msg", ack.Error.Msg)
                        s.resolvePending(ack.ID, fmt.Errorf("binance error %d: %s", ack.Error.Code, ack.Error.Msg))
                    } else {
                        slog.Default().Debug("shard ack ok", "cmp", providerName, "shard", s.ID, "id", ack.ID)
                        s.resolvePending(ack.ID, nil)
                    }
                    continue
                }
            }
            // Combined-stream data frames: {"stream": ..., "data": ...}.
            var frame struct {
                Stream string          `json:"stream"`
                Data   json.RawMessage `json:"data"`
            }
            if json.Unmarshal(data, &frame) == nil && frame.Stream != "" {
                id, err := domain.RawID(providerName, frame.Stream)
                if err == nil {
                    select {
                    case s.bus <- domain.Message{Identifier: id, Payload: frame.Data}:
                    default: // bus full: drop rather than block the read loop
                    }
                }
                continue
            }
            slog.Default().Debug("shard unknown message", "cmp", providerName, "shard", s.ID, "data", string(data))
        }
    }
}
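
// pingLoop pings every 30s; a failed ping tears the connection down and
// triggers a reconnect.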
func (s *shard) pingLoop() {
    defer s.wg.Done()
    for {
        select {
        case <-s.ctx.Done():
            return
        case <-s.pingTicker.C:
            ctx, cancel := context.WithTimeout(s.ctx, 5*time.Second)
            err := s.conn.Ping(ctx)
            cancel()
            if err != nil {
                slog.Default().Warn("shard ping failed", "cmp", providerName, "shard", s.ID, "error", err)
                s.reconnect()
                return
            }
        }
    }
}
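
// recordPending registers a sent batch under its request id so the ack
// can be correlated later.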
func (s *shard) recordPending(id uint64, op opType, subjects []string, waiters map[string][]chan error) {
    s.pendingMu.Lock()
    s.pendingByID[id] = &pendingBatch{Op: op, Subjects: subjects, Waiters: waiters}
    s.pendingMu.Unlock()
}
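
// resolvePending applies the ack outcome for a batch: on success the
// active set is updated, and in all cases every waiter is notified with
// the result (non-blocking sends, so absent readers are skipped).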
func (s *shard) resolvePending(id uint64, err error) {
    s.pendingMu.Lock()
    p := s.pendingByID[id]
    delete(s.pendingByID, id)
    s.pendingMu.Unlock()
    if p == nil {
        return
    }
    if err == nil {
        s.mu.Lock()
        if p.Op == opSubscribe {
            for _, subj := range p.Subjects {
                s.active[subj] = struct{}{}
            }
            slog.Default().Debug("shard subscribed", "cmp", providerName, "shard", s.ID, "subjects", p.Subjects)
        } else {
            for _, subj := range p.Subjects {
                delete(s.active, subj)
            }
            slog.Default().Debug("shard unsubscribed", "cmp", providerName, "shard", s.ID, "subjects", p.Subjects)
        }
        s.mu.Unlock()
    } else {
        slog.Default().Warn("shard pending error", "cmp", providerName, "shard", s.ID, "error", err)
    }
    for _, arr := range p.Waiters {
        for _, ch := range arr {
            select {
            case ch <- err:
            default:
            }
        }
    }
}
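
// queue stages an ad-hoc frame for the write loop, dropping the oldest
// queued frame when the buffer is full.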
func (s *shard) queue(payload []byte) {
    select {
    case s.sendQ <- payload:
    default:
        slog.Default().Warn("shard sendQ full, dropping one message", "cmp", providerName, "shard", s.ID)
        <-s.sendQ
        s.sendQ <- payload
    }
}
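
// reconnect redials until the shard context is cancelled, then re-stages
// every previously active stream so the next write tick resubscribes it.
// The nil waiter channels appended here are harmless: a send on a nil
// channel never proceeds, so resolvePending's select falls through to its
// default case for them.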
func (s *shard) reconnect() {
    reconnectStartTime := time.Now()
    if s.conn != nil {
        _ = s.conn.Close(websocket.StatusGoingAway, "reconnect")
    }
    for {
        select {
        case <-s.ctx.Done():
            return
        default:
            if err := s.connect(); err != nil {
                time.Sleep(200 * time.Millisecond)
                continue
            }
            // Re-stage current actives for batch subscribe on next tick.
            // This mutates subBatch, so a write lock is required.
            s.mu.Lock()
            for k := range s.active {
                s.subBatch[k] = append(s.subBatch[k], nil)
            }
            s.mu.Unlock()
            // Restart loops.
            s.startLoops()
            slog.Default().Info("shard reconnected", "cmp", providerName, "shard", s.ID, "downtime", time.Since(reconnectStartTime).String())
            return
        }
    }
}