From 67b04c90c505c3bae1709a2443888fcdf93d3dda Mon Sep 17 00:00:00 2001 From: kiootic Date: Thu, 30 Nov 2023 14:40:43 +0800 Subject: [PATCH] Ignore ports for custom domain --- internal/config/app_domain.go | 2 +- internal/handler/site/handler.go | 27 ++++++--- internal/handler/site/handler_test.go | 86 +++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 8 deletions(-) create mode 100644 internal/handler/site/handler_test.go diff --git a/internal/config/app_domain.go b/internal/config/app_domain.go index 19eeb7f..7c47e21 100644 --- a/internal/config/app_domain.go +++ b/internal/config/app_domain.go @@ -1,6 +1,6 @@ package config type AppDomainConfig struct { - Domain string `json:"domain" pageship:"required,max=200,hostname_port,lowercase"` + Domain string `json:"domain" pageship:"required,max=200,hostname_rfc1123,lowercase"` Site string `json:"site" pageship:"required,dnsLabel"` } diff --git a/internal/handler/site/handler.go b/internal/handler/site/handler.go index dfc8794..693cef6 100644 --- a/internal/handler/site/handler.go +++ b/internal/handler/site/handler.go @@ -48,7 +48,7 @@ func NewHandler(ctx context.Context, logger *zap.Logger, domainResolver domain.R middlewares: conf.Middlewares, } - cache, err := cache.NewCache(cacheSize, cacheTTL, h.doResolve) + cache, err := cache.NewCache(cacheSize, cacheTTL, h.doResolveHandler) if err != nil { return nil, fmt.Errorf("setup cache: %w", err) } @@ -57,13 +57,17 @@ func NewHandler(ctx context.Context, logger *zap.Logger, domainResolver domain.R return h, nil } -func (h *Handler) resolveSite(hostname string) (*SiteHandler, error) { - return h.cache.Load(hostname) +func (h *Handler) resolveHandler(host string) (*SiteHandler, error) { + return h.cache.Load(host) } -func (h *Handler) doResolve(hostname string) (*SiteHandler, error) { - matchedID, ok := h.hostPattern.MatchString(hostname) +func (h *Handler) ResolveSite(host string) (*site.Descriptor, error) { + matchedID, ok := h.hostPattern.MatchString(host) if !ok { + hostname, _, err := net.SplitHostPort(host) + if err != nil { + hostname = host + } id, err := h.domainResolver.Resolve(h.ctx, hostname) if errors.Is(err, domain.ErrDomainNotFound) { return nil, site.ErrSiteNotFound @@ -78,6 +82,15 @@ func (h *Handler) doResolve(hostname string) (*SiteHandler, error) { return nil, err } + return desc, nil +} + +func (h *Handler) doResolveHandler(host string) (*SiteHandler, error) { + desc, err := h.ResolveSite(host) + if err != nil { + return nil, err + } + return NewSiteHandler(desc, h.middlewares), nil } @@ -89,7 +102,7 @@ func (h *Handler) CheckValidDomain(hostname string) error { if h.siteResolver.IsWildcard() { return nil } - _, err := h.resolveSite(hostname) + _, err := h.ResolveSite(hostname) return err } @@ -112,7 +125,7 @@ func (h *Handler) checkAuthz(r *http.Request, handler *SiteHandler) error { } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - handler, err := h.resolveSite(r.Host) + handler, err := h.resolveHandler(r.Host) if errors.Is(err, site.ErrSiteNotFound) { http.NotFound(w, r) return diff --git a/internal/handler/site/handler_test.go b/internal/handler/site/handler_test.go new file mode 100644 index 0000000..d3a0603 --- /dev/null +++ b/internal/handler/site/handler_test.go @@ -0,0 +1,86 @@ +package site_test + +import ( + "context" + "errors" + "testing" + + "github.com/oursky/pageship/internal/domain" + sitehandler "github.com/oursky/pageship/internal/handler/site" + "github.com/oursky/pageship/internal/site" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +type mockDomainResolver struct { + domains map[string]string +} + +func (*mockDomainResolver) Kind() string { return "mock" } + +func (r *mockDomainResolver) Resolve(ctx context.Context, hostname string) (string, error) { + if id, ok := r.domains[hostname]; ok { + return id, nil + } + return "", domain.ErrDomainNotFound +} + +type mockSiteResolver struct { + wildcard bool + sites map[string]string +} + +func (r *mockSiteResolver) IsWildcard() bool { return r.wildcard } + +func (*mockSiteResolver) Kind() string { return "mock" } + +func (r *mockSiteResolver) Resolve(ctx context.Context, matchedID string) (*site.Descriptor, error) { + if id, ok := r.sites[matchedID]; ok { + return &site.Descriptor{ID: id}, nil + } + return nil, site.ErrSiteNotFound +} + +func TestHandleResolution(t *testing.T) { + hostPattern := "http://*.pageship.local" + domainResolver := &mockDomainResolver{ + domains: map[string]string{ + "example.com": "example", + "dev.example.com": "dev.example", + }, + } + siteResolver := &mockSiteResolver{ + sites: map[string]string{ + "example": "example/main", + "dev.example": "example/dev", + "test": "test/main", + "dev.test": "test/dev", + }, + } + + handler, err := sitehandler.NewHandler(context.Background(), zap.NewNop(), + domainResolver, siteResolver, sitehandler.HandlerConfig{HostPattern: hostPattern}) + assert.NoError(t, err) + + resolve := func(host string) any { + desc, err := handler.ResolveSite(host) + if errors.Is(err, site.ErrSiteNotFound) { + return nil + } else if err != nil { + panic(err) + } + return desc + } + + assert.Equal(t, resolve("example.com"), &site.Descriptor{ID: "example/main"}) + assert.Equal(t, resolve("example.com:8001"), &site.Descriptor{ID: "example/main"}) + assert.Equal(t, resolve("dev.example.com"), &site.Descriptor{ID: "example/dev"}) + assert.Equal(t, resolve("example.local"), nil) + + assert.Equal(t, resolve("test.pageship.local"), &site.Descriptor{ID: "test/main"}) + assert.Equal(t, resolve("dev.test.pageship.local"), &site.Descriptor{ID: "test/dev"}) + assert.Equal(t, resolve("dev.test.pageship.local:8001"), &site.Descriptor{ID: "test/dev"}) + assert.Equal(t, resolve("staging.test.pageship.local"), nil) + assert.Equal(t, resolve("pageship.local"), nil) + assert.Equal(t, resolve("main.pageship.local"), nil) +}