interface: call socket.connect before sll.wrap_socket, for proxy. (fixes bug #207)
authorThomasV <thomasv@gitorious>
Tue, 1 Oct 2013 07:01:46 +0000 (09:01 +0200)
committerThomasV <thomasv@gitorious>
Tue, 1 Oct 2013 07:01:46 +0000 (09:01 +0200)
gui/qt/network_dialog.py
lib/interface.py
lib/network.py

index 16a0927..d66b31c 100644 (file)
@@ -99,7 +99,7 @@ class NetworkDialog(QDialog):
 
         self.server_protocol.connect(self.server_protocol, SIGNAL('currentIndexChanged(int)'), self.change_protocol)
 
-        label = _('Active Servers') if interface.servers else _('Default Servers')
+        label = _('Active Servers') #if interface.servers else _('Default Servers')
         self.servers_list_widget = QTreeWidget(parent)
         self.servers_list_widget.setHeaderLabels( [ label, _('Limit') ] )
         self.servers_list_widget.setMaximumHeight(150)
index 713ea50..7873396 100644 (file)
@@ -32,11 +32,23 @@ proxy_modes = ['socks4', 'socks5', 'http']
 class Interface(threading.Thread):
 
 
-    def init_server(self, host, port, proxy=None, use_ssl=True):
-        self.host = host
-        self.port = port
-        self.proxy = proxy
-        self.use_ssl = use_ssl
+    def __init__(self, config=None):
+
+        if config is None:
+            from simple_config import SimpleConfig
+            config = SimpleConfig()
+
+        threading.Thread.__init__(self)
+        self.daemon = True
+        self.config = config
+        self.connect_event = threading.Event()
+
+        self.subscriptions = {}
+        self.lock = threading.Lock()
+
+        self.rtime = 0
+        self.bytes_received = 0
+        self.is_connected = False
         self.poll_interval = 1
 
         #json
@@ -44,6 +56,21 @@ class Interface(threading.Thread):
         self.unanswered_requests = {}
         self.pending_transactions_for_notifications= []
 
+        # parse server
+        s = config.get('server')
+        host, port, protocol = s.split(':')
+        port = int(port)
+        if protocol not in 'ghst':
+            raise BaseException('Unknown protocol: %s'%protocol)
+
+        self.host = host
+        self.port = port
+        self.protocol = protocol
+        self.use_ssl = ( protocol in 'sg' )
+        self.proxy = self.parse_proxy_options(config.get('proxy'))
+        self.server = host + ':%d:%s'%(port, protocol)
+
+
 
     def queue_json_response(self, c):
 
@@ -104,7 +131,6 @@ class Interface(threading.Thread):
 
 
     def init_http(self, host, port, proxy=None, use_ssl=True):
-        self.init_server(host, port, proxy, use_ssl)
         self.session_id = None
         self.is_connected = True
         self.connection_msg = ('https' if self.use_ssl else 'http') + '://%s:%d'%( self.host, self.port )
@@ -209,27 +235,25 @@ class Interface(threading.Thread):
 
 
 
-    def init_tcp(self, host, port, proxy=None, use_ssl=True):
+    def start_tcp(self):
 
         if self.use_ssl:
-            cert_path = os.path.join( self.config.get('path'), 'certs', host)
+            cert_path = os.path.join( self.config.get('path'), 'certs', self.host)
             if not os.path.exists(cert_path):
                 dir_path = os.path.join( self.config.get('path'), 'certs')
                 if not os.path.exists(dir_path):
                     os.mkdir(dir_path)
                 try:
-                    cert = ssl.get_server_certificate((host, port))
+                    cert = ssl.get_server_certificate((self.host, self.port))
                 except:
-                    print_error("failed to connect", host, port)
+                    print_error("failed to connect", self.host, self.port)
                     return
                     
                 with open(cert_path,"w") as f:
                     f.write(cert)
 
-        self.init_server(host, port, proxy, use_ssl)
+        self.connection_msg = "%s:%d"%(self.host, self.port)
 
-        global proxy_modes
-        self.connection_msg = "%s:%d"%(self.host,self.port)
 
         if self.proxy is None:
             s = socket.socket( socket.AF_INET, socket.SOCK_STREAM )
@@ -238,6 +262,14 @@ class Interface(threading.Thread):
             s = socks.socksocket()
             s.setproxy(proxy_modes.index(self.proxy["mode"]) + 1, self.proxy["host"], int(self.proxy["port"]) )
 
+        s.settimeout(2)
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+
+        try:
+            s.connect(( self.host.encode('ascii'), int(self.port)))
+        except:
+            print_error("failed to connect", self.host, self.port)
+            return
 
         if self.use_ssl:
             try:
@@ -246,37 +278,28 @@ class Interface(threading.Thread):
                                     cert_reqs=ssl.CERT_REQUIRED,
                                     ca_certs=cert_path,
                                     do_handshake_on_connect=True)
