Skip to content

Commit

Permalink
Change locate_working_dir logic to be breadth first.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Jun 19, 2024
1 parent a83956c commit 7a957dd
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 58 deletions.
6 changes: 2 additions & 4 deletions runhouse/resources/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from runhouse.constants import CONDA_INSTALL_CMDS, EMPTY_DEFAULT_ENV_NAME
from runhouse.globals import rns_client
from runhouse.resources.resource import Resource
from runhouse.utils import locate_working_dir


def _process_reqs(reqs):
Expand All @@ -26,10 +27,7 @@ def _process_reqs(reqs):
else:
# if package refers to a local path package
path = Path(package.split(":")[-1]).expanduser()
if (
path.is_absolute()
or (rns_client.locate_working_dir() / path).exists()
):
if path.is_absolute() or (locate_working_dir() / path).exists():
package = Package.from_string(package)
elif isinstance(package, dict):
package = Package.from_config(package)
Expand Down
5 changes: 3 additions & 2 deletions runhouse/resources/folders/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from runhouse.resources.resource import Resource
from runhouse.rns.top_level_rns_fns import exists
from runhouse.rns.utils.api import generate_uuid
from runhouse.utils import locate_working_dir

fsspec.register_implementation("ssh", sshfs.SSHFileSystem)
# SSHFileSystem is not yet builtin.
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(
if system != "file"
else path
if Path(path).expanduser().is_absolute()
else str(Path(rns_client.locate_working_dir()) / path)
else str(Path(locate_working_dir()) / path)
)
self.data_config = data_config or {}

Expand Down Expand Up @@ -619,7 +620,7 @@ def _save_sub_resources(self, folder: str = None):

@staticmethod
def _path_relative_to_rh_workdir(path):
rh_workdir = Path(rns_client.locate_working_dir())
rh_workdir = Path(locate_working_dir())
try:
return str(Path(path).relative_to(rh_workdir))
except ValueError:
Expand Down
3 changes: 2 additions & 1 deletion runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from runhouse.rns.utils.api import ResourceAccess, ResourceVisibility
from runhouse.servers.http.certs import TLSCertConfig
from runhouse.utils import locate_working_dir

# Filter out DeprecationWarnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
Expand Down Expand Up @@ -1540,7 +1541,7 @@ def notebook(
from runhouse.resources.packages.package import Package

if sync_package_on_close == "./":
sync_package_on_close = rns_client.locate_working_dir()
sync_package_on_close = locate_working_dir()
pkg = Package.from_string("local:" + sync_package_on_close)
self._rsync(source=f"~/{pkg.name}", dest=pkg.local_path, up=False)
if not persist:
Expand Down
3 changes: 2 additions & 1 deletion runhouse/resources/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from runhouse.rns.utils.names import _generate_default_name
from runhouse.servers.http import HTTPClient
from runhouse.servers.http.http_utils import CallParams
from runhouse.utils import locate_working_dir

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1035,7 +1036,7 @@ def _extract_pointers(raw_cls_or_fn: Union[Type, Callable], reqs: List[str]):
local_path = (
Path(req).expanduser()
if Path(req).expanduser().is_absolute()
else Path(rns_client.locate_working_dir()) / req
else Path(locate_working_dir()) / req
)

if local_path:
Expand Down
5 changes: 3 additions & 2 deletions runhouse/resources/packages/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pathlib import Path
from typing import Dict, Optional, Union

from runhouse import globals
from runhouse.resources.envs.utils import install_conda, run_setup_command
from runhouse.resources.folders import Folder, folder
from runhouse.resources.hardware.cluster import Cluster
Expand All @@ -14,6 +13,8 @@
detect_cuda_version_or_cpu,
)
from runhouse.resources.resource import Resource
from runhouse.utils import locate_working_dir


INSTALL_METHODS = {"local", "reqs", "pip", "conda"}

