Merge pull request #605 from bghira/main
CSV data backend
bghira authored Jul 28, 2024
2 parents 679c6c0 + 11074cf commit e8f707f
Showing 16 changed files with 556 additions and 70 deletions.
78 changes: 76 additions & 2 deletions documentation/DATALOADER.md
@@ -68,8 +68,8 @@ Here is the most basic example of a dataloader configuration file, as `multidata

### `type`

- **Values:** `aws` | `local`
- **Description:** Determines the storage backend (local or cloud) used for this dataset.
- **Values:** `aws` | `local` | `csv`
- **Description:** Determines the storage backend (local, csv or cloud) used for this dataset.

### `instance_data_dir` / `aws_data_prefix`

@@ -232,6 +232,80 @@ In order, the lines behave as follows:
]
```

## Train directly from CSV URL list

**Note: Your CSV must contain the captions for your images.**

> ⚠️ This is an advanced **and** experimental feature, and you may run into problems. If you do, please open an [issue](https://github.com/bghira/simpletuner/issues)!

Instead of manually downloading the images from a URL list, you might wish to feed the list directly to the trainer.

**Note:** It's always better to manually download the image data. Another strategy to save local disk space might be to try [using cloud storage with local encoder caches](#local-cache-with-cloud-dataset) instead.

### Advantages

- No need to directly download the data
- Can make use of SimpleTuner's caption toolkit to directly caption the URL list
- Saves on disk space, since only the image embeds (if applicable) and text embeds are stored

### Disadvantages

- Requires a costly and potentially slow aspect bucket scan where each image is downloaded and its metadata collected
- The downloaded images are cached on-disk, and this cache can grow to be very large; cache management in this version is very basic (write-only, never pruned) and is an area for future improvement
- If your dataset has a large number of invalid URLs, these can waste time on every resumption, as bad samples are currently **never** removed from the URL list
- **Suggestion:** Run a URL validation task beforehand and remove any bad samples; a minimal sketch follows this list.

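Below is a minimal sketch of such a validation pass. It assumes the CSV has a `url` column holding the image locations and uses `pandas` and `requests`; the column name and those libraries are assumptions for illustration, not part of SimpleTuner.

```python
# validate_urls.py: drop rows whose image URL is unreachable before training.
# The "url" column name is an assumption; adjust it to match your CSV.
import pandas as pd
import requests


def validate_csv(in_path: str, out_path: str, url_column: str = "url") -> None:
    df = pd.read_csv(in_path)
    keep = []
    for url in df[url_column]:
        try:
            # HEAD is cheap; some servers reject it, in which case a GET may be needed.
            response = requests.head(url, timeout=10, allow_redirects=True)
            keep.append(response.ok)
        except requests.RequestException:
            keep.append(False)
    # Boolean-mask the rows and write out only the reachable samples.
    df[keep].to_csv(out_path, index=False)


if __name__ == "__main__":
    validate_csv("test_list.csv", "test_list_validated.csv")
```
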
### Configuration

Required keys:

- `type: "csv"`
- `csv_caption_column`
- `csv_cache_dir`
- `caption_strategy: "csv"`

```json
[
{
"id": "csvtest",
"type": "csv",
"csv_caption_column": "caption",
"csv_file": "/Volumes/ml/dataset/test_list.csv",
"csv_cache_dir": "/Volumes/ml/cache/csv/test",
"cache_dir_vae": "/Volumes/ml/cache/vae/sdxl",
"caption_strategy": "csv",
"image_embeds": "image-embeds",
"crop": true,
"crop_aspect": "square",
"crop_style": "center",
"resolution": 1024,
"maximum_image_size": 1024,
"target_downsample_size": 1024,
"resolution_type": "pixel",
"minimum_image_size": 0.5,
"disabled": false,
"skip_file_discovery": "",
"preserve_data_backend_cache": false,
"hash_filenames": true
},
{
"id": "image-embeds",
"type": "local"
},
{
"id": "text-embeds",
"type": "local",
"dataset_type": "text_embeds",
"default": true,
"cache_dir": "/Volumes/ml/cache/text/sdxl",
"disabled": false,
"preserve_data_backend_cache": false,
"skip_file_discovery": "",
"write_batch_size": 128
}
]
```
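
For illustration only, the file referenced by `csv_file` above might be generated like this. The `url` column name is a guess at how the image locations could be stored (the document does not show which column the CSV backend reads URLs from), while the `caption` column matches `csv_caption_column` in the configuration:

```python
# Hypothetical builder for the CSV the configuration above points at.
# Only the "caption" column name comes from the example config; "url" is assumed.
import pandas as pd

rows = [
    {"url": "https://example.com/images/0001.png", "caption": "a photo of a red bicycle"},
    {"url": "https://example.com/images/0002.png", "caption": "a watercolour painting of a lighthouse"},
]
pd.DataFrame(rows).to_csv("/Volumes/ml/dataset/test_list.csv", index=False)
```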

## Parquet caption strategy / JSON Lines datasets

> ⚠️ This is an advanced feature, and will not be necessary for most users.
2 changes: 1 addition & 1 deletion helpers/caching/text_embeds.py
@@ -180,7 +180,7 @@ def discover_all_files(self):
StateTracker.get_text_cache_files(data_backend_id=self.id)
or StateTracker.set_text_cache_files(
self.data_backend.list_files(
instance_data_root=self.cache_dir,
instance_data_dir=self.cache_dir,
str_pattern="*.pt",
),
data_backend_id=self.id,
18 changes: 10 additions & 8 deletions helpers/caching/vae.py
@@ -55,7 +55,7 @@ def __init__(
vae,
accelerator,
metadata_backend: MetadataBackend,
instance_data_root: str,
instance_data_dir: str,
image_data_backend: BaseDataBackend,
cache_data_backend: BaseDataBackend = None,
cache_dir="vae_cache",
@@ -100,7 +100,7 @@ def __init__(
self.read_batch_size = read_batch_size
self.process_queue_size = process_queue_size
self.vae_batch_size = vae_batch_size
self.instance_data_root = instance_data_root
self.instance_data_dir = instance_data_dir
self.transform = MultiaspectImage.get_image_transforms()
self.rank_info = rank_info()
self.metadata_backend = metadata_backend
@@ -136,8 +136,10 @@ def generate_vae_cache_filename(self, filepath: str) -> tuple:
if self.hash_filenames:
base_filename = str(sha256(str(base_filename).encode()).hexdigest())
base_filename = str(base_filename) + ".pt"
# Find the subfolders the sample was in, and replace the instance_data_root with the cache_dir
subfolders = os.path.dirname(filepath).replace(self.instance_data_root, "")
# Find the subfolders the sample was in, and replace the instance_data_dir with the cache_dir
subfolders = ""
if self.instance_data_dir is not None:
subfolders = os.path.dirname(filepath).replace(self.instance_data_dir, "")
if len(subfolders) > 0 and subfolders[0] == "/" and self.cache_dir[0] != "/":
subfolders = subfolders[1:]
full_filename = os.path.join(self.cache_dir, subfolders, base_filename)
@@ -229,7 +231,7 @@ def discover_all_files(self):
data_backend_id=self.id
) or StateTracker.set_image_files(
self.image_data_backend.list_files(
instance_data_root=self.instance_data_root,
instance_data_dir=self.instance_data_dir,
str_pattern="*.[tTwWjJpP][iIeEpPnN][fFbBgG][fFpP]?",
),
data_backend_id=self.id,
@@ -239,7 +241,7 @@ def discover_all_files(self):
StateTracker.get_vae_cache_files(data_backend_id=self.id)
or StateTracker.set_vae_cache_files(
self.cache_data_backend.list_files(
instance_data_root=self.cache_dir,
instance_data_dir=self.cache_dir,
str_pattern="*.pt",
),
data_backend_id=self.id,
@@ -277,7 +279,7 @@ def rebuild_cache(self):
self.debug_log("Updating StateTracker with new VAE cache entry list.")
StateTracker.set_vae_cache_files(
self.cache_data_backend.list_files(
instance_data_root=self.cache_dir,
instance_data_dir=self.cache_dir,
str_pattern="*.pt",
),
data_backend_id=self.id,
@@ -296,7 +298,7 @@ def rebuild_cache(self):
self.debug_log("Updating StateTracker with new VAE cache entry list.")
StateTracker.set_vae_cache_files(
self.cache_data_backend.list_files(
instance_data_root=self.cache_dir,
instance_data_dir=self.cache_dir,
str_pattern="*.pt",
),
data_backend_id=self.id,
8 changes: 4 additions & 4 deletions helpers/data_backend/aws.py
@@ -195,7 +195,7 @@ def list_by_prefix(self, prefix=""):
for item in response.get("Contents", [])
]

def list_files(self, str_pattern: str, instance_data_root: str = None):
def list_files(self, str_pattern: str, instance_data_dir: str = None):
# Initialize the results list
results = []

@@ -206,16 +206,16 @@ def list_files(self, str_pattern: str, instance_data_root: str = None):
paginator = self.client.get_paginator("list_objects_v2")

# We'll use fnmatch to filter based on the provided pattern.
if instance_data_root:
pattern = os.path.join(instance_data_root or None, str_pattern)
if instance_data_dir:
pattern = os.path.join(instance_data_dir or None, str_pattern)
else:
pattern = str_pattern

# Using a dictionary to hold files based on their prefixes (subdirectories)
prefix_dict = {}
# Log the first few items, alphabetically sorted:
logger.debug(
f"Listing files in S3 bucket {self.bucket_name} in prefix {instance_data_root} with search pattern: {pattern}"
f"Listing files in S3 bucket {self.bucket_name} in prefix {instance_data_dir} with search pattern: {pattern}"
)

# Paginating over the entire bucket objects
2 changes: 1 addition & 1 deletion helpers/data_backend/base.py
@@ -41,7 +41,7 @@ def open_file(self, identifier, mode):
pass

@abstractmethod
def list_files(self, pattern: str, instance_data_root: str = None) -> tuple:
def list_files(self, pattern: str, instance_data_dir: str = None) -> tuple:
"""
List all files matching the pattern.
"""