work
authorforrest <forrest@470744a7-cac9-478e-843e-5ec1b25c69e8>
Sat, 11 Jun 2011 13:07:43 +0000 (13:07 +0000)
committerforrest <forrest@470744a7-cac9-478e-843e-5ec1b25c69e8>
Sat, 11 Jun 2011 13:07:43 +0000 (13:07 +0000)
git-svn-id: svn://forre.st/p2pool@1243 470744a7-cac9-478e-843e-5ec1b25c69e8

bitcoin_p2p.py
expiring_dict.py
main.py
p2p.py
util.py

index dfa6631..ec134bc 100644 (file)
@@ -5,7 +5,7 @@ Implementation of Bitcoin's p2p protocol
 import struct
 import socket
 import random
-import StringIO
+import cStringIO as StringIO
 import hashlib
 import time
 import traceback
@@ -14,10 +14,36 @@ from twisted.internet import protocol, reactor, defer
 
 import util
 
-def hex(n):
-    return '0x%x' % n
+class Type(object):
+    def _unpack(self, data, ignore_extra=False):
+        f = StringIO.StringIO(data)
+        obj = self.read(f)
+        
+        if not ignore_extra:
+            if f.tell() != len(data):
+                raise ValueError("underread " + repr((self, data)))
+        
+        return obj
+    
+    def unpack(self, data, ignore_extra=False):
+        obj = self._unpack(data, ignore_extra)
+        assert self._unpack(self._pack(obj)) == obj
+        return obj
+    
+    def _pack(self, obj):
+        f = StringIO.StringIO()
+        self.write(f, obj)
+        
+        data = f.getvalue()
+        
+        return data
+    
+    def pack(self, obj):
+        data = self._pack(obj)
+        assert self._unpack(data) == obj
+        return data
 
-class VarIntType(object):
+class VarIntType(Type):
     def read(self, file):
         first, = struct.unpack("<B", file.read(1))
         if first == 0xff:
@@ -28,96 +54,122 @@ class VarIntType(object):
             return struct.unpack("<H", file.read(2))[0]
         else:
             return first
-    def pack(self, item):
+    
+    def write(self, file, item):
         if item < 0xfd:
-            return struct.pack("<B", item)
+            file.write(struct.pack("<B", item))
         elif item <= 0xffff:
-            return struct.pack("<BH", 0xfd, item)
+            file.write(struct.pack("<BH", 0xfd, item))
         elif item <= 0xffffffff:
-            return struct.pack("<BI", 0xfe, item)
+            file.write(struct.pack("<BI", 0xfe, item))
         elif item <= 0xffffffffffffffff:
-            return struct.pack("<BQ", 0xff, item)
+            file.write(struct.pack("<BQ", 0xff, item))
         else:
             raise ValueError("int too large for varint")
 
-class VarStrType(object):
+class VarStrType(Type):
     def read(self, file):
         length = VarIntType().read(file)
         res = file.read(length)
         if len(res) != length:
             raise ValueError("var str not long enough %r" % ((length, len(res), res),))
         return res
-    def pack(self, item):
-        return VarIntType().pack(len(item)) + item
+    
+    def write(self, file, item):
+        VarIntType().write(file, len(item))
+        file.write(item)
 
-class FixedStrType(object):
+class FixedStrType(Type):
     def __init__(self, length):
         self.length = length
+    
     def read(self, file):
         res = file.read(self.length)
         if len(res) != self.length:
             raise ValueError("early EOF!")
         return res
-    def pack(self, item):
+    
+    def write(self, file, item):
         if len(item) != self.length:
             raise ValueError("incorrect length!")
-        return item
+        file.write(item)
 
-class EnumType(object):
-    def __init__(self, inner, map):
+class EnumType(Type):
+    def __init__(self, inner, values):
         self.inner = inner
-        self.map = map
-        self.revmap = dict((v, k) for k, v in map.iteritems())
+        self.values = values
+        
+        self.keys = {}
+        for k, v in values.iteritems():
+            if v in self.keys:
+                raise ValueError("duplicate value in values")
+            self.keys[v] = k
+    
     def read(self, file):
-        inner = self.inner.read(file)
-        return self.map[inner]
-    def pack(self, item):
-        return self.inner.pack(self.revmap[item])
+        return self.keys[self.inner.read(file)]
+    
+    def write(self, file, item):
+        self.inner.write(file, self.values[item])
 
