Skip to content

Commit

Permalink
Feat/add_tests (#27)
Browse files Browse the repository at this point in the history
* feat(tests): Add tests for models

* ruff auto-fix

* test(terraform_resource): Add to_dict() test.

* doc(checkout): Add note and update action example regarding fetch-depth.

* feat(action): Add log message when updating pr.

* fix(provider_cache): Fix provider cache not working.

* fix(markdown_tables): Skip table generate for providers without changes.

* fix(pr_update): Fix handling of existing prs

* fix(action): Fix searching existing pr.

* fix(action): Filter pr in python instead of search over github api.

* feat(action): Update pr even when no upgrades where performed.

* fix(provider_handler): Switch reference of object.

* fix(provider_handler): Fix resource update.

* fix(provider_handler): Fix resource replace in cache.

* fix(provider_handler): REmove indent.
  • Loading branch information
Noahnc authored Nov 29, 2023
1 parent e3bb166 commit 229c644
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 38 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0

- name: Run in update mode
uses: Noahnc/infrapatch@main
Expand All @@ -54,6 +56,8 @@ jobs:

```

> **_NOTE:_** It's important to set the `fetch-depth: 0` in the Checkout step, otherwise rebases performed by InfraPatch will not work correctly.
### Example PR

InfraPatch will create a new branch with the changes and open a PR to the branch for which the Action was triggered.
Expand Down
53 changes: 41 additions & 12 deletions infrapatch/action/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def main(debug: bool):
upgradable_resources_head_branch = None
pr = None
if github_target_branch is not None and config.report_only is False:
pr = get_pr(github_repo, config.head_branch, config.target_branch)
pr = get_pr(github_repo, head=config.target_branch, base=config.head_branch)
if pr is not None:
upgradable_resources_head_branch = provider_handler.get_upgradable_resources()
log.info(f"Branch {config.target_branch} already exists. Checking out...")
Expand All @@ -68,6 +68,10 @@ def main(debug: bool):

if provider_handler.check_if_upgrades_available() is False:
log.info("No resources with pending upgrade found.")
if pr is not None and upgradable_resources_head_branch is not None:
log.info("Updating PR Body...")
provider_handler.set_resources_patched_based_on_existing_resources(upgradable_resources_head_branch)
update_pr_body(pr, provider_handler)
return

if github_target_branch is None:
Expand All @@ -79,22 +83,30 @@ def main(debug: bool):
if upgradable_resources_head_branch is not None:
log.info("Updating status of resources from previous branch...")
provider_handler.set_resources_patched_based_on_existing_resources(upgradable_resources_head_branch)

provider_handler.print_statistics_table()
provider_handler.dump_statistics()

git.push(["-f", "-u", "origin", config.target_branch])

body = get_pr_body(provider_handler)
if pr is not None:
update_pr_body(pr, provider_handler)
return
create_pr(github_repo, config.head_branch, config.target_branch, provider_handler)


def update_pr_body(pr, provider_handler):
if pr is not None:
log.info("Updating existing pull request with new body.")
body = get_pr_body(provider_handler)
log.debug(f"Pull request body:\n{body}")
pr.edit(body=body)
return
create_pr(github_repo, config.head_branch, config.target_branch, body)


def get_pr_body(provider_handler: ProviderHandler) -> str:
body = ""
markdown_tables = provider_handler.get_markdown_tables()
markdown_tables = provider_handler.get_markdown_table_for_changed_resources()
for table in markdown_tables:
body += table.dumps()
body += "\n"
Expand All @@ -104,17 +116,34 @@ def get_pr_body(provider_handler: ProviderHandler) -> str:
return body


def get_pr(repo: Repository, head_branch, target_branch) -> Union[PullRequest, None]:
pull = repo.get_pulls(state="open", sort="created", base=head_branch, head=target_branch)
if pull.totalCount != 0:
log.info(f"Pull request found from '{target_branch}' to '{head_branch}'")
return pull[0]
log.debug(f"No pull request found from '{target_branch}' to '{head_branch}'.")
return None
def get_pr(repo: Repository, base: str, head: str) -> Union[PullRequest, None]:
base_ref = base
head_ref = head
if base_ref.startswith("origin/"):
base_ref = base_ref[len("origin/") :]
if head_ref.startswith("origin/"):
head_ref = head_ref[len("origin/") :]
pulls = repo.get_pulls(state="open", sort="created", direction="desc")

if pulls.totalCount == 0:
log.debug("No pull request found")
return None

pr = [pr for pr in pulls if pr.base.ref == base_ref and pr.head.ref == head_ref]
if len(pr) == 0:
log.debug(f"No pull request found from '{head}' to '{base}'.")
return None
elif len(pr) == 1:
log.debug(f"Pull request found from '{head}' to '{base}'.")
return pr[0]
if len(pr) > 1:
raise Exception(f"Multiple pull requests found from '{head}' to '{base}'.")

def create_pr(repo: Repository, head_branch: str, target_branch: str, body: str) -> PullRequest:

def create_pr(repo: Repository, head_branch: str, target_branch: str, provider_handler: ProviderHandler) -> PullRequest:
body = get_pr_body(provider_handler)
log.info(f"Creating new pull request from '{target_branch}' to '{head_branch}'.")
log.debug(f"Pull request body:\n{body}")
return repo.create_pull(title="InfraPatch Module and Provider Update", body=body, base=head_branch, head=target_branch)


Expand Down
6 changes: 3 additions & 3 deletions infrapatch/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Union

import click
from infrapatch.core.credentials_helper import get_registry_credentials
from infrapatch.core.provider_handler import ProviderHandler
from infrapatch.core.provider_handler_builder import ProviderHandlerBuilder

from infrapatch.cli.__init__ import __version__
from infrapatch.core.credentials_helper import get_registry_credentials
from infrapatch.core.log_helper import catch_exception, setup_logging
from infrapatch.core.provider_handler import ProviderHandler
from infrapatch.core.provider_handler_builder import ProviderHandlerBuilder
from infrapatch.core.utils.terraform.hcl_edit_cli import HclEditCli
from infrapatch.core.utils.terraform.hcl_handler import HclHandler

Expand Down
70 changes: 70 additions & 0 deletions infrapatch/core/models/tests/test_versioned_resource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from pathlib import Path

from infrapatch.core.models.versioned_resource import ResourceStatus, VersionedResource


def test_version_management():
# Create new resource with newer version
resource = VersionedResource(name="test_resource", current_version="1.0.0", _source_file="test_file.py")
resource.newest_version = "2.0.0"

assert resource.status == ResourceStatus.UNPATCHED
assert resource.installed_version_equal_or_newer_than_new_version() is False

resource.set_patched()
assert resource.status == ResourceStatus.PATCHED

resource = VersionedResource(name="test_resource", current_version="1.0.0", _source_file="test_file.py")
resource.newest_version = "1.0.0"

assert resource.status == ResourceStatus.UNPATCHED
assert resource.installed_version_equal_or_newer_than_new_version() is True


def test_tile_constraint():
resource = VersionedResource(name="test_resource", current_version="~>1.0.0", _source_file="test_file.py")
resource.newest_version = "~>1.0.1"
assert resource.has_tile_constraint() is True
assert resource.installed_version_equal_or_newer_than_new_version() is True

resource.newest_version = "~>1.1.0"
assert resource.installed_version_equal_or_newer_than_new_version() is False

resource = VersionedResource(name="test_resource", current_version="1.0.0", _source_file="test_file.py")
assert resource.has_tile_constraint() is False

resource = VersionedResource(name="test_resource", current_version="~>1.0.0", _source_file="test_file.py")
resource.newest_version = "1.1.0"
assert resource.newest_version == "~>1.1.0"


def test_patch_error():
resource = VersionedResource(name="test_resource", current_version="1.0.0", _source_file="test_file.py")
resource.set_patch_error()
assert resource.status == ResourceStatus.PATCH_ERROR


def test_path():
resource = VersionedResource(name="test_resource", current_version="1.0.0", _source_file="/var/testdir/test_file.py")
assert resource.source_file == Path("/var/testdir/test_file.py")


def test_find():
findably_resource = VersionedResource(name="test_resource3", current_version="1.0.0", _source_file="test_file3.py")
unfindably_resource = VersionedResource(name="test_resource6", current_version="1.0.0", _source_file="test_file8.py")
resources = [
VersionedResource(name="test_resource1", current_version="1.0.0", _source_file="test_file1.py"),
VersionedResource(name="test_resource2", current_version="1.0.0", _source_file="test_file2.py"),
VersionedResource(name="test_resource3", current_version="1.0.0", _source_file="test_file3.py"),
VersionedResource(name="test_resource4", current_version="1.0.0", _source_file="test_file4.py"),
VersionedResource(name="test_resource5", current_version="1.0.0", _source_file="test_file5.py"),
]
assert len(findably_resource.find(resources)) == 1
assert findably_resource.find(resources) == [resources[2]]
assert len(unfindably_resource.find(resources)) == 0


def test_versioned_resource_to_dict():
resource = VersionedResource(name="test_resource", current_version="1.0.0", _source_file="test_file.py")
expected_dict = {"name": "test_resource", "current_version": "1.0.0", "_source_file": "test_file.py", "_newest_version": None, "_status": ResourceStatus.UNPATCHED}
assert resource.to_dict() == expected_dict
82 changes: 82 additions & 0 deletions infrapatch/core/models/tests/test_versioned_terraform_resource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest

from infrapatch.core.models.versioned_terraform_resources import TerraformModule, TerraformProvider


def test_attributes():
# test with default registry
module = TerraformModule(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="test/test_module/test_provider")
provider = TerraformProvider(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="test_provider/test_provider")

assert module.source == "test/test_module/test_provider"
assert module.base_domain is None
assert module.identifier == "test/test_module/test_provider"

assert provider.source == "test_provider/test_provider"
assert provider.base_domain is None
assert provider.identifier == "test_provider/test_provider"

# test with custom registry
module = TerraformModule(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="testregistry.ch/test/test_module/test_provider")
provider = TerraformProvider(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="testregistry.ch/test_provider/test_provider")

assert module.source == "testregistry.ch/test/test_module/test_provider"
assert module.base_domain == "testregistry.ch"
assert module.identifier == "test/test_module/test_provider"

assert provider.source == "testregistry.ch/test_provider/test_provider"
assert provider.base_domain == "testregistry.ch"
assert provider.identifier == "test_provider/test_provider"

# test invalid sources
with pytest.raises(Exception):
TerraformModule(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="test/test_module/test_provider/test")
TerraformModule(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="/test_module")

with pytest.raises(Exception):
TerraformProvider(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="/test_module")
TerraformProvider(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="kfdsjflksdj/kldfsjflsdkj/dkljflsk/test_module")


def test_find():
findably_resource = TerraformModule(name="test_resource3", current_version="1.0.0", _source_file="test_file3.py", _source="test/test_module3/test_provider")
unfindably_resource = TerraformModule(name="test_resource6", current_version="1.0.0", _source_file="test_file8.py", _source="test/test_module3/test_provider")
resources = [
TerraformModule(name="test_resource1", current_version="1.0.0", _source_file="test_file1.py", _source="test/test_module1/test_provider"),
TerraformModule(name="test_resource2", current_version="1.0.0", _source_file="test_file2.py", _source="test/test_module2/test_provider"),
TerraformModule(name="test_resource3", current_version="1.0.0", _source_file="test_file3.py", _source="test/test_module3/test_provider"),
TerraformModule(name="test_resource4", current_version="1.0.0", _source_file="test_file4.py", _source="test/test_module4/test_provider"),
TerraformModule(name="test_resource5", current_version="1.0.0", _source_file="test_file5.py", _source="test/test_module5/test_provider"),
]
assert len(findably_resource.find(resources)) == 1
assert findably_resource.find(resources) == [resources[2]]
assert len(unfindably_resource.find(resources)) == 0


def test_to_dict():
module = TerraformModule(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="test/test_module/test_provider")
provider = TerraformProvider(name="test_resource", current_version="1.0.0", _source_file="test_file.py", _source="test_provider/test_provider")

module_dict = module.to_dict()
provider_dict = provider.to_dict()

assert module_dict == {
"name": "test_resource",
"current_version": "1.0.0",
"_newest_version": None,
"_status": "unpatched",
"_source_file": "test_file.py",
"_source": "test/test_module/test_provider",
"_base_domain": None,
"_identifier": "test/test_module/test_provider",
}
assert provider_dict == {
"name": "test_resource",
"current_version": "1.0.0",
"_newest_version": None,
"_status": "unpatched",
"_source_file": "test_file.py",
"_source": "test_provider/test_provider",
"_base_domain": None,
"_identifier": "test_provider/test_provider",
}
10 changes: 7 additions & 3 deletions infrapatch/core/models/versioned_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,18 @@ def newest_version(self, version: str):
def set_patched(self):
self._status = ResourceStatus.PATCHED

def has_tile_constraint(self):
return re.match(r"^~>[0-9]+\.[0-9]+\.[0-9]+$", self.current_version)
def has_tile_constraint(self) -> bool:
result = re.match(r"^~>[0-9]+\.[0-9]+\.[0-9]+$", self.current_version)
if result is None:
return False
return True

def set_patch_error(self):
self._status = ResourceStatus.PATCH_ERROR

def find(self, resources):
return [resource for resource in resources if resource.name == self.name and resource._source_file == self._source_file]
result = [resource for resource in resources if resource.name == self.name and resource._source_file == self._source_file]
return result

def installed_version_equal_or_newer_than_new_version(self):
if self.newest_version is None:
Expand Down
10 changes: 1 addition & 9 deletions infrapatch/core/models/versioned_terraform_resources.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging as log
import re
from dataclasses import dataclass
from typing import Optional, Sequence, Union
from typing import Optional, Union

from infrapatch.core.models.versioned_resource import VersionedResource

Expand Down Expand Up @@ -93,11 +93,3 @@ def source(self, source: str) -> None:
self._identifier = source_lower_case
else:
raise Exception(f"Source '{source_lower_case}' is not a valid terraform resource source.")


def get_upgradable_resources(resources: Sequence[VersionedTerraformResource]) -> Sequence[VersionedTerraformResource]:
return [resource for resource in resources if not resource.check_if_up_to_date()]


def from_terraform_resources_to_dict_list(terraform_resources: Sequence[VersionedTerraformResource]) -> Sequence[dict]:
return [terraform_resource.to_dict() for terraform_resource in terraform_resources]
37 changes: 26 additions & 11 deletions infrapatch/core/provider_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import logging as log
from pathlib import Path
from typing import Sequence, Union
from rich import progress

from git import Repo
from pytablewriter import MarkdownTableWriter
from rich import progress
from rich.console import Console

from infrapatch.core.models.statistics import ProviderStatistics, Statistics
Expand All @@ -25,8 +26,16 @@ def __init__(self, providers: Sequence[BaseProviderInterface], console: Console,

def get_resources(self, disable_cache: bool = False) -> dict[str, Sequence[VersionedResource]]:
for provider_name, provider in self.providers.items():
if not disable_cache and provider_name not in self._resource_cache:
if provider_name not in self._resource_cache:
log.debug(f"Fetching resources for provider {provider.get_provider_name()} since cache is empty.")
self._resource_cache[provider.get_provider_name()] = provider.get_resources()
continue
elif disable_cache:
log.debug(f"Fetching resources for provider {provider.get_provider_name()} since cache is disabled.")
self._resource_cache[provider.get_provider_name()] = provider.get_resources()
continue
else:
log.debug(f"Using cached resources for provider {provider.get_provider_name()}.")
return self._resource_cache

def get_upgradable_resources(self, disable_cache: bool = False) -> dict[str, Sequence[VersionedResource]]:
Expand Down Expand Up @@ -119,7 +128,7 @@ def print_statistics_table(self, disable_cache: bool = False):
table = self._get_statistics(disable_cache).get_rich_table()
self.console.print(table)

def get_markdown_tables(self) -> list[MarkdownTableWriter]:
def get_markdown_table_for_changed_resources(self) -> list[MarkdownTableWriter]:
if self._resource_cache is None:
raise Exception("No resources found. Run get_resources() first.")

Expand All @@ -128,17 +137,23 @@ def get_markdown_tables(self) -> list[MarkdownTableWriter]:
changed_resources = [
resource for resource in self._resource_cache[provider_name] if resource.status == ResourceStatus.PATCHED or resource.status == ResourceStatus.PATCH_ERROR
]
if len(changed_resources) == 0:
log.debug(f"No changed resources found for provider {provider_name}. Skipping.")
continue
markdown_tables.append(provider.get_markdown_table(changed_resources))
return markdown_tables

def set_resources_patched_based_on_existing_resources(self, resources: dict[str, Sequence[VersionedResource]]) -> None:
def set_resources_patched_based_on_existing_resources(self, original_resources: dict[str, Sequence[VersionedResource]]) -> None:
for provider_name, provider in self.providers.items():
current_resources = resources[provider_name]
for resource in resources[provider_name]:
current_resource = resource.find(current_resources)
if len(current_resource) == 0:
log.info(f"Resource '{resource.name}' not found in current resources. Skipping.")
original_resources_provider = original_resources[provider_name]
for i, resource in enumerate(self._resource_cache[provider_name]):
found_resources = resource.find(original_resources_provider)
if len(found_resources) == 0:
log.debug(f"Resource '{resource.name}' not found in original resources. Skipping update.")
continue
if len(current_resource) > 1:
if len(found_resources) > 1:
raise Exception(f"Found multiple resources with the same name: {resource.name}")
current_resource[0].set_patched()
log.debug(f"Updating resource '{resource.name}' from provider {provider_name} with original resource.")
found_resource = found_resources[0]
found_resource.set_patched()
self._resource_cache[provider_name][i] = found_resource # type: ignore

0 comments on commit 229c644

Please sign in to comment.