diff --git a/internal/database/db.go b/internal/database/db.go index 2fed3d9..2043f40 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -3,6 +3,7 @@ package database import ( "errors" "strings" + "sync" "gorm.io/gorm" ) @@ -10,6 +11,7 @@ import ( var ( DB *gorm.DB Driver DriverType + Locker sync.Mutex ) type DriverType = string @@ -17,7 +19,7 @@ type DriverType = string const Sqlite = "sqlite" const Mysql = "mysql" const Postgres = "postgres" -const UnknowDatabase = "unknown database" +const UnknownDatabase = "unknown database" func InitDB(dsn string) (err error) { dbName, dbDsn := dbType(dsn) @@ -25,10 +27,13 @@ func InitDB(dsn string) (err error) { switch dbName { case Sqlite: DB, err = NewSqlite3(dbDsn) + Driver = Sqlite case Mysql: DB, err = NewMysql(dbDsn) + Driver = Mysql case Postgres: DB, err = NewPostgres(dbDsn) + Driver = Postgres default: err = errors.New("unsupported database") } @@ -45,5 +50,5 @@ func dbType(dsn string) (DriverType, string) { } else if strings.Contains(dsn, "@tcp") { //兼容上个版本的写法 return Mysql, dsn } - return UnknowDatabase, dsn + return UnknownDatabase, dsn } diff --git a/pkg/dns/record.go b/pkg/dns/record.go index b968a9f..203e6b3 100644 --- a/pkg/dns/record.go +++ b/pkg/dns/record.go @@ -41,6 +41,13 @@ func newRecord(rule *Rule, flag, domain, remoteIp, ipArea string) (r *Record, er Domain: domain, Rule: *rule, } + + // sqlite db-level lock to prevent too much write operation lead to error of `database is locked` #54 + if database.Driver == database.Sqlite { + database.Locker.Lock() + defer database.Locker.Unlock() + } + return r, database.DB.Create(r).Error } diff --git a/pkg/ftp/record.go b/pkg/ftp/record.go index bd5dd9b..08eed92 100644 --- a/pkg/ftp/record.go +++ b/pkg/ftp/record.go @@ -52,6 +52,13 @@ func NewRecord(rule *Rule, flag, user, password, method, path, ip, area string, File: file, Rule: *rule, } + + // sqlite db-level lock to prevent too much write operation lead to error of `database is locked` #54 + if database.Driver == database.Sqlite { + database.Locker.Lock() + defer database.Locker.Unlock() + } + return r, database.DB.Create(r).Error } diff --git a/pkg/ldap/record.go b/pkg/ldap/record.go index 88690fc..4ab218f 100644 --- a/pkg/ldap/record.go +++ b/pkg/ldap/record.go @@ -40,6 +40,13 @@ func NewRecord(rule *Rule, flag, path, ip, area string) (r *Record, err error) { Path: path, Rule: *rule, } + + // sqlite db-level lock to prevent too much write operation lead to error of `database is locked` #54 + if database.Driver == database.Sqlite { + database.Locker.Lock() + defer database.Locker.Unlock() + } + return r, database.DB.Create(r).Error } diff --git a/pkg/mysql/record.go b/pkg/mysql/record.go index 20f9181..7784068 100644 --- a/pkg/mysql/record.go +++ b/pkg/mysql/record.go @@ -51,6 +51,13 @@ func newRecord(rule *Rule, flag, username, schema, clientName, clientOS, remoteI Files: files, Rule: *rule, } + + // sqlite db-level lock to prevent too much write operation lead to error of `database is locked` #54 + if database.Driver == database.Sqlite { + database.Locker.Lock() + defer database.Locker.Unlock() + } + return r, database.DB.Create(r).Error } diff --git a/pkg/rhttp/record.go b/pkg/rhttp/record.go index 9455af7..eca7db1 100644 --- a/pkg/rhttp/record.go +++ b/pkg/rhttp/record.go @@ -44,6 +44,13 @@ func NewRecord(rule *Rule, flag, method, url, ip, area, raw string) (r *Record, RawRequest: raw, Rule: *rule, } + + // sqlite db-level lock to prevent too much write operation lead to error of `database is locked` #54 + if database.Driver == database.Sqlite { + database.Locker.Lock() + defer database.Locker.Unlock() + } + return r, database.DB.Create(r).Error } diff --git a/pkg/rmi/record.go b/pkg/rmi/record.go index 22575e5..4bc77f9 100644 --- a/pkg/rmi/record.go +++ b/pkg/rmi/record.go @@ -40,6 +40,13 @@ func NewRecord(rule *Rule, flag, path, ip, area string) (r *Record, err error) { Path: path, Rule: *rule, } + + // sqlite db-level lock to prevent too much write operation lead to error of `database is locked` #54 + if database.Driver == database.Sqlite { + database.Locker.Lock() + defer database.Locker.Unlock() + } + return r, database.DB.Create(r).Error }