diff --git a/commands/msg/mark.go b/commands/msg/mark.go index 939d456..b5a228d 100644 --- a/commands/msg/mark.go +++ b/commands/msg/mark.go @@ -32,7 +32,7 @@ func (Mark) Execute(aerc *widgets.Aerc, args []string) error { return err } marker := store.Marker() - opts, _, err := getopt.Getopts(args, "atvV") + opts, _, err := getopt.Getopts(args, "atvVT") if err != nil { return err } @@ -40,6 +40,7 @@ func (Mark) Execute(aerc *widgets.Aerc, args []string) error { var toggle bool var visual bool var clearVisual bool + var thread bool for _, opt := range opts { switch opt.Option { case 'a': @@ -51,9 +52,23 @@ func (Mark) Execute(aerc *widgets.Aerc, args []string) error { visual = true case 't': toggle = true + case 'T': + thread = true } } + if thread && len(store.Threads()) == 0 { + return fmt.Errorf("No threads found") + } + + if thread && all { + return fmt.Errorf("-a and -T are mutually exclusive") + } + + if thread && visual { + return fmt.Errorf("-v and -T are mutually exclusive") + } + switch args[0] { case "mark": if all && visual { @@ -77,7 +92,13 @@ func (Mark) Execute(aerc *widgets.Aerc, args []string) error { marker.ToggleVisualMark(clearVisual) return nil default: - modFunc(selected.Uid) + if thread { + for _, uid := range store.SelectedThread().Root().Uids() { + modFunc(uid) + } + } else { + modFunc(selected.Uid) + } return nil } @@ -97,11 +118,17 @@ func (Mark) Execute(aerc *widgets.Aerc, args []string) error { marker.ClearVisualMark() return nil default: - marker.Unmark(selected.Uid) + if thread { + for _, uid := range store.SelectedThread().Root().Uids() { + marker.Unmark(uid) + } + } else { + marker.Unmark(selected.Uid) + } return nil } case "remark": - if all || visual || toggle { + if all || visual || toggle || thread { return fmt.Errorf("Usage: :remark") } marker.Remark() diff --git a/doc/aerc.1.scd b/doc/aerc.1.scd index b5edc32..84bc775 100644 --- a/doc/aerc.1.scd +++ b/doc/aerc.1.scd @@ -402,7 +402,7 @@ message list, the message in the message viewer, etc). *-a*: Save all attachments. Individual filenames cannot be specified. -*mark* [-atv] +*mark* [-atvT] Marks messages. Commands will execute on all marked messages instead of the highlighted one if applicable. The flags below can be combined as needed. @@ -414,6 +414,8 @@ message list, the message in the message viewer, etc). *-V*: Same as -v but does not clear existing selection + *-T*: Marks the displayed message thread of the selected message. + *unmark* [-at] Unmarks messages. The flags below can be combined as needed. diff --git a/lib/msgstore.go b/lib/msgstore.go index d126fee..d47c14f 100644 --- a/lib/msgstore.go +++ b/lib/msgstore.go @@ -411,6 +411,28 @@ func (store *MessageStore) runThreadBuilder() { }) } +// SelectedThread returns the thread with the UID from the selected message +func (store *MessageStore) SelectedThread() *types.Thread { + var thread *types.Thread + for _, root := range store.Threads() { + found := false + err := root.Walk(func(t *types.Thread, _ int, _ error) error { + if t.Uid == store.SelectedUid() { + thread = t + found = true + } + return nil + }) + if err != nil { + logging.Errorf("SelectedThread failed: %w", err) + } + if found { + break + } + } + return thread +} + func (store *MessageStore) Delete(uids []uint32, cb func(msg types.WorkerMessage), ) { diff --git a/worker/types/thread.go b/worker/types/thread.go index 9f59e9e..60ecc7c 100644 --- a/worker/types/thread.go +++ b/worker/types/thread.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "sort" + + "git.sr.ht/~rjarry/aerc/logging" ) type Thread struct { @@ -48,6 +50,33 @@ func (t *Thread) Walk(walkFn NewThreadWalkFn) error { return err } +// Root returns the root thread of the thread tree +func (t *Thread) Root() *Thread { + if t == nil { + return nil + } + var iter *Thread + for iter = t; iter.Parent != nil; iter = iter.Parent { + } + return iter +} + +// Uids returns all associated uids for the given thread and its children +func (t *Thread) Uids() []uint32 { + if t == nil { + return nil + } + uids := make([]uint32, 0) + err := t.Walk(func(node *Thread, _ int, _ error) error { + uids = append(uids, node.Uid) + return nil + }) + if err != nil { + logging.Errorf("walk to collect uids failed: %w", err) + } + return uids +} + func (t *Thread) String() string { if t == nil { return ""