Typo fix
[stratum-mining.git] / lib / halfnode.py
1 #!/usr/bin/python
2 # Public Domain
3 # Original author: ArtForz
4 # Twisted integration: slush
5
6 import struct
7 import socket
8 import binascii
9 import time
10 import sys
11 import random
12 import cStringIO
13 from Crypto.Hash import SHA256
14
15 from twisted.internet.protocol import Protocol
16 from util import *
17
# Protocol version number used as the default nVersion of msg_version and
# CBlockLocator instances.
MY_VERSION = 31402
# Default strSubVer value of msg_version instances.
MY_SUBVERSION = ".4"
20
class CAddress(object):
    """Network address record: service bits, 12 reserved bytes, IPv4 and port.

    The ``nTime`` attribute exists on the object but is neither read by
    ``deserialize`` nor written by ``serialize``.
    """

    def __init__(self):
        self.nTime = 0
        self.nServices = 1
        self.pchReserved = "\x00" * 10 + "\xff" * 2
        self.ip = "0.0.0.0"
        self.port = 0

    def deserialize(self, f):
        """Fill the fields from the file-like object *f* (26 bytes consumed)."""
        # nTime is not part of the encoding handled here.
        (self.nServices,) = struct.unpack("<Q", f.read(8))
        self.pchReserved = f.read(12)
        self.ip = socket.inet_ntoa(f.read(4))
        # The port is the only big-endian field.
        (self.port,) = struct.unpack(">H", f.read(2))

    def serialize(self):
        """Return the encoded form: services, reserved bytes, IPv4, port."""
        return (
            struct.pack("<Q", self.nServices)
            + self.pchReserved
            + socket.inet_aton(self.ip)
            + struct.pack(">H", self.port)
        )

    def __repr__(self):
        shown = (self.nServices, self.ip, self.port)
        return "CAddress(nServices=%i ip=%s port=%i)" % shown
44
class CInv(object):
    """Inventory entry: a signed 32-bit type code plus a uint256 hash.

    ``typemap`` gives display names for the type codes this module knows.
    """

    typemap = {
        0: "Error",
        1: "TX",
        2: "Block",
    }

    def __init__(self):
        self.type = 0
        self.hash = 0

    def deserialize(self, f):
        """Read the type code and the hash from the file-like object *f*."""
        self.type = struct.unpack("<i", f.read(4))[0]
        self.hash = deser_uint256(f)

    def serialize(self):
        """Return the encoded form: type code followed by the hash."""
        return struct.pack("<i", self.type) + ser_uint256(self.hash)

    def __repr__(self):
        # The type code comes straight off the wire, so it may be a value that
        # typemap does not list; repr() must not raise KeyError for it.
        type_name = self.typemap.get(self.type, "Unknown(%i)" % self.type)
        return "CInv(type=%s hash=%064x)" % (type_name, self.hash)
63
class CBlockLocator(object):
    """A version number followed by a vector of uint256 values (``vHave``)."""

    def __init__(self):
        self.nVersion = MY_VERSION
        self.vHave = []

    def deserialize(self, f):
        """Read the version and the hash vector from the file-like object *f*."""
        (self.nVersion,) = struct.unpack("<i", f.read(4))
        self.vHave = deser_uint256_vector(f)

    def serialize(self):
        """Return the encoded form: version followed by the hash vector."""
        return struct.pack("<i", self.nVersion) + ser_uint256_vector(self.vHave)

    def __repr__(self):
        have_text = repr(self.vHave)
        return "CBlockLocator(nVersion=%i vHave=%s)" % (self.nVersion, have_text)
78
class COutPoint(object):
    """A uint256 hash paired with an unsigned 32-bit index ``n``.

    Instances are used as the ``prevout`` field of ``CTxIn``.
    """

    def __init__(self):
        self.hash = 0
        self.n = 0

    def deserialize(self, f):
        """Read the hash and then the index from the file-like object *f*."""
        self.hash = deser_uint256(f)
        (self.n,) = struct.unpack("<I", f.read(4))

    def serialize(self):
        """Return the encoded form: hash followed by the index."""
        return ser_uint256(self.hash) + struct.pack("<I", self.n)

    def __repr__(self):
        shown = (self.hash, self.n)
        return "COutPoint(hash=%064x n=%i)" % shown
