Skip to content

Commit

Permalink
fix lux s3 to ignore import errors (#317)
Browse files Browse the repository at this point in the history
* Update lux_ai_s3.py

* Update lux_ai_s3.py

* Update lux_ai_s3.py
  • Loading branch information
StoneT2000 authored Nov 21, 2024
1 parent 6fccbc7 commit 51997e0
Showing 1 changed file with 69 additions and 65 deletions.
134 changes: 69 additions & 65 deletions kaggle_environments/envs/lux_ai_s3/lux_ai_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,83 +41,87 @@ def enqueue_output(out, queue):
out.close()

def interpreter(state, env):
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
global luxenv, prev_obs, state_obs, default_env_cfg
player_0 = state[0]
player_1 = state[1]
# filter out actions such as debug annotations so they aren't saved
# filter_actions(state, env)

if env.done:
if "seed" in env.configuration:
seed = int(env.configuration["seed"])
else:
seed = math.floor(random.random() * 1e9);
env.configuration["seed"] = seed

luxenv = LuxAIS3GymEnv(numpy_output=True)
luxenv = RecordEpisode(luxenv, save_on_close=False, save_on_reset=False)
obs, info = luxenv.reset(seed=seed)

env_cfg_json = info["params"]

env.configuration.env_cfg = env_cfg_json
try:
from luxai_s3.wrappers import LuxAIS3GymEnv, RecordEpisode
global luxenv, prev_obs, state_obs, default_env_cfg
player_0 = state[0]
player_1 = state[1]
# filter out actions such as debug annotations so they aren't saved
# filter_actions(state, env)

if env.done:
if "seed" in env.configuration:
seed = int(env.configuration["seed"])
else:
seed = math.floor(random.random() * 1e9);
env.configuration["seed"] = seed

luxenv = LuxAIS3GymEnv(numpy_output=True)
luxenv = RecordEpisode(luxenv, save_on_close=False, save_on_reset=False)
obs, info = luxenv.reset(seed=seed)

env_cfg_json = info["params"]

env.configuration.env_cfg = env_cfg_json

player_0.observation.player = "player_0"
player_1.observation.player = "player_1"
player_0.observation.obs = json.dumps(to_json(obs["player_0"]))
player_1.observation.obs = json.dumps(to_json(obs["player_1"]))

replay_frame = luxenv.serialize_episode_data(dict(
states=[luxenv.episode["states"][-1]],
metadata=luxenv.episode["metadata"],
params=luxenv.episode["params"]
))
# don't need to keep metadata/params beyond first step
player_0.info = dict(replay=replay_frame)
return state

new_state_obs, rewards, terminations, truncations, infos = luxenv.step({
"player_0": np.array(player_0.action["action"]),
"player_1": np.array(player_1.action["action"])
})

# cannot store np arrays in replay jsons so must convert to list
player_0.action = player_0.action["action"]
player_1.action = player_1.action["action"]

dones = dict()
for k in terminations:
dones[k] = terminations[k] | truncations[k]

player_0.observation.player = "player_0"
player_1.observation.player = "player_1"
player_0.observation.obs = json.dumps(to_json(obs["player_0"]))
player_1.observation.obs = json.dumps(to_json(obs["player_1"]))

player_0.observation.obs = json.dumps(to_json(new_state_obs["player_0"]))
player_1.observation.obs = json.dumps(to_json(new_state_obs["player_1"]))


player_0.reward = int(rewards["player_0"])
player_1.reward = int(rewards["player_1"])

player_0.observation.reward = int(player_0.reward)
player_1.observation.reward = int(player_1.reward)
replay_frame = luxenv.serialize_episode_data(dict(
states=[luxenv.episode["states"][-1]],
actions=[luxenv.episode["actions"][-1]],
metadata=luxenv.episode["metadata"],
params=luxenv.episode["params"]
))
# don't need to keep metadata/params beyond first step
del replay_frame["metadata"]
del replay_frame["params"]
player_0.info = dict(replay=replay_frame)
return state

new_state_obs, rewards, terminations, truncations, infos = luxenv.step({
"player_0": np.array(player_0.action["action"]),
"player_1": np.array(player_1.action["action"])
})

# cannot store np arrays in replay jsons so must convert to list
player_0.action = player_0.action["action"]
player_1.action = player_1.action["action"]

dones = dict()
for k in terminations:
dones[k] = terminations[k] | truncations[k]

player_0.observation.player = "player_0"
player_1.observation.player = "player_1"

player_0.observation.obs = json.dumps(to_json(new_state_obs["player_0"]))
player_1.observation.obs = json.dumps(to_json(new_state_obs["player_1"]))


player_0.reward = int(rewards["player_0"])
player_1.reward = int(rewards["player_1"])

player_0.observation.reward = int(player_0.reward)
player_1.observation.reward = int(player_1.reward)
replay_frame = luxenv.serialize_episode_data(dict(
states=[luxenv.episode["states"][-1]],
actions=[luxenv.episode["actions"][-1]],
metadata=luxenv.episode["metadata"],
params=luxenv.episode["params"]
))
# don't need to keep metadata/params beyond first step
del replay_frame["metadata"]
del replay_frame["params"]
player_0.info = dict(replay=replay_frame)

if np.all([dones[k] for k in dones]):
if player_0.status == "ACTIVE":
player_0.status = "DONE"
if player_1.status == "ACTIVE":
player_1.status = "DONE"
if np.all([dones[k] for k in dones]):
if player_0.status == "ACTIVE":
player_0.status = "DONE"
if player_1.status == "ACTIVE":
player_1.status = "DONE"
return state
except ModuleNotFoundError as e:
print("Lux AI S3 Dependencies are missing, interpreter will not work")
return state

def renderer(state, env):
Expand Down

0 comments on commit 51997e0

Please sign in to comment.