Skip to content

Commit

Permalink
[CherryPick] [FRONTEND] Add support for using remote cache managers (#…
Browse files Browse the repository at this point in the history
…2934) (#3379)

Adds a redis-based one as an initial implementation, but it should be
straightforward to extend with more impls.

Co-authored-by: andrewjcg <andrewjcg@gmail.com>
  • Loading branch information
shunting314 and andrewjcg authored Mar 14, 2024
1 parent 0e7b97b commit 996b6c0
Showing 1 changed file with 122 additions and 10 deletions.
132 changes: 122 additions & 10 deletions python/triton/runtime/cache.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import importlib
import json
import os
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, Optional
from typing import Dict, List, Optional
import hashlib


def default_cache_dir():
Expand All @@ -27,10 +29,6 @@ def __init__(self, key):
def get_file(self, filename) -> Optional[str]:
pass

@abstractmethod
def has_file(self, filename) -> bool:
pass

@abstractmethod
def put(self, data, filename, binary=True) -> str:
pass
Expand Down Expand Up @@ -70,20 +68,20 @@ def __init__(self, key, override=False, dump=False):
def _make_path(self, filename) -> str:
return os.path.join(self.cache_dir, filename)

def has_file(self, filename) -> bool:
def _has_file(self, filename) -> bool:
if not self.cache_dir:
raise RuntimeError("Could not create or locate cache dir")
return os.path.exists(self._make_path(filename))

def get_file(self, filename) -> Optional[str]:
if self.has_file(filename):
if self._has_file(filename):
return self._make_path(filename)
else:
return None

def get_group(self, filename: str) -> Optional[Dict[str, str]]:
grp_filename = f"__grp__{filename}"
if not self.has_file(grp_filename):
if not self._has_file(grp_filename):
return None
grp_filepath = self._make_path(grp_filename)
with open(grp_filepath) as f:
Expand Down Expand Up @@ -130,6 +128,122 @@ def put(self, data, filename, binary=True) -> str:
return filepath


class RemoteCacheBackend:
"""
A backend implementation for accessing a remote/distributed cache.
"""

def __init__(self, key: str):
pass

@abstractmethod
def get(self, filenames: List[str]) -> Dict[str, bytes]:
pass

@abstractmethod
def put(self, filename: str, data: bytes):
pass


class RedisRemoteCacheBackend(RemoteCacheBackend):

def __init__(self, key):
import redis
self._key = key
self._key_fmt = os.environ.get("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
self._redis = redis.Redis(
host=os.environ.get("TRITON_REDIS_HOST", "localhost"),
port=int(os.environ.get("TRITON_REDIS_PORT", 6379)),
)

def _get_key(self, filename: str) -> str:
return self._key_fmt.format(key=self._key, filename=filename)

def get(self, filenames: List[str]) -> Dict[str, str]:
results = self._redis.mget([self._get_key(f) for f in filenames])
return {filename: result for filename, result in zip(filenames, results) if result is not None}

def put(self, filename: str, data: bytes) -> Dict[str, bytes]:
self._redis.set(self._get_key(filename), data)


class RemoteCacheManager(CacheManager):

def __init__(self, key, override=False, dump=False):
# Setup backend pointed too by `TRITON_REMOTE_CACHE_BACKEND`.
remote_cache_manager = os.environ["TRITON_REMOTE_CACHE_BACKEND"]
module_path, clz_nme = remote_cache_manager.split(":")
module = importlib.import_module(module_path)
remote_cache_cls = getattr(module, clz_nme)
self._backend = remote_cache_cls(key)

self._override = override
self._dump = dump

# Use a `FileCacheManager` to materialize remote cache paths locally.
self._file_cache_manager = FileCacheManager(key, override=override, dump=dump)

def _materialize(self, filename: str, data: bytes):
# We use a backing `FileCacheManager` to provide the materialized data.
return self._file_cache_manager.put(data, filename, binary=True)

def get_file(self, filename: str) -> Optional[str]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_file(filename)

# We always check the remote cache backend -- even if our internal file-
# based cache has the item -- to make sure LRU accounting works as
# expected.
results = self._backend.get([filename])
if len(results) == 0:
return None
(_, data), = results.items()
return self._materialize(filename, data)

def put(self, data, filename: str, binary=True) -> str:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put(data, filename, binary=binary)

if not isinstance(data, bytes):
data = str(data).encode("utf-8")
self._backend.put(filename, data)
return self._materialize(filename, data)

def get_group(self, filename: str) -> Optional[Dict[str, str]]:
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.get_group(filename)

grp_filename = f"__grp__{filename}"
grp_filepath = self.get_file(grp_filename)
if grp_filepath is None:
return None
with open(grp_filepath) as f:
grp_data = json.load(f)
child_paths = grp_data.get("child_paths", None)

result = None

# Found group data.
if child_paths is not None:
result = {}
for child_path, data in self._backend.get(child_paths).items():
result[child_path] = self._materialize(child_path, data)

return result

def put_group(self, filename: str, group: Dict[str, str]):
# We don't handle the dump/override cases.
if self._dump or self._override:
return self._file_cache_manager.put_group(filename, group)

grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))})
grp_filename = f"__grp__{filename}"
return self.put(grp_contents, grp_filename)


__cache_cls = FileCacheManager
__cache_cls_nme = "DEFAULT"

Expand All @@ -142,8 +256,6 @@ def get_cache_manager(key) -> CacheManager:
global __cache_cls_nme

if user_cache_manager is not None and user_cache_manager != __cache_cls_nme:
import importlib

module_path, clz_nme = user_cache_manager.split(":")
module = importlib.import_module(module_path)
__cache_cls = getattr(module, clz_nme)
Expand Down

0 comments on commit 996b6c0

Please sign in to comment.