use a dict for sessions, instead of a list. Deprecates #48
authorThomasV <thomasv@gitorious>
Fri, 6 Dec 2013 10:12:07 +0000 (14:12 +0400)
committerThomasV <thomasv@gitorious>
Fri, 6 Dec 2013 10:12:07 +0000 (14:12 +0400)
backends/bitcoind/blockchain_processor.py
processor.py
server.py
transports/stratum_http.py
transports/stratum_tcp.py

index 6891c41..fb7b957 100644 (file)
@@ -925,9 +925,12 @@ class BlockchainProcessor(Processor):
                 print_log("cache: invalidating", address)
                 self.history_cache.pop(address)
 
-        if address in self.watched_addresses:
+        with self.watch_lock:
+            sessions = self.watched_addresses.get(address)
+
+        if sessions:
             # TODO: update cache here. if new value equals cached value, do not send notification
-            self.address_queue.put(address)
+            self.address_queue.put((address,sessions))
 
     def main_iteration(self):
         if self.shared.stopped():
@@ -962,12 +965,12 @@ class BlockchainProcessor(Processor):
 
         while True:
             try:
-                addr = self.address_queue.get(False)
+                addr, sessions = self.address_queue.get(False)
             except:
                 break
 
             status = self.get_status(addr)
-            for session in self.watched_addresses.get(addr,[]):
+            for session in sessions:
                 self.push_response(session, {
                         'id': None,
                         'method': 'blockchain.address.subscribe',
index 4094580..0bdb994 100644 (file)
@@ -82,7 +82,7 @@ class RequestDispatcher(threading.Thread):
         self.response_queue = queue.Queue()
         self.lock = threading.Lock()
         self.idlock = threading.Lock()
-        self.sessions = []
+        self.sessions = {}
         self.processors = {}
 
     def push_response(self, session, item):
@@ -98,13 +98,16 @@ class RequestDispatcher(threading.Thread):
         return self.request_queue.get()
 
     def get_session_by_address(self, address):
-        for x in self.sessions:
+        for x in self.sessions.values():
             if x.address == address:
                 return x
 
     def run(self):
         if self.shared is None:
             raise TypeError("self.shared not set in Processor")
+
+        lastgc = 0 
+
         while not self.shared.stopped():
             session, request = self.pop_request()
             try:
@@ -112,6 +115,10 @@ class RequestDispatcher(threading.Thread):
             except:
                 traceback.print_exc(file=sys.stdout)
 
+            if time.time() - lastgc > 60.0:
+                self.collect_garbage()
+                lastgc = time.time()
+
         self.stop()
 
     def stop(self):
@@ -147,44 +154,32 @@ class RequestDispatcher(threading.Thread):
 
     def get_sessions(self):
         with self.lock:
-            r = self.sessions[:]
+            r = self.sessions.values()
         return r
 
     def add_session(self, session):
+        key = session.key()
         with self.lock:
-            self.sessions.append(session)
+            self.sessions[key] = session
 
-    def collect_garbage(self):
-        # Deep copy entire sessions list and blank it
-        # This is done to minimize lock contention
+    def remove_session(self, session):
+        key = session.key()
         with self.lock:
-            sessions = self.sessions[:]
-
-        active_sessions = []
+            self.sessions.pop(key)
 
+    def collect_garbage(self):
         now = time.time()
-        for session in sessions:
-            if (now - session.time) > 1000:
+        for session in self.sessions.values():
+            if (now - session.time) > session.timeout:
                 session.stop()
 
-        bp = self.processors['blockchain']
-
-        for session in sessions:
-            if not session.stopped():
-                # If session is still alive then re-add it back
-                # to our internal register
-                active_sessions.append(session)
-            else:
-                session.stop_subscriptions(bp)
-
-        with self.lock:
-            self.sessions = active_sessions[:]
-
 
 
 class Session:
 
-    def __init__(self):
+    def __init__(self, dispatcher):
+        self.dispatcher = dispatcher
+        self.bp = self.dispatcher.processors['blockchain']
         self._stopped = False
         self.lock = threading.Lock()
         self.subscriptions = []
@@ -196,6 +191,10 @@ class Session:
         threading.Timer(2, self.info).start()
 
 
+    def key(self):
+        return self.name + self.address
+
+
     # Debugging method. Doesn't need to be threadsafe.
     def info(self):
         for sub in self.subscriptions:
@@ -214,6 +213,21 @@ class Session:
                       "%3d" % len(self.subscriptions),
                       self.version)
 
+    def stop(self):
+        with self.lock:
+            if self._stopped:
+                return
+            self._stopped = True
+
+        self.shutdown()
+        self.dispatcher.remove_session(self)
+        self.stop_subscriptions()
+
+
+    def shutdown(self):
+        pass
+
+
     def stopped(self):
         with self.lock:
             return self._stopped
@@ -225,7 +239,9 @@ class Session:
                 self.subscriptions.append((method,params))
 
 
-    def stop_subscriptions(self, bp):
+    def stop_subscriptions(self):
+        bp = self.bp
+
         with self.lock:
             s = self.subscriptions[:]
 
@@ -244,6 +260,10 @@ class Session:
                         continue
                     if self in l:
                         l.remove(self)
+                    if self in l:
+                        print "error rc!!"
+                        bp.shared.stop()
+
                     if l == []:
                         bp.watched_addresses.pop(addr)
 
index 9c2bb99..d41bcf3 100755 (executable)
--- a/server.py
+++ b/server.py
@@ -108,9 +108,9 @@ def run_rpc_command(command, stratum_tcp_port):
 
     if command == 'info':
         now = time.time()
-        print 'type           address   sub  version  time'
+        print 'type           address         sub  version  time'
         for item in r:
-            print '%4s   %15s   %3s  %7s  %.2f' % (item.get('name'),
+            print '%4s   %21s   %3s  %7s  %.2f' % (item.get('name'),
                                                    item.get('address'),
                                                    item.get('subscriptions'),
                                                    item.get('version'),
index eace4ab..2b86869 100644 (file)
@@ -126,9 +126,8 @@ class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
         return response
 
     def create_session(self):
-        session_id = random_string(10)
-        session = HttpSession(session_id)
-        self.dispatcher.add_session(session)
+        session_id = random_string(20)
+        session = HttpSession(self.dispatcher, session_id)
         return session_id
 
     def poll_session(self, session):
@@ -335,21 +334,18 @@ class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
 
 class HttpSession(Session):
 
-    def __init__(self, session_id):
-        Session.__init__(self)
+    def __init__(self, dispatcher, session_id):
+        Session.__init__(self, dispatcher)
         self.pending_responses = Queue.Queue()
         self.address = session_id
         self.name = "HTTP"
+        self.timeout = 60
+        self.dispatcher.add_session(self)
 
     def send_response(self, response):
         raw_response = json.dumps(response)
         self.pending_responses.put(response)
 
-    def stopped(self):
-        with self.lock:
-            if time.time() - self.time > 60:
-                self._stopped = True
-            return self._stopped
 
 
 class HttpServer(threading.Thread):
index 6ae2940..aad4c17 100644 (file)
@@ -11,8 +11,8 @@ from utils import print_log
 
 class TcpSession(Session):
 
-    def __init__(self, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
-        Session.__init__(self)
+    def __init__(self, dispatcher, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
+        Session.__init__(self, dispatcher)
         self.use_ssl = use_ssl
         if use_ssl:
             import ssl
@@ -26,9 +26,11 @@ class TcpSession(Session):
         else:
             self._connection = connection
 
-        self.address = address[0]
+        self.address = address[0] + ":%d"%address[1]
         self.name = "TCP " if not use_ssl else "SSL "
+        self.timeout = 1000
         self.response_queue = queue.Queue()
+        self.dispatcher.add_session(self)
 
     def do_handshake(self):
         if self.use_ssl:
@@ -40,10 +42,7 @@ class TcpSession(Session):
         else:
             return self._connection
 
-    def stop(self):
-        if self.stopped():
-            return
-
+    def shutdown(self):
         try:
             self._connection.shutdown(socket.SHUT_RDWR)
         except:
@@ -52,8 +51,6 @@ class TcpSession(Session):
             pass
 
         self._connection.close()
-        with self.lock:
-            self._stopped = True
 
     def send_response(self, response):
         self.response_queue.put(response)
@@ -183,7 +180,7 @@ class TcpServer(threading.Thread):
 
             #if self.use_ssl: print_log("SSL: new session", address)
             try:
-                session = TcpSession(connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
+                session = TcpSession(self.dispatcher, connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
             except BaseException, e:
                 error = str(e)
                 print_log("cannot start TCP session", error, address)
@@ -191,8 +188,6 @@ class TcpServer(threading.Thread):
                 time.sleep(0.1)
                 continue
 
-            self.dispatcher.add_session(session)
-            self.dispatcher.collect_garbage()
             client_req = TcpClientRequestor(self.dispatcher, session)
             client_req.start()
             responder = TcpClientResponder(session)