pass sessions to processors; fixes memory leak in watched_addresses
authorThomasV <thomasv@gitorious>
Tue, 3 Dec 2013 20:53:39 +0000 (00:53 +0400)
committerThomasV <thomasv@gitorious>
Tue, 3 Dec 2013 20:54:45 +0000 (00:54 +0400)
backends/bitcoind/blockchain_processor.py
backends/irc/__init__.py
processor.py

index 68b8117..d4c56f4 100644 (file)
@@ -24,7 +24,12 @@ class BlockchainProcessor(Processor):
         self.shared = shared
         self.config = config
         self.up_to_date = False
-        self.watched_addresses = []
+
+        self.watch_lock = threading.Lock()
+        self.watch_blocks = []
+        self.watch_headers = []
+        self.watched_addresses = {}
+
         self.history_cache = {}
         self.chunk_cache = {}
         self.cache_lock = threading.Lock()
@@ -664,14 +669,16 @@ class BlockchainProcessor(Processor):
         for addr in self.batch_list.keys():
             self.invalidate_cache(addr)
 
-    def add_request(self, request):
+    def add_request(self, session, request):
         # see if we can get if from cache. if not, add to queue
-        if self.process(request, cache_only=True) == -1:
-            self.queue.put(request)
+        if self.process(session, request, cache_only=True) == -1:
+            self.queue.put((session, request))
+
+
 
-    def process(self, request, cache_only=False):
-        #print "abe process", request
 
+    def process(self, session, request, cache_only=False):
+        
         message_id = request['id']
         method = request['method']
         params = request.get('params', [])
@@ -679,35 +686,33 @@ class BlockchainProcessor(Processor):
         error = None
 
         if method == 'blockchain.numblocks.subscribe':
+            if session not in self.watch_headers:
+                with self.watch_lock:
+                    self.watch_blocks.append(session)
             result = self.height
 
         elif method == 'blockchain.headers.subscribe':
+            if session not in self.watch_headers:
+                with self.watch_lock:
+                    self.watch_headers.append(session)
             result = self.header
 
         elif method == 'blockchain.address.subscribe':
             try:
                 address = params[0]
                 result = self.get_status(address, cache_only)
-                self.watch_address(address)
-            except BaseException, e:
-                error = str(e) + ': ' + address
-                print_log("error:", error)
+                with self.watch_lock:
+                    l = self.watched_addresses.get(address)
+                    if l is None:
+                        self.watched_addresses[address] = [session]
+                    elif session not in l:
+                        l.append(session)
 
-        elif method == 'blockchain.address.unsubscribe':
-            try:
-                password = params[0]
-                address = params[1]
-                if password == self.config.get('server', 'password'):
-                    self.watched_addresses.remove(address)
-                    # print_log('unsubscribed', address)
-                    result = "ok"
-                else:
-                    print_log('incorrect password')
-                    result = "authentication error"
             except BaseException, e:
                 error = str(e) + ': ' + address
                 print_log("error:", error)
 
+
         elif method == 'blockchain.address.get_history':
             try:
                 address = params[0]
@@ -774,13 +779,10 @@ class BlockchainProcessor(Processor):
             return -1
 
         if error:
-            self.push_response({'id': message_id, 'error': error})
+            self.push_response(session, {'id': message_id, 'error': error})
         elif result != '':
-            self.push_response({'id': message_id, 'result': result})
+            self.push_response(session, {'id': message_id, 'result': result})
 
-    def watch_address(self, addr):
-        if addr not in self.watched_addresses:
-            self.watched_addresses.append(addr)
 
     def getfullblock(self, block_hash):
         block = self.bitcoind('getblock', [block_hash])
@@ -941,33 +943,36 @@ class BlockchainProcessor(Processor):
 
         if self.sent_height != self.height:
             self.sent_height = self.height
-            self.push_response({
-                'id': None,
-                'method': 'blockchain.numblocks.subscribe',
-                'params': [self.height],
-            })
+            for session in self.watch_blocks:
+                self.push_response(session, {
+                        'id': None,
+                        'method': 'blockchain.numblocks.subscribe',
+                        'params': [self.height],
+                        })
 
         if self.sent_header != self.header:
             print_log("blockchain: %d (%.3fs)" % (self.height, t2 - t1))
             self.sent_header = self.header
