SNOW-1902019: Python CVEs january batch (#2154)

Co-authored-by: Jamison <jamison.rose@snowflake.com>
Co-authored-by: Adam Ling <adam.ling@snowflake.com>
3 people authored Jan 29, 2025
1 parent ec3002d commit 3769b43
Showing 9 changed files with 354 additions and 47 deletions.
7 changes: 6 additions & 1 deletion src/snowflake/connector/auth/_auth.py

@@ -52,6 +52,7 @@
     ProgrammingError,
     ServiceUnavailableError,
 )
+from ..file_util import owner_rw_opener
 from ..network import (
     ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
     CONTENT_TYPE_APPLICATION_JSON,
@@ -625,7 +626,11 @@ def flush_temporary_credentials() -> None:
         )
         try:
             with open(
-                TEMPORARY_CREDENTIAL_FILE, "w", encoding="utf-8", errors="ignore"
+                TEMPORARY_CREDENTIAL_FILE,
+                "w",
+                encoding="utf-8",
+                errors="ignore",
+                opener=owner_rw_opener,
             ) as f:
                 json.dump(TEMPORARY_CREDENTIAL, f)
         except Exception as ex:
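
The fix relies on the opener parameter of Python's built-in open(): the custom opener creates the credential file with owner-only permissions instead of the default 0o666 masked by the umask. A minimal sketch of the effect (the demo path is illustrative; the mode argument applies only when the file is created, and the process umask can clear bits but never add them):

import os
import stat
import tempfile

def owner_rw_opener(path, flags) -> int:
    # Same helper the commit adds: open/create with owner read/write only.
    return os.open(path, flags, mode=0o600)

demo_path = os.path.join(tempfile.gettempdir(), "cred_demo.json")  # illustrative
with open(demo_path, "w", encoding="utf-8", opener=owner_rw_opener) as f:
    f.write("{}")

print(oct(stat.S_IMODE(os.stat(demo_path).st_mode)))  # typically 0o600 on POSIX
os.remove(demo_path)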
14 changes: 11 additions & 3 deletions src/snowflake/connector/cache.py

@@ -388,6 +388,7 @@ def __init__(
         file_path: str | dict[str, str],
         entry_lifetime: int = constants.DAY_IN_SECONDS,
         file_timeout: int = 0,
+        load_if_file_exists: bool = True,
     ) -> None:
         """Inits an SFDictFileCache with path, lifetime.
@@ -445,7 +446,7 @@ def __init__(
         self._file_lock_path = f"{self.file_path}.lock"
         self._file_lock = FileLock(self._file_lock_path, timeout=self.file_timeout)
         self.last_loaded: datetime.datetime | None = None
-        if os.path.exists(self.file_path):
+        if os.path.exists(self.file_path) and load_if_file_exists:
             with self._lock:
                 self._load()
         # indicate whether the cache is modified or not, this variable is for
@@ -498,7 +499,7 @@ def _load(self) -> bool:
         """Load cache from disk if possible, returns whether it was able to load."""
         try:
             with open(self.file_path, "rb") as r_file:
-                other: SFDictFileCache = pickle.load(r_file)
+                other: SFDictFileCache = self._deserialize(r_file)
             # Since we want to know whether we are dirty after loading
             # we have to know whether the file could learn anything from self
             # so instead of calling self.update we call other.update and swap
@@ -529,6 +530,13 @@ def load(self) -> bool:
         with self._lock:
             return self._load()
 
+    def _serialize(self):
+        return pickle.dumps(self)
+
+    @classmethod
+    def _deserialize(cls, r_file):
+        return pickle.load(r_file)
+
     def _save(self, load_first: bool = True, force_flush: bool = False) -> bool:
         """Save cache to disk if possible, returns whether it was able to save.
@@ -559,7 +567,7 @@ def _save(self, load_first: bool = True, force_flush: bool = False) -> bool:
             # python program.
             # thus we fall back to the approach using the normal open() method to open a file and write.
             with open(tmp_file, "wb") as w_file:
-                w_file.write(pickle.dumps(self))
+                w_file.write(self._serialize())
             # We write to a tmp file and then move it to have atomic write
             os.replace(tmp_file_path, self.file_path)
             self.last_loaded = datetime.datetime.fromtimestamp(
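
The new _serialize/_deserialize pair turns the on-disk format into an overridable hook: the base class keeps pickle for backward compatibility, while a subclass can swap in a format that cannot execute code when a tampered cache file is loaded. A minimal sketch of the pattern, assuming a toy cache holding only JSON-compatible values (FileCacheBase and JsonFileCache are illustrative names, not the connector's):

import json
import pickle

class FileCacheBase:
    def __init__(self) -> None:
        self._cache: dict = {}

    def _serialize(self) -> bytes:
        # Default wire format: pickle, as in SFDictFileCache above.
        return pickle.dumps(self._cache)

    @classmethod
    def _deserialize(cls, r_file) -> "FileCacheBase":
        obj = cls()
        obj._cache = pickle.load(r_file)
        return obj

class JsonFileCache(FileCacheBase):
    # Overriding both hooks swaps the format without touching load/save logic.
    def _serialize(self) -> bytes:
        return json.dumps(self._cache).encode()

    @classmethod
    def _deserialize(cls, r_file) -> "JsonFileCache":
        obj = cls()
        obj._cache = json.load(r_file)
        return obj

The motivation is the classic pickle hazard: unpickling attacker-controlled bytes can run arbitrary code, so a cache file that another local user could modify is a potential code-execution vector. For instance:

import pickle

class Evil:
    def __reduce__(self):
        import os
        return (os.system, ("echo pwned",))

payload = pickle.dumps(Evil())
# pickle.loads(payload) would run `echo pwned` -- never unpickle untrusted files.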
3 changes: 2 additions & 1 deletion src/snowflake/connector/encryption_util.py

@@ -17,6 +17,7 @@
 
 from .compat import PKCS5_OFFSET, PKCS5_PAD, PKCS5_UNPAD
 from .constants import UTF8, EncryptionMetadata, MaterialDescriptor, kilobyte
+from .file_util import owner_rw_opener
 from .util_text import random_string
 
 block_size = int(algorithms.AES.block_size / 8)  # in bytes
@@ -213,7 +214,7 @@ def decrypt_file(
 
     logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file)
     with open(in_filename, "rb") as infile:
-        with open(temp_output_file, "wb") as outfile:
+        with open(temp_output_file, "wb", opener=owner_rw_opener) as outfile:
             SnowflakeEncryptionUtil.decrypt_stream(
                 metadata, encryption_material, infile, outfile, chunk_size
             )
4 changes: 4 additions & 0 deletions src/snowflake/connector/file_util.py

@@ -21,6 +21,10 @@
 logger = getLogger(__name__)
 
 
+def owner_rw_opener(path, flags) -> int:
+    return os.open(path, flags, mode=0o600)
+
+
 class SnowflakeFileUtil:
     @staticmethod
     def get_digest_and_size(src: IO[bytes]) -> tuple[str, int]:
166 changes: 161 additions & 5 deletions src/snowflake/connector/ocsp_snowflake.py

@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import codecs
+import importlib
 import json
 import os
 import platform
@@ -30,6 +31,7 @@
 from asn1crypto.x509 import Certificate
 from OpenSSL.SSL import Connection
 
+from snowflake.connector import SNOWFLAKE_CONNECTOR_VERSION
 from snowflake.connector.compat import OK, urlsplit, urlunparse
 from snowflake.connector.constants import HTTP_HEADER_USER_AGENT
 from snowflake.connector.errorcode import (
@@ -58,9 +60,10 @@
 
 from . import constants
 from .backoff_policies import exponential_backoff
-from .cache import SFDictCache, SFDictFileCache
+from .cache import CacheEntry, SFDictCache, SFDictFileCache
 from .telemetry import TelemetryField, generate_telemetry_data_dict
 from .url_util import extract_top_level_domain_from_hostname, url_encode_str
+from .util_text import _base64_bytes_to_str
 
 
 class OCSPResponseValidationResult(NamedTuple):
@@ -72,27 +75,180 @@ class OCSPResponseValidationResult(NamedTuple):
     ts: int | None = None
     validated: bool = False
 
+    def _serialize(self):
+        def serialize_exception(exc):
+            # Serialization is not supported for all exception types. Most
+            # exceptions raised in ocsp_snowflake.py are RevocationCheckError,
+            # which is easy to serialize; other exceptions, especially
+            # 3rd-party errors, can have unserializable members and
+            # nondeterministic constructor arguments, so it would take
+            # non-trivial effort to serialize them. For those we do a general
+            # best-effort serialization, recording only the error message.
+            if not exc:
+                return None
+
+            exc_type = type(exc)
+            ret = {"class": exc_type.__name__, "module": exc_type.__module__}
+            if isinstance(exc, RevocationCheckError):
+                ret.update({"errno": exc.errno, "msg": exc.raw_msg})
+            else:
+                ret.update({"msg": str(exc)})
+            return ret
+
+        return json.dumps(
+            {
+                "exception": serialize_exception(self.exception),
+                "issuer": (
+                    _base64_bytes_to_str(self.issuer.dump()) if self.issuer else None
+                ),
+                "subject": (
+                    _base64_bytes_to_str(self.subject.dump()) if self.subject else None
+                ),
+                "cert_id": (
+                    _base64_bytes_to_str(self.cert_id.dump()) if self.cert_id else None
+                ),
+                "ocsp_response": _base64_bytes_to_str(self.ocsp_response),
+                "ts": self.ts,
+                "validated": self.validated,
+            }
+        )
+
+    @classmethod
+    def _deserialize(cls, json_str: str) -> OCSPResponseValidationResult:
+        json_obj = json.loads(json_str)
+
+        def deserialize_exception(exception_dict: dict | None) -> Exception | None:
+            # As pointed out in the serialization method, we do best-effort
+            # deserialization for non-RevocationCheckError exceptions. If we
+            # cannot deserialize the exception, we return a
+            # RevocationCheckError with a message indicating the failure.
+            if not exception_dict:
+                return
+            exc_class = exception_dict.get("class")
+            exc_module = exception_dict.get("module")
+            try:
+                if (
+                    exc_class == "RevocationCheckError"
+                    and exc_module == "snowflake.connector.errors"
+                ):
+                    return RevocationCheckError(
+                        msg=exception_dict["msg"],
+                        errno=exception_dict["errno"],
+                    )
+                else:
+                    module = importlib.import_module(exc_module)
+                    exc_cls = getattr(module, exc_class)
+                    return exc_cls(exception_dict["msg"])
+            except Exception as deserialize_exc:
+                logger.debug(
+                    f"hitting error {str(deserialize_exc)} while deserializing exception,"
+                    f" the original error class and message are {exc_class} and {exception_dict['msg']}"
+                )
+                return RevocationCheckError(
+                    f"Got error {str(deserialize_exc)} while deserializing ocsp cache, please try "
+                    f"cleaning up the "
+                    f"OCSP cache under directory {OCSP_RESPONSE_VALIDATION_CACHE.file_path}",
+                    errno=ER_OCSP_RESPONSE_LOAD_FAILURE,
+                )
+
+        return OCSPResponseValidationResult(
+            exception=deserialize_exception(json_obj.get("exception")),
+            issuer=(
+                Certificate.load(b64decode(json_obj.get("issuer")))
+                if json_obj.get("issuer")
+                else None
+            ),
+            subject=(
+                Certificate.load(b64decode(json_obj.get("subject")))
+                if json_obj.get("subject")
+                else None
+            ),
+            cert_id=(
+                CertId.load(b64decode(json_obj.get("cert_id")))
+                if json_obj.get("cert_id")
+                else None
+            ),
+            ocsp_response=(
+                b64decode(json_obj.get("ocsp_response"))
+                if json_obj.get("ocsp_response")
+                else None
+            ),
+            ts=json_obj.get("ts"),
+            validated=json_obj.get("validated"),
+        )
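
Because the cache is now JSON rather than pickle, an exception can no longer be stored as opaque object state; only its class, module, and message survive the round trip. A minimal standalone sketch of that best-effort scheme (DemoError and the RuntimeError fallback are illustrative, not the connector's types):

import importlib
import json

class DemoError(Exception):
    pass  # stand-in for an arbitrary exception type

def serialize_exception(exc: Exception) -> dict:
    t = type(exc)
    return {"class": t.__name__, "module": t.__module__, "msg": str(exc)}

def deserialize_exception(d: dict) -> Exception:
    try:
        exc_cls = getattr(importlib.import_module(d["module"]), d["class"])
        return exc_cls(d["msg"])
    except Exception:
        # Best-effort fallback when the original class cannot be reconstructed.
        return RuntimeError(d["msg"])

payload = json.dumps(serialize_exception(DemoError("boom")))
restored = deserialize_exception(json.loads(payload))
print(type(restored).__name__, restored)  # DemoError boom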


+class _OCSPResponseValidationResultCache(SFDictFileCache):
+    def _serialize(self) -> bytes:
+        entries = {
+            (
+                _base64_bytes_to_str(k[0]),
+                _base64_bytes_to_str(k[1]),
+                _base64_bytes_to_str(k[2]),
+            ): (v.expiry.isoformat(), v.entry._serialize())
+            for k, v in self._cache.items()
+        }
+
+        return json.dumps(
+            {
+                "cache_keys": list(entries.keys()),
+                "cache_items": list(entries.values()),
+                "entry_lifetime": self._entry_lifetime.total_seconds(),
+                "file_path": str(self.file_path),
+                "file_timeout": self.file_timeout,
+                "last_loaded": (
+                    self.last_loaded.isoformat() if self.last_loaded else None
+                ),
+                "telemetry": self.telemetry,
+                "connector_version": SNOWFLAKE_CONNECTOR_VERSION,  # reserved for schema version control
+            }
+        ).encode()
+
+    @classmethod
+    def _deserialize(cls, opened_fd) -> _OCSPResponseValidationResultCache:
+        data = json.loads(opened_fd.read().decode())
+        cache_instance = cls(
+            file_path=data["file_path"],
+            entry_lifetime=int(data["entry_lifetime"]),
+            file_timeout=data["file_timeout"],
+            load_if_file_exists=False,
+        )
+        cache_instance.file_path = os.path.expanduser(data["file_path"])
+        cache_instance.telemetry = data["telemetry"]
+        cache_instance.last_loaded = (
+            datetime.fromisoformat(data["last_loaded"]) if data["last_loaded"] else None
+        )
+        for k, v in zip(data["cache_keys"], data["cache_items"]):
+            cache_instance._cache[
+                (b64decode(k[0]), b64decode(k[1]), b64decode(k[2]))
+            ] = CacheEntry(
+                datetime.fromisoformat(v[0]),
+                OCSPResponseValidationResult._deserialize(v[1]),
+            )
+        return cache_instance
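
The cache keys are (cert_id, subject, issuer) triples of raw bytes, which JSON cannot represent directly, so _serialize base64-encodes each element and _deserialize reverses it. A minimal sketch of that key round trip (the byte values are illustrative):

import base64

key = (b"\x01cert-id", b"\x02subject", b"\x03issuer")  # illustrative bytes
encoded = tuple(base64.b64encode(part).decode("utf-8") for part in key)
decoded = tuple(base64.b64decode(part) for part in encoded)
assert decoded == key  # lossless round trip through JSON-safe strings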


 try:
     OCSP_RESPONSE_VALIDATION_CACHE: SFDictFileCache[
         tuple[bytes, bytes, bytes],
         OCSPResponseValidationResult,
-    ] = SFDictFileCache(
+    ] = _OCSPResponseValidationResultCache(
         entry_lifetime=constants.DAY_IN_SECONDS,
         file_path={
             "linux": os.path.join(
-                "~", ".cache", "snowflake", "ocsp_response_validation_cache"
+                "~", ".cache", "snowflake", "ocsp_response_validation_cache.json"
             ),
             "darwin": os.path.join(
-                "~", "Library", "Caches", "Snowflake", "ocsp_response_validation_cache"
+                "~",
+                "Library",
+                "Caches",
+                "Snowflake",
+                "ocsp_response_validation_cache.json",
             ),
             "windows": os.path.join(
                 "~",
                 "AppData",
                 "Local",
                 "Snowflake",
                 "Caches",
-                "ocsp_response_validation_cache",
+                "ocsp_response_validation_cache.json",
             ),
         },
     )
9 changes: 7 additions & 2 deletions src/snowflake/connector/storage_client.py

@@ -329,6 +329,11 @@ def _send_request_with_retry(
             f"{verb} with url {url} failed for exceeding maximum retries."
         )
 
+    def _open_intermediate_dst_path(self, mode):
+        if not self.intermediate_dst_path.exists():
+            self.intermediate_dst_path.touch(mode=0o600)
+        return self.intermediate_dst_path.open(mode)
+
     def prepare_download(self) -> None:
         # TODO: add nicer error message for when target directory is not writeable
         # but this should be done before we get here
@@ -352,13 +357,13 @@ def prepare_download(self) -> None:
         self.num_of_chunks = ceil(file_header.content_length / self.chunk_size)
 
         # Preallocate encrypted file.
-        with self.intermediate_dst_path.open("wb+") as fd:
+        with self._open_intermediate_dst_path("wb+") as fd:
             fd.truncate(self.meta.src_file_size)
 
     def write_downloaded_chunk(self, chunk_id: int, data: bytes) -> None:
         """Writes given data to the temp location starting at chunk_id * chunk_size."""
         # TODO: should we use chunking and write content in smaller chunks?
-        with self.intermediate_dst_path.open("rb+") as fd:
+        with self._open_intermediate_dst_path("rb+") as fd:
             fd.seek(self.chunk_size * chunk_id)
             fd.write(data)
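
Path.open() has no equivalent of the opener parameter, so here the helper pre-creates the file with Path.touch(mode=0o600), ensuring the restrictive mode is in place before any data lands in it. A minimal sketch of the same pattern (scratch path and sizes are illustrative; as with os.open, the umask can clear permission bits):

import stat
import tempfile
from pathlib import Path

part_file = Path(tempfile.gettempdir()) / "download_demo.part"  # illustrative
if not part_file.exists():
    part_file.touch(mode=0o600)  # owner-only before any bytes are written
with part_file.open("rb+") as fd:
    fd.truncate(1024)  # preallocate, as prepare_download does

print(oct(stat.S_IMODE(part_file.stat().st_mode)))  # typically 0o600 on POSIX
part_file.unlink()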
5 changes: 5 additions & 0 deletions src/snowflake/connector/util_text.py

@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import base64
 import hashlib
 import logging
 import random
@@ -292,6 +293,10 @@ def random_string(
     return "".join([prefix, random_part, suffix])
 
 
+def _base64_bytes_to_str(x) -> str | None:
+    return base64.b64encode(x).decode("utf-8") if x else None
+
+
 def get_md5(text: str | bytes) -> bytes:
     if isinstance(text, str):
         text = text.encode("utf-8")
8 changes: 5 additions & 3 deletions test/extras/run.py

@@ -35,16 +35,18 @@
 assert (
     cache_files
     == {
-        "ocsp_response_validation_cache.lock",
-        "ocsp_response_validation_cache",
+        "ocsp_response_validation_cache.json.lock",
+        "ocsp_response_validation_cache.json",
         "ocsp_response_cache.json",
     }
     and not platform.system() == "Windows"
 ) or (
     cache_files
     == {
-        "ocsp_response_validation_cache",
+        "ocsp_response_validation_cache.json",
         "ocsp_response_cache.json",
     }
     and platform.system() == "Windows"
 ), str(
     cache_files
 )