From d8bfc14b3e6d81c3388ea33930ec9b9e90fe715d Mon Sep 17 00:00:00 2001 From: Daniel Voogsgerd Date: Thu, 14 Dec 2023 14:04:41 +0100 Subject: [PATCH] fix(sw): Fix go implementation --- internal/smithwaterman/bench.go | 8 +- internal/smithwaterman/bench_test.go | 4 +- internal/smithwaterman/smithwaterman.go | 55 +++---- internal/smithwaterman/smithwaterman_test.go | 145 +++++++++++++------ internal/worker/worker_impl.go | 9 +- internal/worker/worker_rust_bridge.go | 13 +- 6 files changed, 149 insertions(+), 85 deletions(-) diff --git a/internal/smithwaterman/bench.go b/internal/smithwaterman/bench.go index ca3046c..49e0442 100644 --- a/internal/smithwaterman/bench.go +++ b/internal/smithwaterman/bench.go @@ -10,11 +10,17 @@ func runOnce(nQ, nT int) (time.Duration, float32) { const GapPenalty = 2 const MismatchPenalty = 1 + scores := AlignmentScore{ + MatchScore: MatchScore, + MismatchPenalty: MismatchPenalty, + GapPenalty: GapPenalty, + } + query := strings.Repeat("A", nQ) target := strings.Repeat("T", nT) start := time.Now() - findStringScore(query, target, MatchScore, GapPenalty, MismatchPenalty) + findStringScore(query, target, scores) end := time.Now() elapsed := end.Sub(start) diff --git a/internal/smithwaterman/bench_test.go b/internal/smithwaterman/bench_test.go index 46de8d2..8e24e70 100644 --- a/internal/smithwaterman/bench_test.go +++ b/internal/smithwaterman/bench_test.go @@ -1,11 +1,11 @@ package smithwaterman import ( + "testing" "time" - "testing" ) // Kinda, annoying to test, I'll just make sure it runs func TestBenchmark(t *testing.T) { - Benchmark(time.Duration(1e7), 4, 2) + Benchmark(time.Duration(1e7), 4, 2) } diff --git a/internal/smithwaterman/smithwaterman.go b/internal/smithwaterman/smithwaterman.go index 2f1c9d4..2d0b1d3 100644 --- a/internal/smithwaterman/smithwaterman.go +++ b/internal/smithwaterman/smithwaterman.go @@ -10,37 +10,37 @@ import ( "strings" ) -//Commented out for now -// var GAP_PENALTY = 1 -// var MATCH_SCORE = 2 -// var MISMATCH_PENALTY = 1 - -func findStringScore(query string, target string, matchScore int, gapPen int, mismatchPen int) []int { - score := make([]int, (len(query)+1)*(len(target)+1)) +type AlignmentScore struct { + MatchScore int + MismatchPenalty int + GapPenalty int +} - width := len(target) + 1 - height := len(query) + 1 +func findStringScore(query string, target string, scores AlignmentScore) []int { + width := len(query) + 1 + height := len(target) + 1 + score := make([]int, width*height) for y := 1; y < height; y++ { for x := 1; x < width; x++ { - subScore := matchScore + subScore := scores.MatchScore - if query[y-1] != target[x-1] { - subScore = -mismatchPen + if query[x-1] != target[y-1] { + subScore = -scores.MismatchPenalty } score[index(x, y, width)] = max(0, score[index(x-1, y-1, width)]+subScore, - score[index(x-1, y, width)]-gapPen, - score[index(x, y-1, width)]-gapPen) + score[index(x-1, y, width)]-scores.GapPenalty, + score[index(x, y-1, width)]-scores.GapPenalty) } } return score } -func FindLocalAlignment(query, target string, matchScore int, gapPen int, mismatchPen int) (string, string, int) { - score := findStringScore(query, target, matchScore, gapPen, mismatchPen) +func FindLocalAlignment(query, target string, scores AlignmentScore) (string, string, int) { + score := findStringScore(query, target, scores) maxIndex := 0 maxScore := 0 @@ -53,38 +53,39 @@ func FindLocalAlignment(query, target string, matchScore int, gapPen int, mismat } } - width := len(target) + 1 + width := len(query) + 1 x, y := index2coord(maxIndex, width) var queryResult, targetResult strings.Builder - traceback(score, query, target, x, y, width, &queryResult, &targetResult, matchScore, gapPen, mismatchPen) + traceback(score, query, target, x, y, width, &queryResult, &targetResult, scores) return queryResult.String(), targetResult.String(), maxScore } -func traceback(matrix []int, query, target string, x, y, width int, queryResult, targetResult *strings.Builder, matchScore int, gapPen int, mismatchPen int) { +func traceback(matrix []int, query, target string, x, y, width int, queryResult, targetResult *strings.Builder, scores AlignmentScore) { if x == 0 || y == 0 { return } + matchScore := scores.MatchScore if query[x-1] != target[y-1] { - matchScore = -mismatchPen + matchScore = -scores.MismatchPenalty } // TODO: Evaluate what is more important in the case of multiple paths - score := matrix[index(y, x, width)] + score := matrix[index(x, y, width)] if score == 0 { return - } else if score == matrix[index(y-1, x-1, width)]+matchScore { - traceback(matrix, query, target, x-1, y-1, width, queryResult, targetResult, matchScore, gapPen, mismatchPen) + } else if score == matrix[index(x-1, y-1, width)]+matchScore { + traceback(matrix, query, target, x-1, y-1, width, queryResult, targetResult, scores) queryResult.WriteByte(query[x-1]) targetResult.WriteByte(target[y-1]) - } else if score == matrix[index(y, x-1, width)]-gapPen { - traceback(matrix, query, target, x-1, y, width, queryResult, targetResult, matchScore, gapPen, mismatchPen) + } else if score == matrix[index(x-1, y, width)]-scores.GapPenalty { + traceback(matrix, query, target, x-1, y, width, queryResult, targetResult, scores) queryResult.WriteByte(query[x-1]) targetResult.WriteRune('-') } else { - traceback(matrix, query, target, x, y-1, width, queryResult, targetResult, matchScore, gapPen, mismatchPen) + traceback(matrix, query, target, x, y-1, width, queryResult, targetResult, scores) queryResult.WriteRune('-') targetResult.WriteByte(target[y-1]) } @@ -95,5 +96,5 @@ func index(x, y, width int) int { } func index2coord(index, width int) (int, int) { - return index / width, index % width + return index % width, index / width } diff --git a/internal/smithwaterman/smithwaterman_test.go b/internal/smithwaterman/smithwaterman_test.go index 2d590f1..5441aea 100644 --- a/internal/smithwaterman/smithwaterman_test.go +++ b/internal/smithwaterman/smithwaterman_test.go @@ -1,72 +1,97 @@ package smithwaterman import ( + "fmt" + "strings" "testing" ) // Global variables for now only for testing -var GAP_PENALTY = 1 -var MATCH_SCORE = 2 -var MISMATCH_PENALTY = 1 func TestBasic(t *testing.T) { - test_substring("A", "A", "A", "A", t) - test_substring("HOI", "HOI", "HOI", "HOI", t) - test_substring("AAAAAAATAAAAAAAA", "CCTCCCCCCCCCCCCC", "T", "T", t) + scores := AlignmentScore{ + MatchScore: 2, + MismatchPenalty: 1, + GapPenalty: 1, + } + + test_substring("A", "A", "A", "A", scores, t) + test_substring("HOI", "HOI", "HOI", "HOI", scores, t) + test_substring("AAAAAAATAAAAAAAA", "CCTCCCCCCCCCCCCC", "T", "T", scores, t) } func TestNoMatch(t *testing.T) { - test_substring("A", "T", "", "", t) - test_substring("AAAA", "TTTT", "", "", t) - test_substring("ATATTTATTAAATATATTATATATTAA", "CCCCGCGGGGCGCGCGGCGCGCGCGCGCG", "", "", t) + scores := AlignmentScore{ + MatchScore: 2, + MismatchPenalty: 1, + GapPenalty: 1, + } + + test_substring("A", "T", "", "", scores, t) + test_substring("AAAA", "TTTT", "", "", scores, t) + test_substring("ATATTTATTAAATATATTATATATTAA", "CCCCGCGGGGCGCGCGGCGCGCGCGCGCG", "", "", scores, t) } func TestGap(t *testing.T) { - test_with_scoring(1, 1, 2, "CCAA", "GATA", "A-A", "ATA", t) - test_with_scoring(1, 1, 2, "AA", "ATA", "A-A", "ATA", t) - test_with_scoring(1, 1, 2, "AA", "ATTA", "A", "A", t) - test_with_scoring(1, 1, 3, "AA", "ATTA", "A--A", "ATTA", t) - test_with_scoring(1, 1, 3, "ATA", "ATTA", "A-TA", "ATTA", t) - test_with_scoring(1, 1, 2, "AAAAAAAAA", "AAATTAAATTAAA", "AAA--AAA--AAA", "AAATTAAATTAAA", t) + scores := AlignmentScore{ + MatchScore: 2, + MismatchPenalty: 1, + GapPenalty: 1, + } + test_substring("CCAA", "GATA", "A-A", "ATA", scores, t) + test_substring("AA", "ATA", "A-A", "ATA", scores, t) + test_substring("AA", "ATTA", "A", "A", scores, t) + test_substring("AAAAAAAAA", "AAATTAAATTAAA", "AAA--AAA--AAA", "AAATTAAATTAAA", scores, t) + + scores = AlignmentScore{ + MatchScore: 3, + MismatchPenalty: 1, + GapPenalty: 1, + } + + test_substring("AA", "ATTA", "A--A", "ATTA", scores, t) + test_substring("ATA", "ATTA", "A-TA", "ATTA", scores, t) } func TestMismatch(t *testing.T) { - test_with_scoring(1, 1, 2, "ATA", "ACA", "ATA", "ACA", t) - test_with_scoring(3, 2, 5, "ACAC", "ACGCTTTTACC", "ACAC", "ACGC", t) - test_with_scoring(3, 2, 5, "ACAC", "AGGCTTTTACC", "ACAC", "AC-C", t) + scores := AlignmentScore{ + MatchScore: 2, + MismatchPenalty: 1, + GapPenalty: 1, + } + test_substring("ATA", "ACA", "ATA", "ACA", scores, t) + scores = AlignmentScore{ + MatchScore: 5, + MismatchPenalty: 2, + GapPenalty: 3, + } + test_substring("ACAC", "ACGCTTTTACC", "ACAC", "ACGC", scores, t) + test_substring("ACAC", "AGGCTTTTACC", "ACAC", "AC-C", scores, t) } func TestMultipleOptions(t *testing.T) { - test_with_scoring(1, 1, 2, "AA", "AATAA", "AA", "AA", t) - test_with_scoring(1, 1, 2, "ATTA", "ATAA", "ATTA", "AT-A", t) + scores := AlignmentScore{ + MatchScore: 2, + MismatchPenalty: 1, + GapPenalty: 1, + } + test_substring("AA", "AATAA", "AA", "AA", scores, t) + test_substring("ATTA", "ATAA", "ATTA", "AT-A", scores, t) } func TestAdvanced(t *testing.T) { - test_with_scoring(1, 1, 2, "TACGGGCCCGCTAC", "TAGCCCTATCGGTCA", "TACGGGCCCGCTA-C", "TA---G-CC-CTATC", t) - test_with_scoring(1, 1, 2, "AAGTCGTAAAAGTGCACGT", "TAAGCCGTTAAGTGCGCGTG", "AAGTCGTAAAAGTGCACGT", "AAGCCGT-TAAGTGCGCGT", t) - test_with_scoring(1, 1, 2, "AAGTCGTAAAAGTGCACGT", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzTAAGCCGTTAAGTGCGCGTG", "AAGTCGTAAAAGTGCACGT", "AAGCCGT-TAAGTGCGCGT", t) -} - -func test_with_scoring(gap int, mismatch int, match int, query, target, expected_query, expected_target string, t *testing.T) { - if gap < 0 || mismatch < 0 || match < 0 { - t.Log("Normally these should be all positive") + scores := AlignmentScore{ + MatchScore: 2, + MismatchPenalty: 1, + GapPenalty: 1, } - oldGap := GAP_PENALTY - oldMismatch := MISMATCH_PENALTY - oldMatch := MATCH_SCORE - GAP_PENALTY = gap - MISMATCH_PENALTY = mismatch - MATCH_SCORE = match - - test_substring(query, target, expected_query, expected_target, t) - - GAP_PENALTY = oldGap - MISMATCH_PENALTY = oldMismatch - MATCH_SCORE = oldMatch + test_substring("TACGGGCCCGCTAC", "TAGCCCTATCGGTCA", "TACGGGCCCGCTA-C", "TA---G-CC-CTATC", scores, t) + test_substring("AAGTCGTAAAAGTGCACGT", "TAAGCCGTTAAGTGCGCGTG", "AAGTCGTAAAAGTGCACGT", "AAGCCGT-TAAGTGCGCGT", scores, t) + test_substring("AAGTCGTAAAAGTGCACGT", "zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzTAAGCCGTTAAGTGCGCGTG", "AAGTCGTAAAAGTGCACGT", "AAGCCGT-TAAGTGCGCGT", scores, t) } -func test_substring(query, target, expected_query, expected_target string, t *testing.T) { - found_query, found_target, score := FindLocalAlignment(query, target, MATCH_SCORE, GAP_PENALTY, MISMATCH_PENALTY) +func test_substring(query, target, expected_query, expected_target string, scores AlignmentScore, t *testing.T) { + found_query, found_target, score := FindLocalAlignment(query, target, scores) if score == 0 { t.Logf("Found no substring") return @@ -91,3 +116,39 @@ func test_substring(query, target, expected_query, expected_target string, t *te t.Errorf("Did not find the substring for query: %s and target %s", query, target) } } + +func formatMatrix(matrix []int, query, target string) *string { + width := len(query) + 1 + height := len(target) + 1 + + if len(matrix) != width*height { + ret := "Matrix is not the same size as the provided width * height" + return &ret + } + + var output strings.Builder + + output.WriteString(" ") + + for x := 0; x < width-1; x++ { + output.WriteString(fmt.Sprintf(" %c ", query[x])) + } + + output.WriteRune('\n') + + for y := 0; y < height; y++ { + if y > 0 { + output.WriteString(fmt.Sprintf(" %c", target[y-1])) + } else { + output.WriteString(" ") + } + for x := 0; x < width; x++ { + output.WriteString(fmt.Sprintf("%3d ", matrix[index(x, y, width)])) + } + output.WriteRune('\n') + } + + ret := output.String() + + return &ret +} diff --git a/internal/worker/worker_impl.go b/internal/worker/worker_impl.go index a3f9863..1dc4ce2 100644 --- a/internal/worker/worker_impl.go +++ b/internal/worker/worker_impl.go @@ -1,6 +1,7 @@ package worker import ( + "dlsa/internal/smithwaterman" "errors" "fmt" "log" @@ -79,7 +80,7 @@ func (w *Worker) ExecuteWork(work *WorkPackage, queries []QueryTargetType) { // qRes, _, score := smithwaterman.FindLocalAlignment(string(querySeq), string(targetSeq), work.MatchScore, work.MismatchPenalty, work.GapPenalty) rustRes, err := findAlignmentWithFallback(string(querySeq), string(targetSeq), - AlignmentScore{work.MatchScore, -work.MismatchPenalty, -work.GapPenalty}) + smithwaterman.AlignmentScore{work.MatchScore, -work.MismatchPenalty, -work.GapPenalty}) if err != nil { // TODO: What now? @@ -116,7 +117,7 @@ func (w *Worker) ExecuteWork(work *WorkPackage, queries []QueryTargetType) { w.status = Waiting } -func findAlignmentWithFallback(query, target string, alignmentScore AlignmentScore) (*GoResult, error) { +func findAlignmentWithFallback(query, target string, alignmentScore smithwaterman.AlignmentScore) (*GoResult, error) { var rustRes *GoResult var err error @@ -154,8 +155,8 @@ func (w *Worker) ExecuteWorkInParallel(work *WorkPackage) { // Split the work packages into chunks var chunks = make([][]QueryTargetType, cpuCount) for i := 0; i < cpuCount; i++ { - var start = numWorkPackages * i / cpuCount; - var end = numWorkPackages * (i + 1) / cpuCount; + var start = numWorkPackages * i / cpuCount + var end = numWorkPackages * (i + 1) / cpuCount chunks[i] = workPackages[start:end] } var wg sync.WaitGroup diff --git a/internal/worker/worker_rust_bridge.go b/internal/worker/worker_rust_bridge.go index c243ca4..0ae7384 100644 --- a/internal/worker/worker_rust_bridge.go +++ b/internal/worker/worker_rust_bridge.go @@ -5,6 +5,7 @@ package worker import "C" import ( + "dlsa/internal/smithwaterman" "errors" ) @@ -20,12 +21,6 @@ type GoResult struct { Score uint16 } -type AlignmentScore struct { - MatchScore int - MismatchPenalty int - GapPenalty int -} - func cCharPtrToString(cStr *C.char) string { return C.GoString(cStr) } @@ -57,7 +52,7 @@ func ConvertResultToGoResult(cResult *C.struct_Result) GoResult { // return goResult // } -func FindRustAlignmentSequential(query, target string, alignmentScore AlignmentScore) (*GoResult, error) { +func FindRustAlignmentSequential(query, target string, alignmentScore smithwaterman.AlignmentScore) (*GoResult, error) { queryC := C.CString(query) targetC := C.CString(target) @@ -78,7 +73,7 @@ func FindRustAlignmentSequential(query, target string, alignmentScore AlignmentS return &goResult, nil } -func FindRustAlignmentSimd(query, target string, alignmentScore AlignmentScore) (*GoResult, error) { +func FindRustAlignmentSimd(query, target string, alignmentScore smithwaterman.AlignmentScore) (*GoResult, error) { queryC := C.CString(query) targetC := C.CString(target) alignmentScoreC := C.struct_AlignmentScores{ @@ -99,7 +94,7 @@ func FindRustAlignmentSimd(query, target string, alignmentScore AlignmentScore) return &goResult, nil } -func FindRustAlignmentSimdLowMem(query, target string, alignmentScore AlignmentScore) (*GoResult, error) { +func FindRustAlignmentSimdLowMem(query, target string, alignmentScore smithwaterman.AlignmentScore) (*GoResult, error) { queryC := C.CString(query) targetC := C.CString(target) alignmentScoreC := C.struct_AlignmentScores{