Skip to content

Commit

Permalink
feat: enhance position handling in notation decoding (#27)
Browse files Browse the repository at this point in the history
* feat: enhance position handling in notation decoding

* feat: update notation tests to include expected position validation
  • Loading branch information
CorentinGS authored Feb 13, 2025
1 parent 7b215f0 commit e21d500
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 92 deletions.
2 changes: 2 additions & 0 deletions notation.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ func (UCINotation) Decode(pos *Position, s string) (*Move, error) {
// check for check
addTags(&m, pos)

m.position = pos.Update(&m)

return &m, nil
}

Expand Down
141 changes: 49 additions & 92 deletions notation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,63 +4,6 @@ import (
"testing"
)

type _ struct {
Pos1 *Position `json:"pos1"`
Pos2 *Position `json:"pos2"`
AlgText string `json:"alg_text"`
LongAlgText string `json:"long_alg_text"`
UCIText string `json:"uci_text"`
Description string `json:"description"`
}

/*
TODO: Fix this test for new notation system
func TestValidDecoding(t *testing.T) {
f, err := os.Open("fixtures/valid_notation_tests.json")
if err != nil {
t.Fatal(err)
return
}
var validTests []validNotationTest
if err := json.NewDecoder(f).Decode(&validTests); err != nil {
t.Fatal(err)
return
}
for _, test := range validTests {
for i, n := range []Notation{AlgebraicNotation{}, LongAlgebraicNotation{}, UCINotation{}} {
var moveText string
switch i {
case 0:
moveText = test.AlgText
case 1:
moveText = test.LongAlgText
case 2:
moveText = test.UCIText
}
m, err := n.Decode(test.Pos1, moveText)
if err != nil {
movesStrList := []string{}
for _, m := range test.Pos1.ValidMoves() {
s := n.Encode(test.Pos1, &m)
movesStrList = append(movesStrList, s)
}
t.Fatalf("starting from board \n%s\n expected move to be valid error - %s %s\n", test.Pos1.board.Draw(), err.Error(), strings.Join(movesStrList, ","))
}
postPos := test.Pos1.Update(m)
if test.Pos2.String() != postPos.String() {
t.Fatalf("starting from board \n%s%s\n after move %s\n expected board to be %s\n%s\n but was %s\n%s\n",
test.Pos1.String(),
test.Pos1.board.Draw(), m.String(), test.Pos2.String(),
test.Pos2.board.Draw(), postPos.String(), postPos.board.Draw())
}
}
}
}
*/

type notationDecodeTest struct {
N Notation
Pos *Position
Expand Down Expand Up @@ -154,46 +97,52 @@ func TestUCINotationDecode(t *testing.T) {
moveWithCheckCapture.AddTag(Capture)

tests := []struct {
name string
pos *Position
input string
want *Move
wantErr bool
name string
pos *Position
input string
want *Move
wantErr bool
expectedPos *Position
}{
{
name: "valid move without promotion",
pos: unsafeFEN("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"),
input: "e2e4",
want: &Move{s1: E2, s2: E4},
wantErr: false,
name: "valid move without promotion",
pos: unsafeFEN("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"),
input: "e2e4",
want: &Move{s1: E2, s2: E4},
expectedPos: unsafeFEN("rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq e3 0 1"),
wantErr: false,
},
{
name: "valid move with promotion",
pos: unsafeFEN("8/P7/8/8/8/8/8/8 w - - 0 1"),
input: "a7a8q",
want: &Move{s1: A7, s2: A8, promo: Queen},
wantErr: false,
name: "valid move with promotion",
pos: unsafeFEN("8/P7/8/8/8/8/8/8 w - - 0 1"),
input: "a7a8q",
want: &Move{s1: A7, s2: A8, promo: Queen},
expectedPos: unsafeFEN("Q7/8/8/8/8/8/8/8 b - - 0 1"),
wantErr: false,
},
{
name: "valid move with capture",
pos: unsafeFEN("rnbqkb1r/ppp2ppp/3p1n2/4P3/4P3/2N5/PPP2PPP/R1BQKBNR b KQkq - 0 4"),
input: "d6e5",
want: &Move{s1: D6, s2: E5, tags: Capture},
wantErr: false,
name: "valid move with capture",
pos: unsafeFEN("rnbqkb1r/ppp2ppp/3p1n2/4P3/4P3/2N5/PPP2PPP/R1BQKBNR b KQkq - 0 4"),
input: "d6e5",
want: &Move{s1: D6, s2: E5, tags: Capture},
wantErr: false,
expectedPos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"),
},
{
name: "valid move with check only",
pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"),
input: "f1b5",
want: &Move{s1: F1, s2: B5, tags: Check},
wantErr: false,
name: "valid move with check only",
pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"),
input: "f1b5",
want: &Move{s1: F1, s2: B5, tags: Check},
expectedPos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/1B2p3/4P3/2N5/PPP2PPP/R1BQK1NR b KQkq - 1 5"),
wantErr: false,
},
{
name: "valid move with check and capture",
pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"),
input: "d1d8",
want: moveWithCheckCapture,
wantErr: false,
name: "valid move with check and capture",
pos: unsafeFEN("rnbqkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1BQKBNR w KQkq - 0 5"),
input: "d1d8",
want: moveWithCheckCapture,
wantErr: false,
expectedPos: unsafeFEN("rnbQkb1r/ppp2ppp/5n2/4p3/4P3/2N5/PPP2PPP/R1B1KBNR b KQkq - 0 5"),
},
{
name: "invalid UCI notation length",
Expand All @@ -217,11 +166,12 @@ func TestUCINotationDecode(t *testing.T) {
wantErr: true,
},
{
name: "valid en passant move",
pos: unsafeFEN("rnbqkbnr/ppp2ppp/4p3/3pP3/8/8/PPPP1PPP/RNBQKBNR w KQkq d6 0 3"),
input: "e5d6",
want: &Move{s1: E5, s2: D6, tags: EnPassant},
wantErr: false,
name: "valid en passant move",
pos: unsafeFEN("rnbqkbnr/ppp2ppp/4p3/3pP3/8/8/PPPP1PPP/RNBQKBNR w KQkq d6 0 3"),
input: "e5d6",
want: &Move{s1: E5, s2: D6, tags: EnPassant},
wantErr: false,
expectedPos: unsafeFEN("rnbqkbnr/ppp2ppp/3Pp3/8/8/8/PPPP1PPP/RNBQKBNR b KQkq - 0 3"),
},
}

Expand All @@ -236,6 +186,13 @@ func TestUCINotationDecode(t *testing.T) {
if !tt.wantErr && (got.String() != tt.want.String() || got.promo != tt.want.promo || got.tags != tt.want.tags) {
t.Errorf("Decode() = %v (%d), want %v (%d)", got, got.tags, tt.want, tt.want.tags)
}
if !tt.wantErr && tt.want.position != nil && got.position != nil && tt.want.position.String() != got.position.String() {
t.Errorf("Decode() = %v, want %v", got.position, tt.want.position)
}

if !tt.wantErr && tt.expectedPos != nil && got.position.String() != tt.expectedPos.String() {
t.Errorf("Decode() = %v, want %v", got.position.String(), tt.expectedPos)
}
})
}
}
Expand Down

0 comments on commit e21d500

Please sign in to comment.