diff --git a/README.md b/README.md index 4d26ea9..d011622 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,9 @@ -# priorityq - generic prioritized message queue in Go +# priorityq - generic prioritized queues in Go This module was inspired by [a reddit post][reddit] wherein /u/zandery23 asked -how to implement a priority queue in Go. A fantastic solution was [provided by -/u/Ploobers][sol]. That's probably right for 99 out of 100 use cases, but it's -not completely precise. +how to implement a prioritized message queue in Go. A fantastic solution was +[provided by /u/Ploobers][sol]. That's probably right for 99 out of 100 use +cases, but it's not completely precise. Particularly, the second select block does not guarantee that an item from the prioritized queue will be taken if there is also an item in the regular queue. @@ -26,10 +26,15 @@ From the [Go Language Specification][go_select]: Thus, it is possible for the second case to be chosen even if the first case is also ready. -The `precise` package in this module implements a concurrent, prioritized -message queue that guarantees receipt of a high-priority items before -low-priority ones. This is primarily a fun exercise, I cannot recommend that -anyone actually use this in a real project. +The `mq` package in this module implements a concurrent, prioritized message +queue that guarantees receipt of a high-priority items before low-priority +ones. This is primarily a fun exercise, I cannot recommend that anyone +actually use this in a real project. + +Additionally, the root `priorityq` package implements a concurrent priority +queue, using a binary max-heap. This is more general than `mq`, because it +allows multiple levels of priority, instead of just "high" and "low". This, of +course, also makes operations slower. [reddit]: https://www.reddit.com/r/golang/comments/11drc17/worker_pool_reading_from_two_channels_one_chan/ [sol]: https://www.reddit.com/r/golang/comments/11drc17/worker_pool_reading_from_two_channels_one_chan/jabfvkh/ diff --git a/binheap/lib.go b/binheap/lib.go new file mode 100644 index 0000000..0ae83d8 --- /dev/null +++ b/binheap/lib.go @@ -0,0 +1,102 @@ +package binheap + +import "golang.org/x/exp/constraints" + +// H is a generic, non-concurrent binary max-heap. +// +// `I` is the type of the priority IDs, and `E` the type of the elements. +type H[I constraints.Ordered, E any] struct { + heap []I + elems []E + len int +} + +// Make creates a new heap. +func Make[I constraints.Ordered, E any](cap int) H[I, E] { + heap := make([]I, cap) + elems := make([]E, cap) + h := H[I, E]{heap: heap, elems: elems} + return h +} + +// Capacity returns the total capacity of the heap. +func (h *H[I, E]) Capacity() int { + return cap(h.heap) +} + +// Len returns the number of items in the heap. +func (h *H[I, E]) Len() int { + return h.len +} + +// CanExtract returns true if the heap has any item, otherwise false. +func (h *H[I, E]) CanExtract() bool { + return h.len != 0 +} + +// CanInsert returns true if the heap has unused capacity, otherwise false. +func (h *H[I, E]) CanInsert() bool { + return cap(h.heap)-h.len != 0 +} + +// Extract returns the current heap root, then performs a heap-down pass. +// +// If the heap is empty, it panics. +func (h *H[I, E]) Extract() (I, E) { + if !h.CanExtract() { + panic("heap is empty") + } + + id := h.heap[0] + elem := h.elems[0] + var emptyId I + var emptyElem E + h.heap[0] = h.heap[h.len-1] + h.elems[0] = h.elems[h.len-1] + h.heap[h.len-1] = emptyId + h.elems[h.len-1] = emptyElem + h.len-- + idx := 0 + for { + left := idx*2 + 1 + right := idx*2 + 2 + largest := idx + if left < h.len && h.heap[left] > h.heap[largest] { + largest = left + } + if right < h.len && h.heap[right] > h.heap[largest] { + largest = right + } + if largest == idx { + break + } + h.heap[idx], h.heap[largest] = h.heap[largest], h.heap[idx] + h.elems[idx], h.elems[largest] = h.elems[largest], h.elems[idx] + idx = largest + } + + return id, elem +} + +// Insert adds an item to the heap, then performs a heap-up pass. +// +// If the heap is full, it panics. +func (h *H[I, E]) Insert(id I, elem E) { + if !h.CanInsert() { + panic("heap is full") + } + + idx := h.len + h.heap[idx] = id + h.elems[idx] = elem + h.len++ + for { + parent := (idx - 1) / 2 + if parent == idx || h.heap[parent] >= h.heap[idx] { + break + } + h.heap[parent], h.heap[idx] = h.heap[idx], h.heap[parent] + h.elems[parent], h.elems[idx] = h.elems[idx], h.elems[parent] + idx = parent + } +} diff --git a/binheap/lib_test.go b/binheap/lib_test.go new file mode 100644 index 0000000..f4d1115 --- /dev/null +++ b/binheap/lib_test.go @@ -0,0 +1,84 @@ +package binheap_test + +import ( + "math/rand" + "testing" + + "gogs.humancabbage.net/sam/priorityq/binheap" +) + +func TestSmoke(t *testing.T) { + h := binheap.Make[int, int](10) + if h.Capacity() != 10 { + t.Errorf("expected heap capacity to be 10") + } + h.Insert(1, 1) + h.Insert(2, 2) + h.Insert(3, 3) + h.Insert(4, 4) + if h.Len() != 4 { + t.Errorf("expected heap length to be 4") + } + checkExtract := func(n int) { + _, extracted := h.Extract() + if extracted != n { + t.Errorf("expected to extract %d, got %d", n, extracted) + } + } + checkExtract(4) + checkExtract(3) + checkExtract(2) + checkExtract(1) +} + +func TestInsertFullPanic(t *testing.T) { + h := binheap.Make[int, int](4) + h.Insert(1, 1) + h.Insert(2, 2) + h.Insert(3, 3) + h.Insert(4, 4) + defer func() { + if r := recover(); r == nil { + t.Errorf("expected final insert to panic") + } + }() + h.Insert(5, 5) +} + +func TestExtractEmptyPanic(t *testing.T) { + h := binheap.Make[int, int](4) + defer func() { + if r := recover(); r == nil { + t.Errorf("expected extract to panic") + } + }() + h.Extract() +} + +func TestRandomized(t *testing.T) { + h := binheap.Make[int, int](8192) + rs := rand.NewSource(0) + r := rand.New(rs) + // insert a bunch of random integers + for i := 0; i < h.Capacity(); i++ { + n := r.Int() + h.Insert(n, n) + } + // ensure that each extracted integer is <= the last extracted integer + var extracted []int + for h.CanExtract() { + id, item := h.Extract() + if id != item { + t.Errorf("id / item mismatch: %d %d", id, item) + } + lastIdx := len(extracted) - 1 + extracted = append(extracted, item) + if lastIdx < 0 { + continue + } + if item > extracted[lastIdx] { + t.Errorf("newly extracted %d is greater than %d", + item, extracted[lastIdx]) + } + } +} diff --git a/go.mod b/go.mod index 3b1a761..83f382d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module gogs.humancabbage.net/sam/priorityq go 1.20 + +require golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 diff --git a/go.sum b/go.sum index e69de29..0661dd6 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI= +golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= diff --git a/lib.go b/lib.go new file mode 100644 index 0000000..5d7340b --- /dev/null +++ b/lib.go @@ -0,0 +1,117 @@ +package priorityq + +import ( + "sync" + + "gogs.humancabbage.net/sam/priorityq/binheap" + "golang.org/x/exp/constraints" +) + +// Q is a generic, concurrent priority queue. +type Q[P constraints.Ordered, T any] struct { + *state[P, T] +} + +// Make a new queue. +func Make[P constraints.Ordered, T any](cap int) Q[P, T] { + heap := binheap.Make[P, T](cap) + s := &state[P, T]{ + heap: heap, + } + s.canRecv = sync.NewCond(&s.mu) + s.canSend = sync.NewCond(&s.mu) + return Q[P, T]{s} +} + +type state[P constraints.Ordered, T any] struct { + mu sync.Mutex + heap binheap.H[P, T] + canSend *sync.Cond + canRecv *sync.Cond + closed bool +} + +// Close marks the queue as closed. +// +// Subsequent attempts to send will panic. Subsequent calls to Recv will +// continue to return the remaining items in the queue. +func (s *state[P, T]) Close() { + s.mu.Lock() + s.closed = true + s.mu.Unlock() + s.canRecv.Broadcast() +} + +// Recv returns an item from the prioritized buffers, blocking if empty. +// +// The returned bool will be true if the queue still has items or is open. +// It will be false if the queue is empty and closed. +func (s *state[P, T]) Recv() (P, T, bool) { + s.mu.Lock() + defer s.mu.Unlock() + for { + for !s.closed && !s.heap.CanExtract() { + s.canRecv.Wait() + } + if s.closed && !s.heap.CanExtract() { + var emptyP P + var emptyT T + return emptyP, emptyT, false + } + if s.heap.CanExtract() { + priority, value := s.heap.Extract() + s.canSend.Broadcast() + return priority, value, true + } + } +} + +// Send adds an item to the queue, blocking if full. +func (s *state[P, T]) Send(priority P, value T) { + s.mu.Lock() + defer s.mu.Unlock() + for { + for !s.closed && !s.heap.CanInsert() { + s.canSend.Wait() + } + if s.closed { + panic("send on closed queue") + } + if s.heap.CanInsert() { + s.heap.Insert(priority, value) + s.canRecv.Broadcast() + return + } + } +} + +// TryRecv attempts to return an item from the queue. +// +// This method does not block. If there is an item in the queue, it returns +// true. If the queue is empty, it returns false. +func (s *state[P, T]) TryRecv() (priority P, value T, ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.heap.CanExtract() { + priority, value = s.heap.Extract() + ok = true + s.canSend.Broadcast() + return + } + return +} + +// TrySend attempts to add an item to the high priority buffer. +// +// This method does not block. If there is space in the buffer, it returns +// true. If the buffer is full, it returns false. +func (s *state[P, T]) TrySend(priority P, value T) bool { + s.mu.Lock() + defer s.mu.Unlock() + if !s.heap.CanInsert() { + return false + } + s.heap.Insert(priority, value) + s.canRecv.Broadcast() + return true +} diff --git a/lib_test.go b/lib_test.go new file mode 100644 index 0000000..e636b20 --- /dev/null +++ b/lib_test.go @@ -0,0 +1,179 @@ +package priorityq_test + +import ( + "math/rand" + "sync" + "testing" + + "gogs.humancabbage.net/sam/priorityq" +) + +func TestRecvHighestFirst(t *testing.T) { + t.Parallel() + q := priorityq.Make[int, int](8) + q.Send(4, 4) + q.Send(2, 2) + q.Send(1, 1) + q.Send(5, 5) + q.Send(7, 7) + q.Send(8, 8) + q.Send(3, 3) + q.Send(6, 6) + checkRecv := func(n int) { + if _, v, _ := q.Recv(); v != n { + t.Errorf("popped %d, expected %d", v, n) + } + } + checkRecv(8) + checkRecv(7) + checkRecv(6) + checkRecv(5) + checkRecv(4) + checkRecv(3) + checkRecv(2) + checkRecv(1) +} + +func TestSendClosedPanic(t *testing.T) { + t.Parallel() + defer func() { + if r := recover(); r == nil { + t.Errorf("sending to closed queue did not panic") + } + }() + q := priorityq.Make[int, int](4) + q.Close() + q.Send(1, 1) +} + +func TestRecvClosed(t *testing.T) { + t.Parallel() + q := priorityq.Make[int, int](4) + q.Send(1, 1) + q.Close() + _, _, ok := q.Recv() + if !ok { + t.Errorf("queue should have item to receive") + } + _, _, ok = q.Recv() + if ok { + t.Errorf("queue should be closed") + } +} + +func TestTrySendRecv(t *testing.T) { + t.Parallel() + q := priorityq.Make[int, int](4) + assumeSendOk := func(n int) { + ok := q.TrySend(n, n) + if !ok { + t.Errorf("expected to be able to send") + } + } + assumeRecvOk := func(expected int) { + _, actual, ok := q.TryRecv() + if !ok { + t.Errorf("expected to be able to receive") + } + if actual != expected { + t.Errorf("expected %d, got %d", expected, actual) + } + } + assumeSendOk(1) + assumeSendOk(2) + assumeSendOk(3) + assumeSendOk(4) + ok := q.TrySend(5, 5) + if ok { + t.Errorf("expected queue to be full") + } + assumeRecvOk(4) + assumeRecvOk(3) + assumeRecvOk(2) + assumeRecvOk(1) + + _, _, ok = q.TryRecv() + if ok { + t.Errorf("expected queue to be empty") + } +} + +func TestConcProducerConsumer(t *testing.T) { + t.Parallel() + q := priorityq.Make[int, int](4) + var wg sync.WaitGroup + produceDone := make(chan struct{}) + wg.Add(2) + go func() { + for i := 0; i < 10000; i++ { + q.Send(rand.Int(), i) + } + close(produceDone) + wg.Done() + }() + go func() { + ok := true + for ok { + _, _, ok = q.Recv() + } + wg.Done() + }() + <-produceDone + t.Logf("producer done, closing channel") + q.Close() + wg.Wait() +} + +func BenchmarkSend(b *testing.B) { + q := priorityq.Make[int, int](b.N) + // randomize priorities to get amortized cost per op + ps := make([]int, b.N) + for i := 0; i < b.N; i++ { + ps[i] = rand.Int() + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + q.Send(ps[i], i) + } +} + +func BenchmarkRecv(b *testing.B) { + q := priorityq.Make[int, int](b.N) + // randomize priorities to get amortized cost per op + for i := 0; i < b.N; i++ { + q.Send(rand.Int(), i) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + q.Recv() + } +} + +func BenchmarkConcSendRecv(b *testing.B) { + q := priorityq.Make[int, int](b.N) + // randomize priorities to get amortized cost per op + ps := make([]int, b.N) + for i := 0; i < b.N; i++ { + ps[i] = rand.Int() + } + var wg sync.WaitGroup + wg.Add(2) + start := make(chan struct{}) + go func() { + <-start + for i := 0; i < b.N; i++ { + q.Send(ps[i], i) + } + wg.Done() + }() + go func() { + <-start + for i := 0; i < b.N; i++ { + q.Recv() + } + wg.Done() + }() + b.ResetTimer() + close(start) + wg.Wait() +} diff --git a/precise/lib.go b/mq/lib.go similarity index 99% rename from precise/lib.go rename to mq/lib.go index 9065024..9bcfead 100644 --- a/precise/lib.go +++ b/mq/lib.go @@ -1,4 +1,4 @@ -package precise +package mq import ( "sync" diff --git a/precise/lib_test.go b/mq/lib_test.go similarity index 92% rename from precise/lib_test.go rename to mq/lib_test.go index 53c449b..a887a73 100644 --- a/precise/lib_test.go +++ b/mq/lib_test.go @@ -1,16 +1,16 @@ -package precise_test +package mq_test import ( "math/rand" "sync" "testing" - "gogs.humancabbage.net/sam/priorityq/precise" + "gogs.humancabbage.net/sam/priorityq/mq" ) func TestRecvHighFirst(t *testing.T) { t.Parallel() - q := precise.Make[int](4) + q := mq.Make[int](4) q.Send(1) q.Send(2) q.Send(3) @@ -41,14 +41,14 @@ func TestSendClosedPanic(t *testing.T) { t.Errorf("sending to closed queue did not panic") } }() - q := precise.Make[int](4) + q := mq.Make[int](4) q.Close() q.Send(1) } func TestRecvClosed(t *testing.T) { t.Parallel() - q := precise.Make[int](4) + q := mq.Make[int](4) q.Send(1) q.Close() _, ok := q.Recv() @@ -63,7 +63,7 @@ func TestRecvClosed(t *testing.T) { func TestTrySendRecv(t *testing.T) { t.Parallel() - q := precise.Make[int](4) + q := mq.Make[int](4) assumeSendOk := func(n int, f func(int) bool) { ok := f(n) if !ok { @@ -113,7 +113,7 @@ func TestTrySendRecv(t *testing.T) { func TestConcProducerConsumer(t *testing.T) { t.Parallel() - q := precise.Make[int](4) + q := mq.Make[int](4) var wg sync.WaitGroup produceDone := make(chan struct{}) wg.Add(2) @@ -142,7 +142,7 @@ func TestConcProducerConsumer(t *testing.T) { } func BenchmarkSend(b *testing.B) { - q := precise.Make[int](b.N) + q := mq.Make[int](b.N) b.ResetTimer() for i := 0; i < b.N; i++ { q.Send(i) @@ -158,7 +158,7 @@ func BenchmarkSendChan(b *testing.B) { } func BenchmarkRecv(b *testing.B) { - q := precise.Make[int](b.N) + q := mq.Make[int](b.N) for i := 0; i < b.N; i++ { q.Send(i) } @@ -180,7 +180,7 @@ func BenchmarkRecvChan(b *testing.B) { } func BenchmarkConcSendRecv(b *testing.B) { - q := precise.Make[int](b.N) + q := mq.Make[int](b.N) var wg sync.WaitGroup wg.Add(2) start := make(chan struct{})