update a few scripts
[electrum-nvc.git] / lib / interface.py
index e6f2e21..b82cfae 100644 (file)
@@ -25,24 +25,71 @@ import ssl
 
 from version import ELECTRUM_VERSION, PROTOCOL_VERSION
 from util import print_error, print_msg
+from simple_config import SimpleConfig
 
 
 DEFAULT_TIMEOUT = 5
 proxy_modes = ['socks4', 'socks5', 'http']
 
 
-class Interface(threading.Thread):
+def check_cert(host, cert):
+    from OpenSSL import crypto as c
+    _cert = c.load_certificate(c.FILETYPE_PEM, cert)
+
+    m = "host: %s\n"%host
+    m += "has_expired: %s\n"% _cert.has_expired()
+    m += "pubkey: %s bits\n" % _cert.get_pubkey().bits()
+    m += "serial number: %s\n"% _cert.get_serial_number() 
+    #m += "issuer: %s\n"% _cert.get_issuer()
+    #m += "algo: %s\n"% _cert.get_signature_algorithm() 
+    m += "version: %s\n"% _cert.get_version()
+    print_msg(m)
+
+
+def cert_has_expired(cert_path):
+    try:
+        import OpenSSL
+    except:
+        print_error("Warning: cannot import OpenSSL")
+        return False
+    from OpenSSL import crypto as c
+    with open(cert_path) as f:
+        cert = f.read()
+    _cert = c.load_certificate(c.FILETYPE_PEM, cert)
+    return _cert.has_expired()
+
+
+def check_certificates():
+    config = SimpleConfig()
+    mydir = os.path.join(config.path, "certs")
+    certs = os.listdir(mydir)
+    for c in certs:
+        print c
+        p = os.path.join(mydir,c)
+        with open(p) as f:
+            cert = f.read()
+        check_cert(c, cert)
+    
 
+def cert_verify_hostname(s):
+    # hostname verification (disabled)
+    from backports.ssl_match_hostname import match_hostname, CertificateError
+    try:
+        match_hostname(s.getpeercert(True), host)
+        print_error("hostname matches", host)
+    except CertificateError, ce:
+        print_error("hostname did not match", host)
 
-    def __init__(self, config=None):
 
-        if config is None:
-            from simple_config import SimpleConfig
-            config = SimpleConfig()
+
+class Interface(threading.Thread):
+
+
+    def __init__(self, server, config = None):
 
         threading.Thread.__init__(self)
         self.daemon = True
-        self.config = config
+        self.config = config if config is not None else SimpleConfig()
         self.connect_event = threading.Event()
 
         self.subscriptions = {}
@@ -53,14 +100,16 @@ class Interface(threading.Thread):
         self.is_connected = False
         self.poll_interval = 1
 
+        self.debug = False # dump network messages. can be changed at runtime using the console
+
         #json
         self.message_id = 0
         self.unanswered_requests = {}
         self.pending_transactions_for_notifications= []
 
         # parse server
-        s = config.get('server')
-        host, port, protocol = s.split(':')
+        self.server = server
+        host, port, protocol = self.server.split(':')
         port = int(port)
             
         if protocol not in 'ghst':
@@ -70,17 +119,19 @@ class Interface(threading.Thread):
         self.port = port
         self.protocol = protocol
         self.use_ssl = ( protocol in 'sg' )
-        self.proxy = self.parse_proxy_options(config.get('proxy'))
+        self.proxy = self.parse_proxy_options(self.config.get('proxy'))
         if self.proxy:
             self.proxy_mode = proxy_modes.index(self.proxy["mode"]) + 1
-        self.server = host + ':%d:%s'%(port, protocol)
+
+
 
 
 
     def queue_json_response(self, c):
 
         # uncomment to debug
-        # print_error( "<--",c )
+        if self.debug:
+            print_error( "<--",c )
 
         msg_id = c.get('id')
         error = c.get('error')
@@ -135,7 +186,7 @@ class Interface(threading.Thread):
         self.server_version = result
 
 
-    def init_http(self, host, port, proxy=None, use_ssl=True):
+    def start_http(self):
         self.session_id = None
         self.is_connected = True
         self.connection_msg = ('https' if self.use_ssl else 'http') + '://%s:%d'%( self.host, self.port )
@@ -171,7 +222,7 @@ class Interface(threading.Thread):
 
                 
     def poll(self):
-        self.send([])
+        self.send([], None)
 
 
     def send_http(self, messages, callback):
@@ -242,6 +293,8 @@ class Interface(threading.Thread):
 
     def start_tcp(self):
 
+        self.connection_msg = self.host + ':%d' % self.port
+
         if self.proxy is not None:
 
             socks.setdefaultproxy(self.proxy_mode, self.proxy["host"], int(self.proxy["port"]))