93
class CTxIn(object):
    """Transaction input: a ``COutPoint``, a script string and a sequence number."""

    def __init__(self):
        self.prevout = COutPoint()
        self.scriptSig = ""
        self.nSequence = 0

    def deserialize(self, f):
        """Read prevout, scriptSig and nSequence from the file-like object *f*."""
        # A fresh COutPoint is always created rather than reusing the old one.
        outpoint = COutPoint()
        outpoint.deserialize(f)
        self.prevout = outpoint
        self.scriptSig = deser_string(f)
        (self.nSequence,) = struct.unpack("<I", f.read(4))

    def serialize(self):
        """Return the encoded form: prevout, scriptSig, nSequence."""
        return (
            self.prevout.serialize()
            + ser_string(self.scriptSig)
            + struct.pack("<I", self.nSequence)
        )

    def __repr__(self):
        script_hex = binascii.hexlify(self.scriptSig)
        return "CTxIn(prevout=%s scriptSig=%s nSequence=%i)" % (repr(self.prevout), script_hex, self.nSequence)
112
class CTxOut(object):
    """Transaction output: a signed 64-bit amount and a script string.

    ``nValue`` is displayed by ``__repr__`` as whole units and a fractional
    part of 1/100000000 units.
    """

    def __init__(self):
        self.nValue = 0
        self.scriptPubKey = ""

    def deserialize(self, f):
        """Read nValue and scriptPubKey from the file-like object *f*."""
        self.nValue = struct.unpack("<q", f.read(8))[0]
        self.scriptPubKey = deser_string(f)

    def serialize(self):
        """Return the encoded form: nValue followed by scriptPubKey."""
        return struct.pack("<q", self.nValue) + ser_string(self.scriptPubKey)

    def __repr__(self):
        # nValue is signed on the wire, so it can be negative. Floor division
        # and modulo on a negative number would show -1 as "-1.99999999";
        # split the magnitude instead and put the sign in front.
        sign = "-" if self.nValue < 0 else ""
        whole, fraction = divmod(abs(self.nValue), 100000000)
        return "CTxOut(nValue=%s%i.%08i scriptPubKey=%s)" % (sign, whole, fraction, binascii.hexlify(self.scriptPubKey))
127
class CTransaction(object):
    """Transaction: version, timestamp, inputs, outputs and lock time.

    ``sha256`` caches the double-SHA256 of the serialized form; it is filled
    in by ``calc_sha256`` and cleared by ``deserialize``.
    """

    def __init__(self):
        self.nVersion = 1
        self.nTime = 0
        self.vin = []
        self.vout = []
        self.nLockTime = 0
        self.sha256 = None

    def deserialize(self, f):
        """Read all fields from the file-like object *f* and drop the cached hash."""
        (self.nVersion,) = struct.unpack("<i", f.read(4))
        (self.nTime,) = struct.unpack("<i", f.read(4))
        self.vin = deser_vector(f, CTxIn)
        self.vout = deser_vector(f, CTxOut)
        (self.nLockTime,) = struct.unpack("<I", f.read(4))
        self.sha256 = None

    def serialize(self):
        """Return the encoded form: nVersion, nTime, vin, vout, nLockTime."""
        return (
            struct.pack("<i", self.nVersion)
            + struct.pack("<i", self.nTime)
            + ser_vector(self.vin)
            + ser_vector(self.vout)
            + struct.pack("<I", self.nLockTime)
        )

    def calc_sha256(self):
        """Return the cached hash, computing and storing it on first use."""
        if self.sha256 is not None:
            return self.sha256
        inner_digest = SHA256.new(self.serialize()).digest()
        self.sha256 = uint256_from_str(SHA256.new(inner_digest).digest())
        return self.sha256

    def is_valid(self):
        """Return True when every output value lies in 0..21000000*100000000.

        The hash is computed (and cached) as a side effect.
        """
        self.calc_sha256()
        upper_bound = 21000000 * 100000000
        return all(0 <= tout.nValue <= upper_bound for tout in self.vout)

    def __repr__(self):
        shown = (self.nVersion, repr(self.vin), repr(self.vout), self.nLockTime)
        return "CTransaction(nVersion=%i vin=%s vout=%s nLockTime=%i)" % shown