-            self.push_response({
-                'id': None,
-                'method': 'blockchain.headers.subscribe',
-                'params': [self.header],
-            })
+            for session in self.watch_headers:
+                self.push_response(session, {
+                        'id': None,
+                        'method': 'blockchain.headers.subscribe',
+                        'params': [self.header],
+                        })
 
         while True:
             try:
                 addr = self.address_queue.get(False)
             except:
                 break
-            if addr in self.watched_addresses:
-                status = self.get_status(addr)
-                self.push_response({
-                    'id': None,
-                    'method': 'blockchain.address.subscribe',
-                    'params': [addr, status],
-                })
+
+            status = self.get_status(addr)
+            for session in self.watched_addresses[addr]:
+                self.push_response(session, {
+                        'id': None,
+                        'method': 'blockchain.address.subscribe',
+                        'params': [addr, status],
+                        })
 
         if not self.shared.stopped():
             threading.Timer(10, self.main_iteration).start()
index 0ec2a8b..cc3f363 100644 (file)
@@ -128,7 +128,7 @@ class IrcThread(threading.Thread):
                             self.peers[name] = (ip, host, ports)
 
                     if time.time() - t > 5*60:
-                        self.processor.push_response({'method': 'server.peers', 'params': [self.get_peers()]})
+                        #self.processor.push_response({'method': 'server.peers', 'params': [self.get_peers()]})
                         s.send('NAMES #electrum\n')
                         t = time.time()
                         self.peers = {}
@@ -164,19 +164,20 @@ class ServerProcessor(Processor):
             self.irc.start()
         Processor.run(self)
 
-    def process(self, request):
+    def process(self, session, request):
         method = request['method']
         params = request['params']
         result = None
 
-        if method in ['server.stop', 'server.info', 'server.heapy']:
+        if method in ['server.stop', 'server.info', 'server.debug']:
             try:
                 password = request['params'][0]
             except:
                 password = None
 
             if password != self.password:
-                self.push_response({'id': request['id'],
+                self.push_response(session, 
+                                   {'id': request['id'],
                                     'result': None,
                                     'error': 'incorrect password'})
                 return
@@ -202,15 +203,7 @@ class ServerProcessor(Processor):
                                     "subscriptions": len(s.subscriptions)},
                          self.dispatcher.request_dispatcher.get_sessions())
 
-        elif method == 'server.cache':
-            p = self.dispatcher.request_dispatcher.processors['blockchain']
-            result = len(repr(p.history_cache))
-
-        elif method == 'server.load':
-            p = self.dispatcher.request_dispatcher.processors['blockchain']
-            result = p.queue.qsize()
-
-        elif method == 'server.heapy':
+        elif method == 'server.debug':
             try:
                 s = request['params'][1]
             except:
@@ -219,12 +212,13 @@ class ServerProcessor(Processor):
             if s:
                 from guppy import hpy
                 h = hpy()
+                bp = self.dispatcher.request_dispatcher.processors['blockchain']
                 try:
                     result = str(eval(s))
                 except:
                     result = "error"
         else:
-            print_log("unknown method", request)
+            print_log("unknown method", method)
 
         if result != '':
-            self.push_response({'id': request['id'], 'result': result})
+            self.push_response(session, {'id': request['id'], 'result': result})
index 0ad86ff..9210fd6 100644 (file)
@@ -34,21 +34,21 @@ class Processor(threading.Thread):
         self.dispatcher = None
         self.queue = queue.Queue()
 
-    def process(self, request):
+    def process(self, session, request):
         pass
 
-    def add_request(self, request):
-        self.queue.put(request)
+    def add_request(self, session, request):
+        self.queue.put((session, request))
 
-    def push_response(self, response):
+    def push_response(self, session, response):
         #print "response", response
-        self.dispatcher.request_dispatcher.push_response(response)
+        self.dispatcher.request_dispatcher.push_response(session, response)
 
     def run(self):
         while not self.shared.stopped():
-            request = self.queue.get(10000000000)
+            request, session = self.queue.get(10000000000)
             try:
-                self.process(request)
+                self.process(request, session)
             except:
                 traceback.print_exc(file=sys.stdout)
 
@@ -87,8 +87,8 @@ class RequestDispatcher(threading.Thread):
         self.sessions = []
         self.processors = {}
 
-    def push_response(self, item):
-        self.response_queue.put(item)
+    def push_response(self, session, item):
+        self.response_queue.put((session, item))
 
     def pop_response(self):
         return self.response_queue.get()
@@ -141,9 +141,6 @@ class RequestDispatcher(threading.Thread):
             if suffix == 'subscribe':
                 session.subscribe_to_service(method, params)
 
