Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconnect database when connection is invalid. #4

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 82 additions & 10 deletions short/short.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package short

import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"log"
Expand All @@ -10,7 +11,7 @@ import (
"github.com/andyxning/shortme/conf"
"github.com/andyxning/shortme/sequence"
_ "github.com/andyxning/shortme/sequence/db"
_ "github.com/go-sql-driver/mysql"
"github.com/go-sql-driver/mysql"
)

type shorter struct {
Expand All @@ -21,6 +22,16 @@ type shorter struct {

// connect will panic when it can not connect to DB server.
func (shorter *shorter) mustConnect() {
shorter.reconnectReadDB()
shorter.reconnectWriteDB()
}

func (shorter *shorter) reconnectReadDB() {
if shorter.readDB != nil {
shorter.readDB.Close()
shorter.readDB = nil
}

db, err := sql.Open("mysql", conf.Conf.ShortDB.ReadDSN)
if err != nil {
log.Panicf("short read db open error. %v", err)
Expand All @@ -35,8 +46,15 @@ func (shorter *shorter) mustConnect() {
db.SetMaxOpenConns(conf.Conf.ShortDB.MaxOpenConns)

shorter.readDB = db
}

func (shorter *shorter) reconnectWriteDB() {
if shorter.writeDB != nil {
shorter.writeDB.Close()
shorter.writeDB = nil
}

db, err = sql.Open("mysql", conf.Conf.ShortDB.WriteDSN)
db, err := sql.Open("mysql", conf.Conf.ShortDB.WriteDSN)
if err != nil {
log.Panicf("short write db open error. %v", err)
}
Expand All @@ -54,6 +72,15 @@ func (shorter *shorter) mustConnect() {

// initSequence will panic when it can not open the sequence successfully.
func (shorter *shorter) mustInitSequence() {
shorter.reconnectSequence()
}

func (shorter *shorter) reconnectSequence() {
if shorter.sequence != nil {
shorter.sequence.Close()
shorter.sequence = nil
}

sequence, err := sequence.GetSequence("db")
if err != nil {
log.Panicf("get sequence instance error. %v", err)
Expand All @@ -79,14 +106,29 @@ func (shorter *shorter) close() {
}
}

func (shorter *shorter) connectionError(err *error) bool {
return *err == driver.ErrBadConn || *err == mysql.ErrInvalidConn
}

func (shorter *shorter) Expand(shortURL string) (longURL string, err error) {
selectSQL := fmt.Sprintf(`SELECT long_url FROM short WHERE short_url=?`)

var rows *sql.Rows
rows, err = shorter.readDB.Query(selectSQL, shortURL)

if err != nil {
log.Printf("short read db query error. %v", err)
return "", errors.New("short read db query error")
if shorter.connectionError(&err) {
shorter.reconnectReadDB()

rows, err = shorter.readDB.Query(selectSQL, shortURL)
if err != nil {
log.Printf("short read db query error. %v", err)
return "", errors.New("short read db query error")
}
} else {
log.Printf("short read db query error. %v", err)
return "", errors.New("short read db query error")
}
}

defer rows.Close()
Expand All @@ -113,8 +155,18 @@ func (shorter *shorter) Short(longURL string) (shortURL string, err error) {
var seq uint64
seq, err = shorter.sequence.NextSequence()
if err != nil {
log.Printf("get next sequence error. %v", err)
return "", errors.New("get next sequence error")
if shorter.connectionError(&err) {
shorter.reconnectSequence()

seq, err = shorter.sequence.NextSequence()
if err != nil {
log.Printf("get next sequence error. %v", err)
return "", errors.New("get next sequence error")
}
} else {
log.Printf("get next sequence error. %v", err)
return "", errors.New("get next sequence error")
}
}

shortURL = base.Int2String(seq)
Expand All @@ -130,15 +182,35 @@ func (shorter *shorter) Short(longURL string) (shortURL string, err error) {
var stmt *sql.Stmt
stmt, err = shorter.writeDB.Prepare(insertSQL)
if err != nil {
log.Printf("short write db prepares error. %v", err)
return "", errors.New("short write db prepares error")
if shorter.connectionError(&err) {
shorter.reconnectWriteDB()

stmt, err = shorter.writeDB.Prepare(insertSQL)
if err != nil {
log.Printf("short write db prepares error. %v", err)
return "", errors.New("short write db prepares error")
}
} else {
log.Printf("short write db prepares error. %v", err)
return "", errors.New("short write db prepares error")
}
}
defer stmt.Close()

_, err = stmt.Exec(longURL, shortURL)
if err != nil {
log.Printf("short write db insert error. %v", err)
return "", errors.New("short write db insert error")
if shorter.connectionError(&err) {
shorter.reconnectWriteDB()

_, err = stmt.Exec(longURL, shortURL)
if err != nil {
log.Printf("short write db insert error. %v", err)
return "", errors.New("short write db insert error")
}
} else {
log.Printf("short write db insert error. %v", err)
return "", errors.New("short write db insert error")
}
}

return shortURL, nil
Expand Down