Skip to content

Commit

Permalink
Make EncodePositionForNN accept span<Position> (#2097)
Browse files Browse the repository at this point in the history
  • Loading branch information
mooskagh authored Dec 24, 2024
1 parent 4692791 commit c3a160c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/chess/position.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#pragma once

#include <span>
#include <string>

#include "chess/board.h"
Expand Down Expand Up @@ -139,6 +140,8 @@ class PositionHistory {
// Checks for any repetitions since the last time 50 move rule was reset.
bool DidRepeatSinceLastZeroingMove() const;

std::span<const Position> GetPositions() const { return positions_; }

private:
int ComputeLastMoveRepetitions(int* cycle_length) const;

Expand Down
23 changes: 15 additions & 8 deletions src/neural/encoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ int TransformForPosition(pblczero::NetworkFormat::InputFormat input_format,

InputPlanes EncodePositionForNN(
pblczero::NetworkFormat::InputFormat input_format,
const PositionHistory& history, int history_planes,
std::span<const Position> history, int history_planes,
FillEmptyHistory fill_empty_history, int* transform_out) {
InputPlanes result(kAuxPlaneBase + 8);

Expand All @@ -146,7 +146,7 @@ InputPlanes EncodePositionForNN(
// it for the first board.
ChessBoard::Castlings castlings;
{
const ChessBoard& board = history.Last().GetBoard();
const ChessBoard& board = history.back().GetBoard();
const bool we_are_black = board.flipped();
if (IsCanonicalFormat(input_format)) {
transform = ChooseTransform(board);
Expand Down Expand Up @@ -211,9 +211,9 @@ InputPlanes EncodePositionForNN(
if (we_are_black) result[kAuxPlaneBase + 4].SetAll();
}
if (IsHectopliesFormat(input_format)) {
result[kAuxPlaneBase + 5].Fill(history.Last().GetRule50Ply() / 100.0f);
result[kAuxPlaneBase + 5].Fill(history.back().GetRule50Ply() / 100.0f);
} else {
result[kAuxPlaneBase + 5].Fill(history.Last().GetRule50Ply());
result[kAuxPlaneBase + 5].Fill(history.back().GetRule50Ply());
}
// Plane kAuxPlaneBase + 6 used to be movecount plane, now it's all zeros
// unless we need it for canonical armageddon side to move.
Expand All @@ -232,18 +232,17 @@ InputPlanes EncodePositionForNN(
input_format == pblczero::NetworkFormat::
INPUT_112_WITH_CANONICALIZATION_V2_ARMAGEDDON;
bool flip = false;
int history_idx = history.GetLength() - 1;
int history_idx = history.size() - 1;
for (int i = 0; i < std::min(history_planes, kMoveHistory);
++i, --history_idx) {
const Position& position =
history.GetPositionAt(history_idx < 0 ? 0 : history_idx);
const Position& position = history[history_idx < 0 ? 0 : history_idx];
ChessBoard board = position.GetBoard();
if (flip) board.Mirror();
// Castling changes can't be repeated, so we can stop early.
if (stop_early && board.castlings().as_int() != castlings.as_int()) break;
// Enpassants can't be repeated, but we do need to always send the current
// position.
if (stop_early && history_idx != history.GetLength() - 1 &&
if (stop_early && history_idx != static_cast<int>(history.size()) - 1 &&
!board.en_passant().empty()) {
break;
}
Expand Down Expand Up @@ -327,4 +326,12 @@ InputPlanes EncodePositionForNN(
return result;
}

InputPlanes EncodePositionForNN(
pblczero::NetworkFormat::InputFormat input_format,
const PositionHistory& history, int history_planes,
FillEmptyHistory fill_empty_history, int* transform_out) {
return EncodePositionForNN(input_format, history.GetPositions(),
history_planes, fill_empty_history, transform_out);
}

} // namespace lczero
7 changes: 7 additions & 0 deletions src/neural/encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#pragma once

#include <span>

#include "chess/position.h"
#include "neural/network.h"
#include "proto/net.pb.h"
Expand All @@ -49,6 +51,11 @@ InputPlanes EncodePositionForNN(
const PositionHistory& history, int history_planes,
FillEmptyHistory fill_empty_history, int* transform_out);

InputPlanes EncodePositionForNN(
pblczero::NetworkFormat::InputFormat input_format,
std::span<const Position> positions, int history_planes,
FillEmptyHistory fill_empty_history, int* transform_out);

bool IsCanonicalFormat(pblczero::NetworkFormat::InputFormat input_format);
bool IsCanonicalArmageddonFormat(
pblczero::NetworkFormat::InputFormat input_format);
Expand Down

0 comments on commit c3a160c

Please sign in to comment.