Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions internal/base/queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ type Service[T any] interface {
// Queue is a generic message queue service that processes messages asynchronously.
// It is thread-safe and supports graceful shutdown.
type Queue[T any] struct {
name string
queue chan T
handler func(ctx context.Context, msg T) error
mu sync.RWMutex
closed bool
wg sync.WaitGroup
name string
queue chan T
handlers []func(ctx context.Context, msg T) error
mu sync.RWMutex
closed bool
wg sync.WaitGroup
}

// New creates a new queue with the given name and buffer size.
Expand Down Expand Up @@ -77,12 +77,13 @@ func (q *Queue[T]) Send(ctx context.Context, msg T) {
}
}

// RegisterHandler sets the handler function for processing messages.
// RegisterHandler adds a handler function for processing messages.
// Multiple handlers can be registered and all will be called for each message.
// This is thread-safe and can be called at any time.
func (q *Queue[T]) RegisterHandler(handler func(ctx context.Context, msg T) error) {
q.mu.Lock()
defer q.mu.Unlock()
q.handler = handler
q.handlers = append(q.handlers, handler)
}

// Close gracefully shuts down the queue, waiting for pending messages to be processed.
Expand Down Expand Up @@ -114,17 +115,20 @@ func (q *Queue[T]) startWorker() {
// processMessage handles a single message with proper synchronization.
func (q *Queue[T]) processMessage(msg T) {
q.mu.RLock()
handler := q.handler
handlers := make([]func(ctx context.Context, msg T) error, len(q.handlers))
copy(handlers, q.handlers)
q.mu.RUnlock()

if handler == nil {
if len(handlers) == 0 {
log.Warnf("[%s] no handler registered, dropping message: %+v", q.name, msg)
return
}

// Use background context for async processing
// TODO: Consider adding timeout or using a derived context
if err := handler(context.TODO(), msg); err != nil {
log.Errorf("[%s] handler error: %v", q.name, err)
for _, handler := range handlers {
if err := handler(context.TODO(), msg); err != nil {
log.Errorf("[%s] handler error: %v", q.name, err)
}
}
}
201 changes: 201 additions & 0 deletions internal/base/queue/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,207 @@ func TestQueue_ConcurrentRegisterHandler(t *testing.T) {
wg.Wait()
}

// TestQueue_MultipleHandlers verifies that all registered handlers are called
// for each message sent to the queue.
func TestQueue_MultipleHandlers(t *testing.T) {
q := New[*testMessage]("test-multi", 10)
defer q.Close()

var count1, count2, count3 atomic.Int32

q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
count1.Add(1)
return nil
})
q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
count2.Add(1)
return nil
})
q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
count3.Add(1)
return nil
})

numMessages := 10
for i := range numMessages {
q.Send(context.Background(), &testMessage{ID: i})
}

// Close to ensure all messages are processed
q.Close()

if int(count1.Load()) != numMessages {
t.Errorf("handler 1: expected %d calls, got %d", numMessages, count1.Load())
}
if int(count2.Load()) != numMessages {
t.Errorf("handler 2: expected %d calls, got %d", numMessages, count2.Load())
}
if int(count3.Load()) != numMessages {
t.Errorf("handler 3: expected %d calls, got %d", numMessages, count3.Load())
}
}

// TestQueue_MultipleHandlers_MessageContent verifies that each handler receives
// the exact same message reference.
func TestQueue_MultipleHandlers_MessageContent(t *testing.T) {
q := New[*testMessage]("test-multi-content", 10)
defer q.Close()

received1 := make(chan *testMessage, 1)
received2 := make(chan *testMessage, 1)

q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
received1 <- msg
return nil
})
q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
received2 <- msg
return nil
})

msg := &testMessage{ID: 42, Data: "shared"}
q.Send(context.Background(), msg)

select {
case r := <-received1:
if r != msg {
t.Errorf("handler 1: got different message pointer")
}
case <-time.After(time.Second):
t.Fatal("timeout waiting for handler 1")
}

select {
case r := <-received2:
if r != msg {
t.Errorf("handler 2: got different message pointer")
}
case <-time.After(time.Second):
t.Fatal("timeout waiting for handler 2")
}
}

// TestQueue_HandlerErrorDoesNotBlockOthers verifies that when one handler returns
// an error, subsequent handlers are still called.
func TestQueue_HandlerErrorDoesNotBlockOthers(t *testing.T) {
q := New[*testMessage]("test-error", 10)
defer q.Close()

var called1, called2, called3 atomic.Bool

q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
called1.Store(true)
return nil
})
q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
called2.Store(true)
return fmt.Errorf("simulated error")
})
q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
called3.Store(true)
return nil
})

q.Send(context.Background(), &testMessage{ID: 1})

// Close to ensure processing is complete
q.Close()

if !called1.Load() {
t.Error("handler 1 was not called")
}
if !called2.Load() {
t.Error("handler 2 was not called")
}
if !called3.Load() {
t.Error("handler 3 (after error) was not called")
}
}

// TestQueue_RegisterHandlerDuringProcessing verifies that registering a new
// handler while messages are being processed does not cause races, and that
// the new handler is called for subsequent messages.
func TestQueue_RegisterHandlerDuringProcessing(t *testing.T) {
q := New[*testMessage]("test-dynamic", 100)
defer q.Close()

var firstCount atomic.Int32
var secondCount atomic.Int32
secondRegistered := make(chan struct{})

q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
firstCount.Add(1)
// Register a second handler after the first message is processed
if msg.ID == 0 {
q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
secondCount.Add(1)
return nil
})
close(secondRegistered)
}
return nil
})

// Send first message to trigger second handler registration
q.Send(context.Background(), &testMessage{ID: 0})

// Wait for second handler to be registered
select {
case <-secondRegistered:
case <-time.After(time.Second):
t.Fatal("timeout waiting for second handler registration")
}

// Send more messages that should be processed by both handlers
numExtra := 5
for i := 1; i <= numExtra; i++ {
q.Send(context.Background(), &testMessage{ID: i})
}

q.Close()

// First handler should have processed all messages (1 + numExtra)
if int(firstCount.Load()) != 1+numExtra {
t.Errorf("first handler: expected %d calls, got %d", 1+numExtra, firstCount.Load())
}
// Second handler should have processed at least the extra messages
if int(secondCount.Load()) < numExtra {
t.Errorf("second handler: expected at least %d calls, got %d", numExtra, secondCount.Load())
}
}

// TestQueue_ConcurrentRegisterAndSend verifies that concurrently registering
// handlers and sending messages does not cause data races.
func TestQueue_ConcurrentRegisterAndSend(t *testing.T) {
q := New[*testMessage]("test-concurrent-reg-send", 1000)
defer q.Close()

var wg sync.WaitGroup

// Concurrently register handlers
for range 5 {
wg.Add(1)
go func() {
defer wg.Done()
q.RegisterHandler(func(ctx context.Context, msg *testMessage) error {
return nil
})
}()
}

// Concurrently send messages
for i := range 50 {
wg.Add(1)
go func(id int) {
defer wg.Done()
q.Send(context.Background(), &testMessage{ID: id})
}(i)
}

wg.Wait()
// No race or panic = pass
}

// TestQueue_SendCloseRace is a regression test for the race condition between
// Send and Close. Without proper synchronization, concurrent Send and Close
// calls could cause a "send on closed channel" panic.
Expand Down