Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race errors with memory tables #2510

Merged
merged 2 commits into from
May 22, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 47 additions & 10 deletions memory/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/fulltext"
"github.com/dolthub/go-mysql-server/sql/types"
"sync"
)

// Database is an in-memory database.
Expand All @@ -34,6 +35,7 @@ type Database struct {
type MemoryDatabase interface {
sql.Database
AddTable(name string, t MemTable)
DeleteTable(name string)
Database() *BaseDatabase
}

Expand All @@ -60,6 +62,7 @@ type BaseDatabase struct {
events []sql.EventDefinition
primaryKeyIndexes bool
collation sql.CollationID
tablesMu *sync.RWMutex
}

var _ MemoryDatabase = (*Database)(nil)
Expand All @@ -76,9 +79,10 @@ func NewDatabase(name string) *Database {
// NewViewlessDatabase creates a new database that doesn't persist views. Used only for testing. Use NewDatabase.
func NewViewlessDatabase(name string) *BaseDatabase {
return &BaseDatabase{
name: name,
tables: map[string]MemTable{},
fkColl: newForeignKeyCollection(),
name: name,
tables: map[string]MemTable{},
fkColl: newForeignKeyCollection(),
tablesMu: &sync.RWMutex{},
}
}

Expand All @@ -103,6 +107,8 @@ func (d *BaseDatabase) Name() string {

// Tables returns all tables in the database.
func (d *BaseDatabase) Tables() map[string]sql.Table {
d.tablesMu.RLock()
defer d.tablesMu.RUnlock()
tables := make(map[string]sql.Table, len(d.tables))
for name, table := range d.tables {
tables[name] = table
Expand Down Expand Up @@ -133,17 +139,23 @@ func (d *BaseDatabase) GetTableInsensitive(ctx *sql.Context, tblName string) (sq
// putTable writes the table given into database storage. A table with this name must already be present.
func (d *BaseDatabase) putTable(t *Table) {
lowerName := strings.ToLower(t.name)
d.tablesMu.RLock()
for name, table := range d.tables {
if strings.ToLower(name) == lowerName {
t.name = table.Name()
d.tables[name] = t
d.tablesMu.RUnlock()
d.AddTable(name, t)
return
}
}
d.tablesMu.RUnlock()
panic(fmt.Sprintf("table %s not found", t.name))
}

func (d *BaseDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
d.tablesMu.RLock()
defer d.tablesMu.RUnlock()

tblNames := make([]string, 0, len(d.tables))
for k := range d.tables {
tblNames = append(tblNames, k)
Expand All @@ -153,6 +165,9 @@ func (d *BaseDatabase) GetTableNames(ctx *sql.Context) ([]string, error) {
}

func (d *BaseDatabase) CreateFulltextTableNames(ctx *sql.Context, parentTableName string, parentIndexName string) (fulltext.IndexTableNames, error) {
d.tablesMu.RLock()
defer d.tablesMu.RUnlock()

var tablePrefix string
OuterLoop:
for i := uint64(0); true; i++ {
Expand Down Expand Up @@ -225,17 +240,28 @@ func (db *HistoryDatabase) AddTableAsOf(name string, t sql.Table, asOf interface
}

db.Revisions[strings.ToLower(name)][asOf] = t
db.tables[name] = t.(MemTable)
db.AddTable(name, t.(MemTable))
}

// AddTable adds a new table to the database.
func (d *BaseDatabase) AddTable(name string, t MemTable) {
d.tablesMu.Lock()
defer d.tablesMu.Unlock()
d.tables[name] = t
}

// DeleteTable deletes a table from the database.
func (d *BaseDatabase) DeleteTable(name string) {
d.tablesMu.Lock()
defer d.tablesMu.Unlock()
delete(d.tables, name)
}

// CreateTable creates a table with the given name and schema
func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.PrimaryKeySchema, collation sql.CollationID, comment string) error {
d.tablesMu.RLock()
_, ok := d.tables[name]
d.tablesMu.RUnlock()
if ok {
return sql.ErrTableAlreadyExists.New(name)
}
Expand All @@ -247,7 +273,7 @@ func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.Pri
table.EnablePrimaryKeyIndexes()
}

d.tables[name] = table
d.AddTable(name, table)
sess := SessionFromContext(ctx)
sess.putTable(table.data)

Expand All @@ -256,7 +282,9 @@ func (d *BaseDatabase) CreateTable(ctx *sql.Context, name string, schema sql.Pri

// CreateIndexedTable creates a table with the given name and schema
func (d *BaseDatabase) CreateIndexedTable(ctx *sql.Context, name string, sch sql.PrimaryKeySchema, idxDef sql.IndexDef, collation sql.CollationID) error {
d.tablesMu.RLock()
_, ok := d.tables[name]
d.tablesMu.RUnlock()
if ok {
return sql.ErrTableAlreadyExists.New(name)
}
Expand All @@ -278,31 +306,40 @@ func (d *BaseDatabase) CreateIndexedTable(ctx *sql.Context, name string, sch sql
}
}

d.tables[name] = table
d.AddTable(name, table)
return nil
}

// DropTable drops the table with the given name
func (d *BaseDatabase) DropTable(ctx *sql.Context, name string) error {
d.tablesMu.RLock()
t, ok := d.tables[name]
d.tablesMu.RUnlock()

if !ok {
return sql.ErrTableNotFound.New(name)
}

SessionFromContext(ctx).dropTable(t.(*Table).data)

delete(d.tables, name)
d.DeleteTable(name)
return nil
}

func (d *BaseDatabase) RenameTable(ctx *sql.Context, oldName, newName string) error {
d.tablesMu.RLock()
tbl, ok := d.tables[oldName]
d.tablesMu.RUnlock()

if !ok {
// Should be impossible (engine already checks this condition)
return sql.ErrTableNotFound.New(oldName)
}

d.tablesMu.RLock()
_, ok = d.tables[newName]
d.tablesMu.RUnlock()

if ok {
return sql.ErrTableAlreadyExists.New(newName)
}
Expand All @@ -327,8 +364,8 @@ func (d *BaseDatabase) RenameTable(ctx *sql.Context, oldName, newName string) er
}
memTbl.data.tableName = newName

d.tables[newName] = memTbl
delete(d.tables, oldName)
d.AddTable(newName, memTbl)
d.DeleteTable(oldName)
sess.putTable(memTbl.data)

return nil
Expand Down
Loading