165
class CBlock(object):
    """Block: six header fields, a transaction list and a signature string.

    ``sha256`` caches the value returned by ``calc_sha256`` (the result of
    ``scrypt`` over the serialized header fields); ``deserialize`` clears it.
    """

    def __init__(self):
        self.nVersion = 6
        self.hashPrevBlock = 0
        self.hashMerkleRoot = 0
        self.nTime = 0
        self.nBits = 0
        self.nNonce = 0
        self.vtx = []
        self.sha256 = None
        self.signature = b""

    def deserialize(self, f):
        """Read header, transactions and signature from the file-like object *f*."""
        self.nVersion = struct.unpack("<i", f.read(4))[0]
        self.hashPrevBlock = deser_uint256(f)
        self.hashMerkleRoot = deser_uint256(f)
        self.nTime = struct.unpack("<I", f.read(4))[0]
        self.nBits = struct.unpack("<I", f.read(4))[0]
        self.nNonce = struct.unpack("<I", f.read(4))[0]
        self.vtx = deser_vector(f, CTransaction)
        self.signature = deser_string(f)
        # The header just changed, so a hash cached from earlier contents
        # would be stale (CTransaction.deserialize resets its cache too).
        self.sha256 = None

    def _serialize_header(self):
        """Return the six header fields in wire order (shared by serialize and hashing)."""
        return b"".join([
            struct.pack("<i", self.nVersion),
            ser_uint256(self.hashPrevBlock),
            ser_uint256(self.hashMerkleRoot),
            struct.pack("<I", self.nTime),
            struct.pack("<I", self.nBits),
            struct.pack("<I", self.nNonce),
        ])

    def serialize(self):
        """Return the encoded form: header, transaction vector, signature."""
        return b"".join([
            self._serialize_header(),
            ser_vector(self.vtx),
            ser_string(self.signature),
        ])

    def calc_sha256(self):
        """Return the cached header hash, computing and storing it on first use."""
        if self.sha256 is None:
            self.sha256 = uint256_from_str(scrypt(self._serialize_header()))
        return self.sha256

    def is_valid(self):
        """Check the block's hash, transactions and merkle root.

        Returns False when the block has no transactions, when the header
        hash exceeds the target decoded from ``nBits``, when any transaction
        fails its own ``is_valid``, or when the merkle root rebuilt from the
        transaction hashes differs from ``hashMerkleRoot``.
        """
        self.calc_sha256()
        # Without transactions there is no merkle root to rebuild; report the
        # block as invalid instead of failing with IndexError further down.
        if not self.vtx:
            return False
        target = uint256_from_compact(self.nBits)
        if self.sha256 > target:
            return False

        hashes = []
        for tx in self.vtx:
            # Force a fresh hash so a stale cached value cannot be trusted.
            tx.sha256 = None
            if not tx.is_valid():
                return False
            hashes.append(ser_uint256(tx.calc_sha256()))

        # Reduce pairwise with double SHA256; on an odd count the last hash
        # is paired with itself.
        while len(hashes) > 1:
            last = len(hashes) - 1
            newhashes = []
            for i in range(0, len(hashes), 2):
                pair = hashes[i] + hashes[min(i + 1, last)]
                newhashes.append(SHA256.new(SHA256.new(pair).digest()).digest())
            hashes = newhashes

        return uint256_from_str(hashes[0]) == self.hashMerkleRoot

    def __repr__(self):
        return "CBlock(nVersion=%i hashPrevBlock=%064x hashMerkleRoot=%064x nTime=%s nBits=%08x nNonce=%08x vtx=%s)" % (self.nVersion, self.hashPrevBlock, self.hashMerkleRoot, time.ctime(self.nTime), self.nBits, self.nNonce, repr(self.vtx))
