Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Validate proposed file extensions #32830

Merged
merged 3 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
import io
import itertools
import json
import types
import xml.etree.ElementTree

from youtube_dl.utils import (
_UnsafeExtensionError,
age_restricted,
args_to_str,
base_url,
Expand Down Expand Up @@ -270,6 +272,27 @@ def env(var):
expand_path('~/%s' % env('YOUTUBE_DL_EXPATH_PATH')),
'%s/expanded' % compat_getenv('HOME'))

_uncommon_extensions = [
('exe', 'abc.exe.ext'),
('de', 'abc.de.ext'),
('../.mp4', None),
('..\\.mp4', None),
]

def assertUnsafeExtension(self, ext=None):
assert_raises = self.assertRaises(_UnsafeExtensionError)
assert_raises.ext = ext
orig_exit = assert_raises.__exit__

def my_exit(self_, exc_type, exc_val, exc_tb):
did_raise = orig_exit(exc_type, exc_val, exc_tb)
if did_raise and assert_raises.ext is not None:
self.assertEqual(assert_raises.ext, assert_raises.exception.extension, 'Unsafe extension not as unexpected')
return did_raise

assert_raises.__exit__ = types.MethodType(my_exit, assert_raises)
return assert_raises

def test_prepend_extension(self):
self.assertEqual(prepend_extension('abc.ext', 'temp'), 'abc.temp.ext')
self.assertEqual(prepend_extension('abc.ext', 'temp', 'ext'), 'abc.temp.ext')
Expand All @@ -278,6 +301,19 @@ def test_prepend_extension(self):
self.assertEqual(prepend_extension('.abc', 'temp'), '.abc.temp')
self.assertEqual(prepend_extension('.abc.ext', 'temp'), '.abc.temp.ext')

# Test uncommon extensions
self.assertEqual(prepend_extension('abc.ext', 'bin'), 'abc.bin.ext')
for ext, result in self._uncommon_extensions:
with self.assertUnsafeExtension(ext):
prepend_extension('abc', ext)
if result:
self.assertEqual(prepend_extension('abc.ext', ext, 'ext'), result)
else:
with self.assertUnsafeExtension(ext):
prepend_extension('abc.ext', ext, 'ext')
with self.assertUnsafeExtension(ext):
prepend_extension('abc.unexpected_ext', ext, 'ext')

def test_replace_extension(self):
self.assertEqual(replace_extension('abc.ext', 'temp'), 'abc.temp')
self.assertEqual(replace_extension('abc.ext', 'temp', 'ext'), 'abc.temp')
Expand All @@ -286,6 +322,16 @@ def test_replace_extension(self):
self.assertEqual(replace_extension('.abc', 'temp'), '.abc.temp')
self.assertEqual(replace_extension('.abc.ext', 'temp'), '.abc.temp')

# Test uncommon extensions
self.assertEqual(replace_extension('abc.ext', 'bin'), 'abc.unknown_video')
for ext, _ in self._uncommon_extensions:
with self.assertUnsafeExtension(ext):
replace_extension('abc', ext)
with self.assertUnsafeExtension(ext):
replace_extension('abc.ext', ext, 'ext')
with self.assertUnsafeExtension(ext):
replace_extension('abc.unexpected_ext', ext, 'ext')

