Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add live proxy feature #512

Merged
merged 6 commits into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions air_example.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,9 @@ clean_on_exit = true
[screen]
clear_on_rebuild = true
keep_scroll = true

# Enable live-reloading on the browser.
[proxy]
enabled = true
proxy_port = 8090
app_port = 8080
12 changes: 9 additions & 3 deletions runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Config struct {
Log cfgLog `toml:"log"`
Misc cfgMisc `toml:"misc"`
Screen cfgScreen `toml:"screen"`
Proxy cfgProxy `toml:"proxy"`
}

type cfgBuild struct {
Expand Down Expand Up @@ -96,6 +97,12 @@ type cfgScreen struct {
KeepScroll bool `toml:"keep_scroll"`
}

type cfgProxy struct {
Enabled bool `toml:"enabled"`
ProxyPort int `toml:"proxy_port"`
AppPort int `toml:"app_port"`
}

type sliceTransformer struct{}

func (t sliceTransformer) Transformer(typ reflect.Type) func(dst, src reflect.Value) error {
Expand Down Expand Up @@ -350,10 +357,9 @@ func (c *Config) killDelay() time.Duration {
// interpret as milliseconds if less than the value of 1 millisecond
if c.Build.KillDelay < time.Millisecond {
return c.Build.KillDelay * time.Millisecond
} else {
// normalize kill delay to milliseconds
return time.Duration(c.Build.KillDelay.Milliseconds()) * time.Millisecond
}
// normalize kill delay to milliseconds
return time.Duration(c.Build.KillDelay.Milliseconds()) * time.Millisecond
}

func (c *Config) binPath() string {
Expand Down
15 changes: 15 additions & 0 deletions runner/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
// Engine ...
type Engine struct {
config *Config
proxy *Proxy
logger *logger
watcher filenotify.FileWatcher
debugMode bool
Expand Down Expand Up @@ -48,6 +49,7 @@ func NewEngineWithConfig(cfg *Config, debugMode bool) (*Engine, error) {
}
e := Engine{
config: cfg,
proxy: NewProxy(&cfg.Proxy),
logger: logger,
watcher: watcher,
debugMode: debugMode,
Expand Down Expand Up @@ -310,6 +312,11 @@ func (e *Engine) isModified(filename string) bool {

// Endless loop and never return
func (e *Engine) start() {
if e.config.Proxy.Enabled {
go e.proxy.Run()
e.mainLog("Proxy server listening on http://localhost%s", e.proxy.server.Addr)
}

e.running = true
firstRunCh := make(chan bool, 1)
firstRunCh <- true
Expand Down Expand Up @@ -535,6 +542,9 @@ func (e *Engine) runBin() error {
cmd, stdout, stderr, _ := e.startCmd(command)
processExit := make(chan struct{})
e.mainDebug("running process pid %v", cmd.Process.Pid)
if e.config.Proxy.Enabled {
e.proxy.Reload()
}

wg.Add(1)
atomic.AddUint64(&e.round, 1)
Expand Down Expand Up @@ -579,6 +589,11 @@ func (e *Engine) cleanup() {
e.mainLog("cleaning...")
defer e.mainLog("see you again~")

if e.config.Proxy.Enabled {
e.mainDebug("powering down the proxy...")
e.proxy.Stop()
}

e.withLock(func() {
close(e.binStopCh)
e.binStopCh = make(chan bool)
Expand Down
3 changes: 3 additions & 0 deletions runner/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,9 @@ func Test(t *testing.T) {
t.Log("testing")
}
`)
if err != nil {
t.Fatal(err)
}
// run sed
// check the file is exist
if _, err := os.Stat(dftTOML); err != nil {
Expand Down
157 changes: 157 additions & 0 deletions runner/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package runner

import (
"bytes"
"errors"
"fmt"
"io"
"log"
"net/http"
"strconv"
"strings"
"syscall"
"time"
)

type Reloader interface {
AddSubscriber() *Subscriber
RemoveSubscriber(id int)
Reload()
Stop()
}

type Proxy struct {
server *http.Server
config *cfgProxy
stream Reloader
}

func NewProxy(cfg *cfgProxy) *Proxy {
p := &Proxy{
config: cfg,
server: &http.Server{
Addr: fmt.Sprintf(":%d", cfg.ProxyPort),
},
stream: NewProxyStream(),
}
return p
}

func (p *Proxy) Run() {
http.HandleFunc("/", p.proxyHandler)
http.HandleFunc("/internal/reload", p.reloadHandler)
if err := p.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatal(err)
}
}

func (p *Proxy) Stop() {
p.server.Close()
p.stream.Stop()
}

func (p *Proxy) Reload() {
p.stream.Reload()
}

func (p *Proxy) injectLiveReload(origURL string, respBody io.ReadCloser) string {
buf := new(bytes.Buffer)
if _, err := buf.ReadFrom(respBody); err != nil {
log.Fatalf("failed to convert request body to bytes buffer, err: %+v\n", err)
}
s := buf.String()

body := strings.LastIndex(s, "</body>")
if body == -1 {
log.Fatal("invalid html page, missing the body tag")
}
script := `
<script>
const parser = new DOMParser();
const proxyURL = "http://localhost:%d";
new EventSource(proxyURL + "/internal/reload").onmessage = () => {
fetch(proxyURL + "%s").then(res => res.text()).then(resStr => {
const newPage = parser.parseFromString(resStr, "text/html");
document.replaceChild(newPage.documentElement, document.documentElement);
});
};
</script>
`
parsedScript := fmt.Sprintf(script, p.config.ProxyPort, origURL)

s = s[:body] + parsedScript + s[body:]
return s
}

func (p *Proxy) proxyHandler(w http.ResponseWriter, r *http.Request) {
url := fmt.Sprintf("http://localhost:%d%s", p.config.AppPort, r.URL.Path)
req, err := http.NewRequest(r.Method, url, r.Body)
if err != nil {
log.Fatalf("proxy could not create request, err: %+v\n", err)
}
req.Header.Set("X-Forwarded-For", r.RemoteAddr)

client := &http.Client{}
var resp *http.Response
for i := 0; i < 10; i++ {
resp, err = client.Do(req)
if err == nil {
break
}
if !errors.Is(err, syscall.ECONNREFUSED) {
log.Fatalf("proxy failed to call %s, err: %+v\n", url, err)
}
time.Sleep(100 * time.Millisecond)
}
defer resp.Body.Close()

// copy all headers except Content-Length
for k, vv := range resp.Header {
for _, v := range vv {
if k == "Content-Length" {
continue
}
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)

if strings.Contains(resp.Header.Get("Content-Type"), "text/html") {
s := p.injectLiveReload(r.URL.String(), resp.Body)
w.Header().Set("Content-Length", strconv.Itoa((len([]byte(s)))))
if _, err := io.WriteString(w, s); err != nil {
log.Fatalf("proxy failed injected live reloading script, err: %+v\n", err)
}
} else {
w.Header().Set("Content-Length", resp.Header.Get("Content-Length"))
if _, err := io.Copy(w, resp.Body); err != nil {
log.Fatalf("proxy failed to forward request, err: %+v\n", err)
}
}
}

func (p *Proxy) reloadHandler(w http.ResponseWriter, r *http.Request) {
flusher, err := w.(http.Flusher)
if !err {
http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")

sub := p.stream.AddSubscriber()
go func() {
<-r.Context().Done()
p.stream.RemoveSubscriber(sub.id)
}()

w.WriteHeader(http.StatusOK)
flusher.Flush()

for range sub.reloadCh {
fmt.Fprintf(w, "data: reload\n\n")
flusher.Flush()
}
}
50 changes: 50 additions & 0 deletions runner/proxy_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package runner

import (
"sync"
)

type ProxyStream struct {
sync.Mutex
subscribers map[int]*Subscriber
count int
}

type Subscriber struct {
id int
reloadCh chan struct{}
}

func NewProxyStream() *ProxyStream {
return &ProxyStream{subscribers: make(map[int]*Subscriber)}
}

func (stream *ProxyStream) Stop() {
for id := range stream.subscribers {
stream.RemoveSubscriber(id)
}
stream.count = 0
}

func (stream *ProxyStream) AddSubscriber() *Subscriber {
stream.Lock()
defer stream.Unlock()
stream.count++

sub := &Subscriber{id: stream.count, reloadCh: make(chan struct{})}
stream.subscribers[stream.count] = sub
return sub
}

func (stream *ProxyStream) RemoveSubscriber(id int) {
stream.Lock()
defer stream.Unlock()
close(stream.subscribers[id].reloadCh)
delete(stream.subscribers, id)
}

func (stream *ProxyStream) Reload() {
for _, sub := range stream.subscribers {
sub.reloadCh <- struct{}{}
}
}
66 changes: 66 additions & 0 deletions runner/proxy_stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package runner

import (
"sync"
"testing"
)

func find(s map[int]*Subscriber, id int) bool {
for _, sub := range s {
if sub.id == id {
return true
}
}
return false
}

func TestProxyStream(t *testing.T) {
stream := NewProxyStream()

var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
_ = stream.AddSubscriber()
}(i)
}
wg.Wait()

if got, exp := len(stream.subscribers), 10; got != exp {
t.Errorf("expected %d but got %d", exp, got)
}

go func() {
stream.Reload()
}()

reloadCount := 0
for _, sub := range stream.subscribers {
wg.Add(1)
go func(sub *Subscriber) {
defer wg.Done()
<-sub.reloadCh
reloadCount++
}(sub)
}
wg.Wait()

if got, exp := reloadCount, 10; got != exp {
t.Errorf("expected %d but got %d", exp, got)
}

stream.RemoveSubscriber(2)
stream.AddSubscriber()
if got, exp := find(stream.subscribers, 2), false; got != exp {
t.Errorf("expected subscriber found to be %t but got %t", exp, got)
}
if got, exp := find(stream.subscribers, 11), true; got != exp {
t.Errorf("expected subscriber found to be %t but got %t", exp, got)
}

stream.Stop()
if got, exp := len(stream.subscribers), 0; got != exp {
t.Errorf("expected %d but got %d", exp, got)
}
}
Loading