moved share punishment condition checks into Share classes
authorForrest Voight <forrest@forre.st>
Fri, 26 Oct 2012 14:50:07 +0000 (10:50 -0400)
committerForrest Voight <forrest@forre.st>
Sun, 28 Oct 2012 06:08:20 +0000 (02:08 -0400)
p2pool/data.py

index 7238f6e..ba1cb79 100644 (file)
@@ -274,14 +274,14 @@ class Share(object):
     def get_other_tx_hashes(self, tracker):
         return []
     
-    def get_other_txs(self, tracker, known_txs):
-        return []
-    
-    def get_other_txs_size(self, tracker, known_txs):
-        return 0
-    
-    def get_new_txs_size(self, known_txs):
-        return 0
+    def should_punish_reason(self, previous_block, bits, tracker, known_txs):
+        if self.pow_hash <= self.header['bits'].target:
+            return -1, 'block solution'
+        
+        if (self.header['previous_block'], self.header['bits']) != (previous_block, bits) and self.header_hash != previous_block and self.peer is not None:
+            return True, 'Block-stale detected! %x < %x' % (self.header['previous_block'], previous_block)
+        
+        return False, None
     
     def as_block(self, tracker, known_txs):
         if self.other_txs is None:
@@ -537,7 +537,7 @@ class NewShare(object):
     def get_other_tx_hashes(self, tracker):
         return [tracker.items[tracker.get_nth_parent_hash(self.hash, x['share_count'])].share_info['new_transaction_hashes'][x['tx_count']] for x in self.share_info['transaction_hash_refs']]
     
-    def get_other_txs(self, tracker, known_txs):
+    def _get_other_txs(self, tracker, known_txs):
         other_tx_hashes = self.get_other_tx_hashes(tracker)
         
         if not all(tx_hash in known_txs for tx_hash in other_tx_hashes):
@@ -545,19 +545,29 @@ class NewShare(object):
         
         return [known_txs[tx_hash] for tx_hash in other_tx_hashes]
     
-    def get_other_txs_size(self, tracker, known_txs):
-        other_txs = self.get_other_txs(tracker, known_txs)
+    def should_punish_reason(self, previous_block, bits, tracker, known_txs):
+        if self.pow_hash <= self.header['bits'].target:
+            return -1, 'block solution'
+        
+        if (self.header['previous_block'], self.header['bits']) != (previous_block, bits) and self.header_hash != previous_block and self.peer is not None:
+            return True, 'Block-stale detected! %x < %x' % (self.header['previous_block'], previous_block)
+        
+        other_txs = self._get_other_txs(tracker, known_txs)
         if other_txs is None:
-            return None # not all txs present
-        size = sum(len(bitcoin_data.tx_type.pack(tx)) for tx in other_txs)
-    
-    def get_new_txs_size(self, known_txs):
-        if not all(tx_hash in known_txs for tx_hash in self.share_info['new_transaction_hashes']):
-            return None # not all txs present
-        return sum(len(bitcoin_data.tx_type.pack(known_txs[tx_hash])) for tx_hash in self.share_info['new_transaction_hashes'])
+            return True, 'not all txs present'
+        
+        all_txs_size = sum(len(bitcoin_data.tx_type.pack(tx)) for tx in other_txs)
+        if all_txs_size > 1000000:
+            return True, 'txs over block size limit'
+        
+        new_txs_size = sum(len(bitcoin_data.tx_type.pack(known_txs[tx_hash])) for tx_hash in self.share_info['new_transaction_hashes'])
+        if new_txs_size > 50000:
+            return True, 'new txs over limit'
+        
+        return False, None
     
     def as_block(self, tracker, known_txs):
-        other_txs = self.get_other_txs(tracker, known_txs)
+        other_txs = self._get_other_txs(tracker, known_txs)
         if other_txs is None:
             return None # not all txs present
         return dict(header=self.header, txs=[self.check(tracker)] + other_txs)
@@ -691,11 +701,7 @@ class OkayTracker(forest.Tracker):
         decorated_heads = sorted(((
             self.verified.get_work(self.verified.get_nth_parent_hash(h, min(5, self.verified.get_height(h)))),
             #self.items[h].peer is None,
-            self.items[h].pow_hash <= self.items[h].header['bits'].target, # is block solution
-            (self.items[h].header['previous_block'], self.items[h].header['bits']) == (previous_block, bits) or self.items[h].peer is None,
-            self.items[h].get_other_txs(self, known_txs) is not None,
-            self.items[h].get_other_txs_size(self, known_txs) < 1000000,
-            self.items[h].get_new_txs_size(known_txs) < 50000,
+            -self.items[h].should_punish_reason(previous_block, bits, self, known_txs)[0],
             -self.items[h].time_seen,
         ), h) for h in self.verified.tails.get(best_tail, []))
         if p2pool.DEBUG:
@@ -706,18 +712,10 @@ class OkayTracker(forest.Tracker):
         
         if best is not None:
             best_share = self.items[best]
-            if (best_share.header['previous_block'], best_share.header['bits']) != (previous_block, bits) and best_share.header_hash != previous_block and best_share.peer is not None:
+            punish, punish_reason = best_share.should_punish_reason(previous_block, bits, self, known_txs)
+            if punish:
                 if p2pool.DEBUG:
-                    print 'Stale detected! %x < %x' % (best_share.header['previous_block'], previous_block)
-                best = best_share.previous_hash
-            elif best_share.get_other_txs(self, known_txs) is None:
-                print 'Share with incomplete transactions detected! Jumping from %s to %s!' % (format_hash(best), format_hash(best_share.previous_hash))
-                best = best_share.previous_hash
-            elif best_share.get_other_txs_size(self, known_txs) > 1000000:
-                print >>sys.stderr, 'Share with too many transactions detected! Jumping from %s to %s!' % (format_hash(best), format_hash(best_share.previous_hash))
-                best = best_share.previous_hash
-            elif best_share.get_new_txs_size(known_txs) > 50000:
-                print >>sys.stderr, 'Share with too many new transactions detected! Jumping from %s to %s!' % (format_hash(best), format_hash(best_share.previous_hash))
+                    print >>sys.stderr, 'Punishing share for %r! Jumping from %s to %s!' % (punish_reason, format_hash(best), format_hash(best_share.previous_hash))
                 best = best_share.previous_hash
             
             timestamp_cutoff = min(int(time.time()), best_share.timestamp) - 3600