Skip to content

Commit

Permalink
Backport requote_uri
Browse files Browse the repository at this point in the history
  • Loading branch information
lexiforest committed Dec 3, 2024
1 parent 6f040f0 commit 5fc3be5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 12 deletions.
58 changes: 55 additions & 3 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .. import AsyncCurl, Curl, CurlError, CurlHttpVersion, CurlInfo, CurlOpt, CurlSslVersion
from ..curl import CURL_WRITEFUNC_ERROR, CurlMime
from .cookies import Cookies, CookieTypes, CurlMorsel
from .exceptions import ImpersonateError, RequestException, SessionClosed, code2error
from .exceptions import InvalidURL, ImpersonateError, RequestException, SessionClosed, code2error
from .headers import Headers, HeaderTypes
from .impersonate import BrowserType # noqa: F401
from .impersonate import (
Expand Down Expand Up @@ -148,7 +148,6 @@ def _update_url_params(url: str, params: Union[Dict, List, Tuple]) -> str:
new_args_counter = Counter(x[0] for x in params)
for key, value in params:
# Bool and Dict values should be converted to json-friendly values
# you may throw this part away if you don't like it :)
if isinstance(value, (bool, dict)):
value = dumps(value)
# 1 to 1 mapping, we have to search and update it.
Expand All @@ -174,6 +173,57 @@ def _update_url_params(url: str, params: Union[Dict, List, Tuple]) -> str:
return new_url


# Adapted from: https://github.com/psf/requests/blob/1ae6fc3137a11e11565ed22436aa1e77277ac98c/src%2Frequests%2Futils.py#L633-L682
# License: Apache 2.0

# The unreserved URI characters (RFC 3986)
UNRESERVED_SET = frozenset(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + "0123456789-._~"
)


def unquote_unreserved(uri: str) -> str:
"""Un-escape any percent-escape sequences in a URI that are unreserved
characters. This leaves all reserved, illegal and non-ASCII bytes encoded.
"""
parts = uri.split("%")
for i in range(1, len(parts)):
h = parts[i][0:2]
if len(h) == 2 and h.isalnum():
try:
c = chr(int(h, 16))
except ValueError:
raise InvalidURL(f"Invalid percent-escape sequence: '{h}'")

if c in UNRESERVED_SET:
parts[i] = c + parts[i][2:]
else:
parts[i] = f"%{parts[i]}"
else:
parts[i] = f"%{parts[i]}"
return "".join(parts)


def requote_uri(uri: str) -> str:
"""Re-quote the given URI.
This function passes the given URI through an unquote/quote cycle to
ensure that it is fully and consistently quoted.
"""
safe_with_percent = "!#$%&'()*+,/:;=?@[]~|"
safe_without_percent = "!#$&'()*+,/:;=?@[]~|"
try:
# Unquote only the unreserved characters
# Then quote only illegal characters (do not quote reserved,
# unreserved, or '%')
return quote(unquote_unreserved(uri), safe=safe_with_percent)
except InvalidURL:
# We couldn't unquote the given URI, so let's try quoting it, but
# there may be unquoted '%'s in the URI. We need to make sure they're
# properly quoted so they do not cause issues elsewhere.
return quote(uri, safe=safe_without_percent)


# TODO: should we move this function to headers.py?
def _update_header_line(header_lines: List[str], key: str, value: str, replace: bool = False):
"""Update header line list by key value pair."""
Expand Down Expand Up @@ -418,8 +468,10 @@ def _set_curl_options(
url = _update_url_params(url, params)
if self.base_url:
url = urljoin(self.base_url, url)
if quote is not False:
if quote:
url = _quote_path_and_params(url, quote_str=quote)
if quote is not False:
url = requote_uri(url)
c.setopt(CurlOpt.URL, url.encode())

# data/body/json
Expand Down
31 changes: 22 additions & 9 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def test_url_encode(server):

# should not change
url = "http://127.0.0.1:8000/%2f%2f%2f"
r = requests.get(str(url))
assert r.url == str(url)
r = requests.get(url)
assert r.url == url

url = "http://127.0.0.1:8000/imaginary-pagination:7"
r = requests.get(str(url))
Expand All @@ -168,15 +168,17 @@ def test_url_encode(server):
r = requests.get(str(url))
assert r.url == url

# Non-ASCII URL should be percent encoded as UTF-8 sequence
non_ascii_url = "http://127.0.0.1:8000/search?q=测试"
encoded_non_ascii_url = "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95"
# NOTE: this seems to be unnecessary

r = requests.get(non_ascii_url)
assert r.url == encoded_non_ascii_url
# Non-ASCII URL should be percent encoded as UTF-8 sequence
# non_ascii_url = "http://127.0.0.1:8000/search?q=测试"
# encoded_non_ascii_url = "http://127.0.0.1:8000/search?q=%E6%B5%8B%E8%AF%95"
#
# r = requests.get(non_ascii_url)
# assert r.url == encoded_non_ascii_url

r = requests.get(encoded_non_ascii_url)
assert r.url == encoded_non_ascii_url
# r = requests.get(encoded_non_ascii_url)
# assert r.url == encoded_non_ascii_url

# should be quoted
url = "http://127.0.0.1:8000/e x a m p l e"
Expand Down Expand Up @@ -209,6 +211,17 @@ def test_url_encode(server):
r = requests.get(url, quote=False)
assert r.url == url

# empty values should be kept
url = "http://127.0.0.1:8000/api?param1=value1&param2=&param3=value3"
r = requests.get(url)
assert r.url == url

# Do not unquote
url = "http://127.0.0.1:8000/path?token=example%7C2024-10-20T10%3A00%3A00Z"
r = requests.get(url)
print(r.url)
assert r.url == url


def test_headers(server):
r = requests.get(str(server.url.copy_with(path="/echo_headers")), headers={"foo": "bar"})
Expand Down

0 comments on commit 5fc3be5

Please sign in to comment.