diff --git a/internal/state/state.go b/internal/state/state.go index 32ec5c1..1030699 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -2,6 +2,8 @@ package state import ( "sync" + + "github.com/romshark/templier/internal/broadcaster" ) type State struct { @@ -30,30 +32,26 @@ func (s State) IsErr() bool { } type Tracker struct { - state State - lock sync.Mutex - listeners map[chan<- struct{}]struct{} + state State + lock sync.Mutex + broadcaster *broadcaster.SignalBroadcaster } func NewTracker() *Tracker { return &Tracker{ - listeners: make(map[chan<- struct{}]struct{}), + broadcaster: broadcaster.NewSignalBroadcaster(), } } // AddListener adds a listener channel. // c will be written struct{}{} to when a state change happens. func (s *Tracker) AddListener(c chan<- struct{}) { - s.lock.Lock() - defer s.lock.Unlock() - s.listeners[c] = struct{}{} + s.broadcaster.AddListener(c) } // RemoveListener removes a listener channel. func (s *Tracker) RemoveListener(c chan<- struct{}) { - s.lock.Lock() - defer s.lock.Unlock() - delete(s.listeners, c) + s.broadcaster.RemoveListener(c) } // Reset resets the state and notifies all listeners. @@ -63,7 +61,7 @@ func (s *Tracker) Reset() { s.state.ErrTempl = "" s.state.ErrGolangCILint = "" s.state.ErrGo = "" - s.notifyListeners() + s.broadcaster.BroadcastNonblock() } // SetErrTempl sets or resets (if "") the current templ error @@ -75,7 +73,7 @@ func (s *Tracker) SetErrTempl(msg string) { return // State didn't change, ignore. } s.state.ErrTempl = msg - s.notifyListeners() + s.broadcaster.BroadcastNonblock() } // SetErrGolangCILint sets or resets (if "") the current golangci-lint error @@ -87,7 +85,7 @@ func (s *Tracker) SetErrGolangCILint(msg string) { return // State didn't change, ignore. } s.state.ErrGolangCILint = msg - s.notifyListeners() + s.broadcaster.BroadcastNonblock() } // SetErrGo sets or resets (if "") the current Go error @@ -99,7 +97,7 @@ func (s *Tracker) SetErrGo(msg string) { return // State didn't change, ignore. } s.state.ErrGo = msg - s.notifyListeners() + s.broadcaster.BroadcastNonblock() } // Get returns the current state. @@ -108,12 +106,3 @@ func (s *Tracker) Get() State { defer s.lock.Unlock() return s.state } - -func (s *Tracker) notifyListeners() { - for ch := range s.listeners { - select { - case ch <- struct{}{}: - default: // Ignore unresponsive listeners. - } - } -} diff --git a/internal/state/state_test.go b/internal/state/state_test.go new file mode 100644 index 0000000..13b134e --- /dev/null +++ b/internal/state/state_test.go @@ -0,0 +1,78 @@ +package state_test + +import ( + "sync" + "testing" + + "github.com/romshark/templier/internal/state" + "github.com/stretchr/testify/require" +) + +func TestStateListener(t *testing.T) { + s := state.NewTracker() + + require.Equal(t, state.State{}, s.Get()) + + var wg sync.WaitGroup + wg.Add(1) + + c1 := make(chan struct{}, 3) + s.AddListener(c1) + + go func() { + defer wg.Done() + <-c1 + <-c1 + <-c1 + }() + + s.SetErrGo("go failed") + s.SetErrGolangCILint("golangcilint failed") + s.SetErrTempl("templ failed") + wg.Wait() // Wait for the listener goroutine to receive an update + require.Equal(t, state.State{ + ErrGo: "go failed", + ErrGolangCILint: "golangcilint failed", + ErrTempl: "templ failed", + }, s.Get()) +} + +func TestStateMsg(t *testing.T) { + s := state.NewTracker() + + require.Equal(t, state.State{}, s.Get()) + + s.SetErrGo("go failed") + s.SetErrGolangCILint("golangcilint failed") + s.SetErrTempl("templ failed") + + require.Equal(t, state.State{ + ErrGo: "go failed", + ErrGolangCILint: "golangcilint failed", + ErrTempl: "templ failed", + }, s.Get()) + + require.Equal(t, "templ failed", s.Get().Msg()) + s.SetErrTempl("") + require.Equal(t, "golangcilint failed", s.Get().Msg()) + s.SetErrGolangCILint("") + require.Equal(t, "go failed", s.Get().Msg()) + s.SetErrGo("") + require.Zero(t, s.Get().Msg()) +} + +func TestStateReset(t *testing.T) { + s := state.NewTracker() + require.False(t, s.Get().IsErr()) + + require.Equal(t, state.State{}, s.Get()) + + s.SetErrGo("go failed") + s.SetErrGolangCILint("golangcilint failed") + s.SetErrTempl("templ failed") + require.True(t, s.Get().IsErr()) + + s.Reset() + require.Equal(t, state.State{}, s.Get()) + require.False(t, s.Get().IsErr()) +}