self.node = node
self.max_conns = max_conns
- self.conns = set()
+ self.conns = {}
self.running = False
def buildProtocol(self, addr):
- if len(self.conns) >= self.max_conns:
+ if sum(self.conns.itervalues()) >= self.max_conns or self.conns.get(self._host_to_ident(addr.host), 0) >= 3:
return None
p = Protocol(self.node, True)
p.factory = self
return p
+ def _host_to_ident(self, host):
+ a, b, c, d = host.split('.')
+ return a, b
+
def proto_made_connection(self, proto):
- self.conns.add(proto)
+ ident = self._host_to_ident(proto.transport.getPeer().host)
+ self.conns[ident] = self.conns.get(ident, 0) + 1
def proto_lost_connection(self, proto, reason):
- self.conns.remove(proto)
+ ident = self._host_to_ident(proto.transport.getPeer().host)
+ self.conns[ident] -= 1
+ if not self.conns[ident]:
+ del self.conns[ident]
def proto_connected(self, proto):
self.node.got_conn(proto)