-        # store session and id locally
-        request['id'] = self.store_session_id(session, request['id'])
-
         prefix = request['method'].split('.')[0]
         try:
             p = self.processors[prefix]
@@ -151,7 +148,7 @@ class RequestDispatcher(threading.Thread):
             print_log("error: no processor for", prefix)
             return
 
-        p.add_request(request)
+        p.add_request(session, request)
 
         if method in ['server.version']:
             session.version = params[0]
@@ -183,11 +180,15 @@ class RequestDispatcher(threading.Thread):
             if (now - session.time) > 1000:
                 session.stop()
 
+        bp = self.processors['blockchain']
+
         for session in sessions:
             if not session.stopped():
                 # If session is still alive then re-add it back
                 # to our internal register
                 active_sessions.append(session)
+            else:
+                session.stop_subscriptions(bp)
 
         with self.lock:
             self.sessions = active_sessions[:]
@@ -207,6 +208,7 @@ class Session:
         self.time = time.time()
         threading.Timer(2, self.info).start()
 
+
     # Debugging method. Doesn't need to be threadsafe.
     def info(self):
         for sub in self.subscriptions:
@@ -229,30 +231,37 @@ class Session:
         with self.lock:
             return self._stopped
 
+
     def subscribe_to_service(self, method, params):
-        subdesc = self.build_subdesc(method, params)
         with self.lock:
-            if subdesc is not None:
-                self.subscriptions.append(subdesc)
-
-    # subdesc = A subscription description
-    @staticmethod
-    def build_subdesc(method, params):
-        if method == "blockchain.numblocks.subscribe":
-            return method,
-        elif method == "blockchain.headers.subscribe":
-            return method,
-        elif method in ["blockchain.address.subscribe"]:
-            if not params:
-                return None
-            else:
-                return method, params[0]
-        else:
-            return None
+            if (method, params) not in self.subscriptions:
+                self.subscriptions.append((method,params))
+
 
-    def contains_subscription(self, subdesc):
+    def stop_subscriptions(self, bp):
         with self.lock:
-            return subdesc in self.subscriptions
+            s = self.subscriptions[:]
+
+        for method, params in s:
+            with bp.watch_lock:
+                if method == 'blockchain.numblocks.subscribe':
+                    if self in bp.watch_blocks:
+                        bp.watch_blocks.remove(self)
+                elif method == 'blockchain.headers.subscribe':
+                    if self in bp.watch_headers:
+                        bp.watch_headers.remove(self)
+                elif method == "blockchain.address.subscribe":
+                    addr = params[0]
+                    l = bp.watched_addresses.get(addr)
+                    if not l:
+                        continue
+                    if self in l:
+                        l.remove(self)
+                    if l == []:
+                        bp.watched_addresses.pop(addr)
+
+        with self.lock:
+            self.subscriptions = []
 
 
 class ResponseDispatcher(threading.Thread):
@@ -265,48 +274,5 @@ class ResponseDispatcher(threading.Thread):
 
     def run(self):
         while not self.shared.stopped():
-            self.update()
-
-    def update(self):
-        response = self.request_dispatcher.pop_response()
-        #print "pop response", response
-        internal_id = response.get('id')
-        method = response.get('method')
-        params = response.get('params')
-
-        # A notification
-        if internal_id is None:  # and method is not None and params is not None:
-            found = self.notification(method, params, response)
-            if not found and method == 'blockchain.address.subscribe':
-                request = {
-                    'id': None,
-                    'method': method.replace('.subscribe', '.unsubscribe'),
-                    'params': [self.shared.config.get('server', 'password')] + params,
-                }
-
-                self.request_dispatcher.push_request(None, request)
-        # A response
-        elif internal_id is not None:
-            self.send_response(internal_id, response)
-        else:
-            print_log("no method", response)
-
-    def notification(self, method, params, response):
-        subdesc = Session.build_subdesc(method, params)
-        found = False
-        for session in self.request_dispatcher.sessions:
-            if session.stopped():
-                continue
-            if session.contains_subscription(subdesc):
-                session.send_response(response)
-                found = True
-        # if not found: print_log("no subscriber for", subdesc)
-        return found
-
-    def send_response(self, internal_id, response):
-        session, message_id = self.request_dispatcher.get_session_id(internal_id)
-        if session:
-            response['id'] = message_id
+            session, response = self.request_dispatcher.pop_response()
             session.send_response(response)
-        #else:
-        #    print_log("send_response: no session", message_id, internal_id, response )