diff --git a/tests/advanced/test_replicaset.py b/tests/advanced/test_replicaset.py index 801c01af..516fdea7 100644 --- a/tests/advanced/test_replicaset.py +++ b/tests/advanced/test_replicaset.py @@ -17,7 +17,7 @@ from time import time from bson import SON -from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure +from pymongo.errors import AutoReconnect, ConfigurationError, NotPrimaryError from twisted.internet import defer, reactor from twisted.trial import unittest @@ -346,3 +346,26 @@ def test_StaleConnection(self): finally: self.__mongod[0].kill(signal.SIGCONT) yield conn.disconnect() + + @defer.inlineCallbacks + def test_CloseConnectionAfterPrimaryStepDown(self): + conn = ConnectionPool(self.master_with_guaranteed_write) + try: + yield conn.db.coll.insert_one({"x": 42}) + + got_not_primary_error = False + + while True: + try: + yield conn.db.coll.find_one() + if got_not_primary_error: + # We got error and then restored — OK + break + yield self.__sleep(1) + yield conn.admin.command({"replSetStepDown": 86400, "force": 1}) + except (NotPrimaryError, AutoReconnect): + got_not_primary_error = True + + finally: + yield conn.disconnect() + self.flushLoggedErrors(NotPrimaryError) diff --git a/tests/basic/test_bulk.py b/tests/basic/test_bulk.py index 05db828a..d38e6071 100644 --- a/tests/basic/test_bulk.py +++ b/tests/basic/test_bulk.py @@ -269,20 +269,18 @@ def test_OperationFailure(self): def fake_send_query(*args): return defer.succeed( - Msg( - body=bson.encode( - { - "ok": 0.0, - "errmsg": "operation was interrupted", - "code": 11602, - "codeName": "InterruptedDueToReplStateChange", - } - ) + Msg.create( + { + "ok": 0.0, + "errmsg": "operation was interrupted", + "code": 11602, + "codeName": "InterruptedDueToReplStateChange", + } ) ) with patch( - "txmongo.protocol.MongoProtocol.send_msg", side_effect=fake_send_query + "txmongo.protocol.MongoProtocol._send_raw_msg", side_effect=fake_send_query ): yield self.assertFailure( self.coll.bulk_write( diff --git a/tests/basic/test_protocol.py b/tests/basic/test_protocol.py index 0a5d0180..9947f771 100644 --- a/tests/basic/test_protocol.py +++ b/tests/basic/test_protocol.py @@ -87,24 +87,24 @@ def test_EncodeDecodeReply(self): self.assertEqual(decoded.documents, request.documents) def test_EncodeDecodeMsg(self): - request = Msg( - response_to=123, - flag_bits=OP_MSG_MORE_TO_COME, - body=bson.encode({"a": 1, "$db": "dbname"}), + request = Msg.create( + body={"a": 1, "$db": "dbname"}, payload={ "documents": [ - bson.encode({"a": 1}), - bson.encode({"a": 2}), + {"a": 1}, + {"a": 2}, ], "updates": [ - bson.encode({"$set": {"z": 1}}), - bson.encode({"$set": {"z": 2}}), + {"$set": {"z": 1}}, + {"$set": {"z": 2}}, ], "deletes": [ - bson.encode({"_id": ObjectId()}), - bson.encode({"_id": ObjectId()}), + {"_id": ObjectId()}, + {"_id": ObjectId()}, ], }, + acknowledged=False, + response_to=123, ) decoded = self._encode_decode(request) diff --git a/tests/basic/test_queries.py b/tests/basic/test_queries.py index f48083c5..52ba4655 100644 --- a/tests/basic/test_queries.py +++ b/tests/basic/test_queries.py @@ -203,10 +203,7 @@ def test_CursorClosingWithTimeout(self): {"$where": "sleep(100); true"}, batch_size=5, timeout=0.8 ) with patch.object( - MongoProtocol, - "send_msg", - side_effect=MongoProtocol.send_msg, - autospec=True, + MongoProtocol, "send_msg", side_effect=MongoProtocol.send_msg, autospec=True ) as mock: with self.assertRaises(TimeExceeded): yield dfr diff --git a/tests/mongod.py b/tests/mongod.py index b455c88a..f60d03ac 100644 --- a/tests/mongod.py +++ b/tests/mongod.py @@ -96,7 +96,7 @@ def stop(self): if self._proc and self._proc.pid: d = defer.Deferred() self._notify_stop.append(d) - self._proc.signalProcess("INT") + self.kill("INT") return d else: return defer.fail("Not started yet") diff --git a/tests/utils.py b/tests/utils.py index 94f9222a..2240af41 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,4 @@ +from pymongo.errors import AutoReconnect from twisted.internet import defer from twisted.trial import unittest @@ -15,5 +16,11 @@ def setUp(self): @defer.inlineCallbacks def tearDown(self): - yield self.coll.drop() + while True: + try: + yield self.coll.drop() + break + except AutoReconnect: + pass + yield self.conn.disconnect() diff --git a/txmongo/_bulk.py b/txmongo/_bulk.py index c0e1a975..a63b84d9 100644 --- a/txmongo/_bulk.py +++ b/txmongo/_bulk.py @@ -18,27 +18,18 @@ validate_ok_for_update, ) +from txmongo._bulk_constants import ( + _DELETE, + _INSERT, + _UPDATE, + COMMAND_NAME, + PAYLOAD_ARG_NAME, +) from txmongo.protocol import MongoProtocol, Msg from txmongo.types import Document _WriteOp = Union[InsertOne, UpdateOne, UpdateMany, ReplaceOne, DeleteOne, DeleteMany] -_INSERT = 0 -_UPDATE = 1 -_DELETE = 2 - -COMMAND_NAME = { - _INSERT: "insert", - _UPDATE: "update", - _DELETE: "delete", -} - -PAYLOAD_ARG_NAME = { - _INSERT: "documents", - _UPDATE: "updates", - _DELETE: "deletes", -} - class _Run: op_type: int diff --git a/txmongo/_bulk_constants.py b/txmongo/_bulk_constants.py new file mode 100644 index 00000000..1d95483d --- /dev/null +++ b/txmongo/_bulk_constants.py @@ -0,0 +1,13 @@ +_INSERT = 0 +_UPDATE = 1 +_DELETE = 2 +COMMAND_NAME = { + _INSERT: "insert", + _UPDATE: "update", + _DELETE: "delete", +} +PAYLOAD_ARG_NAME = { + _INSERT: "documents", + _UPDATE: "updates", + _DELETE: "deletes", +} diff --git a/txmongo/collection.py b/txmongo/collection.py index 772e9270..7d1aa7f8 100644 --- a/txmongo/collection.py +++ b/txmongo/collection.py @@ -6,7 +6,6 @@ from operator import itemgetter from typing import Iterable, List, Optional -import bson from bson import ObjectId from bson.codec_options import CodecOptions from bson.son import SON @@ -32,19 +31,10 @@ from twisted.python.compat import comparable from txmongo import filter as qf -from txmongo._bulk import _INSERT, _Bulk, _Run -from txmongo.protocol import ( - OP_MSG_MORE_TO_COME, - QUERY_PARTIAL, - QUERY_SLAVE_OK, - MongoProtocol, - Msg, -) -from txmongo.pymongo_internals import ( - _check_command_response, - _check_write_command_response, - _merge_command, -) +from txmongo._bulk import _Bulk, _Run +from txmongo._bulk_constants import _INSERT +from txmongo.protocol import QUERY_PARTIAL, QUERY_SLAVE_OK, MongoProtocol, Msg +from txmongo.pymongo_internals import _check_write_command_response, _merge_command from txmongo.types import Document from txmongo.utils import check_deadline, timeout @@ -413,9 +403,8 @@ def query(): "$showDiskLoc": "showRecordId", # <= MongoDB 3.0 } - @classmethod def _gen_find_command( - cls, + self, db_name: str, coll_name: str, filter_with_modifiers, @@ -425,12 +414,16 @@ def _gen_find_command( batch_size, allow_partial_results, flags: int, - ): + ) -> Msg: cmd = {"find": coll_name} if "$query" in filter_with_modifiers: cmd.update( [ - (cls._MODIFIERS[key], val) if key in cls._MODIFIERS else (key, val) + ( + (self._MODIFIERS[key], val) + if key in self._MODIFIERS + else (key, val) + ) for key, val in filter_with_modifiers.items() ] ) @@ -459,19 +452,17 @@ def _gen_find_command( cmd = {"explain": cmd} cmd["$db"] = db_name - return cmd + return Msg.create(cmd, codec_options=self.codec_options) def __close_cursor_without_response(self, proto: MongoProtocol, cursor_id: int): proto.send_msg( - Msg( - flag_bits=OP_MSG_MORE_TO_COME, - body=bson.encode( - { - "killCursors": self.name, - "$db": self._database.name, - "cursors": [cursor_id], - }, - ), + Msg.create( + { + "killCursors": self.name, + "$db": self._database.name, + "cursors": [cursor_id], + }, + acknowledged=False, ) ) @@ -524,7 +515,7 @@ def after_connection(proto): flags, ) - return proto.send_simple_msg(cmd, codec_options).addCallback( + return proto.send_msg(cmd, codec_options).addCallback( after_reply, after_reply, proto ) @@ -532,7 +523,7 @@ def after_connection(proto): # after_reply can reference to itself directly but this will create a circular # reference between closure and function object which will add unnecessary # work for GC. - def after_reply(reply, this_func, proto, fetched=0): + def after_reply(reply: dict, this_func, proto, fetched=0): try: check_deadline(_deadline) except Exception: @@ -541,8 +532,6 @@ def after_reply(reply, this_func, proto, fetched=0): self.__close_cursor_without_response(proto, cursor_id) raise - _check_command_response(reply) - if "cursor" not in reply: # For example, when we run `explain` command return [reply], defer.succeed(([], None)) @@ -586,7 +575,9 @@ def after_reply(reply, this_func, proto, fetched=0): if batch_size: get_more["batchSize"] = batch_size - next_reply = proto.send_simple_msg(get_more, codec_options) + next_reply = proto.send_msg( + Msg.create(get_more, codec_options=codec_options), codec_options + ) next_reply.addCallback(this_func, this_func, proto, fetched) return out, next_reply @@ -697,26 +688,23 @@ def _insert_one( document["_id"] = ObjectId() inserted_id = document["_id"] - msg = Msg( - flag_bits=Msg.create_flag_bits(self.write_concern.acknowledged), - body=bson.encode( - { - "insert": self.name, - "$db": self.database.name, - "writeConcern": self.write_concern.document, - } - ), - payload={ - "documents": [bson.encode(document, codec_options=self.codec_options)], + msg = Msg.create( + { + "insert": self.name, + "$db": self.database.name, + "writeConcern": self.write_concern.document, }, + { + "documents": [document], + }, + codec_options=self.codec_options, + acknowledged=self.write_concern.acknowledged, ) proto = yield self._database.connection.getprotocol() check_deadline(_deadline) - response: Optional[Msg] = yield proto.send_msg(msg) - if response: - reply = bson.decode(response.body, codec_options=self.codec_options) - _check_command_response(reply) + reply: Optional[dict] = yield proto.send_msg(msg, self.codec_options) + if reply: _check_write_command_response(reply) return InsertOneResult(inserted_id, self.write_concern.acknowledged) @@ -761,37 +749,30 @@ def gen(): @defer.inlineCallbacks def _update(self, filter, update, upsert, multi, _deadline): - msg = Msg( - flag_bits=Msg.create_flag_bits(self.write_concern.acknowledged), - body=bson.encode( - { - "update": self.name, - "$db": self.database.name, - "writeConcern": self.write_concern.document, - } - ), - payload={ + msg = Msg.create( + { + "update": self.name, + "$db": self.database.name, + "writeConcern": self.write_concern.document, + }, + { "updates": [ - bson.encode( - { - "q": filter, - "u": update, - "upsert": bool(upsert), - "multi": bool(multi), - }, - codec_options=self.codec_options, - ) + { + "q": filter, + "u": update, + "upsert": bool(upsert), + "multi": bool(multi), + } ], }, + codec_options=self.codec_options, + acknowledged=self.write_concern.acknowledged, ) proto = yield self._database.connection.getprotocol() check_deadline(_deadline) - response = yield proto.send_msg(msg) - reply = None - if response: - reply = bson.decode(response.body, codec_options=self.codec_options) - _check_command_response(reply) + reply = yield proto.send_msg(msg, self.codec_options) + if reply: _check_write_command_response(reply) if reply.get("n") and "upserted" in reply: # MongoDB >= 2.6.0 returns the upsert _id in an array @@ -916,29 +897,24 @@ def _delete( if let: body["let"] = let - msg = Msg( - flag_bits=Msg.create_flag_bits(self.write_concern.acknowledged), - body=bson.encode(body), - payload={ + msg = Msg.create( + body, + { "deletes": [ - bson.encode( - { - "q": filter, - "limit": 0 if multi else 1, - }, - codec_options=self.codec_options, - ) + { + "q": filter, + "limit": 0 if multi else 1, + }, ], }, + codec_options=self.codec_options, + acknowledged=self.write_concern.acknowledged, ) proto = yield self._database.connection.getprotocol() check_deadline(_deadline) - response = yield proto.send_msg(msg) - reply = None - if response: - reply = bson.decode(response.body, codec_options=self.codec_options) - _check_command_response(reply) + reply = yield proto.send_msg(msg, self.codec_options) + if reply: _check_write_command_response(reply) return DeleteResult(reply, self.write_concern.acknowledged) @@ -1256,13 +1232,9 @@ def _execute_bulk(self, bulk: _Bulk, _deadline: Optional[float]): } def accumulate_response(response: dict, run: _Run, idx_offset: int) -> dict: - _check_command_response(response) _merge_command(run, full_result, idx_offset, response) return response - def decode_response(response: Msg, codec_options: CodecOptions) -> dict: - return bson.decode(response.body, codec_options=codec_options) - got_error = False for run in bulk.gen_runs(): for doc_offset, msg in run.gen_messages( @@ -1273,9 +1245,8 @@ def decode_response(response: Msg, codec_options: CodecOptions) -> dict: self.codec_options, ): check_deadline(_deadline) - deferred = proto.send_msg(msg) + deferred = proto.send_msg(msg, self.codec_options) if effective_write_concern.acknowledged: - deferred.addCallback(decode_response, self.codec_options) if bulk.ordered: reply = yield deferred accumulate_response(reply, run, doc_offset) diff --git a/txmongo/database.py b/txmongo/database.py index 7a4910b4..be1147bb 100644 --- a/txmongo/database.py +++ b/txmongo/database.py @@ -78,12 +78,16 @@ def command( proto = yield self.connection.getprotocol() check_deadline(_deadline) - reply = yield proto.send_simple_msg(command, codec_options) - if check: - msg = "TxMongo: command {0} on namespace {1} failed with '%s'".format( - repr(command), self - ) - _check_command_response(reply, msg, allowable_errors) + errmsg = "TxMongo: command {0} on namespace {1} failed with '%s'".format( + repr(command), self + ) + reply = yield proto.send_msg( + Msg.create(command, codec_options=codec_options), + codec_options, + check=check, + errmsg=errmsg, + allowable_errors=allowable_errors, + ) return reply @timeout diff --git a/txmongo/protocol.py b/txmongo/protocol.py index 75f99343..251be06a 100644 --- a/txmongo/protocol.py +++ b/txmongo/protocol.py @@ -23,10 +23,10 @@ from dataclasses import dataclass, field from hashlib import sha1 from random import SystemRandom -from typing import Dict, List +from typing import Dict, List, Optional import bson -from bson import SON, Binary, CodecOptions +from bson import DEFAULT_CODEC_OPTIONS, SON, Binary, CodecOptions from pymongo.errors import ( AutoReconnect, ConfigurationError, @@ -38,6 +38,10 @@ from twisted.internet import defer, error, protocol from twisted.python import failure, log +from txmongo.pymongo_errors import _NOT_MASTER_CODES +from txmongo.pymongo_internals import _check_command_response +from txmongo.types import Document + try: from pymongo.synchronous import auth except ImportError: @@ -254,6 +258,35 @@ def opcode(cls): def create_flag_bits(cls, not_more_to_come: bool) -> int: return 0 if not_more_to_come else OP_MSG_MORE_TO_COME + @classmethod + def create( + cls, + body: Document, + payload: Dict[str, List[Document]] = None, + *, + codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + acknowledged: bool = True, + request_id: int = 0, + response_to: int = 0, + ) -> "Msg": + encoded_payload = {} + if payload: + encoded_payload = { + key: [bson.encode(doc, codec_options=codec_options) for doc in docs] + for key, docs in payload.items() + } + return Msg( + request_id=request_id, + response_to=response_to, + body=bson.encode(body, codec_options=codec_options), + flag_bits=0 if acknowledged else OP_MSG_MORE_TO_COME, + payload=encoded_payload, + ) + + @property + def acknowledged(self) -> bool: + return (self.flag_bits & OP_MSG_MORE_TO_COME) == 0 + def size_in_bytes(self) -> int: """return estimated overall message length including messageLength and requestID""" # checksum is not added since we don't support it for now @@ -517,29 +550,45 @@ def send_query(self, request): request_id = self._send(request) return self.__wait_for_reply_to(request_id) - def send_msg(self, msg: Msg) -> defer.Deferred[Msg]: - """Send Msg (OP_MSG) and return deferred. - - If OP_MSG has OP_MSG_MORE_TO_COME flag set, returns already fired deferred with None as a result. - """ + def _send_raw_msg(self, msg: Msg) -> defer.Deferred[Optional[Msg]]: + """Send OP_MSG and return result as Msg object (or None if not acknowledged)""" request_id = self._send(msg) - if msg.flag_bits & OP_MSG_MORE_TO_COME: - return defer.succeed(None) - return self.__wait_for_reply_to(request_id) - - def send_simple_msg( - self, body: dict, codec_options: CodecOptions - ) -> defer.Deferred[dict]: - """Send simple OP_MSG without extracted payload and return parsed response.""" + if msg.acknowledged: + return self.__wait_for_reply_to(request_id) + return defer.succeed(None) - def on_response(response: Msg): - reply = bson.decode(response.body, codec_options) - for key, bin_docs in msg.payload.items(): - reply[key] = [bson.decode(doc, codec_options) for doc in bin_docs] - return reply + @defer.inlineCallbacks + def send_msg( + self, + msg: Msg, + codec_options: CodecOptions = DEFAULT_CODEC_OPTIONS, + *, + check: bool = True, + errmsg: str = None, + allowable_errors=None, + ) -> defer.Deferred[Optional[dict]]: + """Send OP_MSG and return parsed response as dict.""" + + response = yield self._send_raw_msg(msg) + if response is None: + return + + reply = bson.decode(response.body, codec_options) + for key, bin_docs in msg.payload.items(): + reply[key] = [bson.decode(doc, codec_options) for doc in bin_docs] + + if reply.get("ok") == 0: + if reply.get("code") in _NOT_MASTER_CODES: + self.transport.loseConnection() + raise NotPrimaryError( + "TxMongo: " + reply.get("errmsg", "Unknown error") + ) - msg = Msg(body=bson.encode(body, codec_options=codec_options)) - return self.send_msg(msg).addCallback(on_response) + if check: + _check_command_response( + reply, msg=errmsg, allowable_errors=allowable_errors + ) + return reply def handle(self, request: BaseMessage): if isinstance(request, Reply): @@ -552,7 +601,7 @@ def handle(self, request: BaseMessage): logLevel=logging.WARNING, ) - def handle_reply(self, request): + def handle_reply(self, request: Reply): if request.response_to in self.__deferreds: df = self.__deferreds.pop(request.response_to) if request.response_flags & REPLY_QUERY_FAILURE: @@ -560,7 +609,8 @@ def handle_reply(self, request): code = doc.get("code") msg = "TxMongo: " + doc.get("$err", "Unknown error") fail_conn = False - if code == 13435: + + if code in _NOT_MASTER_CODES: err = NotPrimaryError(msg) fail_conn = True else: diff --git a/txmongo/pymongo_errors.py b/txmongo/pymongo_errors.py new file mode 100644 index 00000000..119634d6 --- /dev/null +++ b/txmongo/pymongo_errors.py @@ -0,0 +1,34 @@ +# Copied from pymongo/helpers.py:32 at commit d7d94b2776098dba32686ddf3ada1f201172daaf + +# From the SDAM spec, the "node is shutting down" codes. +_SHUTDOWN_CODES = frozenset( + [ + 11600, # InterruptedAtShutdown + 91, # ShutdownInProgress + ] +) +# From the SDAM spec, the "not master" error codes are combined with the +# "node is recovering" error codes (of which the "node is shutting down" +# errors are a subset). +_NOT_MASTER_CODES = ( + frozenset( + [ + 10058, # LegacyNotPrimary <=3.2 "not primary" error code + 10107, # NotMaster + 13435, # NotMasterNoSlaveOk + 11602, # InterruptedDueToReplStateChange + 13436, # NotMasterOrSecondary + 189, # PrimarySteppedDown + ] + ) + | _SHUTDOWN_CODES +) +# From the retryable writes spec. +_RETRYABLE_ERROR_CODES = _NOT_MASTER_CODES | frozenset( + [ + 7, # HostNotFound + 6, # HostUnreachable + 89, # NetworkTimeout + 9001, # SocketException + ] +) diff --git a/txmongo/pymongo_internals.py b/txmongo/pymongo_internals.py index 4fa3c646..503df823 100644 --- a/txmongo/pymongo_internals.py +++ b/txmongo/pymongo_internals.py @@ -1,4 +1,6 @@ -from typing import Any, Mapping, MutableMapping, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional from pymongo.errors import ( CursorNotFound, @@ -11,41 +13,11 @@ WTimeoutError, ) -from txmongo._bulk import _DELETE, _INSERT, _UPDATE, _Run - -# Copied from pymongo/helpers.py:32 at commit d7d94b2776098dba32686ddf3ada1f201172daaf +from txmongo._bulk_constants import _DELETE, _INSERT, _UPDATE +from txmongo.pymongo_errors import _NOT_MASTER_CODES -# From the SDAM spec, the "node is shutting down" codes. -_SHUTDOWN_CODES = frozenset( - [ - 11600, # InterruptedAtShutdown - 91, # ShutdownInProgress - ] -) -# From the SDAM spec, the "not master" error codes are combined with the -# "node is recovering" error codes (of which the "node is shutting down" -# errors are a subset). -_NOT_MASTER_CODES = ( - frozenset( - [ - 10107, # NotMaster - 13435, # NotMasterNoSlaveOk - 11602, # InterruptedDueToReplStateChange - 13436, # NotMasterOrSecondary - 189, # PrimarySteppedDown - ] - ) - | _SHUTDOWN_CODES -) -# From the retryable writes spec. -_RETRYABLE_ERROR_CODES = _NOT_MASTER_CODES | frozenset( - [ - 7, # HostNotFound - 6, # HostUnreachable - 89, # NetworkTimeout - 9001, # SocketException - ] -) +if TYPE_CHECKING: + from txmongo._bulk import _Run # Copied from pymongo/helpers.py:193 at commit 47b0d8ebfd6cefca80c1e4521b47aec7cf8f529d