Skip to content

Commit

Permalink
Add pipeline to sort packages (#1686)
Browse files Browse the repository at this point in the history
* Add pipeline to sort packages

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

* Add tests

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

* Add calculate_version_rank on Package

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

* Start enumerating from 1

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

* Fix tests

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

* Return version rank anyhow

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

* Fix API tests

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

* Address review comments

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>

---------

Signed-off-by: Tushar Goel <tushar.goel.dav@gmail.com>
  • Loading branch information
TG1999 authored Dec 8, 2024
1 parent cec5d9e commit c923364
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 21 deletions.
2 changes: 2 additions & 0 deletions vulnerabilities/improvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from vulnerabilities.improvers import vulnerability_status
from vulnerabilities.pipelines import VulnerableCodePipeline
from vulnerabilities.pipelines import compute_package_risk
from vulnerabilities.pipelines import compute_package_version_rank
from vulnerabilities.pipelines import enhance_with_exploitdb
from vulnerabilities.pipelines import enhance_with_kev
from vulnerabilities.pipelines import enhance_with_metasploit
Expand Down Expand Up @@ -39,6 +40,7 @@
enhance_with_metasploit.MetasploitImproverPipeline,
enhance_with_exploitdb.ExploitDBImproverPipeline,
compute_package_risk.ComputePackageRiskPipeline,
compute_package_version_rank.ComputeVersionRankPipeline,
]

IMPROVERS_REGISTRY = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Generated by Django 4.2.16 on 2024-12-04 11:50

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("vulnerabilities", "0083_alter_packagechangelog_software_version_and_more"),
]

operations = [
migrations.AlterModelOptions(
name="package",
options={
"ordering": [
"type",
"namespace",
"name",
"version_rank",
"version",
"qualifiers",
"subpath",
]
},
),
migrations.AddField(
model_name="package",
name="version_rank",
field=models.IntegerField(
default=0,
help_text="Rank of the version to support ordering by version. Rank zero means the rank has not been defined yet",
),
),
]
62 changes: 44 additions & 18 deletions vulnerabilities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,12 @@ class Package(PackageURLMixin):
"indicate greater vulnerability risk for the package.",
)

version_rank = models.IntegerField(
help_text="Rank of the version to support ordering by version. Rank "
"zero means the rank has not been defined yet",
default=0,
)

objects = PackageQuerySet.as_manager()

def save(self, *args, **kwargs):
Expand Down Expand Up @@ -738,11 +744,34 @@ def purl(self):

class Meta:
unique_together = ["type", "namespace", "name", "version", "qualifiers", "subpath"]
ordering = ["type", "namespace", "name", "version", "qualifiers", "subpath"]
ordering = ["type", "namespace", "name", "version_rank", "version", "qualifiers", "subpath"]

def __str__(self):
return self.package_url

@property
def calculate_version_rank(self):
"""
Calculate and return the `version_rank` for a package that does not have one.
If this package already has a `version_rank`, return it.
The calculated rank will be interpolated between two packages that have
`version_rank` values and are closest to this package in terms of version order.
"""

group_packages = Package.objects.filter(
type=self.type,
namespace=self.namespace,
name=self.name,
)

if any(p.version_rank == 0 for p in group_packages):
sorted_packages = sorted(group_packages, key=lambda p: self.version_class(p.version))
for rank, package in enumerate(sorted_packages, start=1):
package.version_rank = rank
Package.objects.bulk_update(sorted_packages, fields=["version_rank"])
return self.version_rank

@property
def affected_by(self):
"""
Expand Down Expand Up @@ -789,14 +818,6 @@ def get_details_url(self, request):

return reverse("package_details", kwargs={"purl": self.purl}, request=request)

def sort_by_version(self, packages):
"""
Return a sequence of `packages` sorted by version.
"""
if not packages:
return []
return sorted(packages, key=lambda x: self.version_class(x.version))

@cached_property
def version_class(self):
range_class = RANGE_CLASS_BY_SCHEMES.get(self.type)
Expand Down Expand Up @@ -831,19 +852,20 @@ def get_non_vulnerable_versions(self):
Return a tuple of the next and latest non-vulnerable versions as Package instance.
Return a tuple of (None, None) if there is no non-vulnerable version.
"""
if self.version_rank == 0:
self.calculate_version_rank
non_vulnerable_versions = Package.objects.get_fixed_by_package_versions(
self, fix=False
).only_non_vulnerable()
sorted_versions = self.sort_by_version(non_vulnerable_versions)

