package chroncore import ( "bytes" "context" "errors" "slices" "sort" "time" ) var ( ErrHeadNotSet = errors.New("HEAD not set") ErrEntryIDMismatch = errors.New("entry id mismatch") ErrBatchLoadSize = errors.New("LoadBatch returned unexpected number of entries") ) type Ledger struct { entryStore EntryStore referenceStore ReferenceStore } func NewLedger(entryStore EntryStore, referenceStore ReferenceStore) (*Ledger, error) { return &Ledger{ entryStore: entryStore, referenceStore: referenceStore, }, nil } func (l *Ledger) Add(ctx context.Context, parents []EntryID, payload []byte) (EntryID, error) { ps := normalizeParents(parents) ts := time.Now() id := ComputeEntryID(ps, ts, payload) e := Entry{ EntryID: id, Parents: ps, Timestamp: ts, Payload: payload, } if err := l.VerifyEntry(e); err != nil { return EntryID{}, err } if err := l.entryStore.Store(ctx, e); err != nil { return EntryID{}, err } return id, nil } func (l *Ledger) Append(ctx context.Context, payload []byte) (EntryID, error) { head, ok, err := l.referenceStore.Get(ctx, "HEAD") if err != nil { return EntryID{}, err } var parents []EntryID if ok { parents = []EntryID{head} } else { parents = nil } id, err := l.Add(ctx, parents, payload) if err != nil { return EntryID{}, err } if err := l.referenceStore.Set(ctx, "HEAD", id); err != nil { return EntryID{}, err } return id, nil } func (l *Ledger) AppendTo(ctx context.Context, parent EntryID, payload []byte) (EntryID, error) { return l.Add(ctx, []EntryID{parent}, payload) } func (l *Ledger) Get(ctx context.Context, id EntryID) (Entry, error) { return l.entryStore.Load(ctx, id) } func (l *Ledger) Exists(ctx context.Context, id EntryID) (bool, error) { return l.entryStore.Exists(ctx, id) } func (l *Ledger) Verify(ctx context.Context, id EntryID) error { e, err := l.entryStore.Load(ctx, id) if err != nil { return err } return l.VerifyEntry(e) } func (l *Ledger) VerifyEntry(e Entry) error { want := ComputeEntryID(e.Parents, e.Timestamp, e.Payload) if want != e.EntryID { return ErrEntryIDMismatch } return nil } func (l *Ledger) GetRef(ctx context.Context, name string) (EntryID, bool, error) { return l.referenceStore.Get(ctx, name) } func (l *Ledger) SetRef(ctx context.Context, name string, id EntryID) error { return l.referenceStore.Set(ctx, name, id) } func (l *Ledger) DeleteRef(ctx context.Context, name string) error { return l.referenceStore.Delete(ctx, name) } func (l *Ledger) ListRefs(ctx context.Context, prefix string) (map[string]EntryID, error) { return l.referenceStore.List(ctx, prefix) } func (l *Ledger) SetRefs(ctx context.Context, refs map[string]EntryID) error { return l.referenceStore.SetBatch(ctx, refs) } func (l *Ledger) GetRefs(ctx context.Context, names []string) (map[string]EntryID, error) { return l.referenceStore.GetBatch(ctx, names) } func (l *Ledger) GetHead(ctx context.Context) (EntryID, bool, error) { return l.referenceStore.Get(ctx, "HEAD") } func (l *Ledger) SetHead(ctx context.Context, id EntryID) error { return l.referenceStore.Set(ctx, "HEAD", id) } func (l *Ledger) GetHeads(ctx context.Context, prefix string) ([]EntryID, error) { m, err := l.referenceStore.List(ctx, prefix) if err != nil { return nil, err } seen := make(map[EntryID]struct{}, len(m)) for _, id := range m { seen[id] = struct{}{} } out := make([]EntryID, 0, len(seen)) for id := range seen { out = append(out, id) } sort.Slice(out, func(i, j int) bool { return bytes.Compare(out[i][:], out[j][:]) < 0 }) return out, nil } func (l *Ledger) WalkAncestors(ctx context.Context, start []EntryID, fn func(Entry) bool) error { frontier := dedupeIDs(start) visited := make(map[EntryID]struct{}, 1024) for len(frontier) > 0 { batchIDs := make([]EntryID, 0, len(frontier)) for _, id := range frontier { if _, ok := visited[id]; ok { continue } visited[id] = struct{}{} batchIDs = append(batchIDs, id) } if len(batchIDs) == 0 { return nil } entries, err := l.entryStore.LoadBatch(ctx, batchIDs) if err != nil { return err } if len(entries) != len(batchIDs) { return ErrBatchLoadSize } next := make([]EntryID, 0, len(batchIDs)*2) for _, e := range entries { if err := l.VerifyEntry(e); err != nil { return err } if !fn(e) { return nil } for _, p := range e.Parents { next = append(next, p) } } frontier = dedupeIDs(next) } return nil } func (l *Ledger) IsAncestor(ctx context.Context, ancestor, descendant EntryID) (bool, error) { if ancestor == descendant { return true, nil } found := false err := l.WalkAncestors(ctx, []EntryID{descendant}, func(e Entry) bool { if slices.Contains(e.Parents, ancestor) { found = true return false } return true }) if err != nil { return false, err } return found, nil } func (l *Ledger) CommonAncestors(ctx context.Context, a, b []EntryID, limit int) ([]EntryID, error) { aSet := make(map[EntryID]struct{}, 1024) if err := l.WalkAncestors(ctx, a, func(e Entry) bool { aSet[e.EntryID] = struct{}{} return true }); err != nil { return nil, err } var out []EntryID seen := make(map[EntryID]struct{}, 64) err := l.WalkAncestors(ctx, b, func(e Entry) bool { if _, ok := aSet[e.EntryID]; ok { if _, dup := seen[e.EntryID]; !dup { seen[e.EntryID] = struct{}{} out = append(out, e.EntryID) if limit > 0 && len(out) >= limit { return false } } } return true }) if err != nil { return nil, err } return out, nil } func normalizeParents(parents []EntryID) []EntryID { if len(parents) == 0 { return nil } ps := make([]EntryID, len(parents)) copy(ps, parents) sort.Slice(ps, func(i, j int) bool { return bytes.Compare(ps[i][:], ps[j][:]) < 0 }) out := ps[:0] var last EntryID for i, p := range ps { if i == 0 || p != last { out = append(out, p) last = p } } return out } func dedupeIDs(ids []EntryID) []EntryID { if len(ids) == 0 { return nil } tmp := make([]EntryID, len(ids)) copy(tmp, ids) sort.Slice(tmp, func(i, j int) bool { return bytes.Compare(tmp[i][:], tmp[j][:]) < 0 }) out := tmp[:0] var last EntryID for i, id := range tmp { if i == 0 || id != last { out = append(out, id) last = id } } return out }