-class HashType(object):
+class HashType(Type):
     def read(self, file):
         data = file.read(256//8)
         if len(data) != 256//8:
             raise ValueError("incorrect length!")
         return int(data[::-1].encode('hex'), 16)
-    def pack(self, item):
-        return ('%064x' % (item,)).decode('hex')[::-1]
+    
+    def write(self, file, item):
+        file.write(('%064x' % (item,)).decode('hex')[::-1])
 
-class ListType(object):
+class ListType(Type):
     def __init__(self, type):
         self.type = type
+    
     def read(self, file):
         length = VarIntType().read(file)
         return [self.type.read(file) for i in xrange(length)]
-    def pack(self, item):
-        return VarIntType().pack(len(item)) + ''.join(map(self.type.pack, item))
+    
+    def write(self, file, item):
+        VarIntType().write(file, len(item))
+        for subitem in item:
+            self.type.write(file, subitem)
 
-class StructType(object):
+class StructType(Type):
     def __init__(self, desc):
         self.desc = desc
+    
     def read(self, file):
         data = file.read(struct.calcsize(self.desc))
         res, = struct.unpack(self.desc, data)
         return res
-    def pack(self, item):
-        return struct.pack(self.desc, item)
+    
+    def write(self, file, item):
+        data = struct.pack(self.desc, item)
+        if struct.unpack(self.desc, data)[0] != item:
+            # special test because struct doesn't error on some overflows
+            raise ValueError("item didn't survive pack cycle (%r)" % (item,))
+        file.write(data)
 
-class IPV6AddressType(object):
+class IPV6AddressType(Type):
     def read(self, file):
         return socket.inet_ntop(socket.AF_INET6, file.read(16))
-    def pack(self, item):
-        return socket.inet_pton(socket.AF_INET6, item)
+    
+    def write(self, file, item):
+        file.write(socket.inet_pton(socket.AF_INET6, item))
 
-class ComposedType(object):
+class ComposedType(Type):
     def __init__(self, fields):
         self.fields = fields
+    
     def read(self, file):
-        result = {}
+        item = {}
+        for key, type_ in self.fields:
+            item[key] = type_.read(file)
+        return item
+    
+    def write(self, file, item):
         for key, type_ in self.fields:
-            result[key] = type_.read(file)
-        return result
-    def pack(self, item):
-        return ''.join(type_.pack(item[key]) for key, type_ in self.fields)
+            type_.write(file, item[key])
 
 address = ComposedType([
     ('services', StructType('<Q')),
@@ -125,36 +177,20 @@ address = ComposedType([
     ('port', StructType('>H')),
 ])
 
-merkle_record = ComposedType([
-    ('left', HashType()),
-    ('right', HashType()),
-])
-
-inv_vector = ComposedType([
-    ('type', EnumType(StructType('<I'), {1: "tx", 2: "block"})),
-    ('hash', HashType()),
-])
-
-outpoint = ComposedType([
-    ('hash', HashType()),
-    ('index', StructType('<I')),
-])
-
-tx_in = ComposedType([
-    ('previous_output', outpoint),
-    ('script', VarStrType()),
-    ('sequence', StructType('<I')),
-])
-
-tx_out = ComposedType([
-    ('value', StructType('<Q')),
-    ('script', VarStrType()),
-])
-
 tx = ComposedType([
     ('version', StructType('<I')),
-    ('tx_ins', ListType(tx_in)),
-    ('tx_outs', ListType(tx_out)),
+    ('tx_ins', ListType(ComposedType([
+        ('previous_output', ComposedType([
+            ('hash', HashType()),
+            ('index', StructType('<I')),
+        ])),
+        ('script', VarStrType()),
+        ('sequence', StructType('<I')),
+    ]))),
+    ('tx_outs', ListType(ComposedType([
+        ('value', StructType('<Q')),
+        ('script', VarStrType()),
+    ]))),
     ('lock_time', StructType('<I')),
 ])
 
@@ -172,73 +208,15 @@ block = ComposedType([
     ('txns', ListType(tx)),
 ])
 
-message_types = {
-    'version': ComposedType([
-        ('version', StructType('<I')),
-        ('services', StructType('<Q')),
-        ('timestamp', StructType('<Q')),
-        ('addr_me', address),
-        ('addr_you', address),
-        ('nonce', StructType('<Q')),
-        ('sub_version_num', VarStrType()),
-        ('start_height', StructType('<I')),
-    ]),
-    'verack': ComposedType([]),
-    'addr': ListType(ComposedType([
-        ('timestamp', StructType('<I')),
-        ('address', address),
-    ])),
-    'inv': ListType(inv_vector),
-    'getdata': ListType(inv_vector),
-    'getblocks': ComposedType([
-        # XXX has version here?
-        ('have', ListType(HashType())),
-        ('last', HashType()),
-    ]),
-    'getheaders': ComposedType([
-        # XXX has version here?
-        ('have', ListType(HashType())),
-        ('last', HashType()),
-    ]),
-    'tx': tx,
-    'block': block,
-    'headers': ListType(block_headers),
-    'getaddr': ComposedType([]),
-    'checkorder': ComposedType([
-        # XXX
-        ('id', HashType()),
-        ('order', FixedStrType(60)),
-    ]),
-    'submitorder': ComposedType([
-        # XXX
-        ('id', HashType()),
-        ('order', FixedStrType(60)),
-    ]),
-    'reply': ComposedType([
-        ('hash', HashType()),
-        ('reply',  EnumType(StructType('<I'), {0: 'success', 1: 'failure', 2: 'denied'})),
-        ('script', VarStrType()),
-    ]),
-    'ping': ComposedType([]),
-    'alert': ComposedType([
-        ('message', VarStrType()),
-        ('signature', VarStrType()),
-    ]),
-}
-
-def read_type(type_, payload):
-    f = StringIO.StringIO(payload)
-    payload2 = type_.read(f)
-    
-    if f.tell() != len(payload):
-        raise ValueError("underread " + repr((type_, payload)))
-    
-    return payload2
-
 def doublesha(data):
-    return read_type(HashType(), hashlib.sha256(hashlib.sha256(data).digest()).digest())
+    return HashType().unpack(hashlib.sha256(hashlib.sha256(data).digest()).digest())
 
 def merkle_hash(txn_list):
+    merkle_record = ComposedType([
+        ('left', HashType()),
+        ('right', HashType()),
+    ])
+    
     hash_list = [doublesha(tx.pack(txn)) for txn in txn_list]
     while len(hash_list) > 1:
         hash_list = [doublesha(merkle_record.pack(dict(left=left, right=left if right is None else right)))
@@ -248,48 +226,20 @@ def merkle_hash(txn_list):
 def block_hash(headers):
     return doublesha(block_headers.pack(headers))
 
-class Protocol(protocol.Protocol):
-    _prefix = '\xf9\xbe\xb4\xd9'
-    version = 0
-    buf = ""
-    
+class BaseProtocol(protocol.Protocol):
     def connectionMade(self):
         self.dataReceived = util.DataChunker(self.dataReceiver())
-        
-        self.sendPacket("version", dict(
-            version=32200,
-            services=1,
-            timestamp=int(time.time()),
-            addr_me=dict(
-                services=1,
-                address="::ffff:127.0.0.1",
-                port=self.transport.getHost().port,
-            ),
-            addr_you=dict(
-                services=1,
-                address="::ffff:127.0.0.1",
-                port=self.transport.getPeer().port,
-            ),
-            nonce=random.randrange(2**64),
-            sub_version_num="",
-            start_height=0,
-        ))
     
     def dataReceiver(self):
         while True:
-            start = yield 4
-            junk = ""
+            start = ""
             while start != self._prefix:
-                start = start + (yield 1)
-                junk += start[:-4]
-                start = start[-4:]
-            if junk:
-                print "JUNK", repr(junk)
+                start = (start + (yield 1))[-4:]
             
             command = (yield 12).rstrip('\0')
             length, = struct.unpack("<I", (yield 4))
             
-            if self.version >= 209:
+            if self.use_checksum:
                 checksum = yield 4
             else:
                 checksum = None
@@ -302,13 +252,14 @@ class Protocol(protocol.Protocol):
                     print "INVALID HASH"
                     continue
             
-            type_ = message_types.get(command, None)
+            type_ = self.message_types.get(command, None)
             if type_ is None:
-                print "ERROR: NO TYPE FOR", repr(command)
+                print "RECV", command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
+                print "NO TYPE FOR", repr(command)
                 continue
             
             try:
-                payload2 = read_type(type_, payload)
+                payload2 = type_.unpack(payload)
             except:
                 print "RECV", command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
                 traceback.print_exc()
@@ -317,12 +268,119 @@ class Protocol(protocol.Protocol):
             handler = getattr(self, "handle_" + command, None)
             if handler is None:
                 print "RECV", command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
-                print self, "has no handler for", command
-            else:
-                try:
-                    handler(payload2)
-                except:
-                    traceback.print_exc()
+                print "NO HANDLER FOR", command
+                continue
+            
+            
+            #print "RECV", command, payload2
+            
+            try:
+                handler(payload2)
+            except:
+                print "RECV", command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
+                traceback.print_exc()
+                continue
+    
+    def sendPacket(self, command, payload2={}):
+        payload = self.message_types[command].pack(payload2)
+        if len(command) >= 12:
+            raise ValueError("command too long")
+        if self.use_checksum:
+            checksum = hashlib.sha256(hashlib.sha256(payload).digest()).digest()[:4]
+        else:
+            checksum = ""
+        data = self._prefix + struct.pack("<12sI", command, len(payload)) + checksum + payload
+        self.transport.write(data)
+        #print "SEND", command, payload2
+
+class Protocol(BaseProtocol):
+    _prefix = '\xf9\xbe\xb4\xd9'
+    
+    version = 0
+    
+    @property
+    def use_checksum(self):
+        return self.version >= 209
+    
+    message_types = {
+        'version': ComposedType([
+            ('version', StructType('<I')),
+            ('services', StructType('<Q')),
+            ('time', StructType('<Q')),
+            ('addr_to', address),
+            ('addr_from', address),
+            ('nonce', StructType('<Q')),
+            ('sub_version_num', VarStrType()),
+            ('start_height', StructType('<I')),
+        ]),
+        'verack': ComposedType([]),
+        'addr': ListType(ComposedType([
+            ('timestamp', StructType('<I')),
+            ('address', address),
+        ])),
+        'inv': ListType(ComposedType([
+            ('type', EnumType(StructType('<I'), {"tx": 1, "block": 2})),
+            ('hash', HashType()),
+        ])),
+        'getdata': ListType(ComposedType([
+            ('type', EnumType(StructType('<I'), {"tx": 1, "block": 2})),
+            ('hash', HashType()),
+        ])),
+        'getblocks': ComposedType([
+            ('version', StructType('<I')),
+            ('have', ListType(HashType())),
+            ('last', HashType()),
+        ]),
+        'getheaders': ComposedType([
+            ('version', StructType('<I')),
+            ('have', ListType(HashType())),
+            ('last', HashType()),
+        ]),
+        'tx': tx,
+        'block': block,
+        'headers': ListType(block_headers),
+        'getaddr': ComposedType([]),
+        'checkorder': ComposedType([
+            ('id', HashType()),
+            ('order', FixedStrType(60)), # XXX
+        ]),
+        'submitorder': ComposedType([
+            ('id', HashType()),
+            ('order', FixedStrType(60)), # XXX
+        ]),
+        'reply': ComposedType([
+            ('hash', HashType()),
+            ('reply',  EnumType(StructType('<I'), {'success': 0, 'failure': 1, 'denied': 2})),
+            ('script', VarStrType()),
+        ]),
+        'ping': ComposedType([]),
+        'alert': ComposedType([
+            ('message', VarStrType()),
+            ('signature', VarStrType()),
+        ]),
+    }
+    
+    def connectionMade(self):
+        BaseProtocol.connectionMade(self)
+        
+        self.sendPacket("version", dict(
+            version=32200,
+            services=1,
+            time=int(time.time()),
+            addr_to=dict(
+                services=1,
+                address='::ffff:' + self.transport.getPeer().host,
+                port=self.transport.getPeer().port,
+            ),
+            addr_from=dict(
+                services=1,
+                address='::ffff:' + self.transport.getHost().host,
+                port=self.transport.getHost().port,
+            ),
+            nonce=random.randrange(2**64),
+            sub_version_num="",
+            start_height=0,
+        ))
     
     def handle_version(self, payload):
         #print "VERSION", payload
@@ -333,8 +391,10 @@ class Protocol(protocol.Protocol):
         self.version = self.version_after
         
         # connection ready
-        self.checkorder = util.GenericDeferrer(2**256, lambda id, order: self.sendPacket("checkorder", dict(id=id, order=order)))
-        self.submitorder = util.GenericDeferrer(2**256, lambda id, order: self.sendPacket("submitorder", dict(id=id, order=order)))
+        self.check_order = util.GenericDeferrer(2**256, lambda id, order: self.sendPacket("checkorder", dict(id=id, order=order)))
+        self.submit_order = util.GenericDeferrer(2**256, lambda id, order: self.sendPacket("submitorder", dict(id=id, order=order)))
+        self.get_block = util.ReplyMatcher(lambda hash: self.sendPacket("getdata", [dict(type="block", hash=hash)]))
+        self.get_block_headers = util.ReplyMatcher(lambda hash: self.sendPacket("getdata", [dict(type="block", hash=hash)]))
         
         if hasattr(self.factory, "resetDelay"):
             self.factory.resetDelay()
@@ -352,12 +412,14 @@ class Protocol(protocol.Protocol):
     
     def handle_reply(self, payload):
         hash_ = payload.pop('hash')
-        self.checkorder.gotResponse(hash_, payload)
+        self.check_order.got_response(hash_, payload)
+        self.submit_order.got_response(hash_, payload)
     
     def handle_tx(self, payload):
         pass#print "TX", hex(merkle_hash([payload])), payload
     
     def handle_block(self, payload):
+        self.get_block.got_response(block_hash(payload['headers']), payload)
         #print "BLOCK", hex(block_hash(payload['headers']))
         #print payload
         #print merkle_hash(payload['txns'])
@@ -367,18 +429,6 @@ class Protocol(protocol.Protocol):
     def handle_ping(self, payload):
         pass
     
-    def sendPacket(self, command, payload2={}):
-        payload = message_types[command].pack(payload2)
-        if len(command) >= 12:
-            raise ValueError("command too long")
-        if self.version >= 209:
-            checksum = hashlib.sha256(hashlib.sha256(payload).digest()).digest()[:4]
-        else:
-            checksum = ""
-        data = self._prefix + struct.pack("<12sI", command, len(payload)) + checksum + payload
-        self.transport.write(data)
-        #print "SEND", command, repr(payload.encode('hex'))
-    
     def connectionLost(self, reason):
         if hasattr(self.factory, "gotConnection"):
             self.factory.gotConnection(None)
index 6d50714..54450c9 100644 (file)
@@ -39,26 +39,25 @@ class LinkedList(object):
     def __repr__(self):
         return "LinkedList(%r)" % (list(self),)
     
+    def __len__(self):
+        return sum(1 for x in self)
+    
     def __iter__(self):
         cur = self.start.next
-        while True:
-            if cur is self.end:
-                break
-            yield cur.contents
+        while cur is not self.end:
+            cur2 = cur
             cur = cur.next
-    
-    def __len__(self):
-        return sum(1 for x in self)
+            yield cur2 # in case cur is deleted, but items inserted after are ignored
     
     def __reversed__(self):
         cur = self.end.prev
-        while True:
-            if cur is self.start:
-                break
+        while cur is not self.start:
+            cur2 = cur
             cur = cur.prev
-            yield cur.contents
+            yield cur2
     
     def __getitem__(self, index):
+        # odd one out - probably should return Node instance instead of its contents
         if index < 0:
             cur = self.end
             for i in xrange(-index):
@@ -71,7 +70,7 @@ class LinkedList(object):
                 cur = cur.next
                 if cur is self.end:
                     raise IndexError("index out of range")
-        return cur.contents
+        return cur
     
     def appendleft(self, item):
         return self.start.insert_after(item)
@@ -95,53 +94,76 @@ class LinkedList(object):
 
 
 class ExpiringDict(object):
-    def __init__(self, expiry_time=600):
-        self.d = dict()
+    def __init__(self, expiry_time=100, get_touches=True):
         self.expiry_time = expiry_time
+        self.get_touches = get_touches
+        
         self.expiry_deque = LinkedList()
-        self.key_to_node = {}
+        self.d = dict() # key -> node, value
     
     def __repr__(self):
-        self._expire()
+        self.expire()
         return "ExpiringDict" + repr(self.__dict__)
     
-    def _touch(self, key):
-        if key in self.key_to_node:
-            self.key_to_node[key].delete()
-        self.key_to_node[key] = self.expiry_deque.append((time.time(), key))
-    
-    def _expire(self):
-        while self.expiry_deque and self.expiry_deque[0][0] < time.time() - self.expiry_time:
+    def __len__(self):
+        self.expire()
+        return len(self.d)
+    
+    _nothing = object()
+    def touch(self, key, value=_nothing):
+        "Updates expiry node, optionally replacing value, returning new value"
+        if value is self._nothing or key in self.d:
+            node, old_value = self.d[key]
+            node.delete()
+        
+        new_value = old_value if value is self._nothing else value
+        self.d[key] = self.expiry_deque.append((time.time(), key)), new_value
+        return new_value
+    
+    def expire(self):
+        for node in self.expiry_deque:
+            timestamp, key = node.contents
+            if timestamp + self.expiry_time > time.time():
+                break
+            del self.d[key]
+        while self.expiry_deque and self.expiry_deque[0].contents[0] < time.time() - self.expiry_time:
             timestamp, key = self.expiry_deque.popleft()
             del self.d[key]
-            del self.key_to_node[key]
+    
+    def __contains__(self, key):
+        return key in self.d
     
     def __getitem__(self, key):
-        value = self.d[key]
-        self._touch(key)
-        self._expire()
+        if self.get_touches:
+            value = self.touch(key)
+        else:
+            node, value = self.d[key]
+        self.expire()
         return value
     
     def __setitem__(self, key, value):
-        self.d[key] = value
-        self._touch(key)
-        self._expire()
+        self.touch(key, value)
+        self.expire()
     
     def __delitem__(self, key):
-        del self.d[key]
-        self.key_to_node.pop(key).delete()
-        self._expire()
+        node, value = self.d.pop(key)
+        node.delete()
+        self.expire()
     
-    def get(self, key, default_value):
+    def get(self, key, default_value=None):
         if key in self.d:
-            return self[key]
+            res = self[key]
         else:
-            return default_value
+            res = default_value
+            self.expire()
+        return default_value
     
     def setdefault(self, key, default_value):
-        value = self.d.get(key, default_value)
-        self[key] = value
-        return value
+        if key in self.d:
+            return self[key]
+        else:
+            self[key] = default_value
+            return default_value
 
 if __name__ == '__main__':
     x = ExpiringDict(5)
diff --git a/main.py b/main.py
index 481fa2c..42d9616 100644 (file)
--- a/main.py
+++ b/main.py
@@ -6,7 +6,6 @@ import os
 import sys
 import traceback
 import random
-import StringIO
 
 from twisted.internet import reactor, defer
 from twisted.web import server
@@ -63,21 +62,35 @@ bitcoind_group.add_argument(metavar="BITCOIND_RPC_PASSWORD",
 TARGET_MULTIPLIER = 1000000000 # 100
 
 class Node(object):
-    def __init__(self, block):
+    def __init__(self, block, shares):
         self.block = block
-        self.shared = False
+        self.coinbase = coinbase_type.read(self.block['txns'][0]['tx_ins'][0]['script'], ignore_extra=True)
+        self.shares = shares
+    
+    #@classmethod
+    #def accept(
     
     def hash(self):
-        return bitcoin_p2p.block_hash(self.block)
+        return bitcoin_p2p.block_hash(self.block['headers'])
     
     def previous_hash(self):
-        hash_ = bitcoin_p2p.Hash().read(StringIO.StringIO(self.block['transactions'][0]['tx_ins']['script']))
+        hash_ = self.coinbase['previous_block2']
         if hash_ == 2**256 - 1:
             return None
         return hash_
     
-    def check(self, height, previous_node):
-        ah
+    def check(self, chain, height2, previous_node):
+        if self.block['headers']['version'] != chain.version: return False
+        if self.block['headers']['previous_block'] != chain.previous_block: return False
+        if self.block['headers']['merkle_root'] != bitcoin_p2p.merkle_hash(self.block['txns']): return False
+        if self.block['headers']['bits'] != chain.bits: return False
+        
+        if not self.block['txns']: return False
+        if len(self.block['txns'][0]['tx_ins']) != 1: return False
+        
+        okay, self.shares = check_transaction(self.block['txns'][0], {} if previous_node is None else previous_node.shares)
+        
+        return okay
     
     def share(self):
         if self.shared:
@@ -86,10 +99,11 @@ class Node(object):
         a
 
 class Chain(object):
-    def __init__(self):
+    def __init__(self, version, previous_block, bits, height):
+        self.version, self.previous_block1, self.bits, self.height1 = version, previous_block, bits, height
+        
         self.nodes = {} # hash -> (height, node)
-        self.highest = util.Variable(None)
-        self.highest_height = -1
+        self.highest = util.Variable(None) # (height, node)
         self.shared = set()
     
     def accept(self, node, is_current):
@@ -109,17 +123,79 @@ class Chain(object):
         
         height = previous_height + 1
         
-        if not node.check(height, previous_node):
+        if not node.check(self, height, previous_node):
             return
         
         self.nodes[hash_] = (height, node)
         
-        if hieght > self.highest_height:
-            self.highest_height, self.highest.value = height, node
+        if height > self.highest.value[0]:
+            self.highest.set((height, node))
         
         if is_current:
             node.share()
 
+def check_transaction(t, shares):
+    coinbase = coinbase_type.read(t['tx_ins'][0]['script'], ignore_extra=True)
+    t2, new_shares = generate_transaction(shares, t['tx_outs'][coinbase['last_share_index']]['script'], coinbase['subsidy'], coinbase['previous_block2'])
+    return t2 == t, shares
+
+def generate_transaction(shares, add_pubkey, subsidy, previous_block2):
+    shares = shares[1:-1] + [add_pubkey, add_pubkey]
+    total_shares = len(shares)
+    
+    grouped_shares = {}
+    for script in shares:
+        grouped_shares[script]
+    amounts = dict((pubkey, subsidy*shares//total_shares) for (pubkey, shares) in shares.iteritems())
+    amounts = incr_dict(amounts, "XXX", subsidy - sum(amounts.itervalues()))
+    dests = sorted(amounts.iterkeys())
+    
+    return dict(
+        version=1,
+        tx_ins=[dict(
+            previous_output=dict(index=4294967295, hash=0),
+            sequence=4294967295,
+            script=coinbase_type.pack(dict(
+                version=1,
+                subsidy=subsidy,
+                previous_block2=previous_block2,
+                last_share_index=dests.index(add_pubkey),
+                nonce=random.randrange(2**256) if nonce is None else nonce,
+            )),
+        )],
+        tx_outs=[dict(value=amount, script=pubkey) for (pubkey, amount) in dests],
+        lock_time=0,
+    ), shares
+
+class DeferredCacher(object):
+    # XXX should combine requests
+    def __init__(self, func, backing=None):
+        if backing is None:
+            backing = {}
+        
+        self.func = func
+        self.backing = backing
+    
+    @defer.inlineCallbacks
+    def __call__(self, key):
+        if key in self.backing:
+            defer.returnValue(self.backing[key])
+        value = yield self.func(key)
+        self.backing[key] = value
+        defer.returnValue(value)
+
+@defer.inlineCallbacks
+def get_last_p2pool_block(current_block_hash, get_block):
+    block_hash = current_block_hash
+    while True:
+        print hex(block_hash)
+        if block_hash == 0x2c0117ac4e1f784761bc010f5d69c2b107c659a672d0107df64:
+            defer.returnValue(block_hash)
+        block = yield get_block(block_hash)
+        if block == 5:
+            defer.returnValue(block_hash)
+        block_hash = block['headers']['previous_block']
+
 @defer.inlineCallbacks
 def getwork(bitcoind, chains):
     while True:
@@ -130,11 +206,19 @@ def getwork(bitcoind, chains):
             traceback.print_exc()
             yield util.sleep(1)
             continue
+        defer.returnValue((getwork, height))
         defer.returnValue((
-            ((getwork.version, getwork.previous_block, getwork.bits), height, chains.get(getwork.previous_block, Chain()).highest),
+            ((getwork.version, getwork.previous_block, getwork.bits), height, chains.get(getwork.previous_block, Chain()).highest.value),
             (getwork.timestamp,),
         ))
 
+coinbase_type = bitcoin_p2p.ComposedType([
+    ('subsidy', bitcoin_p2p.StructType('<Q')),
+    ('previous_block2', bitcoin_p2p.HashType()),
+    ('last_share_index', bitcoin_p2p.StructType('<I')),
+    ('nonce', bitcoin_p2p.HashType()),
+])
+
 @defer.inlineCallbacks
 def main(args):
     try:
@@ -149,11 +233,20 @@ def main(args):
         print "Testing bitcoind RPC connection..."
         bitcoind = jsonrpc.Proxy('http://%s:%i/' % (args.bitcoind_address, args.bitcoind_rpc_port), (args.bitcoind_rpc_username, args.bitcoind_rpc_password))
         
-        current_work_new, current_work2 = yield getwork(bitcoind, chains)
-        current_work.set(current_work_new)
+        work, height = yield getwork(bitcoind, chains)
+        current_work.set(dict(
+            version=work.version,
+            previous_block=work.previous_block,
+            bits=work.bits,
+            height=height,
+            highest_block2=None,
+        ))
+        current_work2 = dict(
+            timestamp=work.timestamp,
+        )
         
         print "    ...success!"
-        print "    Current block hash: %x height: %i" % (current_work.value[0][1], current_work.value[1])
+        print "    Current block hash: %x height: %i" % (current_work.value['previous_block'], current_work.value['height'])
         print
         
         # connect to bitcoind over bitcoin-p2p and do checkorder to get pubkey to send payouts to
@@ -163,14 +256,14 @@ def main(args):
         
         while True:
             try:
-                res = yield (yield factory.getProtocol()).checkorder(order='\0'*60)
+                res = yield (yield factory.getProtocol()).check_order(order='\0'*60)
                 if res['reply'] != 'success':
                     print "error in checkorder reply:", res
                     continue
+                my_pubkey = res['script']
             except:
                 traceback.print_exc()
             else:
-                my_pubkey = res['script']
                 break
             yield util.sleep(1)
         
@@ -178,30 +271,20 @@ def main(args):
         print "    Payout script:", my_pubkey.encode('hex')
         print
         
+        get_block = DeferredCacher(defer.inlineCallbacks(lambda block_hash: defer.returnValue((yield (yield factory.getProtocol()).get_block(block_hash)))), expiring_dict.ExpiringDict(3600))
+        print (yield get_last_p2pool_block(conv.BlockAttempt.from_getwork((yield bitcoind.rpc_getwork())).previous_block, get_block))
+        
         # setup worker logic
         
-        merkle_root_to_transactions = expiring_dict.ExpiringDict()
+        merkle_root_to_transactions = expiring_dict.ExpiringDict(100)
         
-        def transactions_from_shares(shares):
-            nHeight = 0 # XXX
-            subsidy = (50*100000000) >> (nHeight / 210000)
-            total_shares = sum(shares.itervalues())
-            amounts = dict((pubkey, subsidy*shares//total_shares) for (pubkey, shares) in shares.iteritems())
-            total_amount = sum(amounts.itervalues())
-            amount_left = subsidy - total_amount
-            incr_dict(amounts, "XXX", amount_left)
-            
-            transactions = [{
-                'version': 1,
-                'tx_ins': [{'previous_output': {'index': 4294967295, 'hash': 0}, 'sequence': 4294967295, 'script': bitcoin_p2p.Hash().pack(random.randrange(2**256))}],
-                'tx_outs': [dict(value=amount, script=pubkey) for (pubkey, amount) in sorted(amounts.iteritems())],
-                'lock_time': 0,
-            }]
-            return transactions
-        
-        def compute(((version, previous_block, timestamp, bits), log)):
-            log2 = util.incr_dict(log, my_pubkey)
-            transactions = transactions_from_shares(log2)
+        def compute(state, state2):
+            transactions = [generate_transaction(
+                shares=state['highest'].shares if state['highest'] is not None else {},
+                add_pubkey=my_pubkey,
+                subsidy=50*100000000 >> height//210000,
+                previous_block2=state['highest'].hash() if state['highest'] is not None else {},
+            )]
             merkle_root = bitcoin_p2p.merkle_hash(transactions)
             merkle_root_to_transactions[merkle_root] = transactions
             ba = conv.BlockAttempt(version, previous_block, merkle_root, timestamp, bits)
@@ -211,15 +294,25 @@ def main(args):
             # match up with transactions
             headers = conv.decode_data(data)
             transactions = merkle_root_to_transactions[headers['merkle_root']]
-            block = {'header': headers, 'txns': transactions}
+            block = dict(headers=headers, txns=transactions)
             return p2pCallback(bitcoin_p2p.block.pack(block))
         
         # setup p2p logic and join p2pool network
         
+        seen = set() # grows indefinitely!
+        
         def p2pCallback(block_data):
-            block = bitcoin_p2p.block.read(StringIO.StringIO(block_data))
+            block = bitcoin_p2p.block.unpack(block_data)
             hash_ = bitcoin_p2p.block_hash(block['headers'])
             
+            # early out for worthless
+            if hash_ < 2**256//2**32:
+                return
+            
+            if hash_ in seen:
+                return
+            seen.add(hash_)
+            
             if block['headers']['version'] != 1:
                 return False
             
@@ -243,7 +336,12 @@ def main(args):
         print "Joining p2pool network..."
         
         p2p_node = p2p.Node(p2pCallback, udpPort=random.randrange(49152, 65536) if args.p2pool_port is None else args.p2pool_port)
-        p2p_node.joinNetwork(args.p2pool_nodes)
+        def parse(x):
+            ip, port = x.split(':')
+            return ip, int(port)
+        
+        nodes = [('72.14.191.28', 21519)]*0
+        p2p_node.joinNetwork(map(parse, args.p2pool_nodes) + nodes)
         yield p2p_node._joinDeferred
         
         print "    ...success!"
@@ -264,8 +362,17 @@ def main(args):
         print
         
         while True:
-            current_work_new, current_work2 = yield getwork(bitcoind, chains)
-            current_work.set(current_work_new)
+            work, height = yield getwork(bitcoind, chains)
+            current_work.set(dict(
+                version=work.version,
+                previous_block=work.previous_block,
+                bits=work.bits,
+                height=height,
+                highest_block2=None,
+            ))
+            current_work2 = dict(
+                timestamp=work.timestamp,
+            )
             yield util.sleep(1)
     except:
         traceback.print_exc()
diff --git a/p2p.py b/p2p.py
index c9d39c7..2dcb52e 100644 (file)
--- a/p2p.py
+++ b/p2p.py
@@ -1,23 +1,17 @@
-import random
-import time
-import traceback
-
 from entangled.kademlia import node, encoding, protocol
 from twisted.internet import defer
 
-import util
-
 class CustomBencode(encoding.Bencode):
     def __init__(self, prefix=""):
         self.prefix = prefix
     
     def encode(self, data):
-        return self.prefix + encoding.Bencode.encode(data)
+        return self.prefix + encoding.Bencode.encode(self, data)
     
     def decode(self, data):
         if not data.startswith(self.prefix):
             raise ValueError("invalid prefix")
-        return encoding.Bencode.decode(data[len(self.prefix):])
+        return encoding.Bencode.decode(self, data[len(self.prefix):])
 
 class Node(node.Node):
     @property
@@ -29,45 +23,6 @@ class Node(node.Node):
     def __init__(self, blockCallback, **kwargs):
         node.Node.__init__(self, networkProtocol=protocol.KademliaProtocol(self, msgEncoder=CustomBencode("p2pool")), **kwargs)
         self.blockCallback = blockCallback
-        self.clock_offset = 0
-    
-    # time
-    
-    def joinNetwork(self, *args, **kwargs):
-        node.Node.joinNetwork(self, *args, **kwargs)
-        
-        def go(res):
-            self.joined()
-            return res
-        self._joinDeferred.addBoth(go)
-    
-    def joined(self):
-        self.time_task()
-    
-    def get_my_time(self):
-        return time.time() - self.clock_offset
-    
-    @node.rpcmethod
-    def get_time(self):
-        return time.time()
-    
-    @defer.inlineCallbacks
-    def time_task(self):
-        while True:
-            t_send = time.time()
-            clock_deltas = {None: (t_send, t_send)}
-            for peer, request in [(peer, peer.get_time().addCallback(lambda res: (time.time(), res))) for peer in self.peers]:
-                try:
-                    t_recv, response = yield request
-                    t = (t_send + t_recv)/2
-                    clock_deltas[(peer.id, peer.address, peer.port)] = (t, float(response))
-                except:
-                    traceback.print_exc()
-                    continue
-            
-            self.clock_offset = util.median(mine - theirs for mine, theirs in clock_deltas.itervalues())
-            
-            yield util.sleep(random.expovariate(1/500.))
     
     # disable data storage
     
diff --git a/util.py b/util.py
index 1ab7533..79b7a4e 100644 (file)
--- a/util.py
+++ b/util.py
@@ -1,6 +1,8 @@
 import random
+import collections
 
 from twisted.internet import defer, reactor
+from twisted.python import failure
 from twisted.web import server, resource
 
 class DeferredResource(resource.Resource):
@@ -47,7 +49,6 @@ class Event(object):
 class Variable(object):
     def __init__(self, value):
         self.value = value
-        
         self.changed = Event()
     
     def set(self, value):
@@ -55,7 +56,6 @@ class Variable(object):
             return
         
         self.value = value
-        
         self.changed.happened(value)
 
 def sleep(t):
@@ -70,27 +70,84 @@ def median(x):
     right = len(y)//2
     return (y[left] + y[right])/2
 
+class StringBuffer(object):
+    "Buffer manager with great worst-case behavior"
+    
+    def __init__(self, data=""):
+        self.buf = collections.deque([data])
+        self.buf_len = len(data)
+        self.pos = 0
+    
+    def __len__(self):
+        return self.buf_len - self.pos
+    
+    def add(self, data):
+        self.buf.append(data)
+        self.buf_len += len(data)
+    
+    def get(self, wants):
+        if self.buf_len - self.pos < wants:
+            raise IndexError("not enough data")
+        data = []
+        while wants:
+            seg = self.buf[0][self.pos:self.pos+wants]
+            self.pos += len(seg)
+            while self.buf and self.pos >= len(self.buf[0]):
+                x = self.buf.popleft()
+                self.buf_len -= len(x)
+                self.pos -= len(x)
+            
+            data.append(seg)
+            wants -= len(seg)
+        return ''.join(data)
+
 def _DataChunker(receiver):
     wants = receiver.next()
-    buf = ""
+    buf = StringBuffer()
     
     while True:
-        buf += yield
-        pos = 0
-        
-        while True:
-            if pos + wants > len(buf):
-                break
-            new_wants = receiver.send(buf[pos:pos + wants])
-            pos += wants
-            wants = new_wants
-        
-        buf = buf[pos:]
+        if len(buf) >= wants:
+            wants = receiver.send(buf.get(wants))
+        else:
+            buf.add((yield))
 def DataChunker(receiver):
+    """
+    Produces a function that accepts data that is input into a generator
+    (receiver) in response to the receiver yielding the size of data to wait on
+    """
     x = _DataChunker(receiver)
     x.next()
     return x.send
 
+class ReplyMatcher(object):
+    def __init__(self, func, timeout=5):
+        self.func = func
+        self.timeout = timeout
+        self.map = {}
+    
+    def __call__(self, id):
+      try:
+        self.func(id)
+        uniq = random.randrange(2**256)
+        df = defer.Deferred()
+        def timeout():
+            df, timer = self.map[id].pop(uniq)
+            df.errback(failure.Failure(defer.TimeoutError()))
+            if not self.map[id]:
+                del self.map[id]
+        self.map.setdefault(id, {})[uniq] = (df, reactor.callLater(self.timeout, timeout))
+        return df
+      except:
+        import traceback
+        traceback.print_exc()
+    
+    def got_response(self, id, resp):
+        if id not in self.map:
+            return
+        for df, timer in self.map[id].itervalues():
+            timer.cancel()
+            df.callback(resp)
+
 class GenericDeferrer(object):
     def __init__(self, max_id, func, timeout=5):
         self.max_id = max_id
@@ -106,23 +163,23 @@ class GenericDeferrer(object):
         df = defer.Deferred()
         def timeout():
             self.map.pop(id)
-            df.errback(fail.Failure(defer.TimeoutError()))
+            df.errback(failure.Failure(defer.TimeoutError()))
         timer = reactor.callLater(self.timeout, timeout)
         self.func(id, *args, **kwargs)
         self.map[id] = df, timer
         return df
     
-    def gotResponse(self, id, resp):
+    def got_response(self, id, resp):
         if id not in self.map:
             #print "got id without request", id, resp
             return # XXX
         df, timer = self.map.pop(id)
         timer.cancel()
         df.callback(resp)
-    
+
 def incr_dict(d, key, step=1):
     d = dict(d)
     if key not in d:
         d[key] = 0
     d[key] += 1
-        return d
+    return d