diff --git a/worker/types/worker.go b/worker/types/worker.go index ad35949..61b96da 100644 --- a/worker/types/worker.go +++ b/worker/types/worker.go @@ -1,6 +1,7 @@ package types import ( + "sync" "sync/atomic" "git.sr.ht/~rjarry/aerc/logging" @@ -20,6 +21,8 @@ type Worker struct { actionCallbacks map[int64]func(msg WorkerMessage) messageCallbacks map[int64]func(msg WorkerMessage) + + sync.Mutex } func NewWorker() *Worker { @@ -49,7 +52,9 @@ func (worker *Worker) PostAction(msg WorkerMessage, cb func(msg WorkerMessage)) worker.Actions <- msg if cb != nil { + worker.Lock() worker.actionCallbacks[msg.getId()] = cb + worker.Unlock() } } @@ -68,7 +73,9 @@ func (worker *Worker) PostMessage(msg WorkerMessage, worker.Messages <- msg if cb != nil { + worker.Lock() worker.messageCallbacks[msg.getId()] = cb + worker.Unlock() } } @@ -79,12 +86,14 @@ func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage { logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId()) } if inResponseTo := msg.InResponseTo(); inResponseTo != nil { + worker.Lock() if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok { f(msg) if _, ok := msg.(*Done); ok { delete(worker.actionCallbacks, inResponseTo.getId()) } } + worker.Unlock() } return msg } @@ -96,12 +105,14 @@ func (worker *Worker) ProcessAction(msg WorkerMessage) WorkerMessage { logging.Debugf("ProcessAction %T(%d)", msg, msg.getId()) } if inResponseTo := msg.InResponseTo(); inResponseTo != nil { + worker.Lock() if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok { f(msg) if _, ok := msg.(*Done); ok { delete(worker.messageCallbacks, inResponseTo.getId()) } } + worker.Unlock() } return msg }