Skip to content

Commit

Permalink
refactor: Use broadcaster in state tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
romshark committed Jul 6, 2024
1 parent 7e52099 commit 1561fd8
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 23 deletions.
35 changes: 12 additions & 23 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package state

import (
"sync"

"github.com/romshark/templier/internal/broadcaster"
)

type State struct {
Expand Down Expand Up @@ -30,30 +32,26 @@ func (s State) IsErr() bool {
}

type Tracker struct {
state State
lock sync.Mutex
listeners map[chan<- struct{}]struct{}
state State
lock sync.Mutex
broadcaster *broadcaster.SignalBroadcaster
}

func NewTracker() *Tracker {
return &Tracker{
listeners: make(map[chan<- struct{}]struct{}),
broadcaster: broadcaster.NewSignalBroadcaster(),
}
}

// AddListener adds a listener channel.
// c will be written struct{}{} to when a state change happens.
func (s *Tracker) AddListener(c chan<- struct{}) {
s.lock.Lock()
defer s.lock.Unlock()
s.listeners[c] = struct{}{}
s.broadcaster.AddListener(c)
}

// RemoveListener removes a listener channel.
func (s *Tracker) RemoveListener(c chan<- struct{}) {
s.lock.Lock()
defer s.lock.Unlock()
delete(s.listeners, c)
s.broadcaster.RemoveListener(c)
}

// Reset resets the state and notifies all listeners.
Expand All @@ -63,7 +61,7 @@ func (s *Tracker) Reset() {
s.state.ErrTempl = ""
s.state.ErrGolangCILint = ""
s.state.ErrGo = ""
s.notifyListeners()
s.broadcaster.BroadcastNonblock()
}

// SetErrTempl sets or resets (if "") the current templ error
Expand All @@ -75,7 +73,7 @@ func (s *Tracker) SetErrTempl(msg string) {
return // State didn't change, ignore.
}
s.state.ErrTempl = msg
s.notifyListeners()
s.broadcaster.BroadcastNonblock()
}

// SetErrGolangCILint sets or resets (if "") the current golangci-lint error
Expand All @@ -87,7 +85,7 @@ func (s *Tracker) SetErrGolangCILint(msg string) {
return // State didn't change, ignore.
}
s.state.ErrGolangCILint = msg
s.notifyListeners()
s.broadcaster.BroadcastNonblock()
}

// SetErrGo sets or resets (if "") the current Go error
Expand All @@ -99,7 +97,7 @@ func (s *Tracker) SetErrGo(msg string) {
return // State didn't change, ignore.
}
s.state.ErrGo = msg
s.notifyListeners()
s.broadcaster.BroadcastNonblock()
}

// Get returns the current state.
Expand All @@ -108,12 +106,3 @@ func (s *Tracker) Get() State {
defer s.lock.Unlock()
return s.state
}

func (s *Tracker) notifyListeners() {
for ch := range s.listeners {
select {
case ch <- struct{}{}:
default: // Ignore unresponsive listeners.
}
}
}
78 changes: 78 additions & 0 deletions internal/state/state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package state_test

import (
"sync"
"testing"

"github.com/romshark/templier/internal/state"
"github.com/stretchr/testify/require"
)

func TestStateListener(t *testing.T) {
s := state.NewTracker()

require.Equal(t, state.State{}, s.Get())

var wg sync.WaitGroup
wg.Add(1)

c1 := make(chan struct{}, 3)
s.AddListener(c1)

go func() {
defer wg.Done()
<-c1
<-c1
<-c1
}()

s.SetErrGo("go failed")
s.SetErrGolangCILint("golangcilint failed")
s.SetErrTempl("templ failed")
wg.Wait() // Wait for the listener goroutine to receive an update
require.Equal(t, state.State{
ErrGo: "go failed",
ErrGolangCILint: "golangcilint failed",
ErrTempl: "templ failed",
}, s.Get())
}

func TestStateMsg(t *testing.T) {
s := state.NewTracker()

require.Equal(t, state.State{}, s.Get())

s.SetErrGo("go failed")
s.SetErrGolangCILint("golangcilint failed")
s.SetErrTempl("templ failed")

require.Equal(t, state.State{
ErrGo: "go failed",
ErrGolangCILint: "golangcilint failed",
ErrTempl: "templ failed",
}, s.Get())

require.Equal(t, "templ failed", s.Get().Msg())
s.SetErrTempl("")
require.Equal(t, "golangcilint failed", s.Get().Msg())
s.SetErrGolangCILint("")
require.Equal(t, "go failed", s.Get().Msg())
s.SetErrGo("")
require.Zero(t, s.Get().Msg())
}

func TestStateReset(t *testing.T) {
s := state.NewTracker()
require.False(t, s.Get().IsErr())

require.Equal(t, state.State{}, s.Get())

s.SetErrGo("go failed")
s.SetErrGolangCILint("golangcilint failed")
s.SetErrTempl("templ failed")
require.True(t, s.Get().IsErr())

s.Reset()
require.Equal(t, state.State{}, s.Get())
require.False(t, s.Get().IsErr())
}

0 comments on commit 1561fd8

Please sign in to comment.