Commit

fully working code update
jesbu1 committed Mar 11, 2024
1 parent e6f378e commit 37ac30f
Showing 8 changed files with 490 additions and 638 deletions.
160 changes: 160 additions & 0 deletions .gitignore
@@ -1 +1,161 @@
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
18 changes: 15 additions & 3 deletions README.md
@@ -59,10 +59,11 @@ export SPRINT=[SPRINT_DOWNLOAD_LOCATION]
```

## 2. Downloading the data
### 2.1 Model Training Data
You need to pre-train models to run zero-shot or fine-tuning experiments.
If you don't want to pre-train a model yourself, you can skip to step 3; you won't need the pre-training dataset file.

We have two options for obtaining the ALFRED dataset---either download the data from here: [Google Drive Link](https://drive.google.com/file/d/1ZgKDgG9Fv491GVb9rxIVNJpViPNKFWMF) or set it up yourself with the instructions in the [README in the `datasets` folder](datasets/README.md).
Download the ALFRED dataset here: [Google Drive Link](https://drive.google.com/file/d/1ZgKDgG9Fv491GVb9rxIVNJpViPNKFWMF).

You can use [Gdown](https://github.com/wkentaro/gdown) to directly download the dataset to your server/computer at the desired location (18GB download):
```
gdown 1ZgKDgG9Fv491GVb9rxIVNJpViPNKFWMF
```

@@ -78,6 +79,15 @@ Once the dataset is downloaded (`px_llama_13b.tar.gz`) simply untar it (36GB after untarring):
```
tar -xvzf px_llama_13b.tar.gz
```
### 2.2 ALFRED Evaluation Data
To run evals and fine-tuning experiments, you must download and extract the ALFRED evaluation data we have processed ([Google Drive Link](https://drive.google.com/file/d/1MHDrKSRmyag-DwipyLj-i-BbKU_dxbne/view)):

```
cd [SPRINT_REPO_LOCATION]
cd sprint/alfred/data
gdown 1MHDrKSRmyag-DwipyLj-i-BbKU_dxbne
tar -xvzf json_2.1.0_merge_goto.tar.gz
```

## 3. Setting up WandB
We log using WandB. First create a wandb account if you don't already have one [here](https://wandb.ai).
@@ -87,7 +97,7 @@ Finally, fill in `WANDB_ENTITY_NAME, WANDB_PROJECT_NAME` in the file `utils/wand
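
For illustration, filling in those two constants might look like the following sketch; the values are placeholders you should replace with your own, and the exact file they live in is the (truncated) path mentioned above.

```python
# Placeholder values only -- substitute your own WandB entity (username or
# team) and project name. The constant names come from the README above.
WANDB_ENTITY_NAME = "your-wandb-entity"
WANDB_PROJECT_NAME = "sprint-experiments"
```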


## 4. Pre-training a Model
You can either pre-train a model yourself or download a pre-trained checkpoint. Pre-trained model checkpoints can be found here: [Google Drive Link](https://drive.google.com/file/d/1PDNX7Z1BBoB3pmeBTfOgNxe2I53kUoS0).
You can either pre-train a model yourself or download a pre-trained checkpoint. Pre-trained model checkpoints can be found here: [Google Drive Link](https://drive.google.com/file/d/1PDNX7Z1BBoB3pmeBTfOgNxe2I53kUoS0/view).

If you want to pre-train a model yourself instead, run the following command from the base SPRINT repo location to train our model, SPRINT:

@@ -150,8 +160,10 @@ Checkpoints are saved in `sprint_saved_rl_models/`

To run SayCan zero-shot evals, pre-train the L-BC baseline above and then:
```
TODO: coming soon
python sprint/saycan_eval.py --model_checkpoint_dir [L-BC PATH] --env_type {eval_instruct, eval_length, eval_scene} --run_group [RUN_GROUP] --experiment_name [EXP_NAME] --llm_gpus [GPU]
```
The optional `llm_gpus` flag accepts a comma-separated list of GPU IDs on which to place the LLM, since it may be too large to fit on the same GPU as the model.
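
As a rough sketch only (not the repository's actual argument handling; the function name and per-GPU memory budget below are assumptions), a comma-separated GPU list such as `0,1` could be mapped to a Hugging Face `max_memory` budget so that `device_map="auto"` shards the LLM across just those GPUs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_llm_on_gpus(model_name: str, llm_gpus: str, per_gpu_mem: str = "20GiB"):
    """Load an LLM sharded across only the GPUs listed in `llm_gpus`, e.g. "0,1"."""
    gpu_ids = {int(g) for g in llm_gpus.split(",")}
    # Give the allowed GPUs a memory budget and exclude all others.
    max_memory = {i: (per_gpu_mem if i in gpu_ids else "0GiB") for i in range(max(gpu_ids) + 1)}
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",      # let accelerate place layers automatically...
        max_memory=max_memory,  # ...but only on the GPUs listed in llm_gpus
    )
    return tokenizer, model
```

For example, passing `--llm_gpus 0,1` to the command above would correspond to `llm_gpus="0,1"` here.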

The 13B LLaMA model we used is no longer available on Hugging Face, so to fully reproduce our results you should follow the official LLaMA instructions to download LLaMA-13B.
Right now this script defaults to LLaMA-7B, but empirically the performance is very similar.

7 changes: 5 additions & 2 deletions sprint/datasets/large_language_model.py
@@ -2,6 +2,7 @@
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from threading import Lock
from sprint.utils.utils import process_skill_strings


@@ -48,8 +49,8 @@ class LargeLanguageModel:

def __init__(self, config):
assert (
"opt" in config.llm_model or "gpt" in config.llm_model
), "No tokenizer support for non-gpt/opt models"
"opt" in config.llm_model or "gpt" in config.llm_model or "llama" in config.llm_model
), "No tokenizer support for non-gpt/opt/llama models"
self.config = config
self.llm_gpus = config.llm_gpus
self.llm_max_new_tokens = config.llm_max_new_tokens
@@ -76,6 +77,7 @@ def __init__(self, config):
pad_token_id=self.tokenizer.eos_token_id,
device_map="auto",
)
self.lock = Lock()
self.next_skill_top_p = 0.9
self.next_skill_temp = 0.8
self.ret_tensor_type = "pt"
@@ -320,6 +322,7 @@ def _get_non_generated_logprobs_hf(
):
second_skill_start_pos = second_skill_attn_mask.sum(-1)
with torch.no_grad():
# so that even if multiple threads are using it at once, our maximum batch size won't be exceeded
with self.lock:
logits = (
self.model(
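
To make the intent of the new `Lock` clearer, here is a minimal standalone sketch of the pattern this diff introduces: several rollout threads may score skills through the same model, and serializing the GPU forward pass with a lock keeps the effective batch size (and therefore GPU memory) bounded. The class and method names are illustrative, not the repository's exact API.

```python
from threading import Lock, Thread

import torch


class ThreadSafeScorer:
    def __init__(self, model, max_batch_size: int = 8):
        self.model = model
        self.max_batch_size = max_batch_size
        self.lock = Lock()

    def score(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token log-probs; safe to call from multiple threads."""
        logprob_chunks = []
        with torch.no_grad():
            # Only one thread runs forward passes at a time, so the GPU never
            # sees more than `max_batch_size` sequences, regardless of how
            # many rollout threads are calling in concurrently.
            with self.lock:
                for chunk in torch.split(input_ids, self.max_batch_size):
                    logits = self.model(chunk).logits
                    logprob_chunks.append(torch.log_softmax(logits, dim=-1))
        return torch.cat(logprob_chunks)


# Usage sketch: two rollout threads sharing one scorer.
# scorer = ThreadSafeScorer(hf_model)
# threads = [Thread(target=scorer.score, args=(batch,)) for batch in batches]
# for t in threads:
#     t.start()
# for t in threads:
#     t.join()
```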
2 changes: 1 addition & 1 deletion sprint/datasets/llm_aggregate_dataset.py
@@ -81,7 +81,7 @@ def main():
"--llm_model",
type=str,
default="decapoda-research/llama-13b-hf",
help="which model to use for the large language model. For optimal performance, use GPT-J-6B or bigger. For speed and decent performance, opt-2.7b is fine.",
help="which model to use for the large language model.",
choices=[
"facebook/opt-125m",
"facebook/opt-350m",
1 change: 0 additions & 1 deletion sprint/eval.py
@@ -217,7 +217,6 @@ def main(config):
)
model_checkpoint_dir = config.model_checkpoint_dir
print(model_checkpoint_dir)
config.use_llm = False # set use llm to false so it doesn't load the LLM
list_of_checkpoints, list_of_epochs = get_list_of_checkpoints(model_checkpoint_dir)
# load one of the checkpoints' configs

11 changes: 0 additions & 11 deletions sprint/rollouts/rollout.py
@@ -11,12 +11,9 @@
)
from sprint.utils.utils import (
get_action_from_agent,
AttrDict,
load_object_class,
process_skill_strings,
)

# import random
path = "."

import sys
@@ -34,15 +31,7 @@
DATA_PATH = (
f"{os.environ['SPRINT']}/sprint/alfred/data/json_2.1.0_merge_goto/preprocess"
)
# VISUAL_MODEL = "resnet18"
REWARD_CONFIG_PATH = f"{os.environ['SPRINT']}/sprint/alfred/models/config/rewards.json"
# DEFAULT_NUM_STEPS = 30
# EVAL_STEP_RATIO = 2
# TRAIN_STEP_RATIO = 2

# global_task_args = AttrDict()
# global_task_args.reward_config = REWARD_CONFIG
# global_task_args.visual_model = VISUAL_MODEL


def run_policy(