I've implemented an async connection pool (ConnectionPool) and transport layer (Transport) in Python for managing persistent connections to a backend service. The goal is to handle multiple clients efficiently with sticky sessions, timeouts, and connection reuse. However, I'm unsure about:
TLS/SSL Integration:
Currently, the code doesn't implement TLS. How should I properly add TLS encryption to the Transport class (e.g., using asyncio.open_connection with SSL contexts)?
Are there security pitfalls (e.g., certificate verification, protocol versions) to avoid?
Connection Pool Optimization:
The pool uses a Semaphore for max connections and sticky sessions via sticky_map. Is this approach thread-safe and scalable for high concurrency?
How can I improve connection reuse (e.g., health checks, idle timeout)?
Dynamic Client Handling:
The current design assumes a single backend address. How can I extend it to support dynamic endpoints (e.g., load balancing across multiple hosts)?
Should I pre-warm connections or implement lazy initialization?
Error Handling:
Are there critical edge cases (e.g., partial writes, zombie connections) that aren’t handled robustly?
Code Reference: ConnectionPool: Manages connections with sticky sessions and semaphore-based limits.
Transport: Handles low-level socket communication with retries and timeouts. Connection_pool.py
import asyncio
import logging
from transport import Transport
logger = logging.getLogger(__name__)
class ConnectionPool:
def __init__(self, address, max_connections=100, timeout=5):
self.address = address
self.max_connections = max_connections
self.timeout = timeout
self.connections = set() # All active connections (alive or dead)
self.available_connections = asyncio.Queue() # Connections ready for reuse
self.sticky_map = {} # request_id -> Transport (lazily cleaned)
self.semaphore = asyncio.Semaphore(max_connections) # Limits concurrency
async def _create_connection(self):
"""Create and return a new connection (internal)."""
host, port = self.address.split(":")
conn = Transport(host, int(port), self.timeout)
try:
if await conn.connect():
self.connections.add(conn)
return conn
except Exception as e:
logger.error(f"Connection failed: {e}")
return None
async def get_connection(self, request_id=None, sticky=False):
"""Get a connection, reusing sticky/available ones or creating new."""
# Fast path: Reuse sticky connection if valid
if sticky and request_id:
async with asyncio.Lock():
if request_id in self.sticky_map:
conn = self.sticky_map[request_id]
if conn.is_alive():
return conn
del self.sticky_map[request_id] # Cleanup dead sticky
# Reuse available connections
while not self.available_connections.empty():
conn = await self.available_connections.get()
if conn.is_alive():
if sticky and request_id:
async with asyncio.Lock():
self.sticky_map[request_id] = conn
return conn
else:
async with asyncio.Lock():
self.connections.discard(conn) # Remove dead
# Create a new connection if under limit
async with self.semaphore:
conn = await self._create_connection()
if conn and sticky and request_id:
async with asyncio.Lock():
self.sticky_map[request_id] = conn
return conn
async def release_connection(self, conn, request_id=None):
"""Release a connection back to the pool."""
if request_id:
async with asyncio.Lock():
self.sticky_map.pop(request_id, None) # Unstick if needed
if conn.is_alive():
await self.available_connections.put(conn) # Reuse alive connections
else:
async with asyncio.Lock():
self.connections.discard(conn) # Remove dead
self.semaphore.release() # Important: Release after cleanup!
async def send_request(self, request, sticky=False):
"""Send a request using the pool (handles timeouts and errors)."""
await self.semaphore.acquire()
conn = None
try:
conn = await self.get_connection(request.id, sticky)
if not conn:
return {"error": "No connection available."}
response = await asyncio.wait_for(
conn.send_request(request),
timeout=self.timeout
)
return response
except asyncio.TimeoutError:
logger.error("Request timed out.")
return {"error": "Timeout"}
except Exception as e:
logger.error(f"Request failed: {e}")
return {"error": str(e)}
finally:
if conn:
await self.release_connection(conn, request.id if sticky else None)
async def close_all(self):
"""Close all connections and reset the pool."""
async with asyncio.Lock():
for conn in self.connections:
await conn.close()
self.connections.clear()
self.sticky_map.clear()
while not self.available_connections.empty():
self.available_connections.get_nowait()
Transport.py
import asyncio
import struct
import logging
import response_pb2 as pb2_response
logger = logging.getLogger(__name__)
class Transport:
def __init__(self, host, port, timeout=5):
self.host = host
self.port = port
self.timeout = timeout
self.reader = None
self.writer = None
self.lock = asyncio.Lock() # Per-connection lock
self.connected = False
async def connect(self):
"""Establish a connection with retries and exponential backoff."""
retry_delay = 1
for _ in range(3):
try:
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_connection(self.host, self.port),
timeout=self.timeout
)
self.connected = True
logger.info(f"Connected to {self.host}:{self.port}")
return True
except (asyncio.TimeoutError, ConnectionRefusedError) as e:
logger.error(f"Connection failed: {e}")
break
except Exception as e:
logger.warning(f"Retrying connection: {e}")
await asyncio.sleep(retry_delay)
retry_delay *= 2
self.connected = False
return False
def is_alive(self):
"""Check if the connection is active."""
return (
self.connected
and self.writer
and not self.writer.is_closing()
)
async def send_request(self, request):
"""Thread-safe request sending with timeout."""
if not self.is_alive() and not await self.connect():
return {"error": "Connection failed."}
async with self.lock: # Ensure only one coroutine uses this connection
try:
# Serialize and send request
req_data = request.SerializeToString()
req_len = struct.pack(">I", len(req_data))
self.writer.write(req_len + req_data)
await self.writer.drain()
# Await response
return await asyncio.wait_for(
self.receive_response(),
timeout=self.timeout
)
except Exception as e:
logger.error(f"Request failed: {e}")
self.connected = False
return {"error": str(e)}
async def receive_response(self):
"""Receive and parse a protobuf response."""
try:
len_buf = await self.reader.readexactly(4)
resp_len = struct.unpack(">I", len_buf)[0]
response_data = await self.reader.readexactly(resp_len)
response = pb2_response.Response()
response.ParseFromString(response_data)
return response
except Exception as e:
self.connected = False
return {"error": f"Receiver error: {e}"}
async def close(self):
"""Gracefully close the connection."""
if self.writer:
self.writer.close()
await self.writer.wait_closed()
self.connected = False
Observe the codes, implement the changes and give me the updated code files