diff --git a/cassandra/connection.py b/cassandra/connection.py index 87f860f32b..32037000ca 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -866,6 +866,9 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs): if conn.is_unsupported_proto_version: raise ProtocolVersionUnsupported(endpoint, conn.protocol_version) raise conn.last_error + elif conn.is_closed or conn.is_defunct: + raise ConnectionShutdown( + "Connection to %s was closed during setup" % (endpoint,)) elif not conn.connected_event.is_set(): conn.close() raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 6ac63ff761..d6f7f5ade9 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -571,3 +571,41 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange): second_run = list(itertools.islice(gen.generate(0, 2), 5)) assert first_run == second_run + + +class FactoryCloseRaceTest(unittest.TestCase): + """Tests for Connection.factory() handling connections closed during setup.""" + + def _make_fake_connection_class(self, is_closed=False, is_defunct=False, last_error=None): + """Create a fake connection class whose __init__ sets up minimal state + needed by factory() without actually connecting to anything.""" + from threading import Event + + class FakeConnection(Connection): + def __init__(self, endpoint, *args, **kwargs): # noqa - intentionally skips super().__init__ + self.connected_event = Event() + self.connected_event.set() + self.is_closed = is_closed + self.is_defunct = is_defunct + self.is_unsupported_proto_version = False + self.last_error = last_error + self.endpoint = endpoint + + return FakeConnection + + def test_factory_raises_on_closed_during_setup(self): + FakeConn = self._make_fake_connection_class(is_closed=True) + with pytest.raises(ConnectionShutdown, match="closed during setup"): + FakeConn.factory(DefaultEndPoint('1.2.3.4'), timeout=5) + + def test_factory_raises_on_defunct_during_setup(self): + FakeConn = self._make_fake_connection_class(is_defunct=True) + with pytest.raises(ConnectionShutdown, match="closed during setup"): + FakeConn.factory(DefaultEndPoint('1.2.3.4'), timeout=5) + + def test_factory_returns_conn_when_connected_normally(self): + FakeConn = self._make_fake_connection_class(is_closed=False, is_defunct=False) + result = FakeConn.factory(DefaultEndPoint('1.2.3.4'), timeout=5) + assert result is not None + assert not result.is_closed + assert not result.is_defunct