Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,9 @@ def factory(cls, endpoint, timeout, host_conn = None, *args, **kwargs):
if conn.is_unsupported_proto_version:
raise ProtocolVersionUnsupported(endpoint, conn.protocol_version)
raise conn.last_error
elif conn.is_closed or conn.is_defunct:
raise ConnectionShutdown(
"Connection to %s was closed during setup" % (endpoint,))
elif not conn.connected_event.is_set():
conn.close()
raise OperationTimedOut("Timed out creating connection (%s seconds)" % timeout)
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,3 +571,41 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange):
second_run = list(itertools.islice(gen.generate(0, 2), 5))

assert first_run == second_run


class FactoryCloseRaceTest(unittest.TestCase):
"""Tests for Connection.factory() handling connections closed during setup."""

def _make_fake_connection_class(self, is_closed=False, is_defunct=False, last_error=None):
"""Create a fake connection class whose __init__ sets up minimal state
needed by factory() without actually connecting to anything."""
from threading import Event

class FakeConnection(Connection):
def __init__(self, endpoint, *args, **kwargs): # noqa - intentionally skips super().__init__
self.connected_event = Event()
self.connected_event.set()
self.is_closed = is_closed
self.is_defunct = is_defunct
self.is_unsupported_proto_version = False
self.last_error = last_error
self.endpoint = endpoint

return FakeConnection

def test_factory_raises_on_closed_during_setup(self):
FakeConn = self._make_fake_connection_class(is_closed=True)
with pytest.raises(ConnectionShutdown, match="closed during setup"):
FakeConn.factory(DefaultEndPoint('1.2.3.4'), timeout=5)

def test_factory_raises_on_defunct_during_setup(self):
FakeConn = self._make_fake_connection_class(is_defunct=True)
with pytest.raises(ConnectionShutdown, match="closed during setup"):
FakeConn.factory(DefaultEndPoint('1.2.3.4'), timeout=5)

def test_factory_returns_conn_when_connected_normally(self):
FakeConn = self._make_fake_connection_class(is_closed=False, is_defunct=False)
result = FakeConn.factory(DefaultEndPoint('1.2.3.4'), timeout=5)
assert result is not None
assert not result.is_closed
assert not result.is_defunct
Loading