later_non_vulnerable_versions = [
non_vuln_ver
for non_vuln_ver in sorted_versions
if self.version_class(non_vuln_ver.version) > self.current_version
]
later_non_vulnerable_versions = non_vulnerable_versions.filter(
version_rank__gt=self.version_rank
)

later_non_vulnerable_versions = list(later_non_vulnerable_versions)

if later_non_vulnerable_versions:
sorted_versions = self.sort_by_version(later_non_vulnerable_versions)
sorted_versions = later_non_vulnerable_versions
next_non_vulnerable = sorted_versions[0]
latest_non_vulnerable = sorted_versions[-1]
return next_non_vulnerable, latest_non_vulnerable
Expand Down Expand Up @@ -872,6 +894,8 @@ def get_affecting_vulnerabilities(self):
Return a list of vulnerabilities that affect this package together with information regarding
the versions that fix the vulnerabilities.
"""
if self.version_rank == 0:
self.calculate_version_rank
package_details_vulns = []

fixed_by_packages = Package.objects.get_fixed_by_package_versions(self, fix=True)
Expand All @@ -895,12 +919,13 @@ def get_affecting_vulnerabilities(self):
if fixed_version > self.current_version:
later_fixed_packages.append(fixed_pkg)

next_fixed_package = None
next_fixed_package_vulns = []

sort_fixed_by_packages_by_version = []
if later_fixed_packages:
sort_fixed_by_packages_by_version = self.sort_by_version(later_fixed_packages)
sort_fixed_by_packages_by_version = sorted(
later_fixed_packages, key=lambda p: p.version_rank
)

fixed_by_pkgs = []

Expand Down Expand Up @@ -930,6 +955,7 @@ def fixing_vulnerabilities(self):
"""
Return a queryset of Vulnerabilities that are fixed by this package.
"""
print("A")
return self.fixed_by_vulnerabilities.all()

@property
Expand Down
93 changes: 93 additions & 0 deletions vulnerabilities/pipelines/compute_package_version_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#
# Copyright (c) nexB Inc. and others. All rights reserved.
# VulnerableCode is a trademark of nexB Inc.
# SPDX-License-Identifier: Apache-2.0
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
# See https://aboutcode.org for more information about nexB OSS projects.
#

from itertools import groupby

from aboutcode.pipeline import LoopProgress
from django.db import transaction
from univers.version_range import RANGE_CLASS_BY_SCHEMES
from univers.versions import Version

from vulnerabilities.models import Package
from vulnerabilities.pipelines import VulnerableCodePipeline


class ComputeVersionRankPipeline(VulnerableCodePipeline):
"""
A pipeline to compute and assign version ranks for all packages.
"""

pipeline_id = "compute_version_rank"
license_expression = None

@classmethod
def steps(cls):
return (cls.compute_and_store_version_rank,)

def compute_and_store_version_rank(self):
"""
Compute and assign version ranks to all packages.
"""
groups = Package.objects.only("type", "namespace", "name").order_by(
"type", "namespace", "name"
)

def key(package):
return package.type, package.namespace, package.name

groups = groupby(groups, key=key)

groups = [(list(x), list(y)) for x, y in groups]

total_groups = len(groups)
self.log(f"Calculating `version_rank` for {total_groups:,d} groups of packages.")

progress = LoopProgress(
total_iterations=total_groups,
logger=self.log,
progress_step=5,
)

for group, packages in progress.iter(groups):
type, namespace, name = group
if type not in RANGE_CLASS_BY_SCHEMES:
continue
self.update_version_rank_for_group(packages)

self.log("Successfully populated `version_rank` for all packages.")

@transaction.atomic
def update_version_rank_for_group(self, packages):
"""
Update the `version_rank` for all packages in a specific group.
"""

# Sort the packages by version
sorted_packages = self.sort_packages_by_version(packages)

# Assign version ranks
updates = []
for rank, package in enumerate(sorted_packages, start=1):
package.version_rank = rank
updates.append(package)

# Bulk update to save the ranks
Package.objects.bulk_update(updates, fields=["version_rank"])

def sort_packages_by_version(self, packages):
"""
Sort packages by version using `version_class`.
"""

if not packages:
return []
version_class = RANGE_CLASS_BY_SCHEMES.get(packages[0].type).version_class
if not version_class:
version_class = Version
return sorted(packages, key=lambda p: version_class(p.version))
4 changes: 3 additions & 1 deletion vulnerabilities/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def setUp(self):
self.pkg_2_14_0_rc1 = from_purl(
"pkg:maven/com.fasterxml.jackson.core/jackson-databind@2.14.0-rc1"
)
self.pkg_2_12_6.calculate_version_rank

set_as_fixing(package=self.pkg_2_12_6, vulnerability=self.vul3)

Expand Down Expand Up @@ -608,6 +609,7 @@ def setUp(self):
self.pkg_2_14_0_rc1 = from_purl(
"pkg:maven/com.fasterxml.jackson.core/jackson-databind@2.14.0-rc1"
)
self.pkg_2_12_6.calculate_version_rank

self.ref = VulnerabilityReference.objects.create(
reference_type="advisory", reference_id="CVE-xxx-xxx", url="https://example.com"
Expand Down Expand Up @@ -806,7 +808,7 @@ def test_api_with_ghost_package_no_fixing_vulnerabilities(self):
"qualifiers": {},
"subpath": "",
"is_vulnerable": True,
"next_non_vulnerable_version": "2.14.0-rc1",
"next_non_vulnerable_version": "2.12.6",
"latest_non_vulnerable_version": "2.14.0-rc1",
"affected_by_vulnerabilities": [
{
Expand Down
59 changes: 59 additions & 0 deletions vulnerabilities/tests/test_compute_package_version_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from unittest.mock import patch

import pytest
from univers.versions import Version

from vulnerabilities.models import Package
from vulnerabilities.pipelines.compute_package_version_rank import ComputeVersionRankPipeline


@pytest.mark.django_db
class TestComputeVersionRankPipeline:
@pytest.fixture
def pipeline(self):
return ComputeVersionRankPipeline()

@pytest.fixture
def packages(self, db):
package_type = "pypi"
namespace = "test_namespace"
name = "test_package"
Package.objects.create(type=package_type, namespace=namespace, name=name, version="1.0.0")
Package.objects.create(type=package_type, namespace=namespace, name=name, version="1.1.0")
Package.objects.create(type=package_type, namespace=namespace, name=name, version="0.9.0")
return Package.objects.filter(type=package_type, namespace=namespace, name=name)

def test_compute_and_store_version_rank(self, pipeline, packages):
with patch.object(pipeline, "log") as mock_log:
pipeline.compute_and_store_version_rank()
assert mock_log.call_count > 0
for package in packages:
assert package.version_rank is not None

def test_update_version_rank_for_group(self, pipeline, packages):
with patch.object(Package.objects, "bulk_update") as mock_bulk_update:
pipeline.update_version_rank_for_group(packages)
mock_bulk_update.assert_called_once()
updated_packages = mock_bulk_update.call_args[0][0]
assert len(updated_packages) == len(packages)
for idx, package in enumerate(sorted(packages, key=lambda p: Version(p.version))):
assert updated_packages[idx].version_rank == idx

def test_sort_packages_by_version(self, pipeline, packages):
sorted_packages = pipeline.sort_packages_by_version(packages)
versions = [p.version for p in sorted_packages]
assert versions == sorted(versions, key=Version)

def test_sort_packages_by_version_empty(self, pipeline):
assert pipeline.sort_packages_by_version([]) == []

def test_sort_packages_by_version_invalid_scheme(self, pipeline, packages):
for package in packages:
package.type = "invalid"
assert pipeline.sort_packages_by_version(packages) == []

def test_compute_and_store_version_rank_invalid_scheme(self, pipeline):
Package.objects.create(type="invalid", namespace="test", name="package", version="1.0.0")
with patch.object(pipeline, "log") as mock_log:
pipeline.compute_and_store_version_rank()
mock_log.assert_any_call("Successfully populated `version_rank` for all packages.")
7 changes: 5 additions & 2 deletions vulnerabilities/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,11 @@ def test_sort_by_version(self):
version="3.0.0",
)

sorted_pkgs = requesting_package.sort_by_version(vuln_pkg_list)
first_sorted_item = sorted_pkgs[0]
requesting_package.calculate_version_rank

sorted_pkgs = Package.objects.filter(package_url__in=list_to_sort)

sorted_pkgs = list(sorted_pkgs)

assert sorted_pkgs[0].purl == "pkg:npm/sequelize@3.9.1"
assert sorted_pkgs[-1].purl == "pkg:npm/sequelize@3.40.1"
Expand Down

0 comments on commit c923364

Please sign in to comment.