refactor functions passed to worker_interface into WorkerBridge class
authorForrest Voight <forrest@forre.st>
Sat, 7 Jan 2012 04:55:54 +0000 (23:55 -0500)
committerForrest Voight <forrest@forre.st>
Sat, 7 Jan 2012 04:55:54 +0000 (23:55 -0500)
p2pool/bitcoin/worker_interface.py
p2pool/main.py

index fced903..7da8bd8 100644 (file)
@@ -24,16 +24,28 @@ class _Page(jsonrpc.Server):
         request.content = StringIO.StringIO(json.dumps(dict(id=0, method='getwork')))
         return self.render_POST(request)
 
+class WorkerBridge(object):
+    def __init__(self):
+        self.new_work_event = variable.Event()
+    
+    def preprocess_request(self, request):
+        return request, # *args to self.compute
+    
+    def get_work(self, request):
+        raise NotImplementedError()
+    
+    def got_response(self, block_header):
+        print self.got_response, "called with", block_header
+
 class WorkerInterface(object):
-    def __init__(self, compute, response_callback, new_work_event=variable.Event(), request_process_func=lambda request: (request,)):
-        self.compute = compute
-        self.response_callback = response_callback
-        self.new_work_event = new_work_event
-        self.request_process_func = request_process_func
+    def __init__(self, worker_bridge):
+        self.worker_bridge = worker_bridge
         
         self.worker_views = {}
         
         self.work_cache = {} # request_process_func(request) -> blockattempt
+        
+        new_work_event = self.worker_bridge.new_work_event
         watch_id = new_work_event.watch(lambda *args: self_ref().work_cache.clear())
         self_ref = weakref.ref(self, lambda _: new_work_event.unwatch(watch_id))
     
@@ -47,7 +59,7 @@ class WorkerInterface(object):
         request.setHeader('X-Roll-NTime', 'expire=10')
         
         if data is not None:
-            defer.returnValue(self.response_callback(getwork.decode_data(data), request))
+            defer.returnValue(self.worker_bridge.got_response(getwork.decode_data(data), request))
         
         if p2pool.DEBUG:
             id = random.randrange(1000, 10000)
@@ -55,27 +67,27 @@ class WorkerInterface(object):
         
         if long_poll:
             request_id = request.getClientIP(), request.getHeader('Authorization')
-            if self.worker_views.get(request_id, self.new_work_event.times) != self.new_work_event.times:
+            if self.worker_views.get(request_id, self.worker_bridge.new_work_event.times) != self.worker_bridge.new_work_event.times:
                 if p2pool.DEBUG:
                     print 'POLL %i PUSH' % (id,)
             else:
                 if p2pool.DEBUG:
                     print 'POLL %i WAITING' % (id,)
-                yield self.new_work_event.get_deferred()
-            self.worker_views[request_id] = self.new_work_event.times
+                yield self.worker_bridge.new_work_event.get_deferred()
+            self.worker_views[request_id] = self.worker_bridge.new_work_event.times
         
-        key = self.request_process_func(request)
+        key = self.worker_bridge.preprocess_request(request)
         
         if key in self.work_cache:
             res, orig_timestamp = self.work_cache.pop(key)
         else:
-            res = self.compute(*key)
+            res = self.worker_bridge.get_work(*key)
             orig_timestamp = res.timestamp
         
         if res.timestamp + 12 < orig_timestamp + 600:
             self.work_cache[key] = res.update(timestamp=res.timestamp + 12), orig_timestamp
         
         if p2pool.DEBUG:
-            print 'POLL %i END identifier=%i' % (id, self.new_work_event.times)
+            print 'POLL %i END identifier=%i' % (id, self.worker_bridge.new_work_event.times)
         
-        defer.returnValue(res.getwork(identifier=str(self.new_work_event.times)))
+        defer.returnValue(res.getwork(identifier=str(self.worker_bridge.new_work_event.times)))
index 96712ce..1a2d3a9 100644 (file)
@@ -414,9 +414,16 @@ def main(args, net, datadir_path):
             my_doa_shares_not_in_chain = my_doa_shares - my_doa_shares_in_chain
             
             return (my_shares_not_in_chain - my_doa_shares_not_in_chain, my_doa_shares_not_in_chain), my_shares, (orphans_recorded_in_chain, doas_recorded_in_chain)
+            
+        my_share_hashes = set()
+        my_doa_share_hashes = set()
         
+        class WorkerBridge(worker_interface.WorkerBridge):
+          def __init__(self):
+            worker_interface.WorkerBridge.__init__(self)
+            self.new_work_event = current_work.changed
         
-        def get_payout_script_from_username(user):
+          def _get_payout_script_from_username(self, user):
             if user is None:
                 return None
             try:
@@ -425,13 +432,13 @@ def main(args, net, datadir_path):
                 return None
             return bitcoin_data.pubkey_hash_to_script2(pubkey_hash)
         
-        def precompute(request):
-            payout_script = get_payout_script_from_username(request.getUser())
+          def preprocess_request(self, request):
+            payout_script = self._get_payout_script_from_username(request.getUser())
             if payout_script is None or random.uniform(0, 100) < args.worker_fee:
                 payout_script = my_script
             return payout_script,
 
-        def compute(payout_script):
+          def get_work(self, payout_script):
             if len(p2p_node.peers) == 0 and net.PERSIST:
                 raise jsonrpc.Error(-12345, u'p2pool is not connected to any peers')
             if current_work.value['best_share_hash'] is None and net.PERSIST:
@@ -480,10 +487,7 @@ def main(args, net, datadir_path):
                 share_target=share_info['bits'].target,
             )
         
-        my_share_hashes = set()
-        my_doa_share_hashes = set()
-        
-        def got_response(header, request):
+          def got_response(self, header, request):
             try:
                 # match up with transactions
                 xxx = merkle_root_to_transactions.get(header['merkle_root'], None)
@@ -547,7 +551,7 @@ def main(args, net, datadir_path):
                 return False
         
         web_root = resource.Resource()
-        worker_interface.WorkerInterface(compute, got_response, current_work.changed, precompute).attach_to(web_root)
+        worker_interface.WorkerInterface(WorkerBridge()).attach_to(web_root)
         
         def get_rate():
             if tracker.get_height(current_work.value['best_share_hash']) < 720: