diff --git a/lib/iterator/impl.go b/lib/iterator/impl.go new file mode 100644 index 0000000..5e68516 --- /dev/null +++ b/lib/iterator/impl.go @@ -0,0 +1,125 @@ +package iterator + +import ( + "errors" + + "git.sr.ht/~rjarry/aerc/worker/types" +) + +// defaultFactory +type defaultFactory struct{} + +func (df *defaultFactory) NewIterator(a interface{}) Iterator { + switch data := a.(type) { + case []uint32: + return &defaultUid{data: data, index: len(data)} + case []*types.Thread: + return &defaultThread{data: data, index: len(data)} + } + panic(errors.New("a iterator for this type is not implemented yet")) +} + +// defaultUid +type defaultUid struct { + data []uint32 + index int +} + +func (du *defaultUid) Next() bool { + du.index-- + return du.index >= 0 +} + +func (du *defaultUid) Value() interface{} { + return du.data[du.index] +} + +func (du *defaultUid) StartIndex() int { + return len(du.data) - 1 +} + +func (du *defaultUid) EndIndex() int { + return 0 +} + +// defaultThread +type defaultThread struct { + data []*types.Thread + index int +} + +func (dt *defaultThread) Next() bool { + dt.index-- + return dt.index >= 0 +} + +func (dt *defaultThread) Value() interface{} { + return dt.data[dt.index] +} + +func (dt *defaultThread) StartIndex() int { + return len(dt.data) - 1 +} + +func (dt *defaultThread) EndIndex() int { + return 0 +} + +// reverseFactory +type reverseFactory struct{} + +func (rf *reverseFactory) NewIterator(a interface{}) Iterator { + switch data := a.(type) { + case []uint32: + return &reverseUid{data: data, index: -1} + case []*types.Thread: + return &reverseThread{data: data, index: -1} + } + panic(errors.New("an iterator for this type is not implemented yet")) +} + +// reverseUid +type reverseUid struct { + data []uint32 + index int +} + +func (ru *reverseUid) Next() bool { + ru.index++ + return ru.index < len(ru.data) +} + +func (ru *reverseUid) Value() interface{} { + return ru.data[ru.index] +} + +func (ru *reverseUid) StartIndex() int { + return 0 +} + +func (ru *reverseUid) EndIndex() int { + return len(ru.data) - 1 +} + +// reverseThread +type reverseThread struct { + data []*types.Thread + index int +} + +func (rt *reverseThread) Next() bool { + rt.index++ + return rt.index < len(rt.data) +} + +func (rt *reverseThread) Value() interface{} { + return rt.data[rt.index] +} + +func (rt *reverseThread) StartIndex() int { + return 0 +} + +func (rt *reverseThread) EndIndex() int { + return len(rt.data) - 1 +} diff --git a/lib/iterator/iterator.go b/lib/iterator/iterator.go new file mode 100644 index 0000000..28a9b8b --- /dev/null +++ b/lib/iterator/iterator.go @@ -0,0 +1,35 @@ +package iterator + +// Factory is the interface that wraps the NewIterator method. The +// NewIterator() creates either UID or thread iterators and ensures that both +// types of iterators implement the same iteration direction. +type Factory interface { + NewIterator(a interface{}) Iterator +} + +// Iterator implements an interface for iterating over UID or thread data. If +// Next() returns true, the current value of the iterator can be read with +// Value(). The return value of Value() is an interface{} type which needs to +// be cast to the correct type. +// +// The iterators are implemented such that the first returned value always +// represents the top message in the message list. Hence, StartIndex() would +// return the index of the top message whereas EndIndex() returns the index of +// message at the bottom of the list. +type Iterator interface { + Next() bool + Value() interface{} + StartIndex() int + EndIndex() int +} + +// NewFactory creates an iterator factory. When reverse is true, the iterators +// are reversed in the sense that the lowest UID messages are displayed at the +// top of the message list. Otherwise, the default order is with the highest +// UID message on top. +func NewFactory(reverse bool) Factory { + if reverse { + return &reverseFactory{} + } + return &defaultFactory{} +} diff --git a/lib/iterator/iterator_test.go b/lib/iterator/iterator_test.go new file mode 100644 index 0000000..6a8d3f6 --- /dev/null +++ b/lib/iterator/iterator_test.go @@ -0,0 +1,95 @@ +package iterator_test + +import ( + "testing" + + "git.sr.ht/~rjarry/aerc/lib/iterator" + "git.sr.ht/~rjarry/aerc/worker/types" +) + +func toThreads(uids []uint32) []*types.Thread { + threads := make([]*types.Thread, len(uids)) + for i, u := range uids { + threads[i] = &types.Thread{Uid: u} + } + return threads +} + +func TestIterator_DefaultFactory(t *testing.T) { + input := []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9} + want := []uint32{9, 8, 7, 6, 5, 4, 3, 2, 1} + + factory := iterator.NewFactory(false) + if factory == nil { + t.Errorf("could not create factory") + } + start, end := len(input)-1, 0 + checkUids(t, factory, input, want, start, end) + checkThreads(t, factory, toThreads(input), + toThreads(want), start, end) +} + +func TestIterator_ReverseFactory(t *testing.T) { + input := []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9} + want := []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9} + + factory := iterator.NewFactory(true) + if factory == nil { + t.Errorf("could not create factory") + } + + start, end := 0, len(input)-1 + checkUids(t, factory, input, want, start, end) + checkThreads(t, factory, toThreads(input), + toThreads(want), start, end) +} + +func checkUids(t *testing.T, factory iterator.Factory, + input []uint32, want []uint32, start, end int, +) { + label := "uids" + got := make([]uint32, 0) + iter := factory.NewIterator(input) + for iter.Next() { + got = append(got, iter.Value().(uint32)) + } + if len(got) != len(want) { + t.Errorf(label + "number of elements not correct") + } + for i, u := range want { + if got[i] != u { + t.Errorf(label + "order not correct") + } + } + if iter.StartIndex() != start { + t.Errorf(label + "start index not correct") + } + if iter.EndIndex() != end { + t.Errorf(label + "end index not correct") + } +} + +func checkThreads(t *testing.T, factory iterator.Factory, + input []*types.Thread, want []*types.Thread, start, end int, +) { + label := "threads" + got := make([]*types.Thread, 0) + iter := factory.NewIterator(input) + for iter.Next() { + got = append(got, iter.Value().(*types.Thread)) + } + if len(got) != len(want) { + t.Errorf(label + "number of elements not correct") + } + for i, th := range want { + if got[i].Uid != th.Uid { + t.Errorf(label + "order not correct") + } + } + if iter.StartIndex() != start { + t.Errorf(label + "start index not correct") + } + if iter.EndIndex() != end { + t.Errorf(label + "end index not correct") + } +}