234
class msg_version(object):
    """The ``version`` message: protocol version, services, time, two
    addresses, a random 64-bit nonce, sub-version string and starting height.
    """

    command = "version"

    def __init__(self):
        self.nVersion = MY_VERSION
        self.nServices = 0
        # nTime is packed with the integer format "<q"; time.time() returns a
        # float, so truncate it here instead of relying on struct to coerce.
        self.nTime = int(time.time())
        self.addrTo = CAddress()
        self.addrFrom = CAddress()
        self.nNonce = random.getrandbits(64)
        self.strSubVer = MY_SUBVERSION
        self.nStartingHeight = 0

    def deserialize(self, f):
        """Read all fields from the file-like object *f*.

        A received version number of 10300 is stored as 300.
        """
        self.nVersion = struct.unpack("<i", f.read(4))[0]
        if self.nVersion == 10300:
            self.nVersion = 300
        self.nServices = struct.unpack("<Q", f.read(8))[0]
        self.nTime = struct.unpack("<q", f.read(8))[0]
        self.addrTo = CAddress()
        self.addrTo.deserialize(f)
        self.addrFrom = CAddress()
        self.addrFrom.deserialize(f)
        self.nNonce = struct.unpack("<Q", f.read(8))[0]
        self.strSubVer = deser_string(f)
        self.nStartingHeight = struct.unpack("<i", f.read(4))[0]

    def serialize(self):
        """Return the encoded message payload."""
        r = []
        r.append(struct.pack("<i", self.nVersion))
        r.append(struct.pack("<Q", self.nServices))
        # int() also covers callers that assigned a float timestamp themselves.
        r.append(struct.pack("<q", int(self.nTime)))
        r.append(self.addrTo.serialize())
        r.append(self.addrFrom.serialize())
        r.append(struct.pack("<Q", self.nNonce))
        r.append(ser_string(self.strSubVer))
        r.append(struct.pack("<i", self.nStartingHeight))
        return ''.join(r)

    def __repr__(self):
        return "msg_version(nVersion=%i nServices=%i nTime=%s addrTo=%s addrFrom=%s nNonce=0x%016X strSubVer=%s nStartingHeight=%i)" % (self.nVersion, self.nServices, time.ctime(self.nTime), repr(self.addrTo), repr(self.addrFrom), self.nNonce, self.strSubVer, self.nStartingHeight)
273
class msg_verack(object):
    """The ``verack`` message; it has no fields and an empty payload."""

    command = "verack"

    def __init__(self):
        """Nothing to set up: the message carries no fields."""

    def deserialize(self, f):
        """Read nothing from *f*."""
        return None

    def serialize(self):
        """Return the empty payload."""
        return ""

    def __repr__(self):
        return "msg_verack()"
284
class msg_addr(object):
    """The ``addr`` message: a vector of ``CAddress`` entries."""

    command = "addr"

    def __init__(self):
        self.addrs = []

    def deserialize(self, f):
        """Replace ``addrs`` with the CAddress vector read from *f*."""
        parsed = deser_vector(f, CAddress)
        self.addrs = parsed

    def serialize(self):
        """Return the encoded address vector."""
        payload = ser_vector(self.addrs)
        return payload

    def __repr__(self):
        addrs_text = repr(self.addrs)
        return "msg_addr(addrs=%s)" % addrs_text
295
class msg_inv(object):
    """The ``inv`` message: a vector of ``CInv`` entries."""

    command = "inv"

    def __init__(self):
        self.inv = []

    def deserialize(self, f):
        """Replace ``inv`` with the CInv vector read from *f*."""
        parsed = deser_vector(f, CInv)
        self.inv = parsed

    def serialize(self):
        """Return the encoded inventory vector."""
        payload = ser_vector(self.inv)
        return payload

    def __repr__(self):
        inv_text = repr(self.inv)
        return "msg_inv(inv=%s)" % inv_text
306
class msg_getdata(object):
    """The ``getdata`` message: a vector of ``CInv`` entries."""

    command = "getdata"

    def __init__(self):
        self.inv = []

    def deserialize(self, f):
        """Replace ``inv`` with the CInv vector read from *f*."""
        parsed = deser_vector(f, CInv)
        self.inv = parsed

    def serialize(self):
        """Return the encoded inventory vector."""
        payload = ser_vector(self.inv)
        return payload

    def __repr__(self):
        inv_text = repr(self.inv)
        return "msg_getdata(inv=%s)" % inv_text
