Skip to content

Commit

Permalink
Fix arbitrary file write during tarfile extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
Ali-Razmjoo committed Sep 4, 2024
1 parent 91ce61f commit 45f272c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 30 deletions.
24 changes: 4 additions & 20 deletions luigi/contrib/lsf_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Expand All @@ -28,7 +28,7 @@
except ImportError:
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def do_work_on_compute_node(work_dir):
Expand All @@ -44,21 +44,6 @@ def do_work_on_compute_node(work_dir):
job.work()


def _is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory


def _safe_extract(tar, path=".", members=None, *, numeric_owner=False):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not _is_within_directory(path, member_path):
raise ValueError("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner=numeric_owner)


def extract_packages_archive(work_dir):
package_file = os.path.join(work_dir, "packages.tar")
if not os.path.exists(package_file):
Expand All @@ -67,9 +52,8 @@ def extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
_safe_extract(tar)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
8 changes: 3 additions & 5 deletions luigi/contrib/sge_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import sys
import pickle
import logging
import tarfile
from luigi.safe_extractor import SafeExtractor


def _do_work_on_compute_node(work_dir, tarball=True):
Expand Down Expand Up @@ -64,10 +64,8 @@ def _extract_packages_archive(work_dir):
curdir = os.path.abspath(os.curdir)

os.chdir(work_dir)
tar = tarfile.open(package_file)
for tarinfo in tar:
tar.extract(tarinfo)
tar.close()
extractor = SafeExtractor(work_dir)
extractor.safe_extract(package_file)
if '' not in sys.path:
sys.path.insert(0, '')

Expand Down
96 changes: 96 additions & 0 deletions luigi/safe_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
This module provides a class `SafeExtractor` that offers a secure way to extract tar files while
mitigating path traversal vulnerabilities, which can occur when files inside the archive are
crafted to escape the intended extraction directory.
The `SafeExtractor` ensures that the extracted file paths are validated before extraction to
prevent malicious archives from extracting files outside the intended directory.
Classes:
SafeExtractor: A class to securely extract tar files with protection against path traversal attacks.
Usage Example:
extractor = SafeExtractor("/desired/directory")
extractor.safe_extract("archive.tar")
"""

import os
import tarfile


class SafeExtractor:
"""
A class to safely extract tar files, ensuring that no path traversal
vulnerabilities are exploited.
Attributes:
path (str): The directory to extract files into.
Methods:
_is_within_directory(directory, target):
Checks if a target path is within a given directory.
safe_extract(tar_path, members=None, *, numeric_owner=False):
Safely extracts the contents of a tar file to the specified directory.
"""

def __init__(self, path="."):
"""
Initializes the SafeExtractor with the specified directory path.
Args:
path (str): The directory to extract files into. Defaults to the current directory.
"""
self.path = path

def _is_within_directory(self, directory, target):
"""
Checks if a target path is within a given directory.
Args:
directory (str): The directory to check against.
target (str): The target path to check.
Returns:
bool: True if the target path is within the directory, False otherwise.
"""
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory

def safe_extract(self, tar_path, members=None, *, numeric_owner=False):
"""
Safely extracts the contents of a tar file to the specified directory.
Args:
tar_path (str): The path to the tar file to extract.
members (list, optional): A list of members to extract. Defaults to None.
numeric_owner (bool, optional): If True, only the numeric owner will be used. Defaults to False.
Raises:
ValueError: If a path traversal attempt is detected.
"""
with tarfile.open(tar_path, 'r') as tar:
for member in tar.getmembers():
member_path = os.path.join(self.path, member.name)
if not self._is_within_directory(self.path, member_path):
raise ValueError("Attempted Path Traversal in Tar File")
tar.extractall(self.path, members, numeric_owner=numeric_owner)
11 changes: 6 additions & 5 deletions test/contrib/lsf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import shutil

import pytest
from luigi.safe_extractor import SafeExtractor

DEFAULT_HOME = ''

Expand Down Expand Up @@ -123,8 +124,8 @@ def test_safe_extract(self):
f.write(f'This is test file {i}')
tar.add(file_path, arcname=f'test_file_{i}.txt')

with tarfile.open(tar_path, 'r') as tar:
_safe_extract(tar, self.temp_dir)
extractor = SafeExtractor(self.temp_dir)
extractor.safe_extract(tar_path)

for i in range(3):
file_path = os.path.join(self.temp_dir, f'test_file_{i}.txt')
Expand All @@ -141,9 +142,9 @@ def test_safe_extract_with_traversal(self):
f.write('This is a test file')
tar.add(file_path, arcname='../../test_file.txt')

with tarfile.open(tar_path, 'r') as tar:
with self.assertRaises(ValueError):
_safe_extract(tar, self.temp_dir)
extractor = SafeExtractor(self.temp_dir)
with self.assertRaises(ValueError):
extractor.safe_extract(tar_path)


if __name__ == '__main__':
Expand Down

0 comments on commit 45f272c

Please sign in to comment.