@@ -252,25 +305,34 @@ class Interface(threading.Thread):
             socket.getaddrinfo = getaddrinfo
 
         if self.use_ssl:
-            cert_path = os.path.join( self.config.get('path'), 'certs', self.host)
+            cert_path = os.path.join( self.config.path, 'certs', self.host)
+
             if not os.path.exists(cert_path):
+                is_new = True
                 # get server certificate.
                 # Do not use ssl.get_server_certificate because it does not work with proxy
                 s = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
                 try:
                     s.connect((self.host, self.port))
                 except:
-                    print_error("failed to connect", self.host, self.port)
+                    # print_error("failed to connect", self.host, self.port)
                     return
 
-                s = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_SSLv3, cert_reqs=ssl.CERT_NONE, ca_certs=None)
+                try:
+                    s = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_SSLv3, cert_reqs=ssl.CERT_NONE, ca_certs=None)
+                except ssl.SSLError, e:
+                    print_error("SSL error:", self.host, e)
+                    return
                 dercert = s.getpeercert(True)
                 s.close()
                 cert = ssl.DER_cert_to_PEM_cert(dercert)
-                    
-                with open(cert_path,"w") as f:
+                temporary_path = cert_path + '.temp'
+                with open(temporary_path,"w") as f:
                     f.write(cert)
 
+            else:
+                is_new = False
+
 
         s = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
         s.settimeout(2)
@@ -287,25 +349,29 @@ class Interface(threading.Thread):
                 s = ssl.wrap_socket(s,
                                     ssl_version=ssl.PROTOCOL_SSLv3,
                                     cert_reqs=ssl.CERT_REQUIRED,
-                                    ca_certs=cert_path,
+                                    ca_certs= (temporary_path if is_new else cert_path),
                                     do_handshake_on_connect=True)
             except ssl.SSLError, e:
                 print_error("SSL error:", self.host, e)
+                if e.errno != 1:
+                    return
+                if is_new:
+                    os.rename(temporary_path, cert_path + '.rej')
+                else:
+                    if cert_has_expired(cert_path):
+                        print_error("certificate has expired:", cert_path)
+                        os.unlink(cert_path)
+                    else:
+                        print_msg("wrong certificate", self.host)
                 return
             except:
-                traceback.print_exc(file=sys.stdout)
                 print_error("wrap_socket failed", self.host)
+                traceback.print_exc(file=sys.stdout)
                 return
 
-        # hostname verification (disabled)
-        if self.use_ssl and False:
-            from backports.ssl_match_hostname import match_hostname, CertificateError
-            try:
-                match_hostname(s.getpeercert(), self.host)
-                print_error("hostname matches", self.host)
-            except CertificateError, ce:
-                print_error("hostname does not match", self.host, s.getpeercert())
-                return
+            if is_new:
+                print_error("saving certificate for", self.host)
+                os.rename(temporary_path, cert_path)
 
         s.settimeout(60)
         self.s = s
@@ -367,8 +433,8 @@ class Interface(threading.Thread):
             request = json.dumps( { 'id':self.message_id, 'method':method, 'params':params } )
             self.unanswered_requests[self.message_id] = method, params, callback
             ids.append(self.message_id)
-            # uncomment to debug
-            # print "-->", request
+            if self.debug:
+                print "-->", request
             self.message_id += 1
             out += request + '\n'
         while out:
@@ -425,6 +491,7 @@ class Interface(threading.Thread):
                         self.subscriptions[callback].append(message)
 
         if not self.is_connected: 
+            print_error("interface: trying to send while not connected")
             return
 
         if self.protocol in 'st':
@@ -468,27 +535,12 @@ class Interface(threading.Thread):
         return self.unanswered_requests == {}
 
 
-    def synchronous_get(self, requests, timeout=100000000):
-        # todo: use generators, unanswered_requests should be a list of arrays...
-        queue = Queue.Queue()
-        ids = self.send(requests, lambda i,r: queue.put(r))
-        id2 = ids[:]
-        res = {}
-        while ids:
-            r = queue.get(True, timeout)
-            _id = r.get('id')
-            if _id in ids:
-                ids.remove(_id)
-                res[_id] = r.get('result')
-        out = []
-        for _id in id2:
-            out.append(res[_id])
-        return out
-
 
-    def start(self, queue):
-        self.queue = queue
+    def start(self, queue = None, wait = False):
+        self.queue = queue if queue else Queue.Queue()
         threading.Thread.start(self)
+        if wait:
+            self.connect_event.wait()
 
 
     def run(self):
@@ -507,11 +559,5 @@ class Interface(threading.Thread):
 
 
 if __name__ == "__main__":
-    
-    q = Queue.Queue()
-    i = Interface({'server':'btc.it-zone.org:50002:s', 'path':'/extra/key/wallet', 'verbose':True})
-    i.start(q)
-    time.sleep(1)
-    exit()
 
-    
+    check_certificates()