317
class msg_getblocks(object):
    """The ``getblocks`` message: a ``CBlockLocator`` and a uint256 ``hashstop``."""

    command = "getblocks"

    def __init__(self):
        self.locator = CBlockLocator()
        self.hashstop = 0

    def deserialize(self, f):
        """Read a fresh locator and then hashstop from the file-like object *f*."""
        locator = CBlockLocator()
        locator.deserialize(f)
        self.locator = locator
        self.hashstop = deser_uint256(f)

    def serialize(self):
        """Return the encoded form: locator followed by hashstop."""
        pieces = [self.locator.serialize(), ser_uint256(self.hashstop)]
        return ''.join(pieces)

    def __repr__(self):
        shown = (repr(self.locator), self.hashstop)
        return "msg_getblocks(locator=%s hashstop=%064x)" % shown
334
class msg_tx(object):
    """The ``tx`` message: a thin wrapper around one ``CTransaction``."""

    command = "tx"

    def __init__(self):
        self.tx = CTransaction()

    def deserialize(self, f):
        """Let the wrapped transaction read itself from *f*."""
        wrapped = self.tx
        wrapped.deserialize(f)

    def serialize(self):
        """Return the wrapped transaction's encoded form."""
        wrapped = self.tx
        return wrapped.serialize()

    def __repr__(self):
        tx_text = repr(self.tx)
        return "msg_tx(tx=%s)" % tx_text
345
class msg_block(object):
    """The ``block`` message: a thin wrapper around one ``CBlock``."""

    command = "block"

    def __init__(self):
        self.block = CBlock()

    def deserialize(self, f):
        """Let the wrapped block read itself from *f*."""
        wrapped = self.block
        wrapped.deserialize(f)

    def serialize(self):
        """Return the wrapped block's encoded form."""
        wrapped = self.block
        return wrapped.serialize()

    def __repr__(self):
        block_text = repr(self.block)
        return "msg_block(block=%s)" % block_text
356
class msg_getaddr(object):
    """The ``getaddr`` message; it has no fields and an empty payload."""

    command = "getaddr"

    def __init__(self):
        """Nothing to set up: the message carries no fields."""

    def deserialize(self, f):
        """Read nothing from *f*."""
        return None

    def serialize(self):
        """Return the empty payload."""
        return ""

    def __repr__(self):
        return "msg_getaddr()"
367
class msg_ping(object):
    """The ``ping`` message; it has no fields and an empty payload."""

    command = "ping"

    def __init__(self):
        """Nothing to set up: the message carries no fields."""

    def deserialize(self, f):
        """Read nothing from *f*."""
        return None

    def serialize(self):
        """Return the empty payload."""
        return ""

    def __repr__(self):
        return "msg_ping()"
378
class msg_alert(object):
    """The ``alert`` message; any received payload is left unparsed."""

    command = "alert"

    def __init__(self):
        """Nothing to set up: no fields are kept."""

    def deserialize(self, f):
        """Read nothing from *f*; the payload is ignored."""
        return None

    def serialize(self):
        """Return the empty payload."""
        return ""

    def __repr__(self):
        return "msg_alert()"