Expand Down Expand Up @@ -407,7 +408,7 @@ def from_string(specifier: str, dryrun=False):
abs_target = (
Path(rel_target).expanduser()
if Path(rel_target).expanduser().is_absolute()
else Path(globals.rns_client.locate_working_dir()) / rel_target
else Path(locate_working_dir()) / rel_target
)
if abs_target.exists():
target = Folder(
Expand Down
45 changes: 2 additions & 43 deletions runhouse/rns/rns_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
remove_null_values_from_dict,
ResourceAccess,
)
from runhouse.utils import locate_working_dir

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(self, configs) -> None:
self._configs = configs
self._prev_folders = []

self.rh_directory = str(Path(self.locate_working_dir()) / "rh")
self.rh_directory = str(Path(locate_working_dir()) / "rh")
self.rh_builtins_directory = str(
Path(importlib.util.find_spec("runhouse").origin).parent / "builtins"
)
Expand All @@ -76,48 +77,6 @@ def __init__(self, configs) -> None:

self.session = requests.Session()

# Recursively walks upward from dir_path and returns, as a str, the first
# directory that contains `file`. Returns None when the home directory or "/"
# is reached, or when the parent directory has already been visited.
@classmethod
def find_parent_with_file(cls, dir_path, file, searched_dirs=None):
# Stop at the user's home directory or the filesystem root.
# NOTE(review): the second comparison is against a Path object, so a str
# dir_path never equals it; root is only caught by this check once the
# recursion passes a Path (parent_path below) — confirm this is intended.
if Path(dir_path) == Path.home() or dir_path == Path("/"):
return None
if Path(dir_path, file).exists():
return str(dir_path)
else:
# Record this directory as visited, creating the set on the first call.
if searched_dirs is None:
searched_dirs = {
dir_path,
}
else:
searched_dirs.add(dir_path)
parent_path = Path(dir_path).parent
# A parent that was already visited means the walk cannot make progress.
if parent_path in searched_dirs:
return None
return cls.find_parent_with_file(
parent_path, file, searched_dirs=searched_dirs
)

# Returns the directory to treat as the working directory for `cwd`.
# NOTE: the default `cwd=os.getcwd()` is evaluated once, when the function is
# defined, not on each call; a later change of the process working directory
# is not reflected in the default.
@classmethod
def locate_working_dir(cls, cwd=os.getcwd()):
# Try each marker name in list order, delegating the upward walk to
# find_parent_with_file. The marker order is:
# 1. .git
# 2. setup.py
# 3. setup.cfg
# 4. pyproject.toml
# 5. rh
# 6. requirements.txt
# The first marker for which a containing directory is found decides the result.

for search_target in [
".git",
"setup.py",
"setup.cfg",
"pyproject.toml",
"rh",
"requirements.txt",
]:
dir_with_target = cls.find_parent_with_file(cwd, search_target)
if dir_with_target is not None:
return dir_with_target
# NOTE(review): indentation is lost in this view; presumably this `else`
# belongs to the `for` loop, so cwd is returned only after every marker has
# been tried without success — confirm against the original file.
else:
return cwd

@property
def default_folder(self):
return self._configs.default_folder
Expand Down
43 changes: 43 additions & 0 deletions runhouse/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,50 @@
import asyncio
import contextvars
import functools
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def _find_directory_containing_any_file(dir_path, files, searched_dirs=None):
    """Walk upward from ``dir_path`` to the nearest directory holding any of ``files``.

    Each candidate directory is tested for a direct child whose name is in
    ``files`` (regular files and directories both count). The walk never
    inspects the user's home directory or the filesystem root ``/``: reaching
    either one ends the search without a match.

    Args:
        dir_path: Directory to start from, as a ``str`` or ``Path``.
        files: Iterable of file or directory names to look for.
        searched_dirs: Optional set used to record visited directories (as
            ``Path`` objects). It is mutated in place. A fresh set is created
            when omitted.

    Returns:
        The matching directory as a ``str``, or ``None`` if no directory
        between ``dir_path`` and the home/root boundary contains any of
        ``files``.
    """
    # The default used to be dereferenced directly, so omitting the argument
    # raised AttributeError on the first ``add``.
    if searched_dirs is None:
        searched_dirs = set()

    # Materialise once so a one-shot iterable survives several directories.
    files = tuple(files)

    home = Path.home()
    root = Path("/")
    # Normalise to Path up front so the boundary checks and the visited-set
    # lookup behave the same for str and Path inputs.
    current = Path(dir_path)

    while current != home and current != root:
        if any((current / name).exists() for name in files):
            return str(current)

        searched_dirs.add(current)
        current = current.parent
        # A path whose parent is itself (e.g. "." or a drive root) would loop
        # forever; the visited set catches that case.
        if current in searched_dirs:
            return None

    return None