def test_subtitles_filename(self):
self.assertEqual(subtitles_filename('abc.ext', 'en', 'vtt'), 'abc.en.vtt')
self.assertEqual(subtitles_filename('abc.ext', 'en', 'vtt', 'ext'), 'abc.en.vtt')
Expand Down
17 changes: 17 additions & 0 deletions youtube_dl/YoutubeDL.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import copy
import datetime
import errno
import functools
import io
import itertools
import json
Expand Down Expand Up @@ -53,6 +54,7 @@
compat_urllib_request_DataHandler,
)
from .utils import (
_UnsafeExtensionError,
age_restricted,
args_to_str,
bug_reports_message,
Expand Down Expand Up @@ -129,6 +131,20 @@
import ctypes


def _catch_unsafe_file_extension(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except _UnsafeExtensionError as error:
self.report_error(
'{0} found; to avoid damaging your system, this value is disallowed.'
' If you believe this is an error{1}').format(
error.message, bug_reports_message(','))

return wrapper


class YoutubeDL(object):
"""YoutubeDL class.

Expand Down Expand Up @@ -1925,6 +1941,7 @@ def print_optional(field):
if self.params.get('forcejson', False):
self.to_stdout(json.dumps(self.sanitize_info(info_dict)))

@_catch_unsafe_file_extension
def process_info(self, info_dict):
"""Process a single resolved IE result."""

Expand Down
4 changes: 4 additions & 0 deletions youtube_dl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
workaround_optparse_bug9161,
)
from .utils import (
_UnsafeExtensionError,
DateRange,
decodeOption,
DEFAULT_OUTTMPL,
Expand Down Expand Up @@ -173,6 +174,9 @@ def _real_main(argv=None):
if opts.ap_mso and opts.ap_mso not in MSO_INFO:
parser.error('Unsupported TV Provider, use --ap-list-mso to get a list of supported TV Providers')

if opts.no_check_extensions:
_UnsafeExtensionError.lenient = True

def parse_retries(retries):
if retries in ('inf', 'infinite'):
parsed_retries = float('inf')
Expand Down
4 changes: 4 additions & 0 deletions youtube_dl/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,10 @@ def _comma_separated_values_options_callback(option, opt_str, value, parser):
'--no-check-certificate',
action='store_true', dest='no_check_certificate', default=False,
help='Suppress HTTPS certificate validation')
workarounds.add_option(
'--no-check-extensions',
action='store_true', dest='no_check_extensions', default=False,
help='Suppress file extension validation')
workarounds.add_option(
'--prefer-insecure',
'--prefer-unsecure', action='store_true', dest='prefer_insecure',
Expand Down
173 changes: 148 additions & 25 deletions youtube_dl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,21 +1717,6 @@ def random_user_agent():
'PST': -8, 'PDT': -7 # Pacific
}

KNOWN_EXTENSIONS = (
'mp4', 'm4a', 'm4p', 'm4b', 'm4r', 'm4v', 'aac',
'flv', 'f4v', 'f4a', 'f4b',
'webm', 'ogg', 'ogv', 'oga', 'ogx', 'spx', 'opus',
'mkv', 'mka', 'mk3d',
'avi', 'divx',
'mov',
'asf', 'wmv', 'wma',
'3gp', '3g2',
'mp3',
'flac',
'ape',
'wav',
'f4f', 'f4m', 'm3u8', 'smil')

# needed for sanitizing filenames in restricted mode
ACCENT_CHARS = dict(zip('ÂÃÄÀÁÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖŐØŒÙÚÛÜŰÝÞßàáâãäåæçèéêëìíîïðñòóôõöőøœùúûüűýþÿ',
itertools.chain('AAAAAA', ['AE'], 'CEEEEIIIIDNOOOOOOO', ['OE'], 'UUUUUY', ['TH', 'ss'],
Expand Down Expand Up @@ -3959,19 +3944,22 @@ def parse_duration(s):
return duration


def prepend_extension(filename, ext, expected_real_ext=None):
def _change_extension(prepend, filename, ext, expected_real_ext=None):
name, real_ext = os.path.splitext(filename)
return (
'{0}.{1}{2}'.format(name, ext, real_ext)
if not expected_real_ext or real_ext[1:] == expected_real_ext
else '{0}.{1}'.format(filename, ext))
sanitize_extension = _UnsafeExtensionError.sanitize_extension

if not expected_real_ext or real_ext.partition('.')[0::2] == ('', expected_real_ext):
filename = name
if prepend and real_ext:
sanitize_extension(ext, prepend=prepend)
return ''.join((filename, '.', ext, real_ext))

def replace_extension(filename, ext, expected_real_ext=None):
name, real_ext = os.path.splitext(filename)
return '{0}.{1}'.format(
name if not expected_real_ext or real_ext[1:] == expected_real_ext else filename,
ext)
# Mitigate path traversal and file impersonation attacks
return '.'.join((filename, sanitize_extension(ext)))


prepend_extension = functools.partial(_change_extension, True)
replace_extension = functools.partial(_change_extension, False)


def check_executable(exe, args=[]):
Expand Down Expand Up @@ -6561,3 +6549,138 @@ def join_nonempty(*values, **kwargs):
if from_dict is not None:
values = (traverse_obj(from_dict, variadic(v)) for v in values)
return delim.join(map(compat_str, filter(None, values)))


class Namespace(object):
"""Immutable namespace"""

def __init__(self, **kw_attr):
self.__dict__.update(kw_attr)

def __iter__(self):
return iter(self.__dict__.values())

@property
def items_(self):
return self.__dict__.items()


MEDIA_EXTENSIONS = Namespace(
common_video=('avi', 'flv', 'mkv', 'mov', 'mp4', 'webm'),
video=('3g2', '3gp', 'f4v', 'mk3d', 'divx', 'mpg', 'ogv', 'm4v', 'wmv'),
common_audio=('aiff', 'alac', 'flac', 'm4a', 'mka', 'mp3', 'ogg', 'opus', 'wav'),
audio=('aac', 'ape', 'asf', 'f4a', 'f4b', 'm4b', 'm4p', 'm4r', 'oga', 'ogx', 'spx', 'vorbis', 'wma', 'weba'),
thumbnails=('jpg', 'png', 'webp'),
# storyboards=('mhtml', ),
subtitles=('srt', 'vtt', 'ass', 'lrc', 'ttml'),
manifests=('f4f', 'f4m', 'm3u8', 'smil', 'mpd'),
)
MEDIA_EXTENSIONS.video = MEDIA_EXTENSIONS.common_video + MEDIA_EXTENSIONS.video
MEDIA_EXTENSIONS.audio = MEDIA_EXTENSIONS.common_audio + MEDIA_EXTENSIONS.audio

KNOWN_EXTENSIONS = (
MEDIA_EXTENSIONS.video + MEDIA_EXTENSIONS.audio
+ MEDIA_EXTENSIONS.manifests
)


class _UnsafeExtensionError(Exception):
"""
Mitigation exception for unwanted file overwrite/path traversal

Ref: https://github.com/yt-dlp/yt-dlp/security/advisories/GHSA-79w7-vh3h-8g4j
"""
_ALLOWED_EXTENSIONS = frozenset(itertools.chain(
( # internal
'description',
'json',
'meta',
'orig',
'part',
'temp',
'uncut',
'unknown_video',
'ytdl',
),
# video
MEDIA_EXTENSIONS.video, (
'avif',
'ismv',
'm2ts',
'm4s',
'mng',
'mpeg',
'qt',
'swf',
'ts',
'vp9',
'wvm',
),
# audio
MEDIA_EXTENSIONS.audio, (
'isma',
'mid',
'mpga',
'ra',
),
# image
MEDIA_EXTENSIONS.thumbnails, (
'bmp',
'gif',
'ico',
'heic',
'jng',
'jpeg',
'jxl',
'svg',
'tif',
'wbmp',
),
# subtitle
MEDIA_EXTENSIONS.subtitles, (
'dfxp',
'fs',
'ismt',
'sami',
'scc',
'ssa',
'tt',
),
# others
MEDIA_EXTENSIONS.manifests,
(
# not used in yt-dl
# *MEDIA_EXTENSIONS.storyboards,
# 'desktop',
# 'ism',
# 'm3u',
# 'sbv',
# 'swp',
# 'url',
# 'webloc',
# 'xml',
)))

def __init__(self, extension):
super(_UnsafeExtensionError, self).__init__('unsafe file extension: {0!r}'.format(extension))
self.extension = extension

# support --no-check-extensions
lenient = False

@classmethod
def sanitize_extension(cls, extension, **kwargs):
# ... /, *, prepend=False
prepend = kwargs.get('prepend', False)

if '/' in extension or '\\' in extension:
raise cls(extension)

if not prepend:
last = extension.rpartition('.')[-1]
if last == 'bin':
extension = last = 'unknown_video'
if not (cls.lenient or last.lower() in cls._ALLOWED_EXTENSIONS):
raise cls(extension)

return extension
Loading