From 42fc75cb2ced9a27b8baecb08ec33976096007c0 Mon Sep 17 00:00:00 2001 From: black-06 Date: Sat, 18 Feb 2023 09:19:24 +0800 Subject: [PATCH] fix: association concurrently appending (#6044) * fix: association concurrently appending * fix: fix unit test * fix: fix gofumpt --- association.go | 8 ++++-- tests/associations_many2many_test.go | 40 ++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/association.go b/association.go index 06229caa7..6719a1d04 100644 --- a/association.go +++ b/association.go @@ -353,9 +353,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + var fieldValue reflect.Value if clear { - fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap()) + } else { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap()) + reflect.Copy(fieldValue, oldFieldValue) } appendToFieldValues := func(ev reflect.Value) { diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 4ba31f902..845c16af5 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -1,9 +1,12 @@ package tests_test import ( + "fmt" + "sync" "testing" "gorm.io/gorm" + "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) @@ -353,3 +356,40 @@ func TestDuplicateMany2ManyAssociation(t *testing.T) { AssertEqual(t, nil, err) AssertEqual(t, user2, findUser2) } + +func TestConcurrentMany2ManyAssociation(t *testing.T) { + db, err := OpenTestConnection() + if err != nil { + t.Fatalf("open test connection failed, err: %+v", err) + } + + count := 3 + + var languages []Language + for i := 0; i < count; i++ { + language := Language{Code: fmt.Sprintf("consurrent %d", i)} + db.Create(&language) + languages = append(languages, language) + } + + user := User{} + db.Create(&user) + db.Preload("Languages").FirstOrCreate(&user) + + var wg sync.WaitGroup + for i := 0; i < count; i++ { + wg.Add(1) + go func(user User, language Language) { + err := db.Model(&user).Association("Languages").Append(&language) + AssertEqual(t, err, nil) + + wg.Done() + }(user, languages[i]) + } + wg.Wait() + + var find User + err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error + AssertEqual(t, err, nil) + AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") +}