diff --git a/.gitignore b/.gitignore index bf69d70..8a39a0e 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ cacao sqlite.db cookie *.mmdb +*.crt +*.key diff --git a/README.md b/README.md index df27cef..340eab7 100644 --- a/README.md +++ b/README.md @@ -9,5 +9,5 @@ Candy Server with WebUI cacao # loglevel=[debug] listen=[127.0.0.1:8080] storage=[/var/lib/cacao] -cacao --loglevel=debug --listen=127.0.0.1:8080 --stroage=/var/lib/cacao +cacao --loglevel=debug --listen=127.0.0.1:8080 --storage=/var/lib/cacao ``` diff --git a/candy/location.go b/candy/location.go index dafd453..8442a5c 100644 --- a/candy/location.go +++ b/candy/location.go @@ -4,13 +4,13 @@ import ( "crypto/tls" "net" "net/http" - "os" "path" "github.com/ipinfo/go/v2/ipinfo" "github.com/ipinfo/go/v2/ipinfo/cache" "github.com/lanthora/cacao/argp" "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/util" "github.com/oschwald/geoip2-golang" ) @@ -59,22 +59,9 @@ func ipinfoLocation(ip net.IP) (country, region string, ok bool) { return } -func findFileByExtFromDir(dir string, ext string) (string, error) { - files, err := os.ReadDir(dir) - if err != nil { - return "", err - } - for _, file := range files { - if path.Ext(file.Name()) == ext { - return file.Name(), nil - } - } - return "", os.ErrNotExist -} - func mmdbLocation(ip net.IP) (country, region string, ok bool) { storageDir := argp.Get("storage", ".") - filename, err := findFileByExtFromDir(storageDir, ".mmdb") + filename, err := util.FindFileByExtFromDir(storageDir, ".mmdb") if err != nil { return } diff --git a/main.go b/main.go index 962a1ab..3497f99 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,15 @@ package main import ( + "path" + "github.com/gin-gonic/gin" "github.com/lanthora/cacao/api" "github.com/lanthora/cacao/argp" "github.com/lanthora/cacao/candy" "github.com/lanthora/cacao/frontend" "github.com/lanthora/cacao/logger" + "github.com/lanthora/cacao/util" ) func init() { @@ -14,9 +17,6 @@ func init() { } func main() { - addr := argp.Get("listen", ":80") - logger.Info("listen=[%v]", addr) - r := gin.New() r.Use(candy.WebsocketMiddleware(), api.LoginMiddleware(), api.AdminMiddleware()) @@ -60,7 +60,20 @@ func main() { r.NoRoute(frontend.Static) - if err := r.Run(addr); err != nil { - logger.Fatal("service run failed: %v", err) + storageDir := argp.Get("storage", ".") + crtFilename, findCrtErr := util.FindFileByExtFromDir(storageDir, ".crt") + keyFilename, findKeyErr := util.FindFileByExtFromDir(storageDir, ".key") + if findCrtErr == nil && findKeyErr == nil { + addr := argp.Get("listen", ":443") + logger.Info("listen=[%v]", addr) + if err := r.RunTLS(addr, path.Join(storageDir, crtFilename), path.Join(storageDir, keyFilename)); err != nil { + logger.Fatal("tls service run failed: %v", err) + } + } else { + addr := argp.Get("listen", ":80") + logger.Info("listen=[%v]", addr) + if err := r.Run(addr); err != nil { + logger.Fatal("service run failed: %v", err) + } } } diff --git a/util/file.go b/util/file.go new file mode 100644 index 0000000..e436013 --- /dev/null +++ b/util/file.go @@ -0,0 +1,19 @@ +package util + +import ( + "os" + "path" +) + +func FindFileByExtFromDir(dir string, ext string) (string, error) { + files, err := os.ReadDir(dir) + if err != nil { + return "", err + } + for _, file := range files { + if path.Ext(file.Name()) == ext { + return file.Name(), nil + } + } + return "", os.ErrNotExist +}