Skip to content

Commit

Permalink
run inside tmp dir
Browse files Browse the repository at this point in the history
  • Loading branch information
Disservin committed Sep 10, 2024
1 parent cb459bd commit 16bf7a2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 16 deletions.
31 changes: 23 additions & 8 deletions tests/instrumented.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,6 @@ def test_bench_128_threads_3_bench_tmp_epd_depth(self):
)
assert self.stockfish.process.returncode == 0

def test_export_net_verify_nnue(self):
self.stockfish = Stockfish("export_net verify.nnue".split(" "), True)
assert self.stockfish.process.returncode == 0

def test_d(self):
self.stockfish = Stockfish("d".split(" "), True)
assert self.stockfish.process.returncode == 0
Expand All @@ -168,6 +164,13 @@ def test_uci(self):
self.stockfish = Stockfish("uci".split(" "), True)
assert self.stockfish.process.returncode == 0

def test_export_net_verify_nnue(self):
current_path = os.path.abspath(os.getcwd())
self.stockfish = Stockfish(
f"export_net {os.path.join(current_path , 'verify.nnue')}".split(" "), True
)
assert self.stockfish.process.returncode == 0

# verify the generated net equals the base net

def test_network_equals_base(self):
Expand All @@ -184,6 +187,15 @@ def test_network_equals_base(self):
network = line.split(" ")[-1]
break

# find network file in src dir
network = os.path.join(PATH.parent.resolve(), "src", network)

if not os.path.exists(network):
print(
f"Network file {network} not found, please download the network file over the make command."
)
assert False

diff = subprocess.run(["diff", network, f"verify.nnue"])

assert diff.returncode == 0
Expand Down Expand Up @@ -376,6 +388,11 @@ def test_fen_position_with_moves_with_mate_go_depth_and_searchmoves(self):
self.stockfish.starts_with("bestmove e3e2")

def test_verify_nnue_network(self):
current_path = os.path.abspath(os.getcwd())
Stockfish(
f"export_net {os.path.join(current_path , 'verify.nnue')}".split(" "), True
)

self.stockfish.send_command("setoption name EvalFile value verify.nnue")
self.stockfish.send_command("position startpos")
self.stockfish.send_command("go depth 5")
Expand Down Expand Up @@ -482,7 +499,6 @@ def parse_args():
return parser.parse_args()


# To run the tests
if __name__ == "__main__":
args = parse_args()

Expand All @@ -491,10 +507,9 @@ def parse_args():
Syzygy.download_syzygy()

framework = MiniTestFramework()
framework.run([TestCLI, TestInteractive, TestSyzygy])

EPD.delete_bench_epd()
TSAN.unset_tsan_option()
# Each test suite will be ran inside a temporary directory
framework.run([TestCLI, TestInteractive, TestSyzygy])

if framework.has_failed():
sys.exit(1)
Expand Down
22 changes: 14 additions & 8 deletions tests/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import sys
import traceback
import fnmatch
import signal
from functools import wraps
from contextlib import redirect_stdout
import io
import tarfile
import urllib.request
import pathlib
import concurrent.futures
import tempfile

CYAN_COLOR = "\033[36m"
GRAY_COLOR = "\033[2m"
Expand Down Expand Up @@ -88,7 +88,7 @@ def get_syzygy_path():

@staticmethod
def download_syzygy():
if not os.path.isdir("../tests/syzygy"):
if not os.path.isdir(os.path.join(PATH, "syzygy")):
url = "https://api.github.com/repos/niklasf/python-chess/tarball/9b9aa13f9f36d08aadfabff872882f4ab1494e95"
tarball_path = "/tmp/python-chess.tar.gz"

Expand All @@ -97,7 +97,7 @@ def download_syzygy():
with tarfile.open(tarball_path, "r:gz") as tar:
tar.extractall("/tmp")

os.rename("/tmp/niklasf-python-chess-9b9aa13", "../tests/syzygy")
os.rename("/tmp/niklasf-python-chess-9b9aa13", os.path.join(PATH, "syzygy"))


class OrderedClassMembers(type):
Expand Down Expand Up @@ -152,11 +152,17 @@ def run(self, classes: List[type]):
self.start_time = time.time()

for test_class in classes:
ret = self.__run(test_class)
if ret:
self.failed_test_suites += 1
else:
self.passed_test_suites += 1
with tempfile.TemporaryDirectory() as tmpdirname:
original_cwd = os.getcwd()
os.chdir(tmpdirname)
try:
ret = self.__run(test_class)
if ret:
self.failed_test_suites += 1
else:
self.passed_test_suites += 1
finally:
os.chdir(original_cwd)

duration = round(time.time() - self.start_time, 2)

Expand Down

0 comments on commit 16bf7a2

Please sign in to comment.