From 716ade89687150daadbb41bdec4a00d6d6e34193 Mon Sep 17 00:00:00 2001
From: Tim Culverhouse <tim@timculverhouse.com>
Date: Sun, 25 Sep 2022 14:38:45 -0500
Subject: [PATCH] worker: lock access to callback maps

Worker callbacks are inherently set and called from different
goroutines. Protect access to all callback maps with a mutex.

Signed-off-by: Moritz Poldrack <moritz@poldrack.dev>
Signed-off-by: Tim Culverhouse <tim@timculverhouse.com>
Acked-by: Robin Jarry <robin@jarry.cc>
---
 worker/types/worker.go | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/worker/types/worker.go b/worker/types/worker.go
index ad35949..61b96da 100644
--- a/worker/types/worker.go
+++ b/worker/types/worker.go
@@ -1,6 +1,7 @@
 package types
 
 import (
+	"sync"
 	"sync/atomic"
 
 	"git.sr.ht/~rjarry/aerc/logging"
@@ -20,6 +21,8 @@ type Worker struct {
 
 	actionCallbacks  map[int64]func(msg WorkerMessage)
 	messageCallbacks map[int64]func(msg WorkerMessage)
+
+	sync.Mutex
 }
 
 func NewWorker() *Worker {
@@ -49,7 +52,9 @@ func (worker *Worker) PostAction(msg WorkerMessage, cb func(msg WorkerMessage))
 	worker.Actions <- msg
 
 	if cb != nil {
+		worker.Lock()
 		worker.actionCallbacks[msg.getId()] = cb
+		worker.Unlock()
 	}
 }
 
@@ -68,7 +73,9 @@ func (worker *Worker) PostMessage(msg WorkerMessage,
 	worker.Messages <- msg
 
 	if cb != nil {
+		worker.Lock()
 		worker.messageCallbacks[msg.getId()] = cb
+		worker.Unlock()
 	}
 }
 
@@ -79,12 +86,14 @@ func (worker *Worker) ProcessMessage(msg WorkerMessage) WorkerMessage {
 		logging.Debugf("ProcessMessage %T(%d)", msg, msg.getId())
 	}
 	if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
+		worker.Lock()
 		if f, ok := worker.actionCallbacks[inResponseTo.getId()]; ok {
 			f(msg)
 			if _, ok := msg.(*Done); ok {
 				delete(worker.actionCallbacks, inResponseTo.getId())
 			}
 		}
+		worker.Unlock()
 	}
 	return msg
 }
@@ -96,12 +105,14 @@ func (worker *Worker) ProcessAction(msg WorkerMessage) WorkerMessage {
 		logging.Debugf("ProcessAction %T(%d)", msg, msg.getId())
 	}
 	if inResponseTo := msg.InResponseTo(); inResponseTo != nil {
+		worker.Lock()
 		if f, ok := worker.messageCallbacks[inResponseTo.getId()]; ok {
 			f(msg)
 			if _, ok := msg.(*Done); ok {
 				delete(worker.messageCallbacks, inResponseTo.getId())
 			}
 		}
+		worker.Unlock()
 	}
 	return msg
 }