From e21d500fa6d66532e30f342cf014b0d97bbb65f3 Mon Sep 17 00:00:00 2001 From: Corentin Giaufer Saubert <43623834+CorentinGS@users.noreply.github.com> Date: Thu, 13 Feb 2025 14:44:51 +0100 Subject: [PATCH] feat: enhance position handling in notation decoding (#27) * feat: enhance position handling in notation decoding * feat: update notation tests to include expected position validation --- notation.go | 2 + notation_test.go | 141 ++++++++++++++++------------------------------- 2 files changed, 51 insertions(+), 92 deletions(-) diff --git a/notation.go b/notation.go index 5a6b03d..0173d42 100644 --- a/notation.go +++ b/notation.go @@ -165,6 +165,8 @@ func (UCINotation) Decode(pos *Position, s string) (*Move, error) { // check for check addTags(&m, pos) + m.position = pos.Update(&m) + return &m, nil } diff --git a/notation_test.go b/notation_test.go index a0220f0..30667c4 100644 --- a/notation_test.go +++ b/notation_test.go @@ -4,63 +4,6 @@ import ( "testing" ) -type _ struct { - Pos1 *Position `json:"pos1"` - Pos2 *Position `json:"pos2"` - AlgText string `json:"alg_text"` - LongAlgText string `json:"long_alg_text"` - UCIText string `json:"uci_text"` - Description string `json:"description"` -} - -/* -TODO: Fix this test for new notation system -func TestValidDecoding(t *testing.T) { - f, err := os.Open("fixtures/valid_notation_tests.json") - if err != nil { - t.Fatal(err) - return - } - - var validTests []validNotationTest - if err := json.NewDecoder(f).Decode(&validTests); err != nil { - t.Fatal(err) - return - } - - for _, test := range validTests { - for i, n := range []Notation{AlgebraicNotation{}, LongAlgebraicNotation{}, UCINotation{}} { - var moveText string - switch i { - case 0: - moveText = test.AlgText - case 1: - moveText = test.LongAlgText - case 2: - moveText = test.UCIText - } - m, err := n.Decode(test.Pos1, moveText) - if err != nil { - movesStrList := []string{} - for _, m := range test.Pos1.ValidMoves() { - s := n.Encode(test.Pos1, &m) - movesStrList = append(movesStrList, s) - } - t.Fatalf("starting from board \n%s\n expected move to be valid error - %s %s\n", test.Pos1.board.Draw(), err.Error(), strings.Join(movesStrList, ",")) - } - postPos := test.Pos1.Update(m) - if test.Pos2.String() != postPos.String() { - t.Fatalf("starting from board \n%s%s\n after move %s\n expected board to be %s\n%s\n but was %s\n%s\n", - test.Pos1.String(), - test.Pos1.board.Draw(), m.String(), test.Pos2.String(), - test.Pos2.board.Draw(), postPos.String(), postPos.board.Draw()) - } - } - } -} - -*/ - type notationDecodeTest struct { N Notation Pos *Position @@ -154,46 +97,52 @@ func TestUCINotationDecode(t *testing.T) { moveWithCheckCapture.AddTag(Capture) tests := []struct { - name string - pos *Position - input string - want *Move - wantErr bool + name string + pos *Position + input string + want *Move + wantErr bool + expectedPos *Position }{ { - name: "valid move without promotion", - pos: unsafeFEN("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"), - input: "e2e4", - want: &Move{s1: E2, s2: E4}, - wantErr: false, + name: "valid move without promotion", + pos: unsafeFEN("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"), + input: "e2e4", + want: &Move{s1: E2, s2: E4}, + expectedPos: unsafeFEN("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"), + wantErr: false, }, { - name: "valid move with promotion", - pos: unsafeFEN("8/P7/8/8/8/8/8/8 w - - 0 1"), - input: "a7a8q", - want: &Move{s1: A7, s2: A8, promo: Queen}, - wantErr: false, + name: "valid move with promotion", + pos: unsafeFEN("8/P7/8/8/8/8/8/8 w - - 0 1"), + input: "a7a8q", + want: &Move{s1: A7, s2: A8, promo: Queen}, + expectedPos: unsafeFEN("Q7/8/8/8/8/8/8/8 b - - 0 1"), + wantErr: false, }, { - name: "valid move with capture", - pos: unsafeFEN("rnbqkb1r/ppp2ppp/3p1n2/4P3/4P3/2N5/PPP2PPP/R1BQKBNR b KQkq - 0 4"), - input: "d6e5", - want: &Move{s1: D6, s2: E5, tags: Capture}, - wantErr: false, + name: "valid move with capture", + pos: unsafeFEN("rnbqkb1r/ppp2ppp/3p1n2/4P3/4P3/2N5/PPP2PPP/R1BQKBNR b KQkq - 0 4"), + input: "d6e5", + want: &Move{s1: D6, s2: E5, tags: Capture}, + wantErr: false, + expectedPos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"), }, { - name: "valid move with check only", - pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"), - input: "f1b5", - want: &Move{s1: F1, s2: B5, tags: Check}, - wantErr: false, + name: "valid move with check only", + pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"), + input: "f1b5", + want: &Move{s1: F1, s2: B5, tags: Check}, + expectedPos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/1B2p3/4P3/2N5/PPP2PPP/R1BQK1NR b KQkq - 1 5"), + wantErr: false, }, { - name: "valid move with check and capture", - pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"), - input: "d1d8", - want: moveWithCheckCapture, - wantErr: false, + name: "valid move with check and capture", + pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"), + input: "d1d8", + want: moveWithCheckCapture, + wantErr: false, + expectedPos: unsafeFEN("rnbQkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1B1KBNR b KQkq - 0 5"), }, { name: "invalid UCI notation length", @@ -217,11 +166,12 @@ func TestUCINotationDecode(t *testing.T) { wantErr: true, }, { - name: "valid en passant move", - pos: unsafeFEN("rnbqkbnr/ppp2ppp/4p3/3pP3/8/8/PPPP1PPP/RNBQKBNR w KQkq d6 0 3"), - input: "e5d6", - want: &Move{s1: E5, s2: D6, tags: EnPassant}, - wantErr: false, + name: "valid en passant move", + pos: unsafeFEN("rnbqkbnr/ppp2ppp/4p3/3pP3/8/8/PPPP1PPP/RNBQKBNR w KQkq d6 0 3"), + input: "e5d6", + want: &Move{s1: E5, s2: D6, tags: EnPassant}, + wantErr: false, + expectedPos: unsafeFEN("rnbqkbnr/ppp2ppp/3Pp3/8/8/8/PPPP1PPP/RNBQKBNR b KQkq - 0 3"), }, } @@ -236,6 +186,13 @@ func TestUCINotationDecode(t *testing.T) { if !tt.wantErr && (got.String() != tt.want.String() || got.promo != tt.want.promo || got.tags != tt.want.tags) { t.Errorf("Decode() = %v (%d), want %v (%d)", got, got.tags, tt.want, tt.want.tags) } + if !tt.wantErr && tt.want.position != nil && got.position != nil && tt.want.position.String() != got.position.String() { + t.Errorf("Decode() = %v, want %v", got.position, tt.want.position) + } + + if !tt.wantErr && tt.expectedPos != nil && got.position.String() != tt.expectedPos.String() { + t.Errorf("Decode() = %v, want %v", got.position.String(), tt.expectedPos) + } }) } }