worker: lock access to callback maps

Worker callbacks are inherently set and called from different
goroutines. Protect access to all callback maps with a mutex.

Signed-off-by: Moritz Poldrack <moritz@poldrack.dev>
Signed-off-by: Tim Culverhouse <tim@timculverhouse.com>
Acked-by: Robin Jarry <robin@jarry.cc>
This commit is contained in:
Tim Culverhouse 2022-09-25 14:38:45 -05:00 committed by Robin Jarry
parent c8c4b8c7cb
commit 716ade8968

View file

@ -1,6 +1,7 @@
package types package types
import ( import (
"sync"
"sync/atomic" "sync/atomic"
"git.sr.ht/~rjarry/aerc/logging" "git.sr.ht/~rjarry/aerc/logging"
@ -20,6 +21,8 @@ type Worker struct {
actionCallbacks map[int64]func(msg WorkerMessage) actionCallbacks map[int64]func(msg WorkerMessage)
messageCallbacks map[int64]func(msg WorkerMessage) messageCallbacks map[int64]func(msg WorkerMessage)
sync.Mutex
} }
func NewWorker() *Worker { func NewWorker() *Worker {
@ -49,7 +52,9 @@ func (worker *Worker) PostAction(msg WorkerMessage, cb func(msg WorkerMessage))
worker.Actions <- msg worker.Actions <- msg
if cb != nil { if cb != nil {
worker.Lock()
worker.actionCallbacks[msg.getId()] = cb worker.actionCallbacks[msg.getId()] = cb
worker.Unlock()
} }
} }
@ -68,7 +73,9 @@ func (worker *Worker) PostMessage(msg WorkerMessage,
worker.Messages <- msg worker.Messages <- msg
if cb != nil { if cb != nil {
worker.Lock()
worker.messageCallbacks[msg.getId()] = cb worker.messageCallbacks[msg.getId()] = cb
worker.Unlock()
} }
} }
@ -79,12 +86,14 @@ func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage {
logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId()) logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId())
} }
if inResponseTo := msg.InResponseTo(); inResponseTo != nil { if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
worker.Lock()
if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok { if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok {
f(msg) f(msg)
if _, ok := msg.(*Done); ok { if _, ok := msg.(*Done); ok {
delete(worker.actionCallbacks, inResponseTo.getId()) delete(worker.actionCallbacks, inResponseTo.getId())
} }
} }
worker.Unlock()
} }
return msg return msg
} }
@ -96,12 +105,14 @@ func (worker *Worker) ProcessAction(msg WorkerMessage) WorkerMessage {
logging.Debugf("ProcessAction %T(%d)", msg, msg.getId()) logging.Debugf("ProcessAction %T(%d)", msg, msg.getId())
} }
if inResponseTo := msg.InResponseTo(); inResponseTo != nil { if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
worker.Lock()
if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok { if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok {
f(msg) f(msg)
if _, ok := msg.(*Done); ok { if _, ok := msg.(*Done); ok {
delete(worker.messageCallbacks, inResponseTo.getId()) delete(worker.messageCallbacks, inResponseTo.getId())
} }
} }
worker.Unlock()
} }
return msg return msg
} }