worker: lock access to callback maps
Worker callbacks are inherently set and called from different goroutines. Protect access to all callback maps with a mutex. Signed-off-by: Moritz Poldrack <moritz@poldrack.dev> Signed-off-by: Tim Culverhouse <tim@timculverhouse.com> Acked-by: Robin Jarry <robin@jarry.cc>
This commit is contained in:
parent
c8c4b8c7cb
commit
716ade8968
1 changed files with 11 additions and 0 deletions
|
@ -1,6 +1,7 @@
|
||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"git.sr.ht/~rjarry/aerc/logging"
|
"git.sr.ht/~rjarry/aerc/logging"
|
||||||
|
@ -20,6 +21,8 @@ type Worker struct {
|
||||||
|
|
||||||
actionCallbacks map[int64]func(msg WorkerMessage)
|
actionCallbacks map[int64]func(msg WorkerMessage)
|
||||||
messageCallbacks map[int64]func(msg WorkerMessage)
|
messageCallbacks map[int64]func(msg WorkerMessage)
|
||||||
|
|
||||||
|
sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWorker() *Worker {
|
func NewWorker() *Worker {
|
||||||
|
@ -49,7 +52,9 @@ func (worker *Worker) PostAction(msg WorkerMessage, cb func(msg WorkerMessage))
|
||||||
worker.Actions <- msg
|
worker.Actions <- msg
|
||||||
|
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
|
worker.Lock()
|
||||||
worker.actionCallbacks[msg.getId()] = cb
|
worker.actionCallbacks[msg.getId()] = cb
|
||||||
|
worker.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,7 +73,9 @@ func (worker *Worker) PostMessage(msg WorkerMessage,
|
||||||
worker.Messages <- msg
|
worker.Messages <- msg
|
||||||
|
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
|
worker.Lock()
|
||||||
worker.messageCallbacks[msg.getId()] = cb
|
worker.messageCallbacks[msg.getId()] = cb
|
||||||
|
worker.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,12 +86,14 @@ func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage {
|
||||||
logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId())
|
logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId())
|
||||||
}
|
}
|
||||||
if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
|
if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
|
||||||
|
worker.Lock()
|
||||||
if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok {
|
if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok {
|
||||||
f(msg)
|
f(msg)
|
||||||
if _, ok := msg.(*Done); ok {
|
if _, ok := msg.(*Done); ok {
|
||||||
delete(worker.actionCallbacks, inResponseTo.getId())
|
delete(worker.actionCallbacks, inResponseTo.getId())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
worker.Unlock()
|
||||||
}
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
@ -96,12 +105,14 @@ func (worker *Worker) ProcessAction(msg WorkerMessage) WorkerMessage {
|
||||||
logging.Debugf("ProcessAction %T(%d)", msg, msg.getId())
|
logging.Debugf("ProcessAction %T(%d)", msg, msg.getId())
|
||||||
}
|
}
|
||||||
if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
|
if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
|
||||||
|
worker.Lock()
|
||||||
if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok {
|
if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok {
|
||||||
f(msg)
|
f(msg)
|
||||||
if _, ok := msg.(*Done); ok {
|
if _, ok := msg.(*Done); ok {
|
||||||
delete(worker.messageCallbacks, inResponseTo.getId())
|
delete(worker.messageCallbacks, inResponseTo.getId())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
worker.Unlock()
|
||||||
}
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue