indentation and imports cleaned up
[p2pool.git] / p2pool / bitcoin / data.py
index 187928a..dc8b8d6 100644 (file)
@@ -1,11 +1,11 @@
 from __future__ import division
 
-import struct
 import hashlib
-import warnings
+import struct
 
-from . import base58
-from p2pool.util import bases, math
+from . import base58, skiplists
+from p2pool.util import bases, math, variable
+import p2pool
 
 class EarlyEnd(Exception):
     pass
@@ -19,6 +19,9 @@ def read((data, pos), length):
         raise EarlyEnd()
     return data2, (data, pos + length)
 
+def size((data, pos)):
+    return len(data) - pos
+
 class Type(object):
     # the same data can have only one unpacked representation, but multiple packed binary representations
     
@@ -54,17 +57,20 @@ class Type(object):
     def unpack(self, data):
         obj = self._unpack(data)
         
-        if __debug__:
+        if p2pool.DEBUG:
             data2 = self._pack(obj)
             if data2 != data:
-                assert self._unpack(data2) == obj
+                if self._unpack(data2) != obj:
+                    raise AssertionError()
         
         return obj
     
     def pack(self, obj):
         data = self._pack(obj)
         
-        assert self._unpack(data) == obj
+        if p2pool.DEBUG:
+            if self._unpack(data) != obj:
+                raise AssertionError()
         
         return data
     
@@ -162,7 +168,7 @@ class HashType(Type):
         if not 0 <= item < 2**256:
             raise ValueError('invalid hash value - %r' % (item,))
         if item != 0 and item < 2**160:
-            warnings.warn('very low hash value - maybe you meant to use ShortHashType? %x' % (item,))
+            print 'Very low hash value - maybe you meant to use ShortHashType? %x' % (item,)
         return file, ('%064x' % (item,)).decode('hex')[::-1]
 
 class ShortHashType(Type):
@@ -227,12 +233,41 @@ class IPV6AddressType(Type):
         assert len(data) == 16, len(data)
         return file, data
 
+_record_types = {}
+
+def get_record(fields):
+    fields = tuple(sorted(fields))
+    if 'keys' in fields:
+        raise ValueError()
+    if fields not in _record_types:
+        class _Record(object):
+            __slots__ = fields
+            def __getitem__(self, key):
+                return getattr(self, key)
+            def __setitem__(self, key, value):
+                setattr(self, key, value)
+            #def __iter__(self):
+            #    for field in self.__slots__:
+            #        yield field, getattr(self, field)
+            def keys(self):
+                return self.__slots__
+            def __eq__(self, other):
+                if isinstance(other, dict):
+                    return dict(self) == other
+                elif isinstance(other, _Record):
+                    return all(self[k] == other[k] for k in self.keys())
+                raise TypeError()
+            def __ne__(self, other):
+                return not (self == other)
+        _record_types[fields] = _Record
+    return _record_types[fields]()
+
 class ComposedType(Type):
     def __init__(self, fields):
         self.fields = fields
     
     def read(self, file):
-        item = {}
+        item = get_record(k for k, v in self.fields)
         for key, type_ in self.fields:
             item[key], file = type_.read(file)
         return item, file
@@ -268,7 +303,7 @@ class FloatingIntegerType(Type):
     def read(self, file):
         bits, file = self._inner.read(file)
         target = self._bits_to_target(bits)
-        if __debug__:
+        if p2pool.DEBUG:
             if self._target_to_bits(target) != bits:
                 raise ValueError('bits in non-canonical form')
         return target, file
@@ -281,8 +316,9 @@ class FloatingIntegerType(Type):
     
     def _bits_to_target(self, bits2):
         target = math.shift_left(bits2 & 0x00ffffff, 8 * ((bits2 >> 24) - 3))
-        assert target == self._bits_to_target1(struct.pack('<I', bits2))
-        assert self._target_to_bits(target, _check=False) == bits2
+        if p2pool.DEBUG:
+            assert target == self._bits_to_target1(struct.pack('<I', bits2))
+            assert self._target_to_bits(target, _check=False) == bits2
         return target
     
     def _bits_to_target1(self, bits):
@@ -370,6 +406,11 @@ def merkle_hash(tx_list):
 def target_to_average_attempts(target):
     return 2**256//(target + 1)
 
+# tx
+
+def tx_get_sigop_count(tx):
+    return sum(script.get_sigop_count(txin['script']) for txin in tx['tx_ins']) + sum(script.get_sigop_count(txout['script']) for txout in tx['tx_outs'])
+
 # human addresses
 
 human_address_type = ChecksummedType(ComposedType([
@@ -399,25 +440,68 @@ def pubkey_to_script2(pubkey):
 def pubkey_hash_to_script2(pubkey_hash):
     return '\x76\xa9' + ('\x14' + ShortHashType().pack(pubkey_hash)) + '\x88\xac'
 
+def script2_to_human(script2, net):
+    try:
+        pubkey = script2[1:-1]
+        script2_test = pubkey_to_script2(pubkey)
+    except:
+        pass
+    else:
+        if script2_test == script2:
+            return 'Pubkey. Address: %s' % (pubkey_to_address(pubkey, net),)
+    
+    try:
+        pubkey_hash = ShortHashType().unpack(script2[3:-2])
+        script2_test2 = pubkey_hash_to_script2(pubkey_hash)
+    except:
+        pass
+    else:
+        if script2_test2 == script2:
+            return 'Address. Address: %s' % (pubkey_hash_to_address(pubkey_hash, net),)
+    
+    return 'Unknown. Script: %s'  % (script2.encode('hex'),)
+
 # linked list tracker
 
 class Tracker(object):
     def __init__(self):
         self.shares = {} # hash -> share
+        #self.ids = {} # hash -> (id, height)
         self.reverse_shares = {} # previous_hash -> set of share_hashes
         
         self.heads = {} # head hash -> tail_hash
         self.tails = {} # tail hash -> set of head hashes
+        
         self.heights = {} # share_hash -> height_to, other_share_hash
-        self.skips = {} # share_hash -> skip list
         
+        '''
         self.id_generator = itertools.count()
         self.tails_by_id = {}
+        '''
+        
+        self.get_nth_parent_hash = skiplists.DistanceSkipList(self)
+        
+        self.added = variable.Event()
+        self.removed = variable.Event()
     
     def add(self, share):
         assert not isinstance(share, (int, long, type(None)))
         if share.hash in self.shares:
-            return # XXX raise exception?
+            raise ValueError('share already present')
+        
+        '''
+        parent_id = self.ids.get(share.previous_hash, None)
+        children_ids = set(self.ids.get(share2_hash) for share2_hash in self.reverse_shares.get(share.hash, set()))
+        infos = set()
+        if parent_id is not None:
+            infos.add((parent_id[0], parent_id[1] + 1))
+        for child_id in children_ids:
+            infos.add((child_id[0], child_id[1] - 1))
+        if not infos:
+            infos.add((self.id_generator.next(), 0))
+        chosen = min(infos)
+        self.ids[share.hash] = chosen
+        '''
         
         self.shares[share.hash] = share
         self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
@@ -441,6 +525,8 @@ class Tracker(object):
         
         for head in heads:
             self.heads[head] = tail
+        
+        self.added.happened(share)
     
     def test(self):
         t = Tracker()
@@ -457,6 +543,7 @@ class Tracker(object):
         if share_hash not in self.shares:
             raise KeyError()
         share = self.shares[share_hash]
+        children = self.reverse_shares.get(share_hash, set())
         del share_hash
         
         if share.hash in self.heads and share.previous_hash in self.tails:
@@ -473,7 +560,7 @@ class Tracker(object):
                 self.tails[tail].add(share.previous_hash)
                 self.heads[share.previous_hash] = tail
         elif share.previous_hash in self.tails:
-            raise NotImplementedError() # will break other things..
+            #raise NotImplementedError() # will break other things..
             heads = self.tails[share.previous_hash]
             if len(self.reverse_shares[share.previous_hash]) > 1:
                 raise NotImplementedError()
@@ -485,6 +572,24 @@ class Tracker(object):
         else:
             raise NotImplementedError()
         
+        to_remove = set()
+        for share_hash2 in self.heights:
+            height_to, other_share_hash, work_inc = self.heights[share_hash2]
+            if other_share_hash != share.previous_hash:
+                continue
+            assert children
+            if len(children) == 1:
+                height_to -= 1
+                other_share_hash = share.hash
+                work_inc -= target_to_average_attempts(share.target)
+                self.heights[share_hash2] = height_to, other_share_hash, work_inc
+            else:
+                to_remove.add(share_hash2)
+        for share_hash2 in to_remove:
+            del self.heights[share_hash2]
+        if share.hash in self.heights:
+            del self.heights[share.hash]
+        
         '''
         height, tail = self.get_height_and_last(share.hash)
         
@@ -526,36 +631,44 @@ class Tracker(object):
         if not self.reverse_shares[share.previous_hash]:
             self.reverse_shares.pop(share.previous_hash)
         
-        assert self.test() is None
+        #assert self.test() is None
+        self.removed.happened(share)
+    
+    def get_height(self, share_hash):
+        height, work, last = self.get_height_work_and_last(share_hash)
+        return height
+    
+    def get_work(self, share_hash):
+        height, work, last = self.get_height_work_and_last(share_hash)
+        return work
+    
+    def get_last(self, share_hash):
+        height, work, last = self.get_height_work_and_last(share_hash)
+        return last
     
     def get_height_and_last(self, share_hash):
+        height, work, last = self.get_height_work_and_last(share_hash)
+        return height, last
+    
+    def get_height_work_and_last(self, share_hash):
         assert isinstance(share_hash, (int, long, type(None)))
         orig = share_hash
         height = 0
+        work = 0
         updates = []
         while True:
             if share_hash is None or share_hash not in self.shares:
                 break
-            updates.append((share_hash, height))
+            updates.append((share_hash, height, work))
             if share_hash in self.heights:
-                height_inc, share_hash = self.heights[share_hash]
+                height_inc, share_hash, work_inc = self.heights[share_hash]
             else:
-                height_inc, share_hash = 1, self.shares[share_hash].previous_hash
+                height_inc, share_hash, work_inc = 1, self.shares[share_hash].previous_hash, target_to_average_attempts(self.shares[share_hash].target)
             height += height_inc
-        for update_hash, height_then in updates:
-            self.heights[update_hash] = height - height_then, share_hash
-        assert (height, share_hash) == self.get_height_and_last2(orig), ((height, share_hash), self.get_height_and_last2(orig))
-        return height, share_hash
-    
-    def get_height_and_last2(self, share_hash):
-        assert isinstance(share_hash, (int, long, type(None)))
-        height = 0
-        while True:
-            if share_hash not in self.shares:
-                break
-            share_hash = self.shares[share_hash].previous_hash
-            height += 1
-        return height, share_hash
+            work += work_inc
+        for update_hash, height_then, work_then in updates:
+            self.heights[update_hash] = height - height_then, share_hash, work - work_then
+        return height, work, share_hash
     
     def get_chain_known(self, start_hash):
         assert isinstance(start_hash, (int, long, type(None)))
@@ -594,63 +707,26 @@ class Tracker(object):
     
     def get_highest_height(self):
         return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0
-    
-    def get_nth_parent_hash(self, item_hash, n):
-        if n < 0:
-            raise ValueError('n must be >= 0')
-        
-        updates = {}
-        while n:
-            if item_hash not in self.skips:
-                self.skips[item_hash] = math.geometric(.5), [(1, self.shares[item_hash].previous_hash)]
-            skip_length, skip = self.skips[item_hash]
-            
-            for i in xrange(skip_length):
-                if i in updates:
-                    n_then, that_hash = updates.pop(i)
-                    x, y = self.skips[that_hash]
-                    assert len(y) == i
-                    y.append((n_then - n, item_hash))
-            
-            for i in xrange(len(skip), skip_length):
-                updates[i] = n, item_hash
-            
-            for i, (dist, then_hash) in enumerate(reversed(skip)):
-                if dist <= n:
-                    break
-            else:
-                raise AssertionError()
-            
-            n -= dist
-            item_hash = then_hash
-        
-        return item_hash
-    
-    def get_nth_parent2(self, item_hash, n):
-        x = item_hash
-        for i in xrange(n):
-            x = self.shares[item_hash].previous_hash
-        return x
+
+class FakeShare(object):
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
 
 if __name__ == '__main__':
-    class FakeShare(object):
-        def __init__(self, hash, previous_hash):
-            self.hash = hash
-            self.previous_hash = previous_hash
     
     t = Tracker()
     
-    for i in xrange(100):
-        t.add(FakeShare(i, i - 1 if i > 0 else None))
+    for i in xrange(10000):
+        t.add(FakeShare(hash=i, previous_hash=i - 1 if i > 0 else None))
     
-    t.remove(99)
+    #t.remove(99)
     
-    print "HEADS", t.heads
-    print "TAILS", t.tails
+    print 'HEADS', t.heads
+    print 'TAILS', t.tails
     
     import random
     
-    while True:
+    while False:
         print
         print '-'*30
         print
@@ -661,20 +737,20 @@ if __name__ == '__main__':
             t.add(FakeShare(i, x))
         while t.shares:
             x = random.choice(list(t.shares))
-            print "DEL", x, t.__dict__
+            print 'DEL', x, t.__dict__
             try:
                 t.remove(x)
             except NotImplementedError:
-                print "aborted; not implemented"
+                print 'aborted; not implemented'
         import time
         time.sleep(.1)
-        print "HEADS", t.heads
-        print "TAILS", t.tails
+        print 'HEADS', t.heads
+        print 'TAILS', t.tails
     
     #for share_hash, share in sorted(t.shares.iteritems()):
     #    print share_hash, share.previous_hash, t.heads.get(share_hash), t.tails.get(share_hash)
     
-    import sys;sys.exit()
+    #import sys;sys.exit()
     
     print t.get_nth_parent_hash(9000, 5000)
     print t.get_nth_parent_hash(9001, 412)