Holistic Evaluation of Text-to-Image Models (HEIM) #1939

Merged 109 commits on Dec 21, 2023

Commits
46b1a22
added image generation scenarios
teetone Oct 22, 2023
77f6716
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 22, 2023
7a33e7b
added deepfloyd offline script
teetone Oct 22, 2023
e986fec
text-to-image models and adapter
teetone Oct 23, 2023
081f6a6
perturbations
teetone Oct 23, 2023
34afa98
translate
teetone Oct 23, 2023
7fe7280
text-to-image metrics
teetone Oct 23, 2023
d1175b9
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 23, 2023
0273b33
heim conf files
teetone Oct 23, 2023
15662b9
window services for text-to-image models
teetone Oct 23, 2023
8df8eb8
file caching
teetone Oct 23, 2023
33a5eab
requests
teetone Oct 23, 2023
fe0dcb3
clients
teetone Oct 24, 2023
9d5eba5
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 24, 2023
c0a3c38
pass over clients
teetone Oct 24, 2023
691c656
resolve merge conflicts
teetone Oct 24, 2023
6beb673
pass on metrics
teetone Oct 24, 2023
bf999c0
pass over metrics
teetone Oct 24, 2023
b7aef97
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 24, 2023
555bcdf
added HEIM run specs
teetone Oct 24, 2023
7435b11
hf tokens
teetone Oct 24, 2023
fc63ae9
resolve merge conflicts
teetone Oct 25, 2023
13fcd79
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 25, 2023
54e2059
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 26, 2023
41b772c
rename image generation params
teetone Oct 26, 2023
96d3f81
style check
teetone Oct 26, 2023
1e94e1d
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 26, 2023
4a53ec9
handle optional dependencies for metrics
teetone Oct 28, 2023
fcfdfcf
resolved merge conflicts
teetone Oct 28, 2023
2cac30c
second pass on import errors
teetone Oct 29, 2023
c4d85b9
new way to pass in output path
teetone Oct 29, 2023
f655fbb
handle new api changes
teetone Oct 29, 2023
c863b92
rename perturbation
teetone Oct 29, 2023
5b7e271
fix server test
teetone Oct 29, 2023
a60493e
image generation parameters
teetone Oct 30, 2023
96f551b
image generation parameters
teetone Oct 30, 2023
d43a156
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Oct 31, 2023
7b28d2f
route openai clip to huggingface tokenizer
teetone Oct 31, 2023
4c7fe32
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Nov 5, 2023
6cd0f1d
run expanders for max eval instances
teetone Nov 6, 2023
22decfc
check pypi package jaxlib
teetone Nov 6, 2023
9bb8fad
bump jaxlib version
teetone Nov 6, 2023
2425d3e
bump jaxlib version
teetone Nov 6, 2023
0d64a2f
bump jaxlib version
teetone Nov 6, 2023
c95b687
global diffusers
teetone Nov 6, 2023
d4961a3
fix metric pathing and second pass on dependencies
teetone Nov 7, 2023
be4f608
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Nov 7, 2023
b363fcf
handle new perturbation changes
teetone Nov 7, 2023
16ff4f4
handle new perturbation changes
teetone Nov 7, 2023
4b6157c
set default number of instances to 100
teetone Nov 7, 2023
38b7987
fix test
teetone Nov 7, 2023
59a81dd
install script for heim extras
teetone Nov 7, 2023
690fa41
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Nov 7, 2023
66e649f
added HEIM models to test_model_properties
teetone Nov 8, 2023
cd62af1
set heim default
teetone Nov 8, 2023
b81cac5
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Nov 8, 2023
ef4d36f
dependency
teetone Nov 8, 2023
762e4a1
added debug conf file and filter out dependencies
teetone Nov 8, 2023
aac2634
update with heim
teetone Nov 8, 2023
cb24466
update with heim website link in docs
teetone Nov 8, 2023
9043b0f
improve installation pass 2
teetone Nov 8, 2023
061f7a5
bump jax version 0.4
teetone Nov 8, 2023
090cfe0
minor update to instructions
teetone Nov 8, 2023
a6f1085
resolve merge conflict
teetone Nov 9, 2023
ec0edc5
paper
teetone Nov 9, 2023
d2e28c8
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Nov 14, 2023
f7118ad
add original instance
teetone Nov 14, 2023
32bc150
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Nov 14, 2023
fb3b7a9
resolve merge conflicts
teetone Nov 15, 2023
db4755f
resolve merge conflict
teetone Nov 20, 2023
620eeb7
resolve merge conflict
teetone Nov 20, 2023
78d7eb6
resolve merge conflict
teetone Nov 20, 2023
4c707f1
support one text2image model with new model deployment
teetone Nov 20, 2023
354ba33
support one text2image model with new model deployment
teetone Nov 20, 2023
440a021
resolve merge conflicts
teetone Nov 23, 2023
ce4f3aa
added remaining HEIM model deployments/metadata
teetone Nov 23, 2023
7d4633a
added remaining HEIM model deployments/metadata
teetone Nov 23, 2023
a82dfea
rename num_trials and clean up doc
teetone Nov 30, 2023
e80ee3c
resolve merge conflicts
teetone Nov 30, 2023
4da21f6
test
teetone Nov 30, 2023
e995330
fix import
teetone Nov 30, 2023
6dd6997
fix run entry test
teetone Nov 30, 2023
ada3584
remove image generation adapter
teetone Nov 30, 2023
e54e7e0
update suffix perturbation
teetone Nov 30, 2023
eb5cd22
fix test
teetone Nov 30, 2023
1467ca0
simplify installation
teetone Nov 30, 2023
5af08e9
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Dec 2, 2023
178ca73
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Dec 3, 2023
1419095
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Dec 5, 2023
4da4f86
resolve merge conflicts
teetone Dec 6, 2023
1c11b9b
resolve merge conflicts
teetone Dec 8, 2023
779f75a
resolve merge conflicts
teetone Dec 12, 2023
bdb76f2
pass through image generation parameters
teetone Dec 12, 2023
d4fb816
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Dec 14, 2023
6a23880
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Dec 16, 2023
f0e4274
DALL-E client cleanup
teetone Dec 17, 2023
08103d3
rename dall-e-2 to match openai's name
teetone Dec 17, 2023
d55e1fb
rename dalle2 window service to be more general (dall-e-3)
teetone Dec 17, 2023
9e27ba2
dall-e 3
teetone Dec 17, 2023
d7cd500
support segmind/SSD-1B
teetone Dec 18, 2023
53c0141
skip
teetone Dec 18, 2023
8cabc53
support stabilityai/stable-diffusion-xl-base-1.0
teetone Dec 18, 2023
974a9e2
support segmind/Segmind-Vega
teetone Dec 18, 2023
5338a88
remove original instance
teetone Dec 19, 2023
67da209
Merge branch 'main' of https://github.com/stanford-crfm/benchmarking …
teetone Dec 20, 2023
66f6d5d
minor improve logging
teetone Dec 20, 2023
b9dc2e7
cleanup
teetone Dec 20, 2023
3a93754
resolved merge conflicts
teetone Dec 20, 2023
b57cb83
fix test
teetone Dec 20, 2023
33 changes: 33 additions & 0 deletions README.md
@@ -49,3 +49,36 @@ The directory structure for this repo is as follows
└── helm-frontend # New React Front-end
```

# Holistic Evaluation of Text-To-Image Models

<img src="https://github.com/stanford-crfm/helm/raw/heim/src/helm/benchmark/static/heim/images/heim-logo.png" alt="" width="800"/>

Significant effort has recently been made in developing text-to-image generation models, which take textual prompts as
input and generate images. As these models are widely used in real-world applications, there is an urgent need to
comprehensively understand their capabilities and risks. However, existing evaluations primarily focus on image-text
alignment and image quality. To address this limitation, we introduce a new benchmark,
**Holistic Evaluation of Text-To-Image Models (HEIM)**.

We identify 12 different aspects that are important in real-world model deployment, including:

- image-text alignment
- image quality
- aesthetics
- originality
- reasoning
- knowledge
- bias
- toxicity
- fairness
- robustness
- multilinguality
- efficiency

By curating scenarios encompassing these aspects, we evaluate state-of-the-art text-to-image models using this benchmark.
Unlike previous evaluations that focused on alignment and quality, HEIM significantly improves coverage by evaluating all
models across all aspects. Our results reveal that no single model excels in all aspects, with different models
demonstrating strengths in different aspects.

This repository contains the code used to produce the [results on the website](https://crfm.stanford.edu/heim/latest/)
and [paper](https://arxiv.org/abs/2311.04287).
8 changes: 8 additions & 0 deletions docs/code.md
@@ -152,3 +152,11 @@ multiple perturbations and applying it onto a single instance.
4. Add a new class `<Name of tokenizer>WindowService` in file `<Name of tokenizer>_window_service.py`.
Follow what we did for `GPTJWindowService`.
5. Import the new `WindowService` and map the model(s) to it in `WindowServiceFactory`.


## HEIM (text-to-image evaluation)

The overall code structure is the same as HELM's.

When adding new scenarios and metrics for image generation, place the Python files under the `image_generation` package
(e.g., `src/helm/benchmark/scenarios/image_generation`).
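
As a concrete illustration, a new scenario might look like the sketch below. This is a hypothetical example: the module path, class name, and prompts are invented, and the exact `Scenario`/`Instance` interface (e.g., whether `get_instances` takes an `output_path` argument) should be checked against `src/helm/benchmark/scenarios/scenario.py` in your version of the code.

```
# Hypothetical file: src/helm/benchmark/scenarios/image_generation/my_prompts_scenario.py
from typing import List

from helm.benchmark.scenarios.scenario import Scenario, Instance, Input, TEST_SPLIT


class MyPromptsScenario(Scenario):
    """Toy text-to-image scenario where each instance is a bare text prompt."""

    name = "my_prompts"
    description = "Toy prompts for text-to-image evaluation"
    tags = ["text-to-image"]

    def get_instances(self, output_path: str) -> List[Instance]:
        prompts = [
            "a photo of a red bicycle leaning against a brick wall",
            "an oil painting of a lighthouse at dusk",
        ]
        # Text-to-image instances generally have no gold references.
        return [Instance(input=Input(text=p), references=[], split=TEST_SPLIT) for p in prompts]
```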
16 changes: 16 additions & 0 deletions docs/heim.md
@@ -0,0 +1,16 @@
# HEIM Quick Start (text-to-image evaluation)

To run HEIM, follow these steps:

1. Create a run specs configuration file. For example, to evaluate
[Stable Diffusion v1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4) against the
[MS-COCO scenario](https://github.com/stanford-crfm/heim/blob/main/src/helm/benchmark/scenarios/image_generation/mscoco_scenario.py), run:
```
echo 'entries: [{description: "mscoco:model=huggingface/stable-diffusion-v1-4", priority: 1}]' > run_specs.conf
```
2. Run the benchmark with a certain number of instances (e.g., 10):
`helm-run --conf-paths run_specs.conf --suite heim_v1 --max-eval-instances 10`

Examples of run specs configuration files can be found [here](https://github.com/stanford-crfm/helm/tree/main/src/helm/benchmark/presentation).
We used [this configuration file](https://github.com/stanford-crfm/helm/blob/main/src/helm/benchmark/presentation/run_specs_heim.conf)
to produce the results in the paper.
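
The same format accepts multiple entries per file. A sketch (the SDXL deployment name here is an assumption — check the models page for the deployment names available in your version):

```
entries: [
  {description: "mscoco:model=huggingface/stable-diffusion-v1-4", priority: 1},
  {description: "mscoco:model=huggingface/stable-diffusion-xl-base-1.0", priority: 1}
]
```

After `helm-run` finishes, the usual HELM workflow applies: run `helm-summarize --suite heim_v1` to aggregate the results, then `helm-server` to browse them locally.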
4 changes: 4 additions & 0 deletions docs/index.md
@@ -17,3 +17,7 @@ To add new models and scenarios, refer to the Developer Guide's chapters:

- [Developer Setup](developer_setup.md)
- [Code Structure](code.md)


We also support evaluating text-to-image models as introduced in **Holistic Evaluation of Text-to-Image Models (HEIM)**
([paper](https://arxiv.org/abs/2311.04287), [website](https://crfm.stanford.edu/heim/latest)).
16 changes: 16 additions & 0 deletions docs/installation.md
@@ -34,3 +34,19 @@ Within this virtual environment, run:
```
pip install crfm-helm
```

### For HEIM (text-to-image evaluation)

To install the additional dependencies to run HEIM, run:

```
pip install "crfm-helm[heim]"
```

Some models (e.g., DALLE-mini/mega) and metrics (`DetectionMetric`) require extra dependencies that are
not available on PyPI. To install these dependencies, download and run the
[extra install script](https://github.com/stanford-crfm/helm/blob/main/install-heim-extras.sh):

```
bash install-heim-extras.sh
```
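
As a quick sanity check that the optional dependencies resolved (a suggestion, not part of the official instructions), try importing the heavy packages directly:

```
python -c "import torch, diffusers; print(torch.__version__, diffusers.__version__)"
```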
8 changes: 7 additions & 1 deletion docs/models.md
@@ -1,3 +1,9 @@
# Models

Please visit the models [page](https://crfm.stanford.edu/helm/latest/?models) of HELM's website for a list of available models and their descriptions.
Please visit the models [page](https://crfm.stanford.edu/helm/latest/?models) of HELM's website
for a list of available models and their descriptions.


## HEIM (text-to-image evaluation)

Please visit the [models page](https://crfm.stanford.edu/heim/latest/?models) of the HEIM results website.
7 changes: 6 additions & 1 deletion docs/quick_start.md
@@ -18,4 +18,9 @@ helm-server

Then go to http://localhost:8000/ in your browser.

**Next steps:** click [here](get_helm_rank.md) to find out how to to run the full benchmark and get your model's leaderboard rank.

## Next steps

Click [here](get_helm_rank.md) to find out how to run the full benchmark and get your model's leaderboard rank.

For the quick start page for HEIM, visit [here](heim.md).
28 changes: 28 additions & 0 deletions install-heim-extras.sh
@@ -0,0 +1,28 @@
#!/bin/bash

# Extra dependencies for HEIM when evaluating the following:
# Models: craiyon/dalle-mini, craiyon/dalle-mega, thudm/cogview2
# Scenarios: detection with the `DetectionMetric`

# This script fails when any of its commands fail.
set -e

# For DALLE-mini/mega, install the following dependencies.
# On Mac OS, skip installing pytorch with CUDA because CUDA is not supported
if [[ $OSTYPE != 'darwin'* ]]; then
# Manually install pytorch to avoid pip getting killed: https://stackoverflow.com/a/54329850
pip install --no-cache-dir --find-links https://download.pytorch.org/whl/torch_stable.html torch==1.12.1+cu113 torchvision==0.13.1+cu113

# DALLE mini requires jax install
pip install jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn805 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
fi

# For CogView2, manually install apex and Image-Local-Attention. NOTE: need to run this on a GPU machine
echo "Installing CogView2 dependencies..."
pip install localAttention@git+https://github.com/Sleepychord/Image-Local-Attention.git@43fee310cb1c6f64fb0ed77404ba3b01fa586026
pip install --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" apex@git+https://github.com/michiyasunaga/apex.git@9395ba2aab3c05e0e36ef0b7fe48d42de9f10bcf

# For Detectron2. Following https://detectron2.readthedocs.io/en/latest/tutorials/install.html
python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'

echo "Done."
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -35,6 +35,7 @@ nav:
- 'User Guide':
- 'installation.md'
- 'quick_start.md'
- 'heim.md'
- 'get_helm_rank.md'
- 'tutorial.md'
- 'benchmark.md'
Empty file.
220 changes: 220 additions & 0 deletions scripts/offline_eval/deepfloyd/deepfloyd.py
@@ -0,0 +1,220 @@
from collections import Counter
from dacite import from_dict
from tqdm import tqdm
from typing import Dict, List, Tuple
import argparse
import json
import os
import time

from diffusers import DiffusionPipeline
import torch

from helm.common.cache import (
KeyValueStore,
KeyValueStoreCacheConfig,
MongoCacheConfig,
SqliteCacheConfig,
create_key_value_store,
)
from helm.common.request import Request
from helm.common.file_caches.local_file_cache import LocalFileCache
from helm.common.hierarchical_logger import hlog, htrack_block
from helm.proxy.clients.image_generation.deep_floyd_client import DeepFloydClient


"""
Script to run inference for the DeepFloyd-IF models given a dry run benchmark output folder of requests.

From https://huggingface.co/docs/diffusers/main/en/api/pipelines/if#text-to-image-generation

DeepFloyd IF is a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and
language understanding. The model is modular, composed of a frozen text encoder and three cascaded pixel
diffusion modules:

Stage 1: a base model that generates 64x64 px image based on text prompt
Stage 2: a 64x64 px => 256x256 px super-resolution model
Stage 3: a 256x256 px => 1024x1024 px super-resolution model

Stage 1 and Stage 2 utilize a frozen text encoder based on the T5 transformer to extract text embeddings,
which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. Stage 3 is
Stability’s x4 Upscaling model. The result is a highly efficient model that outperforms current
state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores
the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a
promising future for text-to-image synthesis.

The following dependencies need to be installed in order to run inference with DeepFloyd models:

accelerate~=0.19.0
dacite~=1.6.0
diffusers[torch]~=0.16.1
pyhocon~=0.3.59
pymongo~=4.2.0
retrying~=1.3.3
safetensors~=0.3.1
sentencepiece~=0.1.97
sqlitedict~=1.7.0
tqdm~=4.64.1
transformers~=4.29.2
zstandard~=0.18.0

Example usage (after a dryrun with run suite deepfloyd):

python3 scripts/offline_eval/deepfloyd/deepfloyd.py IF-I-XL-v1.0 benchmark_output/runs/deepfloyd \
--mongo-uri <MongoDB address>

"""

ORGANIZATION: str = "DeepFloyd"


class DeepFloyd:
MODEL_NAME_TO_MODELS: Dict[str, Tuple[str, str]] = {
"IF-I-XL-v1.0": ("DeepFloyd/IF-I-XL-v1.0", "DeepFloyd/IF-II-L-v1.0"), # XL
"IF-I-L-v1.0": ("DeepFloyd/IF-I-L-v1.0", "DeepFloyd/IF-II-L-v1.0"), # Large
"IF-I-M-v1.0": ("DeepFloyd/IF-I-M-v1.0", "DeepFloyd/IF-II-M-v1.0"), # Medium
}

@staticmethod
def initialize_model(stage1_model_name: str, stage2_model_name: str):
with htrack_block(f"Initializing the three stages of the IF model: {stage1_model_name}"):
# stage 1
stage_1 = DiffusionPipeline.from_pretrained(stage1_model_name, torch_dtype=torch.float16)
stage_1.enable_model_cpu_offload()

# stage 2
stage_2 = DiffusionPipeline.from_pretrained(stage2_model_name, text_encoder=None, torch_dtype=torch.float16)
stage_2.enable_model_cpu_offload()

# stage 3
safety_modules = {
"feature_extractor": stage_1.feature_extractor,
"safety_checker": stage_1.safety_checker,
"watermarker": stage_1.watermarker,
}
stage_3 = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16
)
stage_3.enable_model_cpu_offload()
return stage_1, stage_2, stage_3

def __init__(self, model_name: str, file_cache_path: str, key_value_cache_config: KeyValueStoreCacheConfig):
stage1_model, stage2_model = self.MODEL_NAME_TO_MODELS[model_name]
self._model_engine: str = model_name
self._stage_1, self._stage_2, self._stage_3 = self.initialize_model(stage1_model, stage2_model)

self._file_cache = LocalFileCache(file_cache_path, "png")
self._key_value_cache_config: KeyValueStoreCacheConfig = key_value_cache_config

def _run_inference_single_image(self, prompt: str, file_path: str, seed: int) -> None:
# Generating text embeddings
prompt_embeds, negative_embeds = self._stage_1.encode_prompt(prompt)

generator = torch.manual_seed(seed)
image = self._stage_1(
prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt"
).images

image = self._stage_2(
image=image,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
generator=generator,
output_type="pt",
).images

image = self._stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
image[0].save(file_path)

def _process_request(self, request_state: Dict, store: KeyValueStore) -> bool:
request: Request = from_dict(Request, request_state["request"])
raw_request: Dict = DeepFloydClient.convert_to_raw_request(request)

if store.contains(raw_request):
return True

image_paths: List[str] = []
start_time: float = time.time()
for i in range(request.num_completions):
file_path: str = self._file_cache.generate_unique_new_file_path()
self._run_inference_single_image(request.prompt, file_path, i)
image_paths.append(file_path)
total_inference_time: float = time.time() - start_time

result: Dict = {"images": image_paths, "total_inference_time": total_inference_time}
store.put(raw_request, result)
return False

def run_all(self, run_suite_path: str):
"""
Given a run suite folder, runs inference for all the requests.
"""

counts = Counter(inference_count=0, cached_count=0)

# Go through all the valid run folders, pull requests from the scenario_state.json
# files and run inference for each request.
with create_key_value_store(self._key_value_cache_config) as store:
for run_dir in tqdm(os.listdir(run_suite_path)):
run_path: str = os.path.join(run_suite_path, run_dir)

if not os.path.isdir(run_path):
continue

with htrack_block(f"Processing run directory: {run_dir}"):
scenario_state_path: str = os.path.join(run_path, "scenario_state.json")
if not os.path.isfile(scenario_state_path):
hlog(
f"{run_dir} is missing a scenario_state.json file. Expected at path: {scenario_state_path}."
)
continue

with open(scenario_state_path) as scenario_state_file:
scenario_state = json.load(scenario_state_file)
model_name: str = scenario_state["adapter_spec"]["model"]
current_model_engine: str = model_name.split("/")[-1]

if current_model_engine != self._model_engine:
hlog(f"Not running inference for {current_model_engine}.")
continue

for request_state in tqdm(scenario_state["request_states"]):
cached: bool = self._process_request(request_state, store)
counts["cached_count" if cached else "inference_count"] += 1

hlog(
f"Processed {counts['inference_count']} requests. "
f"{counts['cached_count']} requests already had entries in the cache."
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--cache-dir", type=str, default="prod_env/cache", help="Path to the cache directory")
parser.add_argument(
"--mongo-uri",
type=str,
help=(
"For a MongoDB cache, Mongo URI to copy items to. "
"Example format: mongodb://[username:password@]host1[:port1]/dbname"
),
)
parser.add_argument("model_name", type=str, help="Name of the model", choices=DeepFloyd.MODEL_NAME_TO_MODELS.keys())
parser.add_argument("run_suite_path", type=str, help="Path to run path.")
args = parser.parse_args()

cache_config: KeyValueStoreCacheConfig
if args.mongo_uri:
hlog(f"Initialized MongoDB cache with URI: {args.mongo_uri}")
cache_config = MongoCacheConfig(args.mongo_uri, ORGANIZATION)
elif args.cache_dir:
hlog(f"WARNING: Initialized SQLite cache at path: {args.cache_dir}. Are you debugging??")
cache_config = SqliteCacheConfig(os.path.join(args.cache_dir, f"{ORGANIZATION}.sqlite"))
else:
raise ValueError("Either --cache-dir or --mongo-uri should be specified")

deep_floyd = DeepFloyd(
model_name=args.model_name,
file_cache_path=os.path.join(args.cache_dir, "output", ORGANIZATION),
key_value_cache_config=cache_config,
)
deep_floyd.run_all(args.run_suite_path)
hlog("Done.")