widgets/msglist: fix MessageList.store race

This field could be written to in the middle of a Draw call, which reads it
multiple times. Use an atomic variable instead.
This commit is contained in:
Simon Ser 2019-04-28 13:26:22 +00:00 committed by Drew DeVault
parent 089740758c
commit f1698a337e
1 changed files with 32 additions and 23 deletions

View File

@ -2,6 +2,7 @@ package widgets
import ( import (
"log" "log"
"sync/atomic"
"github.com/gdamore/tcell" "github.com/gdamore/tcell"
@ -19,7 +20,7 @@ type MessageList struct {
scroll int scroll int
selected int selected int
spinner *Spinner spinner *Spinner
store *lib.MessageStore store atomic.Value // *lib.MessageStore
} }
// TODO: fish in config // TODO: fish in config
@ -29,6 +30,7 @@ func NewMessageList(logger *log.Logger) *MessageList {
selected: 0, selected: 0,
spinner: NewSpinner(), spinner: NewSpinner(),
} }
ml.store.Store((*lib.MessageStore)(nil))
ml.spinner.OnInvalidate(func(_ ui.Drawable) { ml.spinner.OnInvalidate(func(_ ui.Drawable) {
ml.Invalidate() ml.Invalidate()
}) })
@ -45,7 +47,8 @@ func (ml *MessageList) Draw(ctx *ui.Context) {
ml.height = ctx.Height() ml.height = ctx.Height()
ctx.Fill(0, 0, ctx.Width(), ctx.Height(), ' ', tcell.StyleDefault) ctx.Fill(0, 0, ctx.Width(), ctx.Height(), ' ', tcell.StyleDefault)
if ml.store == nil { store := ml.Store()
if store == nil {
ml.spinner.Draw(ctx) ml.spinner.Draw(ctx)
return return
} }
@ -55,9 +58,9 @@ func (ml *MessageList) Draw(ctx *ui.Context) {
row int = 0 row int = 0
) )
for i := len(ml.store.Uids) - 1 - ml.scroll; i >= 0; i-- { for i := len(store.Uids) - 1 - ml.scroll; i >= 0; i-- {
uid := ml.store.Uids[i] uid := store.Uids[i]
msg := ml.store.Messages[uid] msg := store.Messages[uid]
if row >= ctx.Height() { if row >= ctx.Height() {
break break
@ -74,7 +77,7 @@ func (ml *MessageList) Draw(ctx *ui.Context) {
if row == ml.selected-ml.scroll { if row == ml.selected-ml.scroll {
style = style.Reverse(true) style = style.Reverse(true)
} }
if _, ok := ml.store.Deleted[msg.Uid]; ok { if _, ok := store.Deleted[msg.Uid]; ok {
style = style.Foreground(tcell.ColorGray) style = style.Foreground(tcell.ColorGray)
} }
ctx.Fill(0, row, ctx.Width(), 1, ' ', style) ctx.Fill(0, row, ctx.Width(), 1, ' ', style)
@ -83,14 +86,14 @@ func (ml *MessageList) Draw(ctx *ui.Context) {
row += 1 row += 1
} }
if len(ml.store.Uids) == 0 { if len(store.Uids) == 0 {
msg := "(no messages)" msg := "(no messages)"
ctx.Printf((ctx.Width()/2)-(len(msg)/2), 0, ctx.Printf((ctx.Width()/2)-(len(msg)/2), 0,
tcell.StyleDefault, "%s", msg) tcell.StyleDefault, "%s", msg)
} }
if len(needsHeaders) != 0 { if len(needsHeaders) != 0 {
ml.store.FetchHeaders(needsHeaders, nil) store.FetchHeaders(needsHeaders, nil)
ml.spinner.Start() ml.spinner.Start()
} else { } else {
ml.spinner.Stop() ml.spinner.Stop()
@ -102,26 +105,28 @@ func (ml *MessageList) Height() int {
} }
func (ml *MessageList) storeUpdate(store *lib.MessageStore) { func (ml *MessageList) storeUpdate(store *lib.MessageStore) {
if ml.store != store { if ml.Store() != store {
return return
} }
if len(ml.store.Uids) > 0 {
for ml.selected >= len(ml.store.Uids) { if len(store.Uids) > 0 {
for ml.selected >= len(store.Uids) {
ml.Prev() ml.Prev()
} }
} }
ml.Invalidate() ml.Invalidate()
} }
func (ml *MessageList) SetStore(store *lib.MessageStore) { func (ml *MessageList) SetStore(store *lib.MessageStore) {
if ml.store == store { if ml.Store() == store {
ml.scroll = 0 ml.scroll = 0
ml.selected = 0 ml.selected = 0
} }
ml.store = store ml.store.Store(store)
if store != nil { if store != nil {
ml.spinner.Stop() ml.spinner.Stop()
ml.store.OnUpdate(ml.storeUpdate) store.OnUpdate(ml.storeUpdate)
} else { } else {
ml.spinner.Start() ml.spinner.Start()
} }
@ -129,23 +134,26 @@ func (ml *MessageList) SetStore(store *lib.MessageStore) {
} }
func (ml *MessageList) Store() *lib.MessageStore { func (ml *MessageList) Store() *lib.MessageStore {
return ml.store return ml.store.Load().(*lib.MessageStore)
} }
func (ml *MessageList) Empty() bool { func (ml *MessageList) Empty() bool {
return ml.store == nil || len(ml.store.Uids) == 0 store := ml.Store()
return store == nil || len(store.Uids) == 0
} }
func (ml *MessageList) Selected() *types.MessageInfo { func (ml *MessageList) Selected() *types.MessageInfo {
return ml.store.Messages[ml.store.Uids[len(ml.store.Uids)-ml.selected-1]] store := ml.Store()
return store.Messages[store.Uids[len(store.Uids)-ml.selected-1]]
} }
func (ml *MessageList) Select(index int) { func (ml *MessageList) Select(index int) {
store := ml.Store()
ml.selected = index ml.selected = index
for ; ml.selected < 0; ml.selected = len(ml.store.Uids) + ml.selected { for ; ml.selected < 0; ml.selected = len(store.Uids) + ml.selected {
} }
if ml.selected > len(ml.store.Uids) { if ml.selected > len(store.Uids) {
ml.selected = len(ml.store.Uids) ml.selected = len(store.Uids)
} }
// I'm too lazy to do the math right now // I'm too lazy to do the math right now
for ml.selected-ml.scroll >= ml.Height() { for ml.selected-ml.scroll >= ml.Height() {
@ -157,15 +165,16 @@ func (ml *MessageList) Select(index int) {
} }
func (ml *MessageList) nextPrev(delta int) { func (ml *MessageList) nextPrev(delta int) {
if ml.store == nil || len(ml.store.Uids) == 0 { store := ml.Store()
if store == nil || len(store.Uids) == 0 {
return return
} }
ml.selected += delta ml.selected += delta
if ml.selected < 0 { if ml.selected < 0 {
ml.selected = 0 ml.selected = 0
} }
if ml.selected >= len(ml.store.Uids) { if ml.selected >= len(store.Uids) {
ml.selected = len(ml.store.Uids) - 1 ml.selected = len(store.Uids) - 1
} }
if ml.Height() != 0 { if ml.Height() != 0 {
if ml.selected-ml.scroll >= ml.Height() { if ml.selected-ml.scroll >= ml.Height() {