diff --git a/cmd/main.go b/cmd/main.go index c6302af..0a34b74 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -85,7 +85,7 @@ func main() { } // Start the server - webServer := web.New(log, web.WebServerOptions{ + webServer := web.New(log, web.ServerOptions{ Addr: opts.Web.Host + ":" + fmt.Sprintf("%d", opts.Web.Port), Proto: opts.Web.Proto, ReadTimeout: opts.Timeouts.HTTPRead, diff --git a/src/service/service.go b/src/service/service.go index 7fb6646..03ff845 100644 --- a/src/service/service.go +++ b/src/service/service.go @@ -1,232 +1,245 @@ -// Copyright 2021 Ilia Frenkel. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE.txt file. - -// Package service provides methods to work with pastes and users. -// Methods of this package do not log or print out anything, they return -// errors instead. It is up to the user of the Service to handle the errors -// and provide useful information to the end user. -package service - -import ( - "errors" - "fmt" - "math/rand" - "strconv" - "time" - - "github.com/iliafrenkel/go-pb/src/store" - "golang.org/x/crypto/bcrypt" -) - -// Service type provides method to work with pastes and users. -type Service struct { - store store.Interface -} - -var ( - ErrPasteNotFound = errors.New("paste not found") - ErrUserNotFound = errors.New("user not found") - ErrPasteIsPrivate = errors.New("paste is private") - ErrPasteHasPassword = errors.New("paste has password") - ErrWrongPassword = errors.New("paste password is incorrect") - ErrStoreFailure = errors.New("store opertation failed") - ErrEmptyBody = errors.New("body is empty") - ErrWrongPrivacy = errors.New("privacy is wrong") - ErrWrongDuration = errors.New("wrong duration format") -) - -// PasteRequest is an input to Create method, normally comes from a web form. -type PasteRequest struct { - Title string `json:"title" form:"title"` - Body string `json:"body" form:"body" binding:"required"` - Expires string `json:"expires" form:"expires" binding:"required"` - DeleteAfterRead bool `json:"delete_after_read" form:"delete_after_read" binding:"-"` - Privacy string `json:"privacy" form:"privacy" binding:"required"` - Password string `json:"password" form:"password"` - Syntax string `json:"syntax" form:"syntax" binding:"required"` - UserID string `json:"user_id"` -} - -// Returns new Service with provided store as a back-end storage. -func New(store store.Interface) *Service { - var s *Service = new(Service) - s.store = store - rand.Seed(time.Now().UnixNano()) - - return s -} - -// Returns new Service with memory as a store. -func NewWithMemDB() *Service { - return New(store.NewMemDB()) -} - -// Returns new Service with postgres db as a store. -func NewWithPostgres(conn string) (*Service, error) { - s, err := store.NewPostgresDB(conn, true) - if err != nil { - return nil, err - } - return New(s), nil -} - -// Create new Paste from the request and save it in the store. -// Paste.Body is mandatory, Paste.Expires is default to never, Paste.Privacy -// must be on of ["private","public","unlisted"]. If password is provided it -// is stored as a hash. -func (s Service) NewPaste(pr PasteRequest) (store.Paste, error) { - created := time.Now() - expires := time.Time{} // zero time means no expiration, this is the default - - // Check that body is not empty - if pr.Body == "" { - return store.Paste{}, ErrEmptyBody - } - - // Privacy can only be "private", "public" or "unlisted" - if pr.Privacy != "private" && pr.Privacy != "public" && pr.Privacy != "unlisted" { - return store.Paste{}, ErrWrongPrivacy - } - - // We expect the expiration to be in the form of "nx" where "n" is a number - // and "x" is a time unit character: m for minute, h for hour, d for day, - // w for week, M for month and y for year. - if pr.Expires != "never" && len(pr.Expires) > 1 { - dur, err := strconv.Atoi(pr.Expires[:len(pr.Expires)-1]) - if err != nil { - return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: %s (%v)", ErrWrongDuration, pr.Expires, err) - } - switch pr.Expires[len(pr.Expires)-1] { - case 'm': //minutes - expires = created.Add(time.Duration(dur) * time.Minute) - case 'h': //hours - expires = created.Add(time.Duration(dur) * time.Hour) - case 'd': //days - expires = created.AddDate(0, 0, dur) - case 'w': //weeks - expires = created.AddDate(0, 0, dur*7) - case 'M': //months - expires = created.AddDate(0, dur, 0) - case 'y': //years - expires = created.AddDate(dur, 0, 0) - default: - return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: %s", ErrWrongDuration, pr.Expires) - } - } - // If password is not empty, hash it before storing - if pr.Password != "" { - hash, err := bcrypt.GenerateFromPassword([]byte(pr.Password), bcrypt.DefaultCost) - if err != nil { - return store.Paste{}, err - } - pr.Password = string(hash) - } - // If the user is known check that it is in our database and add if it's not - var usr store.User - var err error - if pr.UserID != "" { - usr, err = s.store.User(pr.UserID) - if err != nil || usr == (store.User{}) { - return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: user id [%s] (%v)", ErrUserNotFound, pr.UserID, err) - } - } - // Default syntax to "text" - if pr.Syntax == "" { - pr.Syntax = "text" - } - // Create a new paste and store it - paste := store.Paste{ - Title: pr.Title, - Body: pr.Body, - Expires: expires, - DeleteAfterRead: pr.DeleteAfterRead, - Privacy: pr.Privacy, - Password: pr.Password, - CreatedAt: created, - Syntax: pr.Syntax, - User: usr, - } - id, err := s.store.Create(paste) - if err != nil { - return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: (%v)", ErrStoreFailure, err) - } - // Get the paste back and return it - paste, err = s.store.Get(id) - if err != nil { - return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: (%v)", ErrStoreFailure, err) - } - return paste, nil -} - -// GetPaste returns a paste given encoded URL. -// If the paste is private GetPaste will check that it belongs to the user with -// provided uid. If password is given and the paste has password GetPaste will -// check that the password is correct. -func (s Service) GetPaste(url string, uid string, pwd string) (store.Paste, error) { - p := store.Paste{} - id, err := p.URL2ID(url) - if err != nil { - return store.Paste{}, err - } - p, err = s.store.Get(id) - if err != nil { - return p, fmt.Errorf("Service.GetPaste: %w: (%v)", ErrStoreFailure, err) - } - // Check if paste was not found - if p == (store.Paste{}) { - return p, fmt.Errorf("Service.GetPaste: %w: url [%s], id [%v]", ErrPasteNotFound, url, id) - } - // Check privacy - if p.Privacy == "private" && p.User.ID != uid { - return store.Paste{}, ErrPasteIsPrivate - } - // Check if password protected - if p.Password != "" && pwd == "" { - return store.Paste{}, ErrPasteHasPassword - } - // Check if password is correct - if p.Password != "" && bcrypt.CompareHashAndPassword([]byte(p.Password), []byte(pwd)) != nil { - return store.Paste{}, ErrWrongPassword - } - // Update the view count - p.Views += 1 - p, _ = s.store.Update(p) // we ignore the error here because we only update the view count - // Check if paste is a "burner" and delete it if yes - if p.DeleteAfterRead { - err = s.store.Delete(p.ID) - if err != nil { - return p, fmt.Errorf("Service.GetPaste: %w: (%v)", ErrStoreFailure, err) - } - } - return p, nil -} - -// GetOrUpdateUser saves the user in the store and returns it. -func (s Service) GetOrUpdateUser(usr store.User) (store.User, error) { - _, err := s.store.SaveUser(usr) - if err != nil { - return store.User{}, fmt.Errorf("Service.GetOrUpdateUser: %w: (%v)", ErrStoreFailure, err) - } - return usr, nil -} - -// UserPastes returns a list of the last 10 paste for a user. -func (s Service) UserPastes(uid string) ([]store.Paste, error) { - pastes, err := s.store.Find(store.FindRequest{ - UserID: uid, - Sort: "-created", - Since: time.Time{}, - Limit: 10, - Skip: 0, - }) - if err != nil { - return nil, fmt.Errorf("Service.UserPastes: %w: (%v)", ErrStoreFailure, err) - } - return pastes, nil -} - -// GetCount returns total count of pastes and users. -func (s Service) GetCount() (pastes, users int64) { - return s.store.Count() -} +// Copyright 2021 Ilia Frenkel. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE.txt file. + +// Package service provides methods to work with pastes and users. +// Methods of this package do not log or print out anything, they return +// errors instead. It is up to the user of the Service to handle the errors +// and provide useful information to the end user. +package service + +import ( + "errors" + "fmt" + "math/rand" + "strconv" + "time" + + "github.com/iliafrenkel/go-pb/src/store" + "golang.org/x/crypto/bcrypt" +) + +// Service type provides method to work with pastes and users. +type Service struct { + store store.Interface +} + +// ErrPasteNotFound and other common errors. +var ( + ErrPasteNotFound = errors.New("paste not found") + ErrUserNotFound = errors.New("user not found") + ErrPasteIsPrivate = errors.New("paste is private") + ErrPasteHasPassword = errors.New("paste has password") + ErrWrongPassword = errors.New("paste password is incorrect") + ErrStoreFailure = errors.New("store opertation failed") + ErrEmptyBody = errors.New("body is empty") + ErrWrongPrivacy = errors.New("privacy is wrong") + ErrWrongDuration = errors.New("wrong duration format") +) + +// PasteRequest is an input to Create method, normally comes from a web form. +type PasteRequest struct { + Title string `json:"title" form:"title"` + Body string `json:"body" form:"body" binding:"required"` + Expires string `json:"expires" form:"expires" binding:"required"` + DeleteAfterRead bool `json:"delete_after_read" form:"delete_after_read" binding:"-"` + Privacy string `json:"privacy" form:"privacy" binding:"required"` + Password string `json:"password" form:"password"` + Syntax string `json:"syntax" form:"syntax" binding:"required"` + UserID string `json:"user_id"` +} + +// New returns new Service with provided store as a back-end storage. +func New(store store.Interface) *Service { + var s *Service = new(Service) + s.store = store + rand.Seed(time.Now().UnixNano()) + + return s +} + +// NewWithMemDB returns new Service with memory as a store. +func NewWithMemDB() *Service { + return New(store.NewMemDB()) +} + +// NewWithPostgres returns new Service with postgres db as a store. +func NewWithPostgres(conn string) (*Service, error) { + s, err := store.NewPostgresDB(conn, true) + if err != nil { + return nil, err + } + return New(s), nil +} + +// parseExpiration tries to parse PasteRequest.Expires string and return +// corresponding time.Time. +// We expect the expiration to be in the form of "nx" where "n" is a number +// and "x" is a time unit character: m for minute, h for hour, d for day, +// w for week, M for month and y for year. +func (s Service) parseExpiration(exp string) (time.Time, error) { + res := time.Time{} + now := time.Now() + + if exp != "never" && len(exp) > 1 { + dur, err := strconv.Atoi(exp[:len(exp)-1]) + if err != nil { + return time.Time{}, fmt.Errorf("Service.parseExpiration: %w: %s (%v)", ErrWrongDuration, exp, err) + } + switch exp[len(exp)-1] { + case 'm': //minutes + res = now.Add(time.Duration(dur) * time.Minute) + case 'h': //hours + res = now.Add(time.Duration(dur) * time.Hour) + case 'd': //days + res = now.AddDate(0, 0, dur) + case 'w': //weeks + res = now.AddDate(0, 0, dur*7) + case 'M': //months + res = now.AddDate(0, dur, 0) + case 'y': //years + res = now.AddDate(dur, 0, 0) + default: + return time.Time{}, fmt.Errorf("Service.NewPaste: %w: %s", ErrWrongDuration, exp) + } + } + return res, nil +} + +// NewPaste creates new Paste from the request and saves it in the store. +// Paste.Body is mandatory, Paste.Expires is default to never, Paste.Privacy +// must be on of ["private","public","unlisted"]. If password is provided it +// is stored as a hash. +func (s Service) NewPaste(pr PasteRequest) (store.Paste, error) { + var err error + created := time.Now() + expires, err := s.parseExpiration(pr.Expires) + if err != nil { + return store.Paste{}, fmt.Errorf("Service.NewPaste: %w", err) + } + + // Check that body is not empty + if pr.Body == "" { + return store.Paste{}, ErrEmptyBody + } + + // Privacy can only be "private", "public" or "unlisted" + if pr.Privacy != "private" && pr.Privacy != "public" && pr.Privacy != "unlisted" { + return store.Paste{}, ErrWrongPrivacy + } + + // If password is not empty, hash it before storing + if pr.Password != "" { + hash, err := bcrypt.GenerateFromPassword([]byte(pr.Password), bcrypt.DefaultCost) + if err != nil { + return store.Paste{}, err + } + pr.Password = string(hash) + } + // If the user is known check that it is in our database and add if it's not + var usr store.User + if pr.UserID != "" { + usr, err = s.store.User(pr.UserID) + if err != nil || usr == (store.User{}) { + return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: user id [%s] (%v)", ErrUserNotFound, pr.UserID, err) + } + } + // Default syntax to "text" + if pr.Syntax == "" { + pr.Syntax = "text" + } + // Create a new paste and store it + paste := store.Paste{ + Title: pr.Title, + Body: pr.Body, + Expires: expires, + DeleteAfterRead: pr.DeleteAfterRead, + Privacy: pr.Privacy, + Password: pr.Password, + CreatedAt: created, + Syntax: pr.Syntax, + User: usr, + } + id, err := s.store.Create(paste) + if err != nil { + return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: (%v)", ErrStoreFailure, err) + } + // Get the paste back and return it + paste, err = s.store.Get(id) + if err != nil { + return store.Paste{}, fmt.Errorf("Service.NewPaste: %w: (%v)", ErrStoreFailure, err) + } + return paste, nil +} + +// GetPaste returns a paste given encoded URL. +// If the paste is private GetPaste will check that it belongs to the user with +// provided uid. If password is given and the paste has password GetPaste will +// check that the password is correct. +func (s Service) GetPaste(url string, uid string, pwd string) (store.Paste, error) { + p := store.Paste{} + id, err := p.URL2ID(url) + if err != nil { + return store.Paste{}, err + } + p, err = s.store.Get(id) + if err != nil { + return p, fmt.Errorf("Service.GetPaste: %w: (%v)", ErrStoreFailure, err) + } + // Check if paste was not found + if p == (store.Paste{}) { + return p, fmt.Errorf("Service.GetPaste: %w: url [%s], id [%v]", ErrPasteNotFound, url, id) + } + // Check privacy + if p.Privacy == "private" && p.User.ID != uid { + return store.Paste{}, ErrPasteIsPrivate + } + // Check if password protected + if p.Password != "" && pwd == "" { + return store.Paste{}, ErrPasteHasPassword + } + // Check if password is correct + if p.Password != "" && bcrypt.CompareHashAndPassword([]byte(p.Password), []byte(pwd)) != nil { + return store.Paste{}, ErrWrongPassword + } + // Update the view count + p.Views++ + p, _ = s.store.Update(p) // we ignore the error here because we only update the view count + // Check if paste is a "burner" and delete it if yes + if p.DeleteAfterRead { + err = s.store.Delete(p.ID) + if err != nil { + return p, fmt.Errorf("Service.GetPaste: %w: (%v)", ErrStoreFailure, err) + } + } + return p, nil +} + +// GetOrUpdateUser saves the user in the store and returns it. +func (s Service) GetOrUpdateUser(usr store.User) (store.User, error) { + _, err := s.store.SaveUser(usr) + if err != nil { + return store.User{}, fmt.Errorf("Service.GetOrUpdateUser: %w: (%v)", ErrStoreFailure, err) + } + return usr, nil +} + +// UserPastes returns a list of the last 10 paste for a user. +func (s Service) UserPastes(uid string) ([]store.Paste, error) { + pastes, err := s.store.Find(store.FindRequest{ + UserID: uid, + Sort: "-created", + Since: time.Time{}, + Limit: 10, + Skip: 0, + }) + if err != nil { + return nil, fmt.Errorf("Service.UserPastes: %w: (%v)", ErrStoreFailure, err) + } + return pastes, nil +} + +// GetCount returns total count of pastes and users. +func (s Service) GetCount() (pastes, users int64) { + return s.store.Count() +} diff --git a/src/service/service_test.go b/src/service/service_test.go index 2cd11ce..f7df99b 100644 --- a/src/service/service_test.go +++ b/src/service/service_test.go @@ -1,424 +1,424 @@ -package service - -import ( - "errors" - "os" - "testing" - "time" - - "github.com/iliafrenkel/go-pb/src/store" -) - -var svc *Service - -func TestMain(m *testing.M) { - svc = NewWithMemDB() - os.Exit(m.Run()) -} - -// Test new paste -func TestNewPaste(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if p.ID == 0 { - t.Error("expect paste to have an id") - } -} - -func TestNewPasteEmptyBody(t *testing.T) { - t.Parallel() - - _, err := svc.NewPaste(PasteRequest{}) - if err == nil { - t.Fatal("expected paste creation to fail") - } - if !errors.Is(err, ErrEmptyBody) { - t.Errorf("expected error to be [%v], got [%v]", ErrEmptyBody, err) - } -} - -func TestNewPasteEmptyPrivacy(t *testing.T) { - t.Parallel() - - _, err := svc.NewPaste(PasteRequest{Body: "Test body"}) - if err == nil { - t.Fatal("expected paste creation to fail") - } - if !errors.Is(err, ErrWrongPrivacy) { - t.Errorf("expected error to be [%v], got [%v]", ErrWrongPrivacy, err) - } -} - -func TestNewPasteWithPassword(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Body: "Test body", - Privacy: "public", - Password: "password", - }) - if err != nil { - t.Fatal("failed to create new paste") - } - if p.Password == "" || p.Password == "password" { - t.Errorf("expected password to be hashed, got %s", p.Password) - } -} - -func TestNewPasteWithUser(t *testing.T) { - t.Parallel() - - u := store.User{ - ID: "test_user", - Name: "Test User", - } - u, err := svc.GetOrUpdateUser(u) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - UserID: u.ID, - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if p.User.ID != u.ID { - t.Errorf("expect paste to have user id [%v], got [%v]", u.ID, p.User.ID) - } -} - -func TestNewPasteWithFakeUser(t *testing.T) { - t.Parallel() - - _, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - UserID: "non_existing_user", - }) - if err == nil { - t.Fatal("expected paste creation to fail") - } - if !errors.Is(err, ErrUserNotFound) { - t.Errorf("expected error to be [%v], got [%v]", ErrUserNotFound, err) - } -} - -func TestNewPasteWithExpirationMinutes(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "10m", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if time.Until(p.Expires) > 10*time.Minute { - t.Errorf("expected paste expiration to be less than 10 minutes, got %v", p.Expires) - } -} - -func TestNewPasteWithExpirationHours(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "3h", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if time.Until(p.Expires) > 3*time.Hour { - t.Errorf("expected paste expiration to be less than 3 hours, got %v", p.Expires) - } -} - -func TestNewPasteWithExpirationDays(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "5d", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if time.Until(p.Expires) > 5*24*time.Hour { - t.Errorf("expected paste expiration to be less than 5 days, got %v", p.Expires) - } -} - -func TestNewPasteWithExpirationWeeks(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "2w", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if time.Until(p.Expires) > 14*24*time.Hour { - t.Errorf("expected paste expiration to be less than 2 weeks, got %v", p.Expires) - } -} - -func TestNewPasteWithExpirationMonths(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "6M", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if time.Until(p.Expires.AddDate(0, -6, 0)) > time.Second { - t.Errorf("expected paste expiration to be less than 6 months, got %v", p.Expires) - } -} - -func TestNewPasteWithExpirationYears(t *testing.T) { - t.Parallel() - - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "2y", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - if time.Until(p.Expires.AddDate(-2, 0, 0)) > time.Second { - t.Errorf("expected paste expiration to be less than 2 years, got %v", p.Expires) - } -} - -func TestNewPasteWrongExpiration(t *testing.T) { - t.Parallel() - - _, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "abcdefg", - }) - if err == nil { - t.Fatal("expecte paste creation to fail") - } - if !errors.Is(err, ErrWrongDuration) { - t.Errorf("expected error to be [%v], got [%v]", ErrWrongDuration, err) - } - - _, err = svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Expires: "12g", - }) - if err == nil { - t.Fatal("expecte paste creation to fail") - } - if !errors.Is(err, ErrWrongDuration) { - t.Errorf("expected error to be [%v], got [%v]", ErrWrongDuration, err) - } -} - -// Test get paste -func TestGetPaste(t *testing.T) { - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - - paste, err := svc.GetPaste(p.URL(), "", "") - if err != nil { - t.Fatalf("failed to get the paste: %v", err) - } - if p.Title != paste.Title { - t.Errorf("expected paste titles to be equal, want [%s], got [%s]", p.Title, paste.Title) - } -} - -func TestGetPasteWrongURL(t *testing.T) { - _, err := svc.GetPaste("QwE-AsD", "", "") - if err == nil { - t.Fatalf("expected GetPaste to fail") - } -} - -func TestGetPasteDontExist(t *testing.T) { - t.Parallel() - - _, err := svc.GetPaste("QwEAsD12", "", "") - if err == nil { - t.Fatal("expected GetPaste to fail") - } - if !errors.Is(err, ErrPasteNotFound) { - t.Errorf("expected error to be [%v], got [%v]", ErrPasteNotFound, err) - } -} - -func TestGetPastePrivate(t *testing.T) { - t.Parallel() - - u := store.User{ - ID: "test_user_2", - Name: "Test User", - } - u, err := svc.GetOrUpdateUser(u) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "private", - UserID: u.ID, - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - _, err = svc.GetPaste(p.URL(), u.ID, "") - if err != nil { - t.Errorf("expected to get private paste, got [%v]", err) - } - - _, err = svc.GetPaste(p.URL(), "", "") - if err == nil { - t.Error("expected GetPaste to fail") - } - if !errors.Is(err, ErrPasteIsPrivate) { - t.Errorf("expected error to be [%v], got [%v]", ErrPasteIsPrivate, err) - } -} - -func TestGetPasteWithPassword(t *testing.T) { - t.Parallel() - - u := store.User{ - ID: "test_user_3", - Name: "Test User", - } - u, err := svc.GetOrUpdateUser(u) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - Password: "password", - UserID: u.ID, - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - _, err = svc.GetPaste(p.URL(), "", "password") - if err != nil { - t.Errorf("expected to get paste with password, got [%v]", err) - } - - _, err = svc.GetPaste(p.URL(), "", "") - if err == nil { - t.Error("expected GetPaste with password to fail") - } - if !errors.Is(err, ErrPasteHasPassword) { - t.Errorf("expected error to be [%v], got [%v]", ErrPasteHasPassword, err) - } - - _, err = svc.GetPaste(p.URL(), "", "12345") - if err == nil { - t.Error("expected GetPaste with password to fail") - } - if !errors.Is(err, ErrWrongPassword) { - t.Errorf("expected error to be [%v], got [%v]", ErrWrongPassword, err) - } -} - -func TestGetPasteDeleteAfterRead(t *testing.T) { - p, err := svc.NewPaste(PasteRequest{ - Title: "Test title", - Body: "Test body", - Privacy: "public", - DeleteAfterRead: true, - }) - if err != nil { - t.Fatalf("failed to create new paste: %v", err) - } - // Get the paste for the first time - paste, err := svc.GetPaste(p.URL(), "", "") - if err != nil { - t.Fatalf("failed to get the paste: %v", err) - } - if p.Title != paste.Title { - t.Errorf("expected paste titles to be equal, want [%s], got [%s]", p.Title, paste.Title) - } - // Try to get the paste again and check that it doesn't exist anymore - paste, err = svc.GetPaste(p.URL(), "", "") - if err == nil { - t.Fatalf("expected paste to be deleted, got [%+v]", paste) - } - if !errors.Is(err, ErrPasteNotFound) { - t.Errorf("expected error to be [%v], got [%v]", ErrPasteNotFound, err) - } -} - -// Test user pastes -func TestGetUserPastes(t *testing.T) { - t.Parallel() - - u := store.User{ - ID: "test_user_4", - Name: "Test User", - } - u, err := svc.GetOrUpdateUser(u) - if err != nil { - t.Fatalf("failed to create user: %v", err) - } - for i := 0; i < 12; i++ { - _, err := svc.NewPaste(PasteRequest{ - Body: "Test body", - Privacy: "public", - UserID: u.ID, - }) - if err != nil { - t.Fatalf("failed to create paste: %v", err) - } - } - - pastes, err := svc.UserPastes(u.ID) - if err != nil { - t.Errorf("failed to get user pastes: %v", err) - } - if len(pastes) != 10 { - t.Errorf("expected to get 10 pastes, got %d", len(pastes)) - } -} +package service + +import ( + "errors" + "os" + "testing" + "time" + + "github.com/iliafrenkel/go-pb/src/store" +) + +var svc *Service + +func TestMain(m *testing.M) { + svc = NewWithMemDB() + os.Exit(m.Run()) +} + +// Test new paste +func TestNewPaste(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if p.ID == 0 { + t.Error("expect paste to have an id") + } +} + +func TestNewPasteEmptyBody(t *testing.T) { + t.Parallel() + + _, err := svc.NewPaste(PasteRequest{}) + if err == nil { + t.Fatal("expected paste creation to fail") + } + if !errors.Is(err, ErrEmptyBody) { + t.Errorf("expected error to be [%v], got [%v]", ErrEmptyBody, err) + } +} + +func TestNewPasteEmptyPrivacy(t *testing.T) { + t.Parallel() + + _, err := svc.NewPaste(PasteRequest{Body: "Test body"}) + if err == nil { + t.Fatal("expected paste creation to fail") + } + if !errors.Is(err, ErrWrongPrivacy) { + t.Errorf("expected error to be [%v], got [%v]", ErrWrongPrivacy, err) + } +} + +func TestNewPasteWithPassword(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Body: "Test body", + Privacy: "public", + Password: "password", + }) + if err != nil { + t.Fatal("failed to create new paste") + } + if p.Password == "" || p.Password == "password" { + t.Errorf("expected password to be hashed, got %s", p.Password) + } +} + +func TestNewPasteWithUser(t *testing.T) { + t.Parallel() + + u := store.User{ + ID: "test_user", + Name: "Test User", + } + u, err := svc.GetOrUpdateUser(u) + if err != nil { + t.Fatalf("failed to create user: %v", err) + } + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + UserID: u.ID, + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if p.User.ID != u.ID { + t.Errorf("expect paste to have user id [%v], got [%v]", u.ID, p.User.ID) + } +} + +func TestNewPasteWithFakeUser(t *testing.T) { + t.Parallel() + + _, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + UserID: "non_existing_user", + }) + if err == nil { + t.Fatal("expected paste creation to fail") + } + if !errors.Is(err, ErrUserNotFound) { + t.Errorf("expected error to be [%v], got [%v]", ErrUserNotFound, err) + } +} + +func TestNewPasteWithExpirationMinutes(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "10m", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if time.Until(p.Expires) > 10*time.Minute { + t.Errorf("expected paste expiration to be less than 10 minutes, got %v", p.Expires) + } +} + +func TestNewPasteWithExpirationHours(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "3h", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if time.Until(p.Expires) > 3*time.Hour { + t.Errorf("expected paste expiration to be less than 3 hours, got %v", p.Expires) + } +} + +func TestNewPasteWithExpirationDays(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "5d", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if time.Until(p.Expires) > 5*24*time.Hour { + t.Errorf("expected paste expiration to be less than 5 days, got %v", p.Expires) + } +} + +func TestNewPasteWithExpirationWeeks(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "2w", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if time.Until(p.Expires) > 14*24*time.Hour { + t.Errorf("expected paste expiration to be less than 2 weeks, got %v", p.Expires) + } +} + +func TestNewPasteWithExpirationMonths(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "6M", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if time.Until(p.Expires.AddDate(0, -6, 0)) > time.Second { + t.Errorf("expected paste expiration to be less than 6 months, got %v", p.Expires) + } +} + +func TestNewPasteWithExpirationYears(t *testing.T) { + t.Parallel() + + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "2y", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + if time.Until(p.Expires.AddDate(-2, 0, 0)) > time.Second { + t.Errorf("expected paste expiration to be less than 2 years, got %v", p.Expires) + } +} + +func TestNewPasteWrongExpiration(t *testing.T) { + t.Parallel() + + _, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "abcdefg", + }) + if err == nil { + t.Fatal("expecte paste creation to fail") + } + if !errors.Is(err, ErrWrongDuration) { + t.Errorf("expected error to be [%v], got [%v]", ErrWrongDuration, err) + } + + _, err = svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Expires: "12g", + }) + if err == nil { + t.Fatal("expecte paste creation to fail") + } + if !errors.Is(err, ErrWrongDuration) { + t.Errorf("expected error to be [%v], got [%v]", ErrWrongDuration, err) + } +} + +// Test get paste +func TestGetPaste(t *testing.T) { + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + + paste, err := svc.GetPaste(p.URL(), "", "") + if err != nil { + t.Fatalf("failed to get the paste: %v", err) + } + if p.Title != paste.Title { + t.Errorf("expected paste titles to be equal, want [%s], got [%s]", p.Title, paste.Title) + } +} + +func TestGetPasteWrongURL(t *testing.T) { + _, err := svc.GetPaste("QwE-AsD", "", "") + if err == nil { + t.Fatalf("expected GetPaste to fail") + } +} + +func TestGetPasteDontExist(t *testing.T) { + t.Parallel() + + _, err := svc.GetPaste("QwEAsD12", "", "") + if err == nil { + t.Fatal("expected GetPaste to fail") + } + if !errors.Is(err, ErrPasteNotFound) { + t.Errorf("expected error to be [%v], got [%v]", ErrPasteNotFound, err) + } +} + +func TestGetPastePrivate(t *testing.T) { + t.Parallel() + + u := store.User{ + ID: "test_user_2", + Name: "Test User", + } + u, err := svc.GetOrUpdateUser(u) + if err != nil { + t.Fatalf("failed to create user: %v", err) + } + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "private", + UserID: u.ID, + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + _, err = svc.GetPaste(p.URL(), u.ID, "") + if err != nil { + t.Errorf("expected to get private paste, got [%v]", err) + } + + _, err = svc.GetPaste(p.URL(), "", "") + if err == nil { + t.Error("expected GetPaste to fail") + } + if !errors.Is(err, ErrPasteIsPrivate) { + t.Errorf("expected error to be [%v], got [%v]", ErrPasteIsPrivate, err) + } +} + +func TestGetPasteWithPassword(t *testing.T) { + t.Parallel() + + u := store.User{ + ID: "test_user_3", + Name: "Test User", + } + u, err := svc.GetOrUpdateUser(u) + if err != nil { + t.Fatalf("failed to create user: %v", err) + } + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + Password: "password", + UserID: u.ID, + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + _, err = svc.GetPaste(p.URL(), "", "password") + if err != nil { + t.Errorf("expected to get paste with password, got [%v]", err) + } + + _, err = svc.GetPaste(p.URL(), "", "") + if err == nil { + t.Error("expected GetPaste with password to fail") + } + if !errors.Is(err, ErrPasteHasPassword) { + t.Errorf("expected error to be [%v], got [%v]", ErrPasteHasPassword, err) + } + + _, err = svc.GetPaste(p.URL(), "", "12345") + if err == nil { + t.Error("expected GetPaste with password to fail") + } + if !errors.Is(err, ErrWrongPassword) { + t.Errorf("expected error to be [%v], got [%v]", ErrWrongPassword, err) + } +} + +func TestGetPasteDeleteAfterRead(t *testing.T) { + p, err := svc.NewPaste(PasteRequest{ + Title: "Test title", + Body: "Test body", + Privacy: "public", + DeleteAfterRead: true, + }) + if err != nil { + t.Fatalf("failed to create new paste: %v", err) + } + // Get the paste for the first time + paste, err := svc.GetPaste(p.URL(), "", "") + if err != nil { + t.Fatalf("failed to get the paste: %v", err) + } + if p.Title != paste.Title { + t.Errorf("expected paste titles to be equal, want [%s], got [%s]", p.Title, paste.Title) + } + // Try to get the paste again and check that it doesn't exist anymore + paste, err = svc.GetPaste(p.URL(), "", "") + if err == nil { + t.Fatalf("expected paste to be deleted, got [%+v]", paste) + } + if !errors.Is(err, ErrPasteNotFound) { + t.Errorf("expected error to be [%v], got [%v]", ErrPasteNotFound, err) + } +} + +// Test user pastes +func TestGetUserPastes(t *testing.T) { + t.Parallel() + + u := store.User{ + ID: "test_user_4", + Name: "Test User", + } + u, err := svc.GetOrUpdateUser(u) + if err != nil { + t.Fatalf("failed to create user: %v", err) + } + for i := 0; i < 12; i++ { + _, err := svc.NewPaste(PasteRequest{ + Body: "Test body", + Privacy: "public", + UserID: u.ID, + }) + if err != nil { + t.Fatalf("failed to create paste: %v", err) + } + } + + pastes, err := svc.UserPastes(u.ID) + if err != nil { + t.Errorf("failed to get user pastes: %v", err) + } + if len(pastes) != 10 { + t.Errorf("expected to get 10 pastes, got %d", len(pastes)) + } +} diff --git a/src/store/memory.go b/src/store/memory.go index d29778f..7ea36ea 100644 --- a/src/store/memory.go +++ b/src/store/memory.go @@ -1,150 +1,154 @@ -// Copyright 2021 Ilia Frenkel. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE.txt file. -package store - -import ( - "fmt" - "math/rand" - "sort" - "strings" - "sync" -) - -// MemDB is a memory storage that implements the store.Interface. -// Because it's a transient storage you will loose all the data once the -// process exits. It's not completely useless though. You can use it when a -// temporary sharing is needed or as a cache for another storage. -type MemDB struct { - pastes map[int64]Paste - users map[string]User - sync.RWMutex -} - -// NewMemDB initialises and returns an instance of MemDB. -func NewMemDB() *MemDB { - var s MemDB - s.pastes = make(map[int64]Paste) - s.users = make(map[string]User) - - return &s -} - -// Count returns total count of pastes and users. -func (m *MemDB) Count() (pastes, users int64) { - m.RLock() - defer m.RUnlock() - - return int64(len(m.pastes)), int64(len(m.users)) -} - -// Create creates and stores a new paste returning its ID. -func (m *MemDB) Create(p Paste) (id int64, err error) { - m.Lock() - defer m.Unlock() - - p.ID = rand.Int63() // #nosec - m.pastes[p.ID] = p - - return p.ID, nil -} - -// Delete deletes a paste by ID. -func (m *MemDB) Delete(id int64) error { - m.Lock() - defer m.Unlock() - - delete(m.pastes, id) - - return nil -} - -// Find return a sorted list of pastes for a given request. -func (m *MemDB) Find(req FindRequest) (pastes []Paste, err error) { - pastes = []Paste{} - - m.RLock() - // Find all the pastes for a user - for _, p := range m.pastes { - if p.User.ID == req.UserID && p.CreatedAt.After(req.Since) { - pastes = append(pastes, p) - } - } - m.RUnlock() - // Sort - sort.Slice(pastes, func(i, j int) bool { - switch req.Sort { - case "+created", "-created": - if strings.HasPrefix(req.Sort, "-") { - return pastes[i].CreatedAt.After(pastes[j].CreatedAt) - } - return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) - case "+expires", "-expires": - if strings.HasPrefix(req.Sort, "-") { - return pastes[i].Expires.After(pastes[j].Expires) - } - return pastes[i].Expires.Before(pastes[j].Expires) - case "+views", "-views": - if strings.HasPrefix(req.Sort, "-") { - return pastes[i].Views > pastes[j].Views - } - return pastes[i].Views <= pastes[j].Views - default: - return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) - } - }) - // Slice with skip and limit - skip := req.Skip - if skip > len(pastes) { - skip = len(pastes) - } - end := skip + req.Limit - if end > len(pastes) { - end = len(pastes) - } - - return pastes[skip:end], nil -} - -// Get returns a paste by ID. -func (m *MemDB) Get(id int64) (Paste, error) { - m.RLock() - defer m.RUnlock() - - return m.pastes[id], nil -} - -// SaveUser creates a new or updates an existing user. -func (m *MemDB) SaveUser(usr User) (id string, err error) { - m.Lock() - defer m.Unlock() - - m.users[usr.ID] = usr - - return usr.ID, nil -} - -// User returns a user by ID. -func (m *MemDB) User(id string) (User, error) { - m.RLock() - defer m.RUnlock() - - if usr, ok := m.users[id]; !ok { - return User{}, fmt.Errorf("MemDB.User: user not found") - } else { - return usr, nil - } -} - -func (m *MemDB) Update(p Paste) (Paste, error) { - m.RLock() - if _, ok := m.pastes[p.ID]; !ok { - m.RUnlock() - return Paste{}, nil - } - m.RUnlock() - m.Lock() - defer m.Unlock() - m.pastes[p.ID] = p - return p, nil -} +// Copyright 2021 Ilia Frenkel. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE.txt file. + +package store + +import ( + "fmt" + "math/rand" + "sort" + "strings" + "sync" +) + +// MemDB is a memory storage that implements the store.Interface. +// Because it's a transient storage you will loose all the data once the +// process exits. It's not completely useless though. You can use it when a +// temporary sharing is needed or as a cache for another storage. +type MemDB struct { + pastes map[int64]Paste + users map[string]User + sync.RWMutex +} + +// NewMemDB initialises and returns an instance of MemDB. +func NewMemDB() *MemDB { + var s MemDB + s.pastes = make(map[int64]Paste) + s.users = make(map[string]User) + + return &s +} + +// Count returns total count of pastes and users. +func (m *MemDB) Count() (pastes, users int64) { + m.RLock() + defer m.RUnlock() + + return int64(len(m.pastes)), int64(len(m.users)) +} + +// Create creates and stores a new paste returning its ID. +func (m *MemDB) Create(p Paste) (id int64, err error) { + m.Lock() + defer m.Unlock() + + p.ID = rand.Int63() // #nosec + m.pastes[p.ID] = p + + return p.ID, nil +} + +// Delete deletes a paste by ID. +func (m *MemDB) Delete(id int64) error { + m.Lock() + defer m.Unlock() + + delete(m.pastes, id) + + return nil +} + +// Find return a sorted list of pastes for a given request. +func (m *MemDB) Find(req FindRequest) (pastes []Paste, err error) { + pastes = []Paste{} + + m.RLock() + // Find all the pastes for a user + for _, p := range m.pastes { + if p.User.ID == req.UserID && p.CreatedAt.After(req.Since) { + pastes = append(pastes, p) + } + } + m.RUnlock() + // Sort + sort.Slice(pastes, func(i, j int) bool { + switch req.Sort { + case "+created", "-created": + if strings.HasPrefix(req.Sort, "-") { + return pastes[i].CreatedAt.After(pastes[j].CreatedAt) + } + return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) + case "+expires", "-expires": + if strings.HasPrefix(req.Sort, "-") { + return pastes[i].Expires.After(pastes[j].Expires) + } + return pastes[i].Expires.Before(pastes[j].Expires) + case "+views", "-views": + if strings.HasPrefix(req.Sort, "-") { + return pastes[i].Views > pastes[j].Views + } + return pastes[i].Views <= pastes[j].Views + default: + return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) + } + }) + // Slice with skip and limit + skip := req.Skip + if skip > len(pastes) { + skip = len(pastes) + } + end := skip + req.Limit + if end > len(pastes) { + end = len(pastes) + } + + return pastes[skip:end], nil +} + +// Get returns a paste by ID. +func (m *MemDB) Get(id int64) (Paste, error) { + m.RLock() + defer m.RUnlock() + + return m.pastes[id], nil +} + +// SaveUser creates a new or updates an existing user. +func (m *MemDB) SaveUser(usr User) (id string, err error) { + m.Lock() + defer m.Unlock() + + m.users[usr.ID] = usr + + return usr.ID, nil +} + +// User returns a user by ID. +func (m *MemDB) User(id string) (User, error) { + m.RLock() + defer m.RUnlock() + + var usr User + var ok bool + + if usr, ok = m.users[id]; !ok { + return User{}, fmt.Errorf("MemDB.User: user not found") + } + return usr, nil +} + +// Update updates existing paste. +func (m *MemDB) Update(p Paste) (Paste, error) { + m.RLock() + if _, ok := m.pastes[p.ID]; !ok { + m.RUnlock() + return Paste{}, nil + } + m.RUnlock() + m.Lock() + defer m.Unlock() + m.pastes[p.ID] = p + return p, nil +} diff --git a/src/store/memory_test.go b/src/store/memory_test.go index 0cd8ca3..6ae76ed 100644 --- a/src/store/memory_test.go +++ b/src/store/memory_test.go @@ -1,303 +1,288 @@ -package store - -import ( - "math/rand" - "sort" - "testing" - "time" -) - -// TestCount tests that we can count pastes and users correctly. -func TestCount(t *testing.T) { - t.Parallel() - - // We need a dedicated store because other test running in parallel - // will affect the counts. - m := NewMemDB() - - var usr User - var paste Paste - - // Generate a bunch of users and pastes - uCnt := rand.Int63n(10) - pCnt := rand.Int63n(20) - for i := int64(0); i < uCnt; i++ { - usr = randomUser() - _, err := m.SaveUser(usr) - if err != nil { - t.Fatalf("failed to save user: %v", err) - } - for j := int64(0); j < pCnt; j++ { - u, err := m.User(usr.ID) - if err != nil { - t.Fatalf("failed to get user: %v", err) - } - paste = randomPaste(u) - _, err = m.Create(paste) - if err != nil { - t.Fatalf("failed to create paste: %v", err) - } - } - } - - // Check the counts - wantUsers := uCnt - wantPastes := uCnt * pCnt - gotPastes, gotUsers := m.Count() - - if wantUsers != gotUsers { - t.Errorf("users count is incorrect, want %d, got %d", wantUsers, gotUsers) - } - if wantPastes != gotPastes { - t.Errorf("pastes count is incorrect, want %d, got %d", wantPastes, gotPastes) - } -} - -// TestDelete tests that we can delete a paste. -func TestDelete(t *testing.T) { - t.Parallel() - - // Create random paste - paste := randomPaste(User{}) - _, err := mdb.Create(paste) - if err != nil { - t.Fatalf("failed to create paste: %v", err) - } - // Delete the paste and check that it was indeed deleted. - err = mdb.Delete(paste.ID) - if err != nil { - t.Fatalf("failed to delete paste: %v", err) - } - p, err := mdb.Get(paste.ID) - if err != nil { - t.Fatalf("failed to get paste: %v", err) - } - if p != (Paste{}) { - t.Errorf("expected paste to be deleted but found %+v", p) - } -} - -// TestFind tests that we can find a paste using various parameters. -func TestFind(t *testing.T) { - // Create 2 users with 10 pastes each and 10 anonymous pastes - usr1 := randomUser() - usr2 := randomUser() - var pastes1 []Paste - for i := 0; i < 10; i++ { - p1 := randomPaste(usr1) - p1.CreatedAt = time.Now().AddDate(0, 0, -1*i) - p1.Expires = time.Now().AddDate(0, 1*i, 0) - p1.Views = int64(10 * i) - mdb.Create(p1) - pastes1 = append(pastes1, p1) - p2 := randomPaste(usr2) - p2.CreatedAt = time.Now().AddDate(0, 0, -1*i) - mdb.Create(p2) - p3 := randomPaste(User{}) - p3.CreatedAt = time.Now().AddDate(0, 0, -1*i) - mdb.Create(p3) - } - - // Check all pastes for a user - pastes, err := mdb.Find(FindRequest{ - UserID: usr1.ID, - Limit: 11, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != len(pastes1) { - t.Errorf("expected to find %d pastes, got %d", len(pastes1), len(pastes)) - } - // Check limit - pastes, err = mdb.Find(FindRequest{ - UserID: usr2.ID, - Limit: 5, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != 5 { - t.Errorf("expected to find %d pastes, got %d", 5, len(pastes)) - } - // Check skip - pastes, err = mdb.Find(FindRequest{ - UserID: usr2.ID, - Limit: 5, - Skip: 6, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != 4 { - t.Errorf("expected to find %d pastes, got %d", 4, len(pastes)) - } - // Check skip over limit - pastes, err = mdb.Find(FindRequest{ - UserID: usr2.ID, - Limit: 5, - Skip: 12, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != 0 { - t.Errorf("expected to find %d pastes, got %d", 0, len(pastes)) - } - // Check sort by -created - pastes, err = mdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "-created", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].CreatedAt.After(pastes[j].CreatedAt) - }) { - t.Errorf("expected pastes to be sorted by -created, got %+v", pastes) - } - // Check sort by +created - pastes, err = mdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "+created", - Limit: 5, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) - }) { - t.Errorf("expected pastes to be sorted by +created, got %+v", pastes) - } - // Check sort by -expires - pastes, err = mdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "-expires", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Expires.After(pastes[j].Expires) - }) { - t.Errorf("expected pastes to be sorted by -expires, got %+v", pastes) - } - // Check sort by +expires - pastes, err = mdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "+expires", - Limit: 5, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Expires.Before(pastes[j].Expires) - }) { - t.Errorf("expected pastes to be sorted by +expires, got %+v", pastes) - } - // Check sort by -views - pastes, err = mdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "-views", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Views > pastes[j].Views - }) { - t.Errorf("expected pastes to be sorted by -views, got %+v", pastes) - } - // Check sort by +views - pastes, err = mdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "+views", - Limit: 5, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Views < pastes[j].Views - }) { - t.Errorf("expected pastes to be sorted by +views, got %+v", pastes) - } -} - -func TestUpdate(t *testing.T) { - t.Parallel() - - // Create random paste - paste := randomPaste(User{}) - id, err := mdb.Create(paste) - if err != nil { - t.Fatalf("failed to create paste: %v", err) - } - // Update the paste - paste, _ = mdb.Get(id) - paste.Views = 42 - p, _ := mdb.Update(paste) - - if p.ID != id { - t.Errorf("expected paste to have the same id [%d], got [%d]", id, p.ID) - } - - if p.Views != paste.Views { - t.Errorf("expected paste views to be updated to [%d], got [%d]", paste.Views, p.Views) - } -} - -func TestUpdateNonExisting(t *testing.T) { - t.Parallel() - - // Create random paste - paste := randomPaste(User{}) - p, _ := mdb.Update(paste) - - if p != (Paste{}) { - t.Errorf("expected paste to be empty, got [%+v]", p) - } -} - -func TestGetUser(t *testing.T) { - t.Parallel() - usr := randomUser() - id, err := mdb.SaveUser(usr) - if err != nil { - t.Errorf("failed to create user: %v", err) - } - u, err := mdb.User(id) - if err != nil { - t.Errorf("user not found: %v", err) - } - if usr.ID != u.ID || usr.Name != u.Name { - t.Errorf("expected user to be saved as [%+v], got [%+v]", usr, u) - } -} - -func TestGetUserNotExisting(t *testing.T) { - t.Parallel() - usr := randomUser() - u, err := mdb.User(usr.ID) - if err == nil { - t.Errorf("expected user to be not found") - } - if u != (User{}) { - t.Errorf("expected user to be empty, got %+v", u) - } - -} +package store + +import ( + "math/rand" + "sort" + "testing" + "time" +) + +// testCaseForFind used by TestFind test +type testCaseForFind struct { + name string // test case name + uid string // user id + sort string // sort direction + limit int // max records + skip int // records to skip + exp int // expected result length +} + +var findTestCases = []testCaseForFind{ + { + name: "All user pastes", + uid: "find_user_1", + sort: "", + limit: 11, + skip: 0, + exp: 10, + }, { + name: "Check limit", + uid: "find_user_2", + sort: "", + limit: 5, + skip: 0, + exp: 5, + }, { + name: "Check skip", + uid: "find_user_2", + sort: "", + limit: 5, + skip: 6, + exp: 4, + }, { + name: "Check skip over limit", + uid: "find_user_2", + sort: "", + limit: 5, + skip: 12, + exp: 0, + }, { + name: "Check sort by -created", + uid: "find_user_1", + sort: "-created", + limit: 10, + skip: 0, + exp: 10, + }, { + name: "Check sort by +created", + uid: "find_user_1", + sort: "+created", + limit: 10, + skip: 0, + exp: 10, + }, { + name: "Check sort by -expires", + uid: "find_user_1", + sort: "-expires", + limit: 10, + skip: 0, + exp: 10, + }, { + name: "Check sort by +expires", + uid: "find_user_1", + sort: "+expires", + limit: 10, + skip: 0, + exp: 10, + }, { + name: "Check sort by -views", + uid: "find_user_1", + sort: "-views", + limit: 10, + skip: 0, + exp: 10, + }, { + name: "Check sort by +views", + uid: "find_user_1", + sort: "+views", + limit: 10, + skip: 0, + exp: 10, + }, +} + +// TestCount tests that we can count pastes and users correctly. +func TestCount(t *testing.T) { + t.Parallel() + + // We need a dedicated store because other test running in parallel + // will affect the counts. + m := NewMemDB() + + var usr User + var paste Paste + + // Generate a bunch of users and pastes + uCnt := rand.Int63n(10) + pCnt := rand.Int63n(20) + for i := int64(0); i < uCnt; i++ { + usr = randomUser() + _, err := m.SaveUser(usr) + if err != nil { + t.Fatalf("failed to save user: %v", err) + } + for j := int64(0); j < pCnt; j++ { + u, err := m.User(usr.ID) + if err != nil { + t.Fatalf("failed to get user: %v", err) + } + paste = randomPaste(u) + _, err = m.Create(paste) + if err != nil { + t.Fatalf("failed to create paste: %v", err) + } + } + } + + // Check the counts + wantUsers := uCnt + wantPastes := uCnt * pCnt + gotPastes, gotUsers := m.Count() + + if wantUsers != gotUsers { + t.Errorf("users count is incorrect, want %d, got %d", wantUsers, gotUsers) + } + if wantPastes != gotPastes { + t.Errorf("pastes count is incorrect, want %d, got %d", wantPastes, gotPastes) + } +} + +// TestDelete tests that we can delete a paste. +func TestDelete(t *testing.T) { + t.Parallel() + + // Create random paste + paste := randomPaste(User{}) + _, err := mdb.Create(paste) + if err != nil { + t.Fatalf("failed to create paste: %v", err) + } + // Delete the paste and check that it was indeed deleted. + err = mdb.Delete(paste.ID) + if err != nil { + t.Fatalf("failed to delete paste: %v", err) + } + p, err := mdb.Get(paste.ID) + if err != nil { + t.Fatalf("failed to get paste: %v", err) + } + if p != (Paste{}) { + t.Errorf("expected paste to be deleted but found %+v", p) + } +} + +// TestFind tests that we can find a paste using various parameters. +func TestFind(t *testing.T) { + // Create 2 users with 10 pastes each and 10 anonymous pastes + usr1 := randomUser() + usr1.ID = "find_user_1" + usr2 := randomUser() + usr2.ID = "find_user_2" + for i := 0; i < 10; i++ { + p1 := randomPaste(usr1) + p1.CreatedAt = time.Now().AddDate(0, 0, -1*i) + p1.Expires = time.Now().AddDate(0, 1*i, 0) + p1.Views = int64(10 * i) + mdb.Create(p1) + p2 := randomPaste(usr2) + p2.CreatedAt = time.Now().AddDate(0, 0, -1*i) + mdb.Create(p2) + p3 := randomPaste(User{}) + p3.CreatedAt = time.Now().AddDate(0, 0, -1*i) + mdb.Create(p3) + } + + for _, tc := range findTestCases { + t.Run(tc.name, func(t *testing.T) { + pastes, err := mdb.Find(FindRequest{ + UserID: tc.uid, + Sort: tc.sort, + Limit: tc.limit, + Skip: tc.skip, + }) + if err != nil { + t.Fatalf("failed to find pastes: %v", err) + } + if len(pastes) != tc.exp { + t.Errorf("expected to find %d pastes, got %d", tc.exp, len(pastes)) + } + if tc.sort == "" { + return + } + if !sort.SliceIsSorted(pastes, func(i, j int) bool { + switch tc.sort { + case "-created": + return pastes[i].CreatedAt.After(pastes[j].CreatedAt) + case "+created": + return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) + case "-expires": + return pastes[i].Expires.After(pastes[j].Expires) + case "+expires": + return pastes[i].Expires.Before(pastes[j].Expires) + case "-views": + return pastes[i].Views > pastes[j].Views + case "+views": + return pastes[i].Views < pastes[j].Views + default: + return false + } + }) { + t.Errorf("expected pastes to be sorted by %s, got %+v", tc.sort, pastes) + } + }) + } +} + +func TestUpdate(t *testing.T) { + t.Parallel() + + // Create random paste + paste := randomPaste(User{}) + id, err := mdb.Create(paste) + if err != nil { + t.Fatalf("failed to create paste: %v", err) + } + // Update the paste + paste, _ = mdb.Get(id) + paste.Views = 42 + p, _ := mdb.Update(paste) + + if p.ID != id { + t.Errorf("expected paste to have the same id [%d], got [%d]", id, p.ID) + } + + if p.Views != paste.Views { + t.Errorf("expected paste views to be updated to [%d], got [%d]", paste.Views, p.Views) + } +} + +func TestUpdateNonExisting(t *testing.T) { + t.Parallel() + + // Create random paste + paste := randomPaste(User{}) + p, _ := mdb.Update(paste) + + if p != (Paste{}) { + t.Errorf("expected paste to be empty, got [%+v]", p) + } +} + +func TestGetUser(t *testing.T) { + t.Parallel() + usr := randomUser() + id, err := mdb.SaveUser(usr) + if err != nil { + t.Errorf("failed to create user: %v", err) + } + u, err := mdb.User(id) + if err != nil { + t.Errorf("user not found: %v", err) + } + if usr.ID != u.ID || usr.Name != u.Name { + t.Errorf("expected user to be saved as [%+v], got [%+v]", usr, u) + } +} + +func TestGetUserNotExisting(t *testing.T) { + t.Parallel() + usr := randomUser() + u, err := mdb.User(usr.ID) + if err == nil { + t.Errorf("expected user to be not found") + } + if u != (User{}) { + t.Errorf("expected user to be empty, got %+v", u) + } + +} diff --git a/src/store/postgres.go b/src/store/postgres.go index e016072..e148109 100644 --- a/src/store/postgres.go +++ b/src/store/postgres.go @@ -1,167 +1,168 @@ -// Copyright 2021 Ilia Frenkel. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE.txt file. -package store - -import ( - "fmt" - "math/rand" - "strings" - - "gorm.io/driver/postgres" - "gorm.io/gorm" - "gorm.io/gorm/clause" -) - -// PostgresDB is a Postgres SQL databasse storage that implements the -// store.Interface. -type PostgresDB struct { - db *gorm.DB -} - -// NewPostgresDB initialises a new instance of PostgresDB and returns. -// It tries to establish a database connection specified by conn and if -// autoMigrate is true it will try and create/alter all the tables. -func NewPostgresDB(conn string, autoMigrate bool) (*PostgresDB, error) { - var pg PostgresDB - db, err := gorm.Open(postgres.Open(conn), &gorm.Config{}) - if err != nil { - return nil, fmt.Errorf("NewPostgresDB: failed to establish database connection: %w", err) - } - if autoMigrate { - err = db.AutoMigrate(&Paste{}) - } else { - if d, e := db.DB(); e == nil { - err = d.Ping() - } else { - err = e - } - } - - if err != nil { - return nil, fmt.Errorf("NewPostgresDB: %w", err) - } - - pg.db = db - - return &pg, nil -} - -// Count returns total count of pastes and users. -func (pg *PostgresDB) Count() (pastes, users int64) { - pg.db.Model(&Paste{}).Count(&pastes) - pg.db.Model(&User{}).Count(&users) - return -} - -// Create creates and stores a new paste returning its ID. -func (pg *PostgresDB) Create(p Paste) (id int64, err error) { - p.ID = rand.Int63() // #nosec - if p.User.ID == "" { - err = pg.db.Omit("user_id").Create(&p).Error - } else { - err = pg.db.Create(&p).Error - } - if err != nil { - return 0, fmt.Errorf("PostgresDB.Create: %w", err) - } - return p.ID, nil -} - -// Delete deletes a paste by ID. -func (pg *PostgresDB) Delete(id int64) error { - if id == 0 { - return fmt.Errorf("PostgresDB.Delete: id cannot be null") - } - tx := pg.db.Delete(&Paste{}, id) - err := tx.Error - if err != nil { - return fmt.Errorf("PostgresDB.Delete: %w", err) - } - if tx.RowsAffected == 0 { - return fmt.Errorf("PostgresDB.Delete: no rows deleted") - } - - return nil -} - -// Find return a sorted list of pastes for a given request. -func (pg *PostgresDB) Find(req FindRequest) (pastes []Paste, err error) { - sort := "created_at desc" - switch req.Sort { - case "+created", "-created": - sort = "created_at" - if strings.HasPrefix(req.Sort, "-") { - sort = "created_at desc" - } - case "+expires", "-expires": - sort = "expires" - if strings.HasPrefix(req.Sort, "-") { - sort = "expires desc" - } - case "+views", "-views": - sort = "views" - if strings.HasPrefix(req.Sort, "-") { - sort = "views desc" - } - } - - // err = pg.db.Limit(req.Limit).Offset(req.Skip).Find(&pastes, Paste{User: User{ID: req.UserID}}).Order(sort).Error - err = pg.db.Limit(req.Limit).Offset(req.Skip).Order(sort).Find(&pastes, "user_id = ?", req.UserID).Error - if err != nil { - return pastes, fmt.Errorf("PostgresDB.Find: %w", err) - } - return pastes, nil -} - -// Get returns a paste by ID. -func (pg *PostgresDB) Get(id int64) (Paste, error) { - var paste Paste - err := pg.db.Preload("User").Limit(1).Find(&paste, id).Error - if err != nil { - return paste, fmt.Errorf("PostgresDB.Get: %w", err) - } - - return paste, nil -} - -// SaveUser creates a new or updates an existing user. -func (pg *PostgresDB) SaveUser(usr User) (id string, err error) { - err = pg.db.Clauses(clause.OnConflict{ - UpdateAll: true, - }).Save(&usr).Error - if err != nil { - return "", fmt.Errorf("PostgresDB.SaveUser: %w", err) - } - id = usr.ID - return id, nil -} - -// User returns a user by ID. -func (pg *PostgresDB) User(id string) (User, error) { - var usr User - tx := pg.db.Limit(1).Find(&usr, User{ID: id}) - err := tx.Error - if err != nil { - return usr, fmt.Errorf("PostgresDB.User: %w", err) - } - if tx.RowsAffected == 0 { - return usr, fmt.Errorf("PostgresDB.User: user not found") - } - - return usr, err -} - -// Update saves the paste into database and returns it -func (pg *PostgresDB) Update(p Paste) (Paste, error) { - err := pg.db.First(&Paste{}, p.ID).Error - if err != nil { - return Paste{}, fmt.Errorf("PostgresDB.Update: %w", err) - } - err = pg.db.Save(&p).Error - if err != nil { - return Paste{}, fmt.Errorf("PostgresDB.Update: %w", err) - } - - return p, nil -} +// Copyright 2021 Ilia Frenkel. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE.txt file. + +package store + +import ( + "fmt" + "math/rand" + "strings" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// PostgresDB is a Postgres SQL databasse storage that implements the +// store.Interface. +type PostgresDB struct { + db *gorm.DB +} + +// NewPostgresDB initialises a new instance of PostgresDB and returns. +// It tries to establish a database connection specified by conn and if +// autoMigrate is true it will try and create/alter all the tables. +func NewPostgresDB(conn string, autoMigrate bool) (*PostgresDB, error) { + var pg PostgresDB + db, err := gorm.Open(postgres.Open(conn), &gorm.Config{}) + if err != nil { + return nil, fmt.Errorf("NewPostgresDB: failed to establish database connection: %w", err) + } + if autoMigrate { + err = db.AutoMigrate(&Paste{}) + } else { + if d, e := db.DB(); e == nil { + err = d.Ping() + } else { + err = e + } + } + + if err != nil { + return nil, fmt.Errorf("NewPostgresDB: %w", err) + } + + pg.db = db + + return &pg, nil +} + +// Count returns total count of pastes and users. +func (pg *PostgresDB) Count() (pastes, users int64) { + pg.db.Model(&Paste{}).Count(&pastes) + pg.db.Model(&User{}).Count(&users) + return +} + +// Create creates and stores a new paste returning its ID. +func (pg *PostgresDB) Create(p Paste) (id int64, err error) { + p.ID = rand.Int63() // #nosec + if p.User.ID == "" { + err = pg.db.Omit("user_id").Create(&p).Error + } else { + err = pg.db.Create(&p).Error + } + if err != nil { + return 0, fmt.Errorf("PostgresDB.Create: %w", err) + } + return p.ID, nil +} + +// Delete deletes a paste by ID. +func (pg *PostgresDB) Delete(id int64) error { + if id == 0 { + return fmt.Errorf("PostgresDB.Delete: id cannot be null") + } + tx := pg.db.Delete(&Paste{}, id) + err := tx.Error + if err != nil { + return fmt.Errorf("PostgresDB.Delete: %w", err) + } + if tx.RowsAffected == 0 { + return fmt.Errorf("PostgresDB.Delete: no rows deleted") + } + + return nil +} + +// Find return a sorted list of pastes for a given request. +func (pg *PostgresDB) Find(req FindRequest) (pastes []Paste, err error) { + sort := "created_at desc" + switch req.Sort { + case "+created", "-created": + sort = "created_at" + if strings.HasPrefix(req.Sort, "-") { + sort = "created_at desc" + } + case "+expires", "-expires": + sort = "expires" + if strings.HasPrefix(req.Sort, "-") { + sort = "expires desc" + } + case "+views", "-views": + sort = "views" + if strings.HasPrefix(req.Sort, "-") { + sort = "views desc" + } + } + + // err = pg.db.Limit(req.Limit).Offset(req.Skip).Find(&pastes, Paste{User: User{ID: req.UserID}}).Order(sort).Error + err = pg.db.Limit(req.Limit).Offset(req.Skip).Order(sort).Find(&pastes, "user_id = ?", req.UserID).Error + if err != nil { + return pastes, fmt.Errorf("PostgresDB.Find: %w", err) + } + return pastes, nil +} + +// Get returns a paste by ID. +func (pg *PostgresDB) Get(id int64) (Paste, error) { + var paste Paste + err := pg.db.Preload("User").Limit(1).Find(&paste, id).Error + if err != nil { + return paste, fmt.Errorf("PostgresDB.Get: %w", err) + } + + return paste, nil +} + +// SaveUser creates a new or updates an existing user. +func (pg *PostgresDB) SaveUser(usr User) (id string, err error) { + err = pg.db.Clauses(clause.OnConflict{ + UpdateAll: true, + }).Save(&usr).Error + if err != nil { + return "", fmt.Errorf("PostgresDB.SaveUser: %w", err) + } + id = usr.ID + return id, nil +} + +// User returns a user by ID. +func (pg *PostgresDB) User(id string) (User, error) { + var usr User + tx := pg.db.Limit(1).Find(&usr, User{ID: id}) + err := tx.Error + if err != nil { + return usr, fmt.Errorf("PostgresDB.User: %w", err) + } + if tx.RowsAffected == 0 { + return usr, fmt.Errorf("PostgresDB.User: user not found") + } + + return usr, err +} + +// Update saves the paste into database and returns it +func (pg *PostgresDB) Update(p Paste) (Paste, error) { + err := pg.db.First(&Paste{}, p.ID).Error + if err != nil { + return Paste{}, fmt.Errorf("PostgresDB.Update: %w", err) + } + err = pg.db.Save(&p).Error + if err != nil { + return Paste{}, fmt.Errorf("PostgresDB.Update: %w", err) + } + + return p, nil +} diff --git a/src/store/postgres_test.go b/src/store/postgres_test.go index f901cc2..a78de30 100644 --- a/src/store/postgres_test.go +++ b/src/store/postgres_test.go @@ -1,271 +1,172 @@ -package store - -import ( - "sort" - "testing" - "time" -) - -// TestDelete tests that we can delete a paste. -func TestDeletePDB(t *testing.T) { - t.Parallel() - // Create random paste - paste := randomPaste(randomUser()) - id, err := pdb.Create(paste) - if err != nil { - t.Fatalf("failed to create paste: %v", err) - } - // Delete the paste and check that it was indeed deleted. - err = pdb.Delete(id) - if err != nil { - t.Fatalf("failed to delete paste: %v", err) - } - p, err := pdb.Get(id) - if err != nil { - t.Fatalf("failed to get paste: %v", err) - } - if p != (Paste{}) { - t.Errorf("expected paste to be deleted but found %+v", p) - } -} - -// TestDelete tests that we can't delete a paste if it doesn't exist. -func TestDeleteNonExistingPDB(t *testing.T) { - t.Parallel() - // Create random paste - paste := randomPaste(randomUser()) - // Delete the paste and check that it was indeed deleted. - err := pdb.Delete(paste.ID) - if err == nil { - t.Fatalf("expected delete to fail") - } -} - -// TestFind tests that we can find a paste using various parameters. -func TestFindPDB(t *testing.T) { - t.Parallel() - // Create 2 users with 10 pastes each and 10 anonymous pastes - usr1 := randomUser() - usr2 := randomUser() - var pastes1 []Paste - for i := 0; i < 10; i++ { - p1 := randomPaste(usr1) - p1.CreatedAt = time.Now().AddDate(0, 0, -1*i) - p1.Expires = time.Now().AddDate(0, 1*i, 0) - p1.Views = int64(10*i + 1) - pdb.Create(p1) - pastes1 = append(pastes1, p1) - p2 := randomPaste(usr2) - p2.CreatedAt = time.Now().AddDate(0, 0, -1*i) - pdb.Create(p2) - p3 := randomPaste(User{}) - p3.CreatedAt = time.Now().AddDate(0, 0, -1*i) - pdb.Create(p3) - } - - // Check all pastes for a user - pastes, err := pdb.Find(FindRequest{ - UserID: usr1.ID, - Limit: 11, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != len(pastes1) { - t.Errorf("expected to find %d pastes, got %d", len(pastes1), len(pastes)) - } - // Check limit - pastes, err = pdb.Find(FindRequest{ - UserID: usr2.ID, - Limit: 5, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != 5 { - t.Errorf("expected to find %d pastes, got %d", 5, len(pastes)) - } - // Check skip - pastes, err = pdb.Find(FindRequest{ - UserID: usr2.ID, - Limit: 5, - Skip: 6, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != 4 { - t.Errorf("expected to find %d pastes, got %d", 4, len(pastes)) - } - // Check skip over limit - pastes, err = pdb.Find(FindRequest{ - UserID: usr2.ID, - Limit: 5, - Skip: 12, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if len(pastes) != 0 { - t.Errorf("expected to find %d pastes, got %d", 0, len(pastes)) - } - // Check sort by -created - pastes, err = pdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "-created", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].CreatedAt.After(pastes[j].CreatedAt) - }) { - t.Errorf("expected pastes to be sorted by -created, got %+v", pastes) - } - // Check sort by +created - pastes, err = pdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "+created", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) - }) { - t.Errorf("expected pastes to be sorted by +created, got %+v", pastes) - } - // Check sort by -expires - pastes, err = pdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "-expires", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Expires.After(pastes[j].Expires) - }) { - t.Errorf("expected pastes to be sorted by -expires, got %+v", pastes) - } - // Check sort by +expires - pastes, err = pdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "+expires", - Limit: 5, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Expires.Before(pastes[j].Expires) - }) { - t.Errorf("expected pastes to be sorted by +expires, got %+v", pastes) - } - // Check sort by -views - pastes, err = pdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "-views", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Views > pastes[j].Views - }) { - t.Errorf("expected pastes to be sorted by -views, got %+v", pastes) - } - // Check sort by +views - pastes, err = pdb.Find(FindRequest{ - UserID: usr1.ID, - Sort: "+views", - Limit: 10, - Skip: 0, - }) - if err != nil { - t.Fatalf("failed to find pastes: %v", err) - } - if !sort.SliceIsSorted(pastes, func(i, j int) bool { - return pastes[i].Views < pastes[j].Views - }) { - t.Errorf("expected pastes to be sorted by +views, got %+v", pastes) - } -} - -func TestUpdatePDB(t *testing.T) { - t.Parallel() - usr := randomUser() - // Create random paste - paste := randomPaste(usr) - id, err := pdb.Create(paste) - if err != nil { - t.Fatalf("failed to create paste: %v", err) - } - // Update the paste - paste, _ = pdb.Get(id) - paste.Views = 42 - p, _ := pdb.Update(paste) - - if p.ID != id { - t.Errorf("expected paste to have the same id [%d], got [%d]", id, p.ID) - } - - if p.Views != paste.Views { - t.Errorf("expected paste views to be updated to [%d], got [%d]", paste.Views, p.Views) - } -} - -func TestUpdateNonExistingPDB(t *testing.T) { - t.Parallel() - usr := randomUser() - // Create random paste - paste := randomPaste(usr) - p, err := pdb.Update(paste) - - if err == nil { - t.Error("expected paste update to fail") - } - if p != (Paste{}) { - t.Errorf("expected paste to be empty, got [%+v]", p) - } -} - -func TestGetUserPDB(t *testing.T) { - t.Parallel() - usr := randomUser() - id, err := pdb.SaveUser(usr) - if err != nil { - t.Errorf("failed to create user: %v", err) - } - u, err := pdb.User(id) - if err != nil { - t.Errorf("user not found: %v", err) - } - if usr.ID != u.ID || usr.Name != u.Name { - t.Errorf("expected user to be saved as [%+v], got [%+v]", usr, u) - } -} - -func TestGetUserNotExistingPDB(t *testing.T) { - t.Parallel() - usr := randomUser() - u, err := pdb.User(usr.ID) - if err == nil { - t.Errorf("expected user to be not found") - } - if u != (User{}) { - t.Errorf("expected user to be empty, got %+v", u) - } - -} +package store + +import ( + "sort" + "testing" + "time" +) + +// TestDelete tests that we can delete a paste. +func TestDeletePDB(t *testing.T) { + t.Parallel() + // Create random paste + paste := randomPaste(randomUser()) + id, err := pdb.Create(paste) + if err != nil { + t.Fatalf("failed to create paste: %v", err) + } + // Delete the paste and check that it was indeed deleted. + err = pdb.Delete(id) + if err != nil { + t.Fatalf("failed to delete paste: %v", err) + } + p, err := pdb.Get(id) + if err != nil { + t.Fatalf("failed to get paste: %v", err) + } + if p != (Paste{}) { + t.Errorf("expected paste to be deleted but found %+v", p) + } +} + +// TestDelete tests that we can't delete a paste if it doesn't exist. +func TestDeleteNonExistingPDB(t *testing.T) { + t.Parallel() + // Create random paste + paste := randomPaste(randomUser()) + // Delete the paste and check that it was indeed deleted. + err := pdb.Delete(paste.ID) + if err == nil { + t.Fatalf("expected delete to fail") + } +} + +// TestFind tests that we can find a paste using various parameters. +func TestFindPDB(t *testing.T) { + t.Parallel() + // Create 2 users with 10 pastes each and 10 anonymous pastes + usr1 := randomUser() + usr1.ID = "find_user_1" + usr2 := randomUser() + usr2.ID = "find_user_2" + for i := 0; i < 10; i++ { + p1 := randomPaste(usr1) + p1.CreatedAt = time.Now().AddDate(0, 0, -1*i) + p1.Expires = time.Now().AddDate(0, 1*i, 0) + p1.Views = int64(10*i + 1) + pdb.Create(p1) + p2 := randomPaste(usr2) + p2.CreatedAt = time.Now().AddDate(0, 0, -1*i) + pdb.Create(p2) + p3 := randomPaste(User{}) + p3.CreatedAt = time.Now().AddDate(0, 0, -1*i) + pdb.Create(p3) + } + + for _, tc := range findTestCases { + t.Run(tc.name, func(t *testing.T) { + pastes, err := pdb.Find(FindRequest{ + UserID: tc.uid, + Sort: tc.sort, + Limit: tc.limit, + Skip: tc.skip, + }) + if err != nil { + t.Fatalf("failed to find pastes: %v", err) + } + if len(pastes) != tc.exp { + t.Errorf("expected to find %d pastes, got %d", tc.exp, len(pastes)) + } + if tc.sort == "" { + return + } + if !sort.SliceIsSorted(pastes, func(i, j int) bool { + switch tc.sort { + case "-created": + return pastes[i].CreatedAt.After(pastes[j].CreatedAt) + case "+created": + return pastes[i].CreatedAt.Before(pastes[j].CreatedAt) + case "-expires": + return pastes[i].Expires.After(pastes[j].Expires) + case "+expires": + return pastes[i].Expires.Before(pastes[j].Expires) + case "-views": + return pastes[i].Views > pastes[j].Views + case "+views": + return pastes[i].Views < pastes[j].Views + default: + return false + } + }) { + t.Errorf("expected pastes to be sorted by %s, got %+v", tc.sort, pastes) + } + }) + } +} + +func TestUpdatePDB(t *testing.T) { + t.Parallel() + usr := randomUser() + // Create random paste + paste := randomPaste(usr) + id, err := pdb.Create(paste) + if err != nil { + t.Fatalf("failed to create paste: %v", err) + } + // Update the paste + paste, _ = pdb.Get(id) + paste.Views = 42 + p, _ := pdb.Update(paste) + + if p.ID != id { + t.Errorf("expected paste to have the same id [%d], got [%d]", id, p.ID) + } + + if p.Views != paste.Views { + t.Errorf("expected paste views to be updated to [%d], got [%d]", paste.Views, p.Views) + } +} + +func TestUpdateNonExistingPDB(t *testing.T) { + t.Parallel() + usr := randomUser() + // Create random paste + paste := randomPaste(usr) + p, err := pdb.Update(paste) + + if err == nil { + t.Error("expected paste update to fail") + } + if p != (Paste{}) { + t.Errorf("expected paste to be empty, got [%+v]", p) + } +} + +func TestGetUserPDB(t *testing.T) { + t.Parallel() + usr := randomUser() + id, err := pdb.SaveUser(usr) + if err != nil { + t.Errorf("failed to create user: %v", err) + } + u, err := pdb.User(id) + if err != nil { + t.Errorf("user not found: %v", err) + } + if usr.ID != u.ID || usr.Name != u.Name { + t.Errorf("expected user to be saved as [%+v], got [%+v]", usr, u) + } +} + +func TestGetUserNotExistingPDB(t *testing.T) { + t.Parallel() + usr := randomUser() + u, err := pdb.User(usr.ID) + if err == nil { + t.Errorf("expected user to be not found") + } + if u != (User{}) { + t.Errorf("expected user to be empty, got %+v", u) + } + +} diff --git a/src/store/store.go b/src/store/store.go index d1e4040..ba45442 100644 --- a/src/store/store.go +++ b/src/store/store.go @@ -1,131 +1,131 @@ -// Copyright 2021 Ilia Frenkel. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE.txt file. - -// Package store defines a common interface that any concrete storage -// implementation must implement. Along wiht some supporting types. -// It provides two implementations of store.Interface - MemDB and PostgresDB. -package store - -import ( - "errors" - "fmt" - "math" - "strings" - "time" -) - -// Interface defines methods that an implementation of a concrete storage -// must provide. -type Interface interface { - Count() (pastes, users int64) // return total counts for pastes and users - Create(paste Paste) (id int64, err error) // create new paste and return its id - Delete(id int64) error // delete paste by id - Find(req FindRequest) ([]Paste, error) // find pastes - Get(id int64) (Paste, error) // get paste by id - Update(paste Paste) (Paste, error) // update paste information and return updated paste - SaveUser(usr User) (id string, err error) // creates or updates a user - User(id string) (User, error) // get user by id -} - -// User represents a single user. -type User struct { - ID string `json:"id" gorm:"primaryKey"` - Name string `json:"name" gorm:"index"` - Email string `json:"email" gorm:"index"` - IP string `json:"ip,omitempty"` - Admin bool `json:"admin"` -} - -// Paste represents a single paste with an optional reference to its user. -type Paste struct { - ID int64 `json:"id" gorm:"primaryKey"` - Title string `json:"title"` - Body string `json:"body"` - Expires time.Time `json:"expires" gorm:"index"` - DeleteAfterRead bool `json:"delete_after_read"` - Privacy string `json:"privacy"` - Password string `json:"password"` - CreatedAt time.Time `json:"created"` - Syntax string `json:"syntax"` - UserID string `json:"user_id" gorm:"index default:null"` - User User `json:"user"` - Views int64 `json:"views"` -} - -// URL generates a base62 encoded string from the paste ID. This string is -// used as a unique URL for the paste, hence the name. -func (p Paste) URL() string { - const ( - alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - length = int64(len(alphabet)) - ) - var encodedBuilder strings.Builder - encodedBuilder.Grow(11) - number := p.ID - - for ; number > 0; number = number / length { - encodedBuilder.WriteByte(alphabet[(number % length)]) - } - - return encodedBuilder.String() -} - -// URL2ID decodes the previously generated URL string into a paste ID. -func (p Paste) URL2ID(url string) (int64, error) { - const ( - alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - length = int64(len(alphabet)) - ) - - var number int64 - - for i, symbol := range url { - alphabeticPosition := strings.IndexRune(alphabet, symbol) - - if alphabeticPosition == -1 { - return int64(alphabeticPosition), errors.New("invalid character: " + string(symbol)) - } - number += int64(alphabeticPosition) * int64(math.Pow(float64(length), float64(i))) - } - - return number, nil -} - -// Expiration returns a "humanized" duration between now and the expiry date -// stored in `Expires`. For example: "25 minutes" or "2 months" or "Never". -func (p Paste) Expiration() string { - if p.Expires.IsZero() { - return "Never" - } - - diff := time.Time{}.Add(time.Until(p.Expires)) - years, months, days := diff.Date() - hours, minutes, seconds := diff.Clock() - - switch { - case years >= 2: - return fmt.Sprintf("%d years", years-1) - case months >= 2: - return fmt.Sprintf("%d months", months-1) - case days >= 2: - return fmt.Sprintf("%d days", days-1+hours/12) - case hours >= 1: - return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) - case minutes >= 2: - return fmt.Sprintf("%d min", minutes) - case seconds >= 1: - return fmt.Sprintf("%d sec", seconds) - } - - return p.Expires.Sub(p.CreatedAt).String() -} - -// FindRequest is an input to the Find method -type FindRequest struct { - UserID string - Sort string - Since time.Time - Limit int - Skip int -} +// Copyright 2021 Ilia Frenkel. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE.txt file. + +// Package store defines a common interface that any concrete storage +// implementation must implement. Along with some supporting types. +// It provides two implementations of store.Interface - MemDB and PostgresDB. +package store + +import ( + "errors" + "fmt" + "math" + "strings" + "time" +) + +// Interface defines methods that an implementation of a concrete storage +// must provide. +type Interface interface { + Count() (pastes, users int64) // return total counts for pastes and users + Create(paste Paste) (id int64, err error) // create new paste and return its id + Delete(id int64) error // delete paste by id + Find(req FindRequest) ([]Paste, error) // find pastes + Get(id int64) (Paste, error) // get paste by id + Update(paste Paste) (Paste, error) // update paste information and return updated paste + SaveUser(usr User) (id string, err error) // creates or updates a user + User(id string) (User, error) // get user by id +} + +// User represents a single user. +type User struct { + ID string `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"index"` + Email string `json:"email" gorm:"index"` + IP string `json:"ip,omitempty"` + Admin bool `json:"admin"` +} + +// Paste represents a single paste with an optional reference to its user. +type Paste struct { + ID int64 `json:"id" gorm:"primaryKey"` + Title string `json:"title"` + Body string `json:"body"` + Expires time.Time `json:"expires" gorm:"index"` + DeleteAfterRead bool `json:"delete_after_read"` + Privacy string `json:"privacy"` + Password string `json:"password"` + CreatedAt time.Time `json:"created"` + Syntax string `json:"syntax"` + UserID string `json:"user_id" gorm:"index default:null"` + User User `json:"user"` + Views int64 `json:"views"` +} + +// URL generates a base62 encoded string from the paste ID. This string is +// used as a unique URL for the paste, hence the name. +func (p Paste) URL() string { + const ( + alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + length = int64(len(alphabet)) + ) + var encodedBuilder strings.Builder + encodedBuilder.Grow(11) + number := p.ID + + for ; number > 0; number = number / length { + encodedBuilder.WriteByte(alphabet[(number % length)]) + } + + return encodedBuilder.String() +} + +// URL2ID decodes the previously generated URL string into a paste ID. +func (p Paste) URL2ID(url string) (int64, error) { + const ( + alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + length = int64(len(alphabet)) + ) + + var number int64 + + for i, symbol := range url { + alphabeticPosition := strings.IndexRune(alphabet, symbol) + + if alphabeticPosition == -1 { + return int64(alphabeticPosition), errors.New("invalid character: " + string(symbol)) + } + number += int64(alphabeticPosition) * int64(math.Pow(float64(length), float64(i))) + } + + return number, nil +} + +// Expiration returns a "humanized" duration between now and the expiry date +// stored in `Expires`. For example: "25 minutes" or "2 months" or "Never". +func (p Paste) Expiration() string { + if p.Expires.IsZero() { + return "Never" + } + + diff := time.Time{}.Add(time.Until(p.Expires)) + years, months, days := diff.Date() + hours, minutes, seconds := diff.Clock() + + switch { + case years >= 2: + return fmt.Sprintf("%d years", years-1) + case months >= 2: + return fmt.Sprintf("%d months", months-1) + case days >= 2: + return fmt.Sprintf("%d days", days-1+hours/12) + case hours >= 1: + return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + case minutes >= 2: + return fmt.Sprintf("%d min", minutes) + case seconds >= 1: + return fmt.Sprintf("%d sec", seconds) + } + + return p.Expires.Sub(p.CreatedAt).String() +} + +// FindRequest is an input to the Find method +type FindRequest struct { + UserID string + Sort string + Since time.Time + Limit int + Skip int +} diff --git a/src/store/store_test.go b/src/store/store_test.go index 563baf5..5044287 100644 --- a/src/store/store_test.go +++ b/src/store/store_test.go @@ -1,160 +1,160 @@ -package store - -import ( - "fmt" - "math/rand" - "os" - "strings" - "testing" - "time" -) - -var mdb *MemDB -var pdb *PostgresDB -var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") - -// TestMain is a setup function for the test suite. It creates a new MemDB -// instance and seeds random generator. -func TestMain(m *testing.M) { - var err error - mdb = NewMemDB() - pdb, err = NewPostgresDB("host=localhost user=test password=test dbname=test port=5432 sslmode=disable", true) - pdb.db.AllowGlobalUpdate = true - pdb.db.Delete(Paste{}) - pdb.db.Delete(User{}) - if err != nil { - fmt.Printf("Failed to create a PostgresDB store: %v\n", err) - os.Exit(1) - } - rand.Seed(time.Now().UnixNano()) - - c := m.Run() - - pdb.db.Delete(Paste{}) - pdb.db.Delete(User{}) - - os.Exit(c) -} - -// randSeq generates random string of a given size. -func randSeq(n int) string { - b := make([]rune, n) - for i := range b { - b[i] = letters[rand.Intn(len(letters))] - } - return string(b) -} - -// randomUser creates a User with random values. -func randomUser() User { - u := User{ - ID: randSeq(10), - Name: randSeq(10), - Email: "", - IP: "", - Admin: false, - } - return u -} - -// randomPaste creates a Paste with random values. -func randomPaste(usr User) Paste { - p := Paste{ - ID: rand.Int63(), - Title: randSeq(10), - Body: randSeq(10), - Expires: time.Time{}, - DeleteAfterRead: false, - Privacy: "public", - Password: "", - CreatedAt: time.Now(), - Syntax: "none", - UserID: usr.ID, - User: usr, - Views: 0, - } - return p -} -func TestPasteURL(t *testing.T) { - t.Parallel() - - p := Paste{ - ID: 123, - Title: "", - Body: "qwe", - Expires: time.Time{}, - DeleteAfterRead: false, - Privacy: "public", - Password: "", - CreatedAt: time.Now(), - Syntax: "text", - User: User{}, - Views: 0, - } - - id, _ := p.URL2ID(p.URL()) - if p.ID != id { - t.Errorf("expected paste id to be %d, got %d", p.ID, id) - } - - _, err := p.URL2ID("@#%$#") - if err == nil { - t.Error("expected decoding to fail") - } -} - -func TestPasteExpiration(t *testing.T) { - t.Parallel() - - p := Paste{ - ID: 123, - Title: "", - Body: "qwe", - Expires: time.Now().Add(30 * time.Second), - DeleteAfterRead: false, - Privacy: "public", - Password: "", - CreatedAt: time.Now(), - Syntax: "text", - User: User{}, - Views: 0, - } - if !strings.HasSuffix(p.Expiration(), "sec") { - t.Errorf("expected expiration to have [sec], got [%s]", p.Expiration()) - } - - p.Expires = time.Now().Add(11 * time.Minute) - if !strings.HasSuffix(p.Expiration(), "min") { - t.Errorf("expected expiration to have [min], got [%s]", p.Expiration()) - } - - p.Expires = time.Now().Add(13 * time.Hour) - if p.Expiration()[2:3] != ":" && p.Expiration()[5:6] != ":" { - t.Errorf("expected expiration to be [13:00:00], got [%s]", p.Expiration()) - } - - p.Expires = time.Now().Add(96 * time.Hour) - if !strings.HasSuffix(p.Expiration(), "days") { - t.Errorf("expected expiration to have [days], got [%s]", p.Expiration()) - } - - p.Expires = time.Now().AddDate(0, 5, 0) - if !strings.HasSuffix(p.Expiration(), "months") { - t.Errorf("expected expiration to have [months], got [%s]", p.Expiration()) - } - - p.Expires = time.Now().AddDate(2, 0, 0) - if !strings.HasSuffix(p.Expiration(), "years") { - t.Errorf("expected expiration to have [years], got [%s]", p.Expiration()) - } - - p.Expires = time.Time{} - if p.Expiration() != "Never" { - t.Errorf("expected expiration to be [Never], got [%s]", p.Expiration()) - } - - p.Expires = time.Now().Add(1*time.Second - 1*time.Millisecond) - if !strings.HasPrefix(p.Expiration(), "999") || !strings.HasSuffix(p.Expiration(), "ms") { - t.Errorf("expected expiration to be [999ms], got [%s]", p.Expiration()) - } -} +package store + +import ( + "fmt" + "math/rand" + "os" + "strings" + "testing" + "time" +) + +var mdb *MemDB +var pdb *PostgresDB +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +// TestMain is a setup function for the test suite. It creates a new MemDB +// instance and seeds random generator. +func TestMain(m *testing.M) { + var err error + mdb = NewMemDB() + pdb, err = NewPostgresDB("host=localhost user=test password=test dbname=test port=5432 sslmode=disable", true) + pdb.db.AllowGlobalUpdate = true + pdb.db.Delete(Paste{}) + pdb.db.Delete(User{}) + if err != nil { + fmt.Printf("Failed to create a PostgresDB store: %v\n", err) + os.Exit(1) + } + rand.Seed(time.Now().UnixNano()) + + c := m.Run() + + pdb.db.Delete(Paste{}) + pdb.db.Delete(User{}) + + os.Exit(c) +} + +// randSeq generates random string of a given size. +func randSeq(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +// randomUser creates a User with random values. +func randomUser() User { + u := User{ + ID: randSeq(10), + Name: randSeq(10), + Email: "", + IP: "", + Admin: false, + } + return u +} + +// randomPaste creates a Paste with random values. +func randomPaste(usr User) Paste { + p := Paste{ + ID: rand.Int63(), + Title: randSeq(10), + Body: randSeq(10), + Expires: time.Time{}, + DeleteAfterRead: false, + Privacy: "public", + Password: "", + CreatedAt: time.Now(), + Syntax: "none", + UserID: usr.ID, + User: usr, + Views: 0, + } + return p +} +func TestPasteURL(t *testing.T) { + t.Parallel() + + p := Paste{ + ID: 123, + Title: "", + Body: "qwe", + Expires: time.Time{}, + DeleteAfterRead: false, + Privacy: "public", + Password: "", + CreatedAt: time.Now(), + Syntax: "text", + User: User{}, + Views: 0, + } + + id, _ := p.URL2ID(p.URL()) + if p.ID != id { + t.Errorf("expected paste id to be %d, got %d", p.ID, id) + } + + _, err := p.URL2ID("@#%$#") + if err == nil { + t.Error("expected decoding to fail") + } +} + +func TestPasteExpiration(t *testing.T) { + t.Parallel() + + p := Paste{ + ID: 123, + Title: "", + Body: "qwe", + Expires: time.Now().Add(30 * time.Second), + DeleteAfterRead: false, + Privacy: "public", + Password: "", + CreatedAt: time.Now(), + Syntax: "text", + User: User{}, + Views: 0, + } + if !strings.HasSuffix(p.Expiration(), "sec") { + t.Errorf("expected expiration to have [sec], got [%s]", p.Expiration()) + } + + p.Expires = time.Now().Add(11 * time.Minute) + if !strings.HasSuffix(p.Expiration(), "min") { + t.Errorf("expected expiration to have [min], got [%s]", p.Expiration()) + } + + p.Expires = time.Now().Add(13 * time.Hour) + if p.Expiration()[2:3] != ":" && p.Expiration()[5:6] != ":" { + t.Errorf("expected expiration to be [13:00:00], got [%s]", p.Expiration()) + } + + p.Expires = time.Now().Add(96 * time.Hour) + if !strings.HasSuffix(p.Expiration(), "days") { + t.Errorf("expected expiration to have [days], got [%s]", p.Expiration()) + } + + p.Expires = time.Now().AddDate(0, 5, 0) + if !strings.HasSuffix(p.Expiration(), "months") { + t.Errorf("expected expiration to have [months], got [%s]", p.Expiration()) + } + + p.Expires = time.Now().AddDate(2, 0, 0) + if !strings.HasSuffix(p.Expiration(), "years") { + t.Errorf("expected expiration to have [years], got [%s]", p.Expiration()) + } + + p.Expires = time.Time{} + if p.Expiration() != "Never" { + t.Errorf("expected expiration to be [Never], got [%s]", p.Expiration()) + } + + p.Expires = time.Now().Add(1*time.Second - 1*time.Millisecond) + if !strings.HasPrefix(p.Expiration(), "999") || !strings.HasSuffix(p.Expiration(), "ms") { + t.Errorf("expected expiration to be [999ms], got [%s]", p.Expiration()) + } +} diff --git a/src/web/routes.go b/src/web/routes.go index 20ee466..ac4e907 100644 --- a/src/web/routes.go +++ b/src/web/routes.go @@ -1,335 +1,336 @@ -// Copyright 2021 Ilia Frenkel. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE.txt file. -package web - -import ( - "bytes" - "errors" - "net/http" - - "github.com/go-pkgz/auth/token" - "github.com/gorilla/mux" - "github.com/iliafrenkel/go-pb/src/service" - "github.com/iliafrenkel/go-pb/src/store" -) - -// All the data that any page template may need. -type PageData struct { - Title string - Brand string - Tagline string - Logo string - Theme string - ID string - User token.User - Pastes []store.Paste - Paste store.Paste - Server string - Version string - ErrorCode int - ErrorText string - ErrorMessage string - PastesCount int64 - UsersCount int64 -} - -// Generate HTML from a template with PageData. -func (h *WebServer) generateHTML(tpl string, p PageData) []byte { - var html bytes.Buffer - pcnt, ucnt := h.service.GetCount() - var pd = PageData{ - Title: h.options.BrandName + " - " + p.Title, - Brand: h.options.BrandName, - Tagline: h.options.BrandTagline, - Logo: h.options.Logo, - Theme: h.options.BootstrapTheme, - ID: p.ID, - User: p.User, - Pastes: p.Pastes, - Paste: p.Paste, - Server: h.options.Proto + "://" + h.options.Addr, - Version: h.options.Version, - ErrorCode: p.ErrorCode, - ErrorText: p.ErrorText, - ErrorMessage: p.ErrorMessage, - PastesCount: pcnt, - UsersCount: ucnt, - } - - err := h.templates.ExecuteTemplate(&html, tpl, pd) - if err != nil { - h.log.Logf("ERROR error executing template: %v", err) - } - - return html.Bytes() -} - -func (h *WebServer) showInternalError(w http.ResponseWriter, err error) { - h.log.Logf("ERROR : %v", err) - w.WriteHeader(http.StatusInternalServerError) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusInternalServerError, - ErrorText: http.StatusText(http.StatusInternalServerError), - ErrorMessage: "", - })) - if err != nil { - h.log.Logf("ERROR showInternalError: failed to write: %v", e) - } -} - -// handleGetHomePage shows the homepage in reponse to a GET / request. -func (h *WebServer) handleGetHomePage(w http.ResponseWriter, r *http.Request) { - usr, _ := token.GetUserInfo(r) - - pastes, err := h.service.UserPastes(usr.ID) - if err != nil { - h.showInternalError(w, err) - return - } - - _, e := w.Write(h.generateHTML("index.html", PageData{Title: "Home", Pastes: pastes, User: usr})) - if err != nil { - h.log.Logf("ERROR handleGetHomePage: failed to write: %v", e) - } -} - -// handlePostPaste creates new paste from the form data -func (h *WebServer) handlePostPaste(w http.ResponseWriter, r *http.Request) { - usr, _ := token.GetUserInfo(r) - // Read the form data - r.Body = http.MaxBytesReader(w, r.Body, h.options.MaxBodySize) - if err := r.ParseForm(); err != nil { - h.log.Logf("WARN parsing form failed: %v", err) - w.WriteHeader(http.StatusBadRequest) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusBadRequest, - ErrorText: http.StatusText(http.StatusBadRequest), - ErrorMessage: "", - })) - if e != nil { - h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) - } - return - } - // Update the user - _, err := h.service.GetOrUpdateUser(store.User{ - ID: usr.ID, - Name: usr.Name, - Email: usr.Email, - IP: usr.IP, - Admin: usr.IsAdmin(), - }) - if err != nil { - h.log.Logf("ERROR can't update the user: %v", err) - } - // Create a new paste - var p = service.PasteRequest{ - Title: r.PostFormValue("title"), - Body: r.PostFormValue("body"), - Expires: r.PostFormValue("expires"), - DeleteAfterRead: r.PostFormValue("delete_after_read") == "yes", - Privacy: r.PostFormValue("privacy"), - Password: r.PostFormValue("password"), - Syntax: r.PostFormValue("syntax"), - UserID: usr.ID, - } - paste, err := h.service.NewPaste(p) - if err != nil { - if errors.Is(err, service.ErrEmptyBody) { - w.WriteHeader(http.StatusBadRequest) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusBadRequest, - ErrorText: http.StatusText(http.StatusBadRequest), - ErrorMessage: "Body must not be empty.", - })) - if e != nil { - h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) - } - return - } - if errors.Is(err, service.ErrWrongPrivacy) { - w.WriteHeader(http.StatusBadRequest) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusBadRequest, - ErrorText: http.StatusText(http.StatusBadRequest), - ErrorMessage: "Privacy can be one of 'private', 'public' or 'unlisted'.", - })) - if e != nil { - h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) - } - return - } - if errors.Is(err, service.ErrWrongDuration) { - w.WriteHeader(http.StatusBadRequest) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusBadRequest, - ErrorText: http.StatusText(http.StatusBadRequest), - ErrorMessage: "Duration format is incorrect.", - })) - if e != nil { - h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) - } - return - } - // Some bad thing happened and we don't know what to do - h.showInternalError(w, err) - return - } - // Get a list of user pastes - pastes, err := h.service.UserPastes(usr.ID) - if err != nil { - h.showInternalError(w, err) - return - } - - _, e := w.Write(h.generateHTML("view.html", PageData{ - Title: "Paste", - Pastes: pastes, - Paste: paste, - User: usr, - })) - if e != nil { - h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) - } -} - -// handleGetPastePage generates a page to view a single paste. -func (h *WebServer) handleGetPastePage(w http.ResponseWriter, r *http.Request) { - usr, _ := token.GetUserInfo(r) - // Get paste encoded ID - vars := mux.Vars(r) - id, ok := vars["id"] - if !ok { - h.log.Logf("WARN handleGetPastePage: paste id not found") - w.WriteHeader(http.StatusBadRequest) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusBadRequest, - ErrorText: http.StatusText(http.StatusBadRequest), - ErrorMessage: "", - })) - if e != nil { - h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) - } - return - } - // If the request comes from a password form, get the password - if err := r.ParseForm(); err != nil { - h.log.Logf("WARN parsing form failed: %v", err) - w.WriteHeader(http.StatusBadRequest) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusBadRequest, - ErrorText: http.StatusText(http.StatusBadRequest), - ErrorMessage: "", - })) - if e != nil { - h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) - } - return - } - pwd := r.PostFormValue("password") - - // Get the paste from the storage - paste, err := h.service.GetPaste(id, usr.ID, pwd) - if err != nil { - // Check if paste was not found - if errors.Is(err, service.ErrPasteNotFound) { - w.WriteHeader(http.StatusNotFound) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusNotFound, - ErrorText: http.StatusText(http.StatusNotFound), - ErrorMessage: "There is no such paste", - })) - if e != nil { - h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) - } - return - } - // Check if paste is private an belongs to another user - if errors.Is(err, service.ErrPasteIsPrivate) { - w.WriteHeader(http.StatusForbidden) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusForbidden, - ErrorText: http.StatusText(http.StatusForbidden), - ErrorMessage: "This paste is private", - })) - if e != nil { - h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) - } - return - } - // Check if paste is password-protected - if errors.Is(err, service.ErrPasteHasPassword) || errors.Is(err, service.ErrWrongPassword) { - w.WriteHeader(http.StatusUnauthorized) - _, e := w.Write(h.generateHTML("password.html", PageData{ - ID: id, - User: usr, - Title: "Password", - ErrorMessage: "This paste is protected by a password", - })) - if e != nil { - h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) - } - return - } - // Some other error that we didn't expect - h.showInternalError(w, err) - return - } - - // Get user pastes - pastes, err := h.service.UserPastes(usr.ID) - if err != nil { - h.showInternalError(w, err) - return - } - - _, e := w.Write(h.generateHTML("view.html", PageData{ - Title: "Paste", - Pastes: pastes, - Paste: paste, - User: usr, - })) - if e != nil { - h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) - } -} - -// handleGetPastesList generates a page to view a list of pastes. -func (h *WebServer) handleGetPastesList(w http.ResponseWriter, r *http.Request) { - usr, _ := token.GetUserInfo(r) - - pastes, err := h.service.UserPastes(usr.ID) - if err != nil { - h.showInternalError(w, err) - return - } - - _, e := w.Write(h.generateHTML("list.html", PageData{Title: "Pastes", Pastes: pastes, User: usr})) - if e != nil { - h.log.Logf("ERROR handleGetPastesList: failed to write: %v", e) - } -} - -// Show 404 Not Found error page -func (h *WebServer) notFound(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - _, e := w.Write(h.generateHTML("error.html", PageData{ - Title: "Error", - ErrorCode: http.StatusNotFound, - ErrorText: http.StatusText(http.StatusNotFound), - ErrorMessage: "Unfortunately the page you are looking for is not there 🙁", - })) - if e != nil { - h.log.Logf("ERROR notFound: failed to write: %v", e) - } -} +// Copyright 2021 Ilia Frenkel. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE.txt file. + +package web + +import ( + "bytes" + "errors" + "net/http" + + "github.com/go-pkgz/auth/token" + "github.com/gorilla/mux" + "github.com/iliafrenkel/go-pb/src/service" + "github.com/iliafrenkel/go-pb/src/store" +) + +// PageData contains the data that any page template may need. +type PageData struct { + Title string + Brand string + Tagline string + Logo string + Theme string + ID string + User token.User + Pastes []store.Paste + Paste store.Paste + Server string + Version string + ErrorCode int + ErrorText string + ErrorMessage string + PastesCount int64 + UsersCount int64 +} + +// Generate HTML from a template with PageData. +func (h *Server) generateHTML(tpl string, p PageData) []byte { + var html bytes.Buffer + pcnt, ucnt := h.service.GetCount() + var pd = PageData{ + Title: h.options.BrandName + " - " + p.Title, + Brand: h.options.BrandName, + Tagline: h.options.BrandTagline, + Logo: h.options.Logo, + Theme: h.options.BootstrapTheme, + ID: p.ID, + User: p.User, + Pastes: p.Pastes, + Paste: p.Paste, + Server: h.options.Proto + "://" + h.options.Addr, + Version: h.options.Version, + ErrorCode: p.ErrorCode, + ErrorText: p.ErrorText, + ErrorMessage: p.ErrorMessage, + PastesCount: pcnt, + UsersCount: ucnt, + } + + err := h.templates.ExecuteTemplate(&html, tpl, pd) + if err != nil { + h.log.Logf("ERROR error executing template: %v", err) + } + + return html.Bytes() +} + +func (h *Server) showInternalError(w http.ResponseWriter, err error) { + h.log.Logf("ERROR : %v", err) + w.WriteHeader(http.StatusInternalServerError) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusInternalServerError, + ErrorText: http.StatusText(http.StatusInternalServerError), + ErrorMessage: "", + })) + if err != nil { + h.log.Logf("ERROR showInternalError: failed to write: %v", e) + } +} + +// handleGetHomePage shows the homepage in response to a GET / request. +func (h *Server) handleGetHomePage(w http.ResponseWriter, r *http.Request) { + usr, _ := token.GetUserInfo(r) + + pastes, err := h.service.UserPastes(usr.ID) + if err != nil { + h.showInternalError(w, err) + return + } + + _, e := w.Write(h.generateHTML("index.html", PageData{Title: "Home", Pastes: pastes, User: usr})) + if err != nil { + h.log.Logf("ERROR handleGetHomePage: failed to write: %v", e) + } +} + +// handlePostPaste creates new paste from the form data +func (h *Server) handlePostPaste(w http.ResponseWriter, r *http.Request) { + usr, _ := token.GetUserInfo(r) + // Read the form data + r.Body = http.MaxBytesReader(w, r.Body, h.options.MaxBodySize) + if err := r.ParseForm(); err != nil { + h.log.Logf("WARN parsing form failed: %v", err) + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusBadRequest, + ErrorText: http.StatusText(http.StatusBadRequest), + ErrorMessage: "", + })) + if e != nil { + h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) + } + return + } + // Update the user + _, err := h.service.GetOrUpdateUser(store.User{ + ID: usr.ID, + Name: usr.Name, + Email: usr.Email, + IP: usr.IP, + Admin: usr.IsAdmin(), + }) + if err != nil { + h.log.Logf("ERROR can't update the user: %v", err) + } + // Create a new paste + var p = service.PasteRequest{ + Title: r.PostFormValue("title"), + Body: r.PostFormValue("body"), + Expires: r.PostFormValue("expires"), + DeleteAfterRead: r.PostFormValue("delete_after_read") == "yes", + Privacy: r.PostFormValue("privacy"), + Password: r.PostFormValue("password"), + Syntax: r.PostFormValue("syntax"), + UserID: usr.ID, + } + paste, err := h.service.NewPaste(p) + if err != nil { + if errors.Is(err, service.ErrEmptyBody) { + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusBadRequest, + ErrorText: http.StatusText(http.StatusBadRequest), + ErrorMessage: "Body must not be empty.", + })) + if e != nil { + h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) + } + return + } + if errors.Is(err, service.ErrWrongPrivacy) { + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusBadRequest, + ErrorText: http.StatusText(http.StatusBadRequest), + ErrorMessage: "Privacy can be one of 'private', 'public' or 'unlisted'.", + })) + if e != nil { + h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) + } + return + } + if errors.Is(err, service.ErrWrongDuration) { + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusBadRequest, + ErrorText: http.StatusText(http.StatusBadRequest), + ErrorMessage: "Duration format is incorrect.", + })) + if e != nil { + h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) + } + return + } + // Some bad thing happened and we don't know what to do + h.showInternalError(w, err) + return + } + // Get a list of user pastes + pastes, err := h.service.UserPastes(usr.ID) + if err != nil { + h.showInternalError(w, err) + return + } + + _, e := w.Write(h.generateHTML("view.html", PageData{ + Title: "Paste", + Pastes: pastes, + Paste: paste, + User: usr, + })) + if e != nil { + h.log.Logf("ERROR handlePostPaste: failed to write: %v", e) + } +} + +// handleGetPastePage generates a page to view a single paste. +func (h *Server) handleGetPastePage(w http.ResponseWriter, r *http.Request) { + usr, _ := token.GetUserInfo(r) + // Get paste encoded ID + vars := mux.Vars(r) + id, ok := vars["id"] + if !ok { + h.log.Logf("WARN handleGetPastePage: paste id not found") + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusBadRequest, + ErrorText: http.StatusText(http.StatusBadRequest), + ErrorMessage: "", + })) + if e != nil { + h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) + } + return + } + // If the request comes from a password form, get the password + if err := r.ParseForm(); err != nil { + h.log.Logf("WARN parsing form failed: %v", err) + w.WriteHeader(http.StatusBadRequest) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusBadRequest, + ErrorText: http.StatusText(http.StatusBadRequest), + ErrorMessage: "", + })) + if e != nil { + h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) + } + return + } + pwd := r.PostFormValue("password") + + // Get the paste from the storage + paste, err := h.service.GetPaste(id, usr.ID, pwd) + if err != nil { + // Check if paste was not found + if errors.Is(err, service.ErrPasteNotFound) { + w.WriteHeader(http.StatusNotFound) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusNotFound, + ErrorText: http.StatusText(http.StatusNotFound), + ErrorMessage: "There is no such paste", + })) + if e != nil { + h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) + } + return + } + // Check if paste is private an belongs to another user + if errors.Is(err, service.ErrPasteIsPrivate) { + w.WriteHeader(http.StatusForbidden) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusForbidden, + ErrorText: http.StatusText(http.StatusForbidden), + ErrorMessage: "This paste is private", + })) + if e != nil { + h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) + } + return + } + // Check if paste is password-protected + if errors.Is(err, service.ErrPasteHasPassword) || errors.Is(err, service.ErrWrongPassword) { + w.WriteHeader(http.StatusUnauthorized) + _, e := w.Write(h.generateHTML("password.html", PageData{ + ID: id, + User: usr, + Title: "Password", + ErrorMessage: "This paste is protected by a password", + })) + if e != nil { + h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) + } + return + } + // Some other error that we didn't expect + h.showInternalError(w, err) + return + } + + // Get user pastes + pastes, err := h.service.UserPastes(usr.ID) + if err != nil { + h.showInternalError(w, err) + return + } + + _, e := w.Write(h.generateHTML("view.html", PageData{ + Title: "Paste", + Pastes: pastes, + Paste: paste, + User: usr, + })) + if e != nil { + h.log.Logf("ERROR handleGetPastePage: failed to write: %v", e) + } +} + +// handleGetPastesList generates a page to view a list of pastes. +func (h *Server) handleGetPastesList(w http.ResponseWriter, r *http.Request) { + usr, _ := token.GetUserInfo(r) + + pastes, err := h.service.UserPastes(usr.ID) + if err != nil { + h.showInternalError(w, err) + return + } + + _, e := w.Write(h.generateHTML("list.html", PageData{Title: "Pastes", Pastes: pastes, User: usr})) + if e != nil { + h.log.Logf("ERROR handleGetPastesList: failed to write: %v", e) + } +} + +// Show 404 Not Found error page +func (h *Server) notFound(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, e := w.Write(h.generateHTML("error.html", PageData{ + Title: "Error", + ErrorCode: http.StatusNotFound, + ErrorText: http.StatusText(http.StatusNotFound), + ErrorMessage: "Unfortunately the page you are looking for is not there 🙁", + })) + if e != nil { + h.log.Logf("ERROR notFound: failed to write: %v", e) + } +} diff --git a/src/web/routes_test.go b/src/web/routes_test.go index 5043a01..37f1a12 100644 --- a/src/web/routes_test.go +++ b/src/web/routes_test.go @@ -1,473 +1,473 @@ -package web - -import ( - "net/http" - "net/http/httptest" - "net/url" - "os" - "strings" - "testing" - "time" - - "github.com/go-pkgz/auth/token" - "github.com/go-pkgz/lgr" - "github.com/iliafrenkel/go-pb/src/service" - "github.com/iliafrenkel/go-pb/src/store" -) - -var webSrv *WebServer - -// TestMain is a setup function for the test suite. It creates a new WebServer -// with options suitable for testing. -func TestMain(m *testing.M) { - log := lgr.New(lgr.Debug, lgr.CallerFile, lgr.CallerFunc, lgr.Msec, lgr.LevelBraces) - - webSrv = New(log, WebServerOptions{ - Addr: "localhost:8080", - Proto: "http", - ReadTimeout: 2, - WriteTimeout: 2, - IdleTimeout: 5, - LogFile: "", - LogMode: "debug", - MaxBodySize: 1024, - BrandName: "Go PB", - BrandTagline: "Testing is good!", - Assets: "../../assets", - Templates: "../../templates", - Version: "test", - AuthSecret: "ki7GZphH7bRNhKN8476jUTJn2QaMRxhX", - AuthTokenDuration: 60 * time.Second, - AuthCookieDuration: 60 * time.Second, - AuthIssuer: "go-pb test", - AuthURL: "http://localhost:8080", - DBType: "memory", - }) - - os.Exit(m.Run()) -} - -// TestGetHomePage verifies the GET / route handler. It checks that the home -// page is generated with correct title and that the New Paste form is there. -func TestGetHomePage(t *testing.T) { - t.Parallel() - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/", nil) - webSrv.router.ServeHTTP(w, r) - - if w.Code != http.StatusOK { - t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) - } - - want := webSrv.options.BrandName + " - Home" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = `
` - if !strings.Contains(got, want) { - t.Errorf("Response should have form [%s], got [%s]", want, got) - } -} - -// TestPostPasteDefaults create a paste with just the required fields. -func TestPostPasteDefaults(t *testing.T) { - t.Parallel() - w := httptest.NewRecorder() - form := url.Values{} - form.Add("body", "Test body") - form.Add("privacy", "public") - req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - webSrv.router.ServeHTTP(w, req) - - if w.Code != http.StatusOK { - t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) - } - - want := webSrv.options.BrandName + " - Paste" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = `Test body` - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// TestPostPasteEmptyForm try to POST an empty form -func TestPostPasteEmptyForm(t *testing.T) { - t.Parallel() - w := httptest.NewRecorder() - form := url.Values{} - req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - webSrv.router.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("Status should be %d, got %d", http.StatusBadRequest, w.Code) - } - - want := webSrv.options.BrandName + " - Error" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = "Body must not be empty" - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// Wrong value for privacy -func TestPostPasteWrongPrivacy(t *testing.T) { - t.Parallel() - w := httptest.NewRecorder() - form := url.Values{} - form.Add("body", "Test body") - form.Add("privacy", "absolutely public") - req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - webSrv.router.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("Status should be %d, got %d", http.StatusBadRequest, w.Code) - } - - want := webSrv.options.BrandName + " - Error" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = "Privacy can be one of 'private', 'public' or 'unlisted'." - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// Wrong value for expiration -func TestPostPasteWrongExpiration(t *testing.T) { - t.Parallel() - w := httptest.NewRecorder() - form := url.Values{} - form.Add("body", "Test body") - form.Add("privacy", "public") - form.Add("expires", "1,3z") - req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - webSrv.router.ServeHTTP(w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("Status should be %d, got %d", http.StatusBadRequest, w.Code) - } - - want := webSrv.options.BrandName + " - Error" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = "Duration format is incorrect." - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// TestNotFoundPage verifies the NotFound handler. It checks that the error -// page has the correct title and error message and that there is a link to -// the home page. -func TestNotFoundPage(t *testing.T) { - t.Parallel() - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/NotFoundPage", nil) - webSrv.router.ServeHTTP(w, r) - - if w.Code != http.StatusNotFound { - t.Errorf("Status should be %d, got %d", http.StatusNotFound, w.Code) - } - - want := webSrv.options.BrandName + " - Error" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = "Unfortunately the page you are looking for is not there 🙁" - if !strings.Contains(got, want) { - t.Errorf("Response should have error message [%s], got [%s]", want, got) - } - - want = `Test paste` - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// Get non-existing paste -func TestGetNonExistingPaste(t *testing.T) { - t.Parallel() - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/p/IYCE8rJj8Qg", nil) - webSrv.router.ServeHTTP(w, r) - - if w.Code != http.StatusNotFound { - t.Errorf("Status should be %d, got %d", http.StatusNotFound, w.Code) - } - - want := webSrv.options.BrandName + " - Error" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = "There is no such paste" - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// Get private paste of another user -func TestGetPrivatePasteOfAnotherUser(t *testing.T) { - t.Parallel() - - u, _ := webSrv.service.GetOrUpdateUser(store.User{ - ID: "test_user", - Name: "Test User", - }) - p, _ := webSrv.service.NewPaste(service.PasteRequest{ - Title: "Test", - Body: "Test paste", - Expires: "", - DeleteAfterRead: false, - Privacy: "private", - Password: "", - Syntax: "text", - UserID: u.ID, - }) - - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/p/"+p.URL(), nil) - webSrv.router.ServeHTTP(w, r) - - if w.Code != http.StatusForbidden { - t.Errorf("Status should be %d, got %d", http.StatusForbidden, w.Code) - } - - want := webSrv.options.BrandName + " - Error" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = "This paste is private" - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// Get private paste of the user who created it -func TestGetPrivatePaste(t *testing.T) { - t.Parallel() - - u, _ := webSrv.service.GetOrUpdateUser(store.User{ - ID: "test_user_1", - Name: "Test User 1", - }) - p, _ := webSrv.service.NewPaste(service.PasteRequest{ - Title: "Test", - Body: "Test paste", - Expires: "", - DeleteAfterRead: false, - Privacy: "private", - Password: "", - Syntax: "text", - UserID: u.ID, - }) - - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/p/"+p.URL(), nil) - // Add user to request context - r = token.SetUserInfo(r, token.User{ - Name: u.Name, - ID: u.ID, - }) - webSrv.router.ServeHTTP(w, r) - - if w.Code != http.StatusOK { - t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) - } - - want := webSrv.options.BrandName + " - Paste" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = `Test paste` - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// Get password protected paste without password -func TestGetPasswordProtectedPasteNoPassword(t *testing.T) { - t.Parallel() - - p, _ := webSrv.service.NewPaste(service.PasteRequest{ - Title: "Test", - Body: "Test paste", - Expires: "", - DeleteAfterRead: false, - Privacy: "public", - Password: "pa$$w0rd", - Syntax: "text", - UserID: "", - }) - - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/p/"+p.URL(), nil) - webSrv.router.ServeHTTP(w, r) - - if w.Code != http.StatusUnauthorized { - t.Errorf("Status should be %d, got %d", http.StatusUnauthorized, w.Code) - } - - want := webSrv.options.BrandName + " - Password" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = `Test paste` - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} - -// Get a list of pastes for a user -func TestGetUserPaste(t *testing.T) { - t.Parallel() - - p1, _ := webSrv.service.NewPaste(service.PasteRequest{ - Title: "Test 1", - Body: "Test paste 1", - Expires: "", - DeleteAfterRead: false, - Privacy: "public", - Password: "", - Syntax: "text", - UserID: "", - }) - p2, _ := webSrv.service.NewPaste(service.PasteRequest{ - Title: "Test 2", - Body: "Test paste 2", - Expires: "", - DeleteAfterRead: false, - Privacy: "private", - Password: "", - Syntax: "text", - UserID: "", - }) - - w := httptest.NewRecorder() - r, _ := http.NewRequest("GET", "/l/", nil) - webSrv.router.ServeHTTP(w, r) - - if w.Code != http.StatusOK { - t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) - } - - want := webSrv.options.BrandName + " - Pastes" - got := w.Body.String() - if !strings.Contains(got, want) { - t.Errorf("Response should have title [%s], got [%s]", want, got) - } - - want = p1.Title - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } - - want = p2.Title - if !strings.Contains(got, want) { - t.Errorf("Response should have body [%s], got [%s]", want, got) - } -} +package web + +import ( + "net/http" + "net/http/httptest" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/lgr" + "github.com/iliafrenkel/go-pb/src/service" + "github.com/iliafrenkel/go-pb/src/store" +) + +var webSrv *Server + +// TestMain is a setup function for the test suite. It creates a new WebServer +// with options suitable for testing. +func TestMain(m *testing.M) { + log := lgr.New(lgr.Debug, lgr.CallerFile, lgr.CallerFunc, lgr.Msec, lgr.LevelBraces) + + webSrv = New(log, ServerOptions{ + Addr: "localhost:8080", + Proto: "http", + ReadTimeout: 2, + WriteTimeout: 2, + IdleTimeout: 5, + LogFile: "", + LogMode: "debug", + MaxBodySize: 1024, + BrandName: "Go PB", + BrandTagline: "Testing is good!", + Assets: "../../assets", + Templates: "../../templates", + Version: "test", + AuthSecret: "ki7GZphH7bRNhKN8476jUTJn2QaMRxhX", + AuthTokenDuration: 60 * time.Second, + AuthCookieDuration: 60 * time.Second, + AuthIssuer: "go-pb test", + AuthURL: "http://localhost:8080", + DBType: "memory", + }) + + os.Exit(m.Run()) +} + +// TestGetHomePage verifies the GET / route handler. It checks that the home +// page is generated with correct title and that the New Paste form is there. +func TestGetHomePage(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/", nil) + webSrv.router.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) + } + + want := webSrv.options.BrandName + " - Home" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = `` + if !strings.Contains(got, want) { + t.Errorf("Response should have form [%s], got [%s]", want, got) + } +} + +// TestPostPasteDefaults create a paste with just the required fields. +func TestPostPasteDefaults(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + form := url.Values{} + form.Add("body", "Test body") + form.Add("privacy", "public") + req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + webSrv.router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) + } + + want := webSrv.options.BrandName + " - Paste" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = `Test body` + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// TestPostPasteEmptyForm try to POST an empty form +func TestPostPasteEmptyForm(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + form := url.Values{} + req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + webSrv.router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Status should be %d, got %d", http.StatusBadRequest, w.Code) + } + + want := webSrv.options.BrandName + " - Error" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = "Body must not be empty" + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// Wrong value for privacy +func TestPostPasteWrongPrivacy(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + form := url.Values{} + form.Add("body", "Test body") + form.Add("privacy", "absolutely public") + req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + webSrv.router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Status should be %d, got %d", http.StatusBadRequest, w.Code) + } + + want := webSrv.options.BrandName + " - Error" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = "Privacy can be one of 'private', 'public' or 'unlisted'." + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// Wrong value for expiration +func TestPostPasteWrongExpiration(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + form := url.Values{} + form.Add("body", "Test body") + form.Add("privacy", "public") + form.Add("expires", "1,3z") + req, _ := http.NewRequest("POST", "/p/", strings.NewReader(form.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + webSrv.router.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("Status should be %d, got %d", http.StatusBadRequest, w.Code) + } + + want := webSrv.options.BrandName + " - Error" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = "Duration format is incorrect." + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// TestNotFoundPage verifies the NotFound handler. It checks that the error +// page has the correct title and error message and that there is a link to +// the home page. +func TestNotFoundPage(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/NotFoundPage", nil) + webSrv.router.ServeHTTP(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("Status should be %d, got %d", http.StatusNotFound, w.Code) + } + + want := webSrv.options.BrandName + " - Error" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = "Unfortunately the page you are looking for is not there 🙁" + if !strings.Contains(got, want) { + t.Errorf("Response should have error message [%s], got [%s]", want, got) + } + + want = `Test paste` + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// Get non-existing paste +func TestGetNonExistingPaste(t *testing.T) { + t.Parallel() + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/p/IYCE8rJj8Qg", nil) + webSrv.router.ServeHTTP(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("Status should be %d, got %d", http.StatusNotFound, w.Code) + } + + want := webSrv.options.BrandName + " - Error" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = "There is no such paste" + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// Get private paste of another user +func TestGetPrivatePasteOfAnotherUser(t *testing.T) { + t.Parallel() + + u, _ := webSrv.service.GetOrUpdateUser(store.User{ + ID: "test_user", + Name: "Test User", + }) + p, _ := webSrv.service.NewPaste(service.PasteRequest{ + Title: "Test", + Body: "Test paste", + Expires: "", + DeleteAfterRead: false, + Privacy: "private", + Password: "", + Syntax: "text", + UserID: u.ID, + }) + + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/p/"+p.URL(), nil) + webSrv.router.ServeHTTP(w, r) + + if w.Code != http.StatusForbidden { + t.Errorf("Status should be %d, got %d", http.StatusForbidden, w.Code) + } + + want := webSrv.options.BrandName + " - Error" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = "This paste is private" + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// Get private paste of the user who created it +func TestGetPrivatePaste(t *testing.T) { + t.Parallel() + + u, _ := webSrv.service.GetOrUpdateUser(store.User{ + ID: "test_user_1", + Name: "Test User 1", + }) + p, _ := webSrv.service.NewPaste(service.PasteRequest{ + Title: "Test", + Body: "Test paste", + Expires: "", + DeleteAfterRead: false, + Privacy: "private", + Password: "", + Syntax: "text", + UserID: u.ID, + }) + + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/p/"+p.URL(), nil) + // Add user to request context + r = token.SetUserInfo(r, token.User{ + Name: u.Name, + ID: u.ID, + }) + webSrv.router.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) + } + + want := webSrv.options.BrandName + " - Paste" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = `Test paste` + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// Get password protected paste without password +func TestGetPasswordProtectedPasteNoPassword(t *testing.T) { + t.Parallel() + + p, _ := webSrv.service.NewPaste(service.PasteRequest{ + Title: "Test", + Body: "Test paste", + Expires: "", + DeleteAfterRead: false, + Privacy: "public", + Password: "pa$$w0rd", + Syntax: "text", + UserID: "", + }) + + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/p/"+p.URL(), nil) + webSrv.router.ServeHTTP(w, r) + + if w.Code != http.StatusUnauthorized { + t.Errorf("Status should be %d, got %d", http.StatusUnauthorized, w.Code) + } + + want := webSrv.options.BrandName + " - Password" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = `Test paste` + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} + +// Get a list of pastes for a user +func TestGetUserPaste(t *testing.T) { + t.Parallel() + + p1, _ := webSrv.service.NewPaste(service.PasteRequest{ + Title: "Test 1", + Body: "Test paste 1", + Expires: "", + DeleteAfterRead: false, + Privacy: "public", + Password: "", + Syntax: "text", + UserID: "", + }) + p2, _ := webSrv.service.NewPaste(service.PasteRequest{ + Title: "Test 2", + Body: "Test paste 2", + Expires: "", + DeleteAfterRead: false, + Privacy: "private", + Password: "", + Syntax: "text", + UserID: "", + }) + + w := httptest.NewRecorder() + r, _ := http.NewRequest("GET", "/l/", nil) + webSrv.router.ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Errorf("Status should be %d, got %d", http.StatusOK, w.Code) + } + + want := webSrv.options.BrandName + " - Pastes" + got := w.Body.String() + if !strings.Contains(got, want) { + t.Errorf("Response should have title [%s], got [%s]", want, got) + } + + want = p1.Title + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } + + want = p2.Title + if !strings.Contains(got, want) { + t.Errorf("Response should have body [%s], got [%s]", want, got) + } +} diff --git a/src/web/web.go b/src/web/web.go index d2da44b..f555f6d 100644 --- a/src/web/web.go +++ b/src/web/web.go @@ -1,245 +1,246 @@ -// Copyright 2021 Ilia Frenkel. All rights reserved. -// Use of this source code is governed by a MIT-style -// license that can be found in the LICENSE.txt file. - -// Package web implements a web server that provides a front-end for the -// go-pb application. -package web - -import ( - "context" - "fmt" - "html/template" - "io" - "net" - "net/http" - "os" - "time" - - "github.com/go-pkgz/auth" - "github.com/go-pkgz/auth/avatar" - "github.com/go-pkgz/auth/token" - "github.com/go-pkgz/lgr" - "github.com/gorilla/handlers" - "github.com/gorilla/mux" - "github.com/iliafrenkel/go-pb/src/service" -) - -// WebServerOptions defines various parameters needed to run the WebServer -type WebServerOptions struct { - Addr string // address to listen on, see http.Server docs for details - Proto string // protocol, either "http" or "https" - ReadTimeout time.Duration // maximum duration for reading the entire request. - WriteTimeout time.Duration // maximum duration before timing out writes of the response - IdleTimeout time.Duration // maximum amount of time to wait for the next request - LogFile string // if not empty, will write logs to the file - LogMode string // can be either "debug" or "production" - BrandName string // displayed at the top of each page, default is "Go PB" - BrandTagline string // displayed below the BrandName - Assets string // location of the assets folder (css, js, images) - Templates string // location of the templates folder - Logo string // name of the logo image within the assets folder - MaxBodySize int64 // maximum size for request's body - BootstrapTheme string // one of the themes, see css files in the assets folder - Version string // app version, comes from build - AuthSecret string // secret for JWT token generation and validation - AuthTokenDuration time.Duration // JWT token expiration duration - AuthCookieDuration time.Duration // cookie expiration time - AuthIssuer string // application name used as an issuer in oauth requests - AuthURL string // callback URL for oauth requests - DBType string // type of the store to use - DBConn string // database connection string - GitHubCID string // github client id for oauth - GitHubCSEC string // github client secret for oauth - GoogleCID string // google client id for oauth - GoogleCSEC string // google client secret for oauth - TwitterCID string // twitter client id for oauth - TwitterCSEC string // twitter client secret for oauth -} - -// WebServer encapsulates a router and a server. -// Normally, you'd create a new instance by calling New which configures the -// rotuer and then call ListenAndServe to start serving incoming requests. -type WebServer struct { - router *mux.Router - server *http.Server - options WebServerOptions - templates *template.Template - log *lgr.Logger - service *service.Service -} - -var dbgLogFormatter handlers.LogFormatter = func(writer io.Writer, params handlers.LogFormatterParams) { - const ( - green = "\033[97;42m" - white = "\033[90;47m" - yellow = "\033[90;43m" - red = "\033[97;41m" - blue = "\033[97;44m" - magenta = "\033[97;45m" - cyan = "\033[97;46m" - reset = "\033[0m" - ) - - code := params.StatusCode - cclr := "" - switch { - case code >= http.StatusOK && code < http.StatusMultipleChoices: - cclr = green - case code >= http.StatusMultipleChoices && code < http.StatusBadRequest: - cclr = white - case code >= http.StatusBadRequest && code < http.StatusInternalServerError: - cclr = yellow - default: - cclr = red - } - - method := params.Request.Method - mclr := "" - switch method { - case http.MethodGet: - mclr = blue - case http.MethodPost: - mclr = cyan - case http.MethodPut: - mclr = yellow - case http.MethodDelete: - mclr = red - case http.MethodPatch: - mclr = green - case http.MethodHead: - mclr = magenta - case http.MethodOptions: - mclr = white - default: - mclr = reset - } - - host, _, err := net.SplitHostPort(params.Request.RemoteAddr) - if err != nil { - host = params.Request.RemoteAddr - } - - fmt.Fprintf(writer, "|%s %3d %s| %15s |%s %-7s %s| %8d | %s \n", - cclr, code, reset, - host, - mclr, method, reset, - params.Size, - params.URL.RequestURI(), - ) -} - -// ListenAndServe starts an HTTP server and binds it to the provided address. -// You have to call New() first to initialise the WebServer. -func (h *WebServer) ListenAndServe() error { - var hdlr http.Handler - var w io.Writer - var err error - if h.options.LogFile == "" { - w = lgr.ToWriter(h.log, "") - } else { - w, err = os.OpenFile(h.options.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { - return fmt.Errorf("WebServer.ListenAndServer: cannot open log file: [%s]: %w", h.options.LogFile, err) - } - } - if h.options.LogMode == "debug" { - hdlr = handlers.CustomLoggingHandler(w, h.router, dbgLogFormatter) - } else { - hdlr = handlers.CombinedLoggingHandler(w, h.router) - } - h.server = &http.Server{ - Addr: h.options.Addr, - WriteTimeout: h.options.WriteTimeout, - ReadTimeout: h.options.ReadTimeout, - IdleTimeout: h.options.IdleTimeout, - Handler: hdlr, - } - - return h.server.ListenAndServe() -} - -func (h *WebServer) Shutdown(ctx context.Context) error { - return h.server.Shutdown(ctx) -} - -// New returns an instance of the WebServer with initialised middleware, -// loaded templates and routes. You can call ListenAndServe on a newly -// created instance to initialise the HTTP server and start handling incoming -// requests. -func New(l *lgr.Logger, opts WebServerOptions) *WebServer { - var handler WebServer - handler.log = l - handler.options = opts - - // Load template - tpl, err := template.ParseGlob(handler.options.Templates + "/*.html") - if err != nil { - handler.log.Logf("FATAL error loading templates: %v", err) - } - handler.log.Logf("INFO loaded %d templates", len(tpl.Templates())) - handler.templates = tpl - - // Initialise the service - switch opts.DBType { - case "memory": - handler.service = service.NewWithMemDB() - case "postgres": - handler.service, err = service.NewWithPostgres(opts.DBConn) - if err != nil { - handler.log.Logf("FATAL error creating Postgres service: %v", err) - } - default: - handler.log.Logf("FATAL unknown store type: %v", opts.DBType) - } - - // Initialise the router - handler.router = mux.NewRouter() - - // Templates and static files - handler.router.PathPrefix("/assets/").Handler(http.StripPrefix("/assets/", http.FileServer(http.Dir(handler.options.Assets)))) - - // Auth middleware - authSvc := auth.NewService(auth.Opts{ - SecretReader: token.SecretFunc(func(id string) (string, error) { // secret key for JWT - return handler.options.AuthSecret, nil - }), - TokenDuration: handler.options.AuthTokenDuration, - CookieDuration: handler.options.AuthCookieDuration, - Issuer: handler.options.AuthIssuer, - URL: handler.options.AuthURL, - DisableXSRF: true, - AvatarStore: avatar.NewLocalFS(".tmp"), - Logger: handler.log, // optional logger for auth library - }) - authSvc.AddProvider("github", handler.options.GitHubCID, handler.options.GitHubCSEC) - authSvc.AddProvider("google", handler.options.GoogleCID, handler.options.GoogleCSEC) - authSvc.AddProvider("twitter", handler.options.TwitterCID, handler.options.TwitterCSEC) - authSvc.AddProvider("dev", "", "") // dev auth, runs dev oauth2 server on :8084 - - go func() { - devAuthServer, err := authSvc.DevAuth() - if err != nil { - handler.log.Logf("FATAL %v", err) - } - devAuthServer.Run(context.Background()) - }() - m := authSvc.Middleware() - handler.router.Use(m.Trace) - authRoutes, avaRoutes := authSvc.Handlers() - handler.router.PathPrefix("/auth").Handler(authRoutes) - handler.router.PathPrefix("/avatar").Handler(avaRoutes) - - // Define routes - handler.router.HandleFunc("/", handler.handleGetHomePage).Methods("GET") - handler.router.HandleFunc("/p/", handler.handlePostPaste).Methods("POST") - handler.router.HandleFunc("/p/", handler.handleGetHomePage).Methods("GET") - handler.router.HandleFunc("/p/{id}", handler.handleGetPastePage).Methods("GET") - handler.router.HandleFunc("/p/{id}", handler.handleGetPastePage).Methods("POST") - handler.router.HandleFunc("/l/", handler.handleGetPastesList).Methods("GET") - - // Common error routes - handler.router.NotFoundHandler = handler.router.NewRoute().BuildOnly().HandlerFunc(handler.notFound).GetHandler() - - return &handler -} +// Copyright 2021 Ilia Frenkel. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE.txt file. + +// Package web implements a web server that provides a front-end for the +// go-pb application. +package web + +import ( + "context" + "fmt" + "html/template" + "io" + "net" + "net/http" + "os" + "time" + + "github.com/go-pkgz/auth" + "github.com/go-pkgz/auth/avatar" + "github.com/go-pkgz/auth/token" + "github.com/go-pkgz/lgr" + "github.com/gorilla/handlers" + "github.com/gorilla/mux" + "github.com/iliafrenkel/go-pb/src/service" +) + +// ServerOptions defines various parameters needed to run the WebServer +type ServerOptions struct { + Addr string // address to listen on, see http.Server docs for details + Proto string // protocol, either "http" or "https" + ReadTimeout time.Duration // maximum duration for reading the entire request. + WriteTimeout time.Duration // maximum duration before timing out writes of the response + IdleTimeout time.Duration // maximum amount of time to wait for the next request + LogFile string // if not empty, will write logs to the file + LogMode string // can be either "debug" or "production" + BrandName string // displayed at the top of each page, default is "Go PB" + BrandTagline string // displayed below the BrandName + Assets string // location of the assets folder (css, js, images) + Templates string // location of the templates folder + Logo string // name of the logo image within the assets folder + MaxBodySize int64 // maximum size for request's body + BootstrapTheme string // one of the themes, see css files in the assets folder + Version string // app version, comes from build + AuthSecret string // secret for JWT token generation and validation + AuthTokenDuration time.Duration // JWT token expiration duration + AuthCookieDuration time.Duration // cookie expiration time + AuthIssuer string // application name used as an issuer in oauth requests + AuthURL string // callback URL for oauth requests + DBType string // type of the store to use + DBConn string // database connection string + GitHubCID string // github client id for oauth + GitHubCSEC string // github client secret for oauth + GoogleCID string // google client id for oauth + GoogleCSEC string // google client secret for oauth + TwitterCID string // twitter client id for oauth + TwitterCSEC string // twitter client secret for oauth +} + +// Server encapsulates a router and a server. +// Normally, you'd create a new instance by calling New which configures the +// rotuer and then call ListenAndServe to start serving incoming requests. +type Server struct { + router *mux.Router + server *http.Server + options ServerOptions + templates *template.Template + log *lgr.Logger + service *service.Service +} + +var dbgLogFormatter handlers.LogFormatter = func(writer io.Writer, params handlers.LogFormatterParams) { + const ( + green = "\033[97;42m" + white = "\033[90;47m" + yellow = "\033[90;43m" + red = "\033[97;41m" + blue = "\033[97;44m" + magenta = "\033[97;45m" + cyan = "\033[97;46m" + reset = "\033[0m" + ) + + code := params.StatusCode + cclr := "" + switch { + case code >= http.StatusOK && code < http.StatusMultipleChoices: + cclr = green + case code >= http.StatusMultipleChoices && code < http.StatusBadRequest: + cclr = white + case code >= http.StatusBadRequest && code < http.StatusInternalServerError: + cclr = yellow + default: + cclr = red + } + + method := params.Request.Method + mclr := "" + switch method { + case http.MethodGet: + mclr = blue + case http.MethodPost: + mclr = cyan + case http.MethodPut: + mclr = yellow + case http.MethodDelete: + mclr = red + case http.MethodPatch: + mclr = green + case http.MethodHead: + mclr = magenta + case http.MethodOptions: + mclr = white + default: + mclr = reset + } + + host, _, err := net.SplitHostPort(params.Request.RemoteAddr) + if err != nil { + host = params.Request.RemoteAddr + } + + fmt.Fprintf(writer, "|%s %3d %s| %15s |%s %-7s %s| %8d | %s \n", + cclr, code, reset, + host, + mclr, method, reset, + params.Size, + params.URL.RequestURI(), + ) +} + +// ListenAndServe starts an HTTP server and binds it to the provided address. +// You have to call New() first to initialise the WebServer. +func (h *Server) ListenAndServe() error { + var hdlr http.Handler + var w io.Writer + var err error + if h.options.LogFile == "" { + w = lgr.ToWriter(h.log, "") + } else { + w, err = os.OpenFile(h.options.LogFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("WebServer.ListenAndServer: cannot open log file: [%s]: %w", h.options.LogFile, err) + } + } + if h.options.LogMode == "debug" { + hdlr = handlers.CustomLoggingHandler(w, h.router, dbgLogFormatter) + } else { + hdlr = handlers.CombinedLoggingHandler(w, h.router) + } + h.server = &http.Server{ + Addr: h.options.Addr, + WriteTimeout: h.options.WriteTimeout, + ReadTimeout: h.options.ReadTimeout, + IdleTimeout: h.options.IdleTimeout, + Handler: hdlr, + } + + return h.server.ListenAndServe() +} + +// Shutdown gracefully shutdown the server with the givem context. +func (h *Server) Shutdown(ctx context.Context) error { + return h.server.Shutdown(ctx) +} + +// New returns an instance of the WebServer with initialised middleware, +// loaded templates and routes. You can call ListenAndServe on a newly +// created instance to initialise the HTTP server and start handling incoming +// requests. +func New(l *lgr.Logger, opts ServerOptions) *Server { + var handler Server + handler.log = l + handler.options = opts + + // Load template + tpl, err := template.ParseGlob(handler.options.Templates + "/*.html") + if err != nil { + handler.log.Logf("FATAL error loading templates: %v", err) + } + handler.log.Logf("INFO loaded %d templates", len(tpl.Templates())) + handler.templates = tpl + + // Initialise the service + switch opts.DBType { + case "memory": + handler.service = service.NewWithMemDB() + case "postgres": + handler.service, err = service.NewWithPostgres(opts.DBConn) + if err != nil { + handler.log.Logf("FATAL error creating Postgres service: %v", err) + } + default: + handler.log.Logf("FATAL unknown store type: %v", opts.DBType) + } + + // Initialise the router + handler.router = mux.NewRouter() + + // Templates and static files + handler.router.PathPrefix("/assets/").Handler(http.StripPrefix("/assets/", http.FileServer(http.Dir(handler.options.Assets)))) + + // Auth middleware + authSvc := auth.NewService(auth.Opts{ + SecretReader: token.SecretFunc(func(id string) (string, error) { // secret key for JWT + return handler.options.AuthSecret, nil + }), + TokenDuration: handler.options.AuthTokenDuration, + CookieDuration: handler.options.AuthCookieDuration, + Issuer: handler.options.AuthIssuer, + URL: handler.options.AuthURL, + DisableXSRF: true, + AvatarStore: avatar.NewLocalFS(".tmp"), + Logger: handler.log, // optional logger for auth library + }) + authSvc.AddProvider("github", handler.options.GitHubCID, handler.options.GitHubCSEC) + authSvc.AddProvider("google", handler.options.GoogleCID, handler.options.GoogleCSEC) + authSvc.AddProvider("twitter", handler.options.TwitterCID, handler.options.TwitterCSEC) + authSvc.AddProvider("dev", "", "") // dev auth, runs dev oauth2 server on :8084 + + go func() { + devAuthServer, err := authSvc.DevAuth() + if err != nil { + handler.log.Logf("FATAL %v", err) + } + devAuthServer.Run(context.Background()) + }() + m := authSvc.Middleware() + handler.router.Use(m.Trace) + authRoutes, avaRoutes := authSvc.Handlers() + handler.router.PathPrefix("/auth").Handler(authRoutes) + handler.router.PathPrefix("/avatar").Handler(avaRoutes) + + // Define routes + handler.router.HandleFunc("/", handler.handleGetHomePage).Methods("GET") + handler.router.HandleFunc("/p/", handler.handlePostPaste).Methods("POST") + handler.router.HandleFunc("/p/", handler.handleGetHomePage).Methods("GET") + handler.router.HandleFunc("/p/{id}", handler.handleGetPastePage).Methods("GET") + handler.router.HandleFunc("/p/{id}", handler.handleGetPastePage).Methods("POST") + handler.router.HandleFunc("/l/", handler.handleGetPastesList).Methods("GET") + + // Common error routes + handler.router.NotFoundHandler = handler.router.NewRoute().BuildOnly().HandlerFunc(handler.notFound).GetHandler() + + return &handler +}