+            except ssl.SSLError, e:
+                print_error("SSL error:", self.host, e)
+                return
             except:
-                print_error("wrap_socket failed", host)
+                traceback.print_exc(file=sys.stdout)
+                print_error("wrap_socket failed", self.host)
                 return
 
-        s.settimeout(2)
-        s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
-
-        try:
-            s.connect(( self.host.encode('ascii'), int(self.port)))
-        except ssl.SSLError, e:
-            print_error("SSL error:", host, e)
-            return
-        except:
-            #traceback.print_exc(file=sys.stdout)
-            print_error("failed to connect", host, port)
-            return
-
         # hostname verification (disabled)
         if self.use_ssl and False:
             from backports.ssl_match_hostname import match_hostname, CertificateError
             try:
-                match_hostname(s.getpeercert(), host)
-                print_error("hostname matches", host)
+                match_hostname(s.getpeercert(), self.host)
+                print_error("hostname matches", self.host)
             except CertificateError, ce:
-                print_error("hostname does not match", host, s.getpeercert())
+                print_error("hostname does not match", self.host, s.getpeercert())
                 return
 
         s.settimeout(60)
         self.s = s
         self.is_connected = True
-        print_error("connected to", host, port)
+        print_error("connected to", self.host, self.port)
 
 
     def run_tcp(self):
@@ -355,65 +378,18 @@ class Interface(threading.Thread):
 
 
 
-    def __init__(self, config=None):
-        #self.server = random.choice(filter_protocol(DEFAULT_SERVERS, 's'))
-        self.proxy = None
 
-        if config is None:
-            from simple_config import SimpleConfig
-            config = SimpleConfig()
 
-        threading.Thread.__init__(self)
-        self.daemon = True
-        self.config = config
-        self.connect_event = threading.Event()
-
-        self.subscriptions = {}
-        self.lock = threading.Lock()
-
-        self.servers = {} # actual list from IRC
-        self.rtime = 0
-        self.bytes_received = 0
-        self.is_connected = False
+    def start_interface(self):
 
-        # init with None server, in case we are offline 
-        self.init_server(None, None)
-
-
-
-
-    def init_interface(self):
-        if self.config.get('server'):
-            self.init_with_server(self.config)
-        else:
-            if self.config.get('auto_cycle') is None:
-                self.config.set_key('auto_cycle', True, False)
-
-        if not self.is_connected: 
-            self.connect_event.set()
-            return
+        if self.protocol in 'st':
+            self.start_tcp()
+        elif self.protocol in 'gh':
+            self.start_http()
 
         self.connect_event.set()
 
 
-    def init_with_server(self, config):
-            
-        s = config.get('server')
-        host, port, protocol = s.split(':')
-        port = int(port)
-
-        self.protocol = protocol
-        proxy = self.parse_proxy_options(config.get('proxy'))
-        self.server = host + ':%d:%s'%(port, protocol)
-
-        #print protocol, host, port
-        if protocol in 'st':
-            self.init_tcp(host, port, proxy, use_ssl=(protocol=='s'))
-        elif protocol in 'gh':
-            self.init_http(host, port, proxy, use_ssl=(protocol=='g'))
-        else:
-            raise BaseException('Unknown protocol: %s'%protocol)
-
 
     def stop_subscriptions(self):
         for callback in self.subscriptions.keys():
@@ -505,7 +481,7 @@ class Interface(threading.Thread):
 
 
     def run(self):
-        self.init_interface()
+        self.start_interface()
         if self.is_connected:
             self.send([('server.version', [ELECTRUM_VERSION, PROTOCOL_VERSION])], self.on_version)
             self.change_status()
index d14a8ef..578c109 100644 (file)
@@ -53,6 +53,7 @@ class Network(threading.Thread):
         self.servers = []
         self.banner = ''
         self.interface = None
+        self.proxy = self.config.get('proxy')
         self.heights = {}
 
 
@@ -95,7 +96,7 @@ class Network(threading.Thread):
     def start_interface(self, server):
         if server in self.interfaces.keys():
             return
-        i = interface.Interface({'server':server, 'path':self.config.path})
+        i = interface.Interface({'server':server, 'path':self.config.path, 'proxy':self.proxy})
         self.interfaces[server] = i
         i.start(self.queue)
 
@@ -130,6 +131,7 @@ class Network(threading.Thread):
 
         i = self.interface
         self.default_server = server
+        self.proxy = proxy
         self.start_interface(server)
         self.interface = self.interfaces[server]
         i.stop_subscriptions() # fixme: it should not stop all subscriptions, and send 'unsubscribe'