389
class BitcoinP2PProtocol(Protocol):
    """Twisted protocol that frames, parses and dispatches P2P messages.

    Received bytes are buffered in ``recvbuf`` and split into messages laid
    out as: 4-byte magic, 12-byte NUL-padded command, 4-byte little-endian
    payload length, 4-byte checksum (first four bytes of the double SHA256 of
    the payload), payload. Each parsed message is passed to a method named
    ``do_<command>`` when the instance has one.
    """

    messagemap = {
        "version": msg_version,
        "verack": msg_verack,
        "addr": msg_addr,
        "inv": msg_inv,
        "getdata": msg_getdata,
        "getblocks": msg_getblocks,
        "tx": msg_tx,
        "block": msg_block,
        "getaddr": msg_getaddr,
        "ping": msg_ping,
        "alert": msg_alert,
    }

    # Framing constants shared by the parser and the serializer.
    _MAGIC = "\xf9\xbe\xb4\xd9"
    _COMMAND_LEN = 12
    _HEADER_LEN = 4 + 12 + 4 + 4  # magic + command + length + checksum

    @staticmethod
    def _checksum(payload):
        """Return the 4-byte message checksum of *payload*."""
        return SHA256.new(SHA256.new(payload).digest()).digest()[:4]

    def connectionMade(self):
        """Record the peer address, reset state and send our version message."""
        peer = self.transport.getPeer()
        self.dstaddr = peer.host
        self.dstport = peer.port
        self.recvbuf = ""
        self.last_sent = 0

        t = msg_version()
        # Subclasses may set nStartingHeight before the connection is made.
        t.nStartingHeight = getattr(self, 'nStartingHeight', 0)
        t.addrTo.ip = self.dstaddr
        t.addrTo.port = self.dstport
        t.addrTo.nTime = time.time()
        t.addrFrom.ip = "0.0.0.0"
        t.addrFrom.port = 0
        t.addrFrom.nTime = time.time()
        self.send_message(t)

    def dataReceived(self, data):
        """Append *data* to the receive buffer and parse what is complete."""
        self.recvbuf += data
        self.got_data()

    def got_data(self):
        """Parse and dispatch every complete message held in ``recvbuf``.

        Raises:
            ValueError: if the buffer does not start with the magic bytes or
                a payload does not match its checksum.
        """
        header_len = self._HEADER_LEN
        while True:
            buf = self.recvbuf
            if len(buf) < 4:
                return
            if buf[:4] != self._MAGIC:
                raise ValueError("got garbage %s" % repr(buf))
            if len(buf) < header_len:
                return

            command = buf[4:4 + self._COMMAND_LEN].split("\x00", 1)[0]
            # The length is unsigned on the wire (prepare_message packs "<I").
            # Reading it as signed let a peer send a negative length, which
            # produced negative slice bounds and could leave the buffer
            # unchanged, so this loop would never terminate.
            # NOTE(review): no upper bound is enforced on msglen, so a peer
            # can make the buffer grow while we wait - consider a size cap.
            msglen = struct.unpack("<I", buf[16:20])[0]
            checksum = buf[20:header_len]
            end = header_len + msglen
            if len(buf) < end:
                return

            msg = buf[header_len:end]
            if checksum != self._checksum(msg):
                raise ValueError("got bad checksum %s" % repr(buf))
            # Consume the message before dispatching it.
            self.recvbuf = buf[end:]

            if command in self.messagemap:
                t = self.messagemap[command]()
                t.deserialize(cStringIO.StringIO(msg))
                self.got_message(t)
            else:
                print("UNKNOWN COMMAND %s %s" % (command, repr(msg)))

    def prepare_message(self, message):
        """Return *message* framed for the wire (header plus payload)."""
        command = message.command
        data = message.serialize()
        return ''.join([
            self._MAGIC,
            command,
            "\x00" * (self._COMMAND_LEN - len(command)),
            struct.pack("<I", len(data)),
            self._checksum(data),
            data,
        ])

    def send_serialized_message(self, tmsg):
        """Write an already framed message; does nothing when not connected."""
        if not self.connected:
            return

        self.transport.write(tmsg)
        self.last_sent = time.time()

    def send_message(self, message):
        """Frame and send *message*; does nothing when not connected."""
        # Checked first so nothing is serialized for a dead connection.
        if not self.connected:
            return

        self.send_serialized_message(self.prepare_message(message))

    def got_message(self, message):
        """Dispatch *message* to ``do_<command>`` if this object defines it.

        A ping is sent first when nothing has been sent for 30 minutes.
        """
        if self.last_sent + 30 * 60 < time.time():
            self.send_message(msg_ping())

        handler = getattr(self, 'do_' + message.command, None)
        if handler is not None:
            handler(message)

    def do_version(self, message):
        """Acknowledge the peer's version message."""
        self.send_message(msg_verack())

    def do_inv(self, message):
        """Request every announced item of type 1 (TX) or 2 (Block)."""
        want = msg_getdata()
        want.inv.extend(i for i in message.inv if i.type in (1, 2))
        if want.inv:
            self.send_message(want)