diff --git a/main.go b/main.go index a9b1fa4..fd44aaf 100644 --- a/main.go +++ b/main.go @@ -30,30 +30,39 @@ func (l LastValues) String() string { return fmt.Sprintf("Confirmed: %d, Deaths: %d, Recovered: %d", l.Confirmed, l.Deaths, l.Recovered) } +// fetch runs on its own goroutine +func fetch(ctx context.Context, req *http.Request, ch chan LastValues) error { + defer close(ch) + var r struct { + Data LastValues `json:"data"` + } + body, err := http.DefaultClient.Do(req) + if err != nil { + log.Printf("fetchCOVID19Data: %v", err) + return err + } + defer body.Body.Close() + err = json.NewDecoder(body.Body).Decode(&r) + if err != nil { + log.Printf("fetchCOVID19Data: %v", err) + return err + } + + select { + case ch <- LastValues{r.Data.Confirmed, r.Data.Deaths, r.Data.Recovered}: + case <-ctx.Done(): + } + return nil +} + // fetchCOVID19Data ... -func fetchCOVID19Data(ctx context.Context, req *http.Request) <-chan LastValues { +func fetchCOVID19Data(ctx context.Context) <-chan LastValues { ch := make(chan LastValues) - go func() { - var r struct { - Data LastValues `json:"data"` - } - body, err := http.DefaultClient.Do(req) - if err != nil { - log.Printf("fetchCOVID19Data: %v", err) - return - } - defer body.Body.Close() - err = json.NewDecoder(body.Body).Decode(&r) - if err != nil { - log.Printf("fetchCOVID19Data: %v", err) - return - } - - select { - case ch <- LastValues{r.Data.Confirmed, r.Data.Deaths, r.Data.Recovered}: - case <-ctx.Done(): - } - }() + req, err := http.NewRequestWithContext(ctx, "GET", URL, nil) + if err != nil { + panic("internal error - misuse of NewRequestWithContext") + } + go fetch(ctx, req, ch) return ch } @@ -62,12 +71,8 @@ func routine(sleep time.Duration) { const timeout = time.Second * 2 for { ctx, cancel := context.WithTimeout(context.Background(), timeout) - req, err := http.NewRequestWithContext(ctx, "GET", URL, nil) - if err != nil { - panic("internal error - misuse of NewRequestWithContext") - } select { - case newVal := <-fetchCOVID19Data(ctx, req): + case newVal := <-fetchCOVID19Data(ctx): if cachedVal != newVal { err := beeep.Alert("COVID-19 Brazil", newVal.String(), IMG) if err != nil { diff --git a/main_test.go b/main_test.go index 4f52a23..4b35bcd 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,11 @@ package main import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" "testing" ) @@ -30,3 +35,35 @@ func TestLastValues_String(t *testing.T) { }) } } + +func TestFetch(t *testing.T) { + tests := []struct { + name string + input io.Reader + want bool + }{ + {"ok json", strings.NewReader(`{"data": {"confirmed": 10, "deaths": 10, "recovered": 10}}`), true}, + {"bad json", strings.NewReader(`{"data": "confirmed": 10, "deaths": 10, "recovered": 10}}`), false}, + } + for _, v := range tests { + v := v + t.Run(v.name, func(t *testing.T) { + t.Parallel() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(w, v.input) + })) + defer ts.Close() + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + panic("misuse of NewRequest") + } + ch := make(chan LastValues) + go func() { // drain the fetch output out + _ = <-ch + }() + if err := fetch(context.TODO(), req, ch); (err == nil) != v.want { + t.Errorf("fetch: expected: %v got: %v", v.want, err) + } + }) + } +}