diff --git a/src/snowflake/connector/auth/_auth.py b/src/snowflake/connector/auth/_auth.py
index e0cc71499..1881ab2fc 100644
--- a/src/snowflake/connector/auth/_auth.py
+++ b/src/snowflake/connector/auth/_auth.py
@@ -52,6 +52,7 @@
     ProgrammingError,
     ServiceUnavailableError,
 )
+from ..file_util import owner_rw_opener
 from ..network import (
     ACCEPT_TYPE_APPLICATION_SNOWFLAKE,
     CONTENT_TYPE_APPLICATION_JSON,
@@ -625,7 +626,11 @@ def flush_temporary_credentials() -> None:
         )
         try:
             with open(
-                TEMPORARY_CREDENTIAL_FILE, "w", encoding="utf-8", errors="ignore"
+                TEMPORARY_CREDENTIAL_FILE,
+                "w",
+                encoding="utf-8",
+                errors="ignore",
+                opener=owner_rw_opener,
             ) as f:
                 json.dump(TEMPORARY_CREDENTIAL, f)
         except Exception as ex:
diff --git a/src/snowflake/connector/cache.py b/src/snowflake/connector/cache.py
index 739f7643a..68885fefa 100644
--- a/src/snowflake/connector/cache.py
+++ b/src/snowflake/connector/cache.py
@@ -388,6 +388,7 @@ def __init__(
         file_path: str | dict[str, str],
         entry_lifetime: int = constants.DAY_IN_SECONDS,
         file_timeout: int = 0,
+        load_if_file_exists: bool = True,
     ) -> None:
         """Inits an SFDictFileCache with path, lifetime.
 
@@ -445,7 +446,7 @@ def __init__(
         self._file_lock_path = f"{self.file_path}.lock"
         self._file_lock = FileLock(self._file_lock_path, timeout=self.file_timeout)
         self.last_loaded: datetime.datetime | None = None
-        if os.path.exists(self.file_path):
+        if os.path.exists(self.file_path) and load_if_file_exists:
             with self._lock:
                 self._load()
         # indicate whether the cache is modified or not, this variable is for
@@ -498,7 +499,7 @@ def _load(self) -> bool:
         """Load cache from disk if possible, returns whether it was able to load."""
         try:
             with open(self.file_path, "rb") as r_file:
-                other: SFDictFileCache = pickle.load(r_file)
+                other: SFDictFileCache = self._deserialize(r_file)
             # Since we want to know whether we are dirty after loading
             # we have to know whether the file could learn anything from self
             # so instead of calling self.update we call other.update and swap
@@ -529,6 +530,13 @@ def load(self) -> bool:
         with self._lock:
             return self._load()
 
+    def _serialize(self):
+        return pickle.dumps(self)
+
+    @classmethod
+    def _deserialize(cls, r_file):
+        return pickle.load(r_file)
+
     def _save(self, load_first: bool = True, force_flush: bool = False) -> bool:
         """Save cache to disk if possible, returns whether it was able to save.
 
@@ -559,7 +567,7 @@ def _save(self, load_first: bool = True, force_flush: bool = False) -> bool:
                 # python program.
                 # thus we fall back to the approach using the normal open() method to open a file and write.
                with open(tmp_file, "wb") as w_file:
-                    w_file.write(pickle.dumps(self))
+                    w_file.write(self._serialize())
                 # We write to a tmp file and then move it to have atomic write
                 os.replace(tmp_file_path, self.file_path)
                 self.last_loaded = datetime.datetime.fromtimestamp(
diff --git a/src/snowflake/connector/encryption_util.py b/src/snowflake/connector/encryption_util.py
index c1c34079e..add7e885e 100644
--- a/src/snowflake/connector/encryption_util.py
+++ b/src/snowflake/connector/encryption_util.py
@@ -17,6 +17,7 @@
 
 from .compat import PKCS5_OFFSET, PKCS5_PAD, PKCS5_UNPAD
 from .constants import UTF8, EncryptionMetadata, MaterialDescriptor, kilobyte
+from .file_util import owner_rw_opener
 from .util_text import random_string
 
 block_size = int(algorithms.AES.block_size / 8)  # in bytes
@@ -213,7 +214,7 @@ def decrypt_file(
         logger.debug("encrypted file: %s, tmp file: %s", in_filename, temp_output_file)
         with open(in_filename, "rb") as infile:
-            with open(temp_output_file, "wb") as outfile:
+            with open(temp_output_file, "wb", opener=owner_rw_opener) as outfile:
                 SnowflakeEncryptionUtil.decrypt_stream(
                     metadata, encryption_material, infile, outfile, chunk_size
                 )
diff --git a/src/snowflake/connector/file_util.py b/src/snowflake/connector/file_util.py
index d89e72185..04744f76e 100644
--- a/src/snowflake/connector/file_util.py
+++ b/src/snowflake/connector/file_util.py
@@ -21,6 +21,10 @@
 logger = getLogger(__name__)
 
 
+def owner_rw_opener(path, flags) -> int:
+    return os.open(path, flags, mode=0o600)
+
+
 class SnowflakeFileUtil:
     @staticmethod
     def get_digest_and_size(src: IO[bytes]) -> tuple[str, int]:
diff --git a/src/snowflake/connector/ocsp_snowflake.py b/src/snowflake/connector/ocsp_snowflake.py
index fe9c44225..c8a3d611a 100644
--- a/src/snowflake/connector/ocsp_snowflake.py
+++ b/src/snowflake/connector/ocsp_snowflake.py
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 import codecs
+import importlib
 import json
 import os
 import platform
@@ -30,6 +31,7 @@
 from asn1crypto.x509 import Certificate
 from OpenSSL.SSL import Connection
 
+from snowflake.connector import SNOWFLAKE_CONNECTOR_VERSION
 from snowflake.connector.compat import OK, urlsplit, urlunparse
 from snowflake.connector.constants import HTTP_HEADER_USER_AGENT
 from snowflake.connector.errorcode import (
@@ -58,9 +60,10 @@
 
 from . import constants
 from .backoff_policies import exponential_backoff
-from .cache import SFDictCache, SFDictFileCache
+from .cache import CacheEntry, SFDictCache, SFDictFileCache
 from .telemetry import TelemetryField, generate_telemetry_data_dict
 from .url_util import extract_top_level_domain_from_hostname, url_encode_str
+from .util_text import _base64_bytes_to_str
 
 
 class OCSPResponseValidationResult(NamedTuple):
@@ -72,19 +75,172 @@ class OCSPResponseValidationResult(NamedTuple):
     ts: int | None = None
     validated: bool = False
 
+    def _serialize(self):
+        def serialize_exception(exc):
+            # Serializing exceptions is not supported for all exception types
+            # in ocsp_snowflake.py; most are RevocationCheckError, which is easy to serialize.
+            # However, it would require non-trivial effort to serialize other exceptions, especially third-party
+            # errors, as they can have un-serializable members and nondeterministic constructor arguments.
+            # Here we do a general best-effort serialization for other exceptions, recording only the error message.
+            if not exc:
+                return None
+
+            exc_type = type(exc)
+            ret = {"class": exc_type.__name__, "module": exc_type.__module__}
+            if isinstance(exc, RevocationCheckError):
+                ret.update({"errno": exc.errno, "msg": exc.raw_msg})
+            else:
+                ret.update({"msg": str(exc)})
+            return ret
+
+        return json.dumps(
+            {
+                "exception": serialize_exception(self.exception),
+                "issuer": (
+                    _base64_bytes_to_str(self.issuer.dump()) if self.issuer else None
+                ),
+                "subject": (
+                    _base64_bytes_to_str(self.subject.dump()) if self.subject else None
+                ),
+                "cert_id": (
+                    _base64_bytes_to_str(self.cert_id.dump()) if self.cert_id else None
+                ),
+                "ocsp_response": _base64_bytes_to_str(self.ocsp_response),
+                "ts": self.ts,
+                "validated": self.validated,
+            }
+        )
+
+    @classmethod
+    def _deserialize(cls, json_str: str) -> OCSPResponseValidationResult:
+        json_obj = json.loads(json_str)
+
+        def deserialize_exception(exception_dict: dict | None) -> Exception | None:
+            # As pointed out in the serialization method, here we do a best-effort deserialization
+            # for non-RevocationCheckError exceptions. If we cannot deserialize the exception, we
+            # return a RevocationCheckError with a message indicating the failure.
+            if not exception_dict:
+                return
+            exc_class = exception_dict.get("class")
+            exc_module = exception_dict.get("module")
+            try:
+                if (
+                    exc_class == "RevocationCheckError"
+                    and exc_module == "snowflake.connector.errors"
+                ):
+                    return RevocationCheckError(
+                        msg=exception_dict["msg"],
+                        errno=exception_dict["errno"],
+                    )
+                else:
+                    module = importlib.import_module(exc_module)
+                    exc_cls = getattr(module, exc_class)
+                    return exc_cls(exception_dict["msg"])
+            except Exception as deserialize_exc:
+                logger.debug(
+                    f"hitting error {str(deserialize_exc)} while deserializing exception,"
+                    f" the original error class and message are {exc_class} and {exception_dict['msg']}"
+                )
+                return RevocationCheckError(
+                    f"Got error {str(deserialize_exc)} while deserializing ocsp cache, please try "
+                    f"cleaning up the "
+                    f"OCSP cache under directory {OCSP_RESPONSE_VALIDATION_CACHE.file_path}",
+                    errno=ER_OCSP_RESPONSE_LOAD_FAILURE,
+                )
+
+        return OCSPResponseValidationResult(
+            exception=deserialize_exception(json_obj.get("exception")),
+            issuer=(
+                Certificate.load(b64decode(json_obj.get("issuer")))
+                if json_obj.get("issuer")
+                else None
+            ),
+            subject=(
+                Certificate.load(b64decode(json_obj.get("subject")))
+                if json_obj.get("subject")
+                else None
+            ),
+            cert_id=(
+                CertId.load(b64decode(json_obj.get("cert_id")))
+                if json_obj.get("cert_id")
+                else None
+            ),
+            ocsp_response=(
+                b64decode(json_obj.get("ocsp_response"))
+                if json_obj.get("ocsp_response")
+                else None
+            ),
+            ts=json_obj.get("ts"),
+            validated=json_obj.get("validated"),
+        )
+
+
+class _OCSPResponseValidationResultCache(SFDictFileCache):
+    def _serialize(self) -> bytes:
+        entries = {
+            (
+                _base64_bytes_to_str(k[0]),
+                _base64_bytes_to_str(k[1]),
+                _base64_bytes_to_str(k[2]),
+            ): (v.expiry.isoformat(), v.entry._serialize())
+            for k, v in self._cache.items()
+        }
+
+        return json.dumps(
+            {
+                "cache_keys": list(entries.keys()),
+                "cache_items": list(entries.values()),
+                "entry_lifetime": self._entry_lifetime.total_seconds(),
+                "file_path": str(self.file_path),
+                "file_timeout": self.file_timeout,
+                "last_loaded": (
+                    self.last_loaded.isoformat() if self.last_loaded else None
+                ),
+                "telemetry": self.telemetry,
+                "connector_version": SNOWFLAKE_CONNECTOR_VERSION,  # reserved for schema version control
+            }
+        ).encode()
+
+    @classmethod
+    def _deserialize(cls, opened_fd) -> _OCSPResponseValidationResultCache:
+        data = json.loads(opened_fd.read().decode())
+        cache_instance = cls(
+            file_path=data["file_path"],
+            entry_lifetime=int(data["entry_lifetime"]),
+            file_timeout=data["file_timeout"],
+            load_if_file_exists=False,
+        )
+        cache_instance.file_path = os.path.expanduser(data["file_path"])
+        cache_instance.telemetry = data["telemetry"]
+        cache_instance.last_loaded = (
+            datetime.fromisoformat(data["last_loaded"]) if data["last_loaded"] else None
+        )
+        for k, v in zip(data["cache_keys"], data["cache_items"]):
+            cache_instance._cache[
+                (b64decode(k[0]), b64decode(k[1]), b64decode(k[2]))
+            ] = CacheEntry(
+                datetime.fromisoformat(v[0]),
+                OCSPResponseValidationResult._deserialize(v[1]),
+            )
+        return cache_instance
+
 
 try:
     OCSP_RESPONSE_VALIDATION_CACHE: SFDictFileCache[
         tuple[bytes, bytes, bytes],
         OCSPResponseValidationResult,
-    ] = SFDictFileCache(
+    ] = _OCSPResponseValidationResultCache(
         entry_lifetime=constants.DAY_IN_SECONDS,
         file_path={
             "linux": os.path.join(
-                "~", ".cache", "snowflake", "ocsp_response_validation_cache"
+                "~", ".cache", "snowflake", "ocsp_response_validation_cache.json"
             ),
             "darwin": os.path.join(
-                "~", "Library", "Caches", "Snowflake", "ocsp_response_validation_cache"
+                "~",
+                "Library",
+                "Caches",
+                "Snowflake",
+                "ocsp_response_validation_cache.json",
             ),
             "windows": os.path.join(
                 "~",
@@ -92,7 +248,7 @@ class OCSPResponseValidationResult(NamedTuple):
                 "Local",
                 "Snowflake",
                 "Caches",
-                "ocsp_response_validation_cache",
+                "ocsp_response_validation_cache.json",
             ),
         },
     )
diff --git a/src/snowflake/connector/storage_client.py b/src/snowflake/connector/storage_client.py
index ba74f511b..966860f38 100644
--- a/src/snowflake/connector/storage_client.py
+++ b/src/snowflake/connector/storage_client.py
@@ -329,6 +329,11 @@ def _send_request_with_retry(
             f"{verb} with url {url} failed for exceeding maximum retries."
         )
 
+    def _open_intermediate_dst_path(self, mode):
+        if not self.intermediate_dst_path.exists():
+            self.intermediate_dst_path.touch(mode=0o600)
+        return self.intermediate_dst_path.open(mode)
+
     def prepare_download(self) -> None:
         # TODO: add nicer error message for when target directory is not writeable
         # but this should be done before we get here
@@ -352,13 +357,13 @@ def prepare_download(self) -> None:
             self.num_of_chunks = ceil(file_header.content_length / self.chunk_size)
 
         # Preallocate encrypted file.
-        with self.intermediate_dst_path.open("wb+") as fd:
+        with self._open_intermediate_dst_path("wb+") as fd:
             fd.truncate(self.meta.src_file_size)
 
     def write_downloaded_chunk(self, chunk_id: int, data: bytes) -> None:
         """Writes given data to the temp location starting at chunk_id * chunk_size."""
         # TODO: should we use chunking and write content in smaller chunks?
-        with self.intermediate_dst_path.open("rb+") as fd:
+        with self._open_intermediate_dst_path("rb+") as fd:
             fd.seek(self.chunk_size * chunk_id)
             fd.write(data)
diff --git a/src/snowflake/connector/util_text.py b/src/snowflake/connector/util_text.py
index 52b06f328..2c24ae577 100644
--- a/src/snowflake/connector/util_text.py
+++ b/src/snowflake/connector/util_text.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import base64
 import hashlib
 import logging
 import random
@@ -292,6 +293,10 @@ def random_string(
     return "".join([prefix, random_part, suffix])
 
 
+def _base64_bytes_to_str(x) -> str | None:
+    return base64.b64encode(x).decode("utf-8") if x else None
+
+
 def get_md5(text: str | bytes) -> bytes:
     if isinstance(text, str):
         text = text.encode("utf-8")
diff --git a/test/extras/run.py b/test/extras/run.py
index 856677552..1dab55162 100644
--- a/test/extras/run.py
+++ b/test/extras/run.py
@@ -35,16 +35,18 @@
 assert (
     cache_files
     == {
-        "ocsp_response_validation_cache.lock",
-        "ocsp_response_validation_cache",
+        "ocsp_response_validation_cache.json.lock",
+        "ocsp_response_validation_cache.json",
         "ocsp_response_cache.json",
     }
     and not platform.system() == "Windows"
 ) or (
     cache_files
     == {
-        "ocsp_response_validation_cache",
+        "ocsp_response_validation_cache.json",
         "ocsp_response_cache.json",
     }
     and platform.system() == "Windows"
+), str(
+    cache_files
 )
diff --git a/test/unit/test_ocsp.py b/test/unit/test_ocsp.py
index 700f918fe..02d83c8b3 100644
--- a/test/unit/test_ocsp.py
+++ b/test/unit/test_ocsp.py
@@ -5,7 +5,10 @@
 
 from __future__ import annotations
 
+import copy
 import datetime
+import io
+import json
 import logging
 import os
 import platform
@@ -14,6 +17,8 @@
 from os import environ, path
 from unittest import mock
 
+import asn1crypto.x509
+from asn1crypto import ocsp
 from asn1crypto import x509 as asn1crypto509
 from cryptography import x509
 from cryptography.hazmat.backends import default_backend
@@ -76,6 +81,40 @@ def overwrite_ocsp_cache(tmpdir):
 THIS_DIR = path.dirname(path.realpath(__file__))
 
 
+def create_x509_cert(hash_algorithm):
+    # Generate a private key
+    private_key = rsa.generate_private_key(
+        public_exponent=65537, key_size=1024, backend=default_backend()
+    )
+
+    # Generate a public key
+    public_key = private_key.public_key()
+
+    # Create a certificate
+    subject = x509.Name(
+        [
+            x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"),
+        ]
+    )
+
+    issuer = subject
+
+    return (
+        x509.CertificateBuilder()
+        .subject_name(subject)
+        .issuer_name(issuer)
+        .public_key(public_key)
+        .serial_number(x509.random_serial_number())
+        .not_valid_before(datetime.datetime.now())
+        .not_valid_after(datetime.datetime.now() + datetime.timedelta(days=365))
+        .add_extension(
+            x509.SubjectAlternativeName([x509.DNSName("example.com")]),
+            critical=False,
+        )
+        .sign(private_key, hash_algorithm, default_backend())
+    )
+
+
 @pytest.fixture(autouse=True)
 def random_ocsp_response_validation_cache():
     file_path = {
@@ -576,38 +615,7 @@ def test_building_new_retry():
     ],
 )
 def test_signature_verification(hash_algorithm):
-    # Generate a private key
-    private_key = rsa.generate_private_key(
-        public_exponent=65537, key_size=1024, backend=default_backend()
-    )
-
-    # Generate a public key
-    public_key = private_key.public_key()
-
-    # Create a certificate
-    subject = x509.Name(
-        [
-            x509.NameAttribute(x509.NameOID.COUNTRY_NAME, "US"),
-        ]
-    )
-
-    issuer = subject
-
-    cert = (
-        x509.CertificateBuilder()
-        .subject_name(subject)
-        .issuer_name(issuer)
-        .public_key(public_key)
-        .serial_number(x509.random_serial_number())
-        .not_valid_before(datetime.datetime.now())
-        .not_valid_after(datetime.datetime.now() + datetime.timedelta(days=365))
-        .add_extension(
-            x509.SubjectAlternativeName([x509.DNSName("example.com")]),
-            critical=False,
-        )
-        .sign(private_key, hash_algorithm, default_backend())
-    )
-
+    cert = create_x509_cert(hash_algorithm)
     # in snowflake, we use lib asn1crypto to load certificate, not using lib cryptography
     asy1_509_cert = asn1crypto509.Certificate.load(cert.public_bytes(Encoding.DER))
@@ -702,3 +710,116 @@ def test_ocsp_server_domain_name():
         and SnowflakeOCSP.OCSP_WHITELIST.match("s3.amazonaws.com.cn")
         and not SnowflakeOCSP.OCSP_WHITELIST.match("s3.amazonaws.com.cn.com")
     )
+
+
+@pytest.mark.skipolddriver
+def test_json_cache_serialization_and_deserialization(tmpdir):
+    from snowflake.connector.ocsp_snowflake import (
+        OCSPResponseValidationResult,
+        _OCSPResponseValidationResultCache,
+    )
+
+    cache_path = os.path.join(tmpdir, "cache.json")
+    cert = asn1crypto509.Certificate.load(
+        create_x509_cert(hashes.SHA256()).public_bytes(Encoding.DER)
+    )
+    cert_id = ocsp.CertId(
+        {
+            "hash_algorithm": {"algorithm": "sha1"},  # Minimal hash algorithm
+            "issuer_name_hash": b"\0" * 20,  # Placeholder hash
+            "issuer_key_hash": b"\0" * 20,  # Placeholder hash
+            "serial_number": 1,  # Minimal serial number
+        }
+    )
+    test_cache = _OCSPResponseValidationResultCache(file_path=cache_path)
+    test_cache[(b"key1", b"key2", b"key3")] = OCSPResponseValidationResult(
+        exception=None,
+        issuer=cert,
+        subject=cert,
+        cert_id=cert_id,
+        ocsp_response=b"response",
+        ts=0,
+        validated=True,
+    )
+
+    def verify(verify_method, write_cache):
+        with io.BytesIO() as byte_stream:
+            byte_stream.write(write_cache._serialize())
+            byte_stream.seek(0)
+            read_cache = _OCSPResponseValidationResultCache._deserialize(byte_stream)
+            assert len(write_cache) == len(read_cache)
+            verify_method(write_cache, read_cache)
+
+    def verify_happy_path(origin_cache, loaded_cache):
+        for (key1, value1), (key2, value2) in zip(
+            origin_cache.items(), loaded_cache.items()
+        ):
+            assert key1 == key2
+            for sub_field1, sub_field2 in zip(value1, value2):
+                assert isinstance(sub_field1, type(sub_field2))
+                if isinstance(sub_field1, asn1crypto.x509.Certificate):
+                    for attr in [
+                        "issuer",
+                        "subject",
+                        "serial_number",
+                        "not_valid_before",
+                        "not_valid_after",
+                        "hash_algo",
+                    ]:
+                        assert getattr(sub_field1, attr) == getattr(sub_field2, attr)
+                elif isinstance(sub_field1, asn1crypto.ocsp.CertId):
+                    for attr in [
+                        "hash_algorithm",
+                        "issuer_name_hash",
+                        "issuer_key_hash",
+                        "serial_number",
+                    ]:
+                        assert sub_field1.native[attr] == sub_field2.native[attr]
+                else:
+                    assert sub_field1 == sub_field2
+
+    def verify_none(origin_cache, loaded_cache):
+        for (key1, value1), (key2, value2) in zip(
+            origin_cache.items(), loaded_cache.items()
+        ):
+            assert key1 == key2 and value1 == value2
+
+    def verify_exception(_, loaded_cache):
+        exc_1 = loaded_cache[(b"key1", b"key2", b"key3")].exception
+        exc_2 = loaded_cache[(b"key4", b"key5", b"key6")].exception
+        exc_3 = loaded_cache[(b"key7", b"key8", b"key9")].exception
+        assert (
+            isinstance(exc_1, RevocationCheckError)
+            and exc_1.raw_msg == "error"
+            and exc_1.errno == 1
+        )
+        assert isinstance(exc_2, ValueError) and str(exc_2) == "value error"
+        assert (
+            isinstance(exc_3, RevocationCheckError)
+            and "while deserializing ocsp cache, please try cleaning up the OCSP cache under directory"
+            in exc_3.msg
+        )
+
+    verify(verify_happy_path, copy.deepcopy(test_cache))
+
+    origin_cache = copy.deepcopy(test_cache)
+    origin_cache[(b"key1", b"key2", b"key3")] = OCSPResponseValidationResult(
+        None, None, None, None, None, None, False
+    )
+    verify(verify_none, origin_cache)
+
+    origin_cache = copy.deepcopy(test_cache)
+    origin_cache.update(
+        {
+            (b"key1", b"key2", b"key3"): OCSPResponseValidationResult(
+                exception=RevocationCheckError(msg="error", errno=1),
+            ),
+            (b"key4", b"key5", b"key6"): OCSPResponseValidationResult(
+                exception=ValueError("value error"),
+            ),
+            (b"key7", b"key8", b"key9"): OCSPResponseValidationResult(
+                exception=json.JSONDecodeError("json error", "doc", 0)
+            ),
+        }
+    )
+    verify(verify_exception, origin_cache)
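
Reviewer note on the opener=owner_rw_opener pattern used throughout the patch: passing a custom opener to open() makes the file come into existence with 0o600 (owner read/write only) at creation time, instead of tightening permissions after the fact. A minimal standalone sketch of the behavior, not part of the patch itself; the path below is a placeholder, the effective mode is still filtered through the process umask, and the mode argument has no effect on files that already exist:

import os
import stat

def owner_rw_opener(path, flags) -> int:
    # os.open applies mode=0o600 only when this call creates the file.
    return os.open(path, flags, mode=0o600)

with open("/tmp/demo_credentials.json", "w", opener=owner_rw_opener) as f:
    f.write("{}")

# Typically prints 0o600, subject to the process umask.
print(oct(stat.S_IMODE(os.stat("/tmp/demo_credentials.json").st_mode)))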
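
For reference, a minimal round-trip through the new JSON serialization hooks, mirroring what the test above exercises. This is a sketch that assumes the patch is applied; the cache path and key bytes are placeholders:

import io
import os
import tempfile

from snowflake.connector.ocsp_snowflake import (
    OCSPResponseValidationResult,
    _OCSPResponseValidationResultCache,
)

cache_path = os.path.join(tempfile.mkdtemp(), "ocsp_demo_cache.json")
cache = _OCSPResponseValidationResultCache(file_path=cache_path)
cache[(b"issuer", b"serial", b"hash")] = OCSPResponseValidationResult(
    ocsp_response=b"dummy-response", ts=0, validated=True
)

# _serialize() emits JSON bytes; _deserialize() rebuilds an equivalent
# cache instance from any readable binary stream.
blob = cache._serialize()
restored = _OCSPResponseValidationResultCache._deserialize(io.BytesIO(blob))
assert restored[(b"issuer", b"serial", b"hash")].validated

Because the on-disk format is now JSON rather than pickle, loading an untrusted or corrupted cache file can no longer execute arbitrary code, and the worst case is a RevocationCheckError pointing at the cache directory.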