def locate_working_dir(start_dir=None):
    """Return the directory to treat as the working directory.

    Starting from ``start_dir`` (the current working directory when omitted),
    two upward searches are made in turn:

    1. the nearest directory containing any Python-package marker
       (``.git``, ``setup.py``, ``setup.cfg``, ``pyproject.toml`` or
       ``requirements.txt``);
    2. failing that, the nearest directory containing ``rh``.

    If neither search finds anything, ``start_dir`` itself is returned.
    """
    origin = os.getcwd() if start_dir is None else start_dir

    package_markers = [
        ".git",
        "setup.py",
        "setup.cfg",
        "pyproject.toml",
        "requirements.txt",
    ]

    # Package markers take priority; "rh" is consulted only as a fallback.
    # Each pass gets its own visited set so the passes stay independent.
    for markers in (package_markers, ["rh"]):
        located = _find_directory_containing_any_file(
            origin, markers, searched_dirs=set()
        )
        if located is not None:
            return located

    return origin


def _thread_coroutine(coroutine, context):
Expand Down
11 changes: 6 additions & 5 deletions tests/test_den/test_rns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,29 @@
import pytest

import runhouse as rh
from runhouse.globals import rns_client
from runhouse.utils import locate_working_dir


# Exercises locate_working_dir against marker files and directories created
# under pytest's tmp_path fixture.
# NOTE(review): every rns_client.locate_working_dir(...) call below is
# immediately followed by a locate_working_dir(...) call that reassigns `d`;
# presumably these are the removed and added sides of the diff — confirm.
@pytest.mark.level("unit")
def test_find_working_dir(tmp_path):
starting_dir = Path(tmp_path, "subdir/subdir/subdir/subdir")
# No marker has been created under tmp_path yet: the result only has to be a
# substring of the starting path.
d = rns_client.locate_working_dir(cwd=str(starting_dir))
d = locate_working_dir(str(starting_dir))
assert d in str(starting_dir)

# With an rh/ directory in tmp_path/subdir, that directory must be returned.
Path(tmp_path, "subdir/rh").mkdir(parents=True)
d = rns_client.locate_working_dir(str(starting_dir))
d = locate_working_dir(str(starting_dir))
assert d == str(Path(tmp_path, "subdir"))

Path(tmp_path, "subdir/rh").rmdir()

# With a .git directory in tmp_path/subdir/subdir, the result must be a
# substring of that path.
Path(tmp_path, "subdir/subdir/.git").mkdir(exist_ok=True, parents=True)
d = rns_client.locate_working_dir(str(starting_dir))
d = locate_working_dir(str(starting_dir))
assert d in str(Path(tmp_path, "subdir/subdir"))

Path(tmp_path, "subdir/subdir/.git").rmdir()

# Same expectation with a requirements.txt file as the only marker.
Path(tmp_path, "subdir/subdir/requirements.txt").write_text("....")
d = rns_client.locate_working_dir(str(starting_dir))
d = locate_working_dir(str(starting_dir))
assert d in str(Path(tmp_path, "subdir/subdir"))


Expand Down

0 comments on commit 7a957dd

Please sign in to comment.