rewriting portions of nattraverso
[p2pool.git] / nattraverso / pynupnp / upnp.py
index e3af52a..eae568f 100644 (file)
@@ -28,19 +28,6 @@ class UPnPError(Exception):
     """
     pass
 
-def search_upnp_device ():
-    """
-    Check the network for an UPnP device. Returns a deferred
-    with the L{UPnPDevice} instance as result, if found.
-    
-    @return: A deferred called with the L{UPnPDevice} instance
-    @rtype: L{twisted.internet.defer.Deferred}
-    """
-    try:
-        return UPnPProtocol().search_device()
-    except Exception, msg:
-        return defer.fail(UPnPError(msg))
-
 class UPnPMapper(portmapper.NATMapper):
     """
     This is the UPnP port mapper implementing the
@@ -326,14 +313,11 @@ class UPnPDevice:
         """
         logging.debug("_on_no_port_mapping_received: %s", failure)
         err = failure.value
-        try:
-            message = err.args[0]["UPnPError"]["errorDescription"]
-            if "SpecifiedArrayIndexInvalid" == message:
-                return mappings
-            else:
-                raise UPnPError("GetGenericPortMappingEntry got %s"%(message))
-        except:
-            raise UPnPError("GetGenericPortMappingEntry got %s"%(err.args[0]))
+        message = err.args[0]["UPnPError"]["errorDescription"]
+        if "SpecifiedArrayIndexInvalid" == message:
+            return mappings
+        else:
+            return failure
     
     
     def _on_port_mapping_added(self, response):
@@ -349,7 +333,7 @@ class UPnPDevice:
         
         @raise UPnPError: When the port mapping could not be added
         """
-        raise UPnPError(failure.value.args[0])
+        return failure
     
     def _on_port_mapping_removed(self, response):
         """
@@ -364,7 +348,7 @@ class UPnPDevice:
         
         @raise UPnPError: When the port mapping could not be deleted
         """
-        raise UPnPError(failure.value.args[0])
+        return failure
 
 # UPNP multicast address, port and request string
 _UPNP_MCAST = '239.255.255.250'
@@ -388,13 +372,11 @@ class UPnPProtocol(DatagramProtocol, object):
         """
         super(UPnPProtocol, self).__init__(*args, **kwargs)
         
-        # Url to use to talk to upnp device
-        self._control_url = None
-        self._device = None
-        
         #Device discovery deferred
         self._discovery = None
         self._discovery_timeout = None
+        self.mcast = None
+        self._done = False
     
     # Public methods
     def search_device(self):
@@ -407,31 +389,33 @@ class UPnPProtocol(DatagramProtocol, object):
         @return: A deferred called with the detected L{UPnPDevice} instance.
         @rtype: L{twisted.internet.defer.Deferred}
         """
+        if self._discovery is not None:
+            raise ValueError('already used')
         self._discovery = defer.Deferred()
+        self._discovery_timeout = reactor.callLater(6, self._on_discovery_timeout)
         
         attempt = 0
         mcast = None
         while True:
             try:
-                mcast = reactor.listenMulticast(1900+attempt, self)
+                self.mcast = reactor.listenMulticast(1900+attempt, self)
                 break
             except CannotListenError:
                 attempt = random.randint(0, 500)
         
         # joined multicast group, starting upnp search
-        mcast.joinGroup('239.255.255.250', socket.INADDR_ANY)
+        self.mcast.joinGroup('239.255.255.250', socket.INADDR_ANY)
         
         self.transport.write(_UPNP_SEARCH_REQUEST, (_UPNP_MCAST, _UPNP_PORT))
         self.transport.write(_UPNP_SEARCH_REQUEST, (_UPNP_MCAST, _UPNP_PORT))
         self.transport.write(_UPNP_SEARCH_REQUEST, (_UPNP_MCAST, _UPNP_PORT))
         
-        self._discovery_timeout = reactor.callLater(
-            6, self._on_discovery_timeout)
-        
         return self._discovery
     
     #Private methods
     def datagramReceived(self, dgram, address):
+        if self._done:
+            return
         """
         This is private, handle the multicast answer from the upnp device.
         """
@@ -443,16 +427,11 @@ class UPnPProtocol(DatagramProtocol, object):
         # Prepare status line
         version, status, textstatus = response.split(None, 2)
         
-        if not version.startswith('HTTP') or self._control_url != None:
+        if not version.startswith('HTTP'):
             return
         if status != "200":
             return
         
-        # We had a timeout pending, cancel it
-        if self._discovery_timeout != None:
-            self._discovery_timeout.cancel()
-            self._discovery_timeout = None
-        
         # Launch the info fetching
         def parse_discovery_response(message):
             """Separate headers and body from the received http answer."""
@@ -480,10 +459,11 @@ class UPnPProtocol(DatagramProtocol, object):
         
         loc = headers['location'][0]
         result = client.getPage(url=loc)
-        result.addCallback(self._on_gateway_response, loc).addErrback(
-            self._on_discovery_failed)
+        result.addCallback(self._on_gateway_response, loc).addErrback(self._on_discovery_failed)
     
     def _on_gateway_response(self, body, loc):
+        if self._done:
+            return
         """
         Called with the UPnP device XML description fetched via HTTP.
         
@@ -496,8 +476,6 @@ class UPnPProtocol(DatagramProtocol, object):
         @param body: The xml description of the device.
         @param loc: the url used to retreive the xml description
         """
-        if self._control_url != None:
-            return
         
         # Parse answer
         upnpinfo = UPnPXml(body)
@@ -510,35 +488,42 @@ class UPnPProtocol(DatagramProtocol, object):
         # Check the control url, if None, then the device cannot do what we want
         controlurl = upnpinfo.controlurl
         if controlurl == None:
-            self._on_discovery_failed(
-                UPnPError("upnp response showed no WANConnections"))
+            self._on_discovery_failed(UPnPError("upnp response showed no WANConnections"))
             return
         
-        self._control_url = urlparse.urljoin(urlbase, controlurl)
-        
-        soap_proxy = SoapProxy(self._control_url, upnpinfo.wanservice)
-        if self._discovery != None:
-            self._device = UPnPDevice(soap_proxy, upnpinfo.deviceinfos)
-            self._discovery.callback(self._device)
-            self._discovery = None
+        control_url2 = urlparse.urljoin(urlbase, controlurl)
+        soap_proxy = SoapProxy(control_url2, upnpinfo.wanservice)
+        self._on_discovery_succeeded(UPnPDevice(soap_proxy, upnpinfo.deviceinfos))
+    
+    def _on_discovery_succeeded(self, res):
+        if self._done:
+            return
+        self._done = True
+        self.mcast.stopListening()
+        self._discovery_timeout.cancel()
+        self._discovery.callback(res)
     
     def _on_discovery_failed(self, err):
-        """
-        Called when the UPnP Device discovery has failed.
-        
-        The callback returned in L{search_device} is called with
-        an error, corresponding to the cause of the failure.
-        """
-        self._control_url = None
-        if self._discovery != None:
-            self._discovery.errback(err)
-            self._discovery = None
+        if self._done:
+            return
+        self._done = True
+        self.mcast.stopListening()
+        self._discovery_timeout.cancel()
+        self._discovery.errback(err)
     
     def _on_discovery_timeout(self):
-        """
-        Called when the UPnP Device discovery has timed out.
-        
-        Calls L{_on_discovery_failed}.
-        """
-        self._discovery_timeout = None
-        self._on_discovery_failed(UPnPError())
+        if self._done:
+            return
+        self._done = True
+        self.mcast.stopListening()
+        self._discovery.errback(failure.Failure(defer.TimeoutError()))
+
+def search_upnp_device ():
+    """
+    Check the network for an UPnP device. Returns a deferred
+    with the L{UPnPDevice} instance as result, if found.
+    
+    @return: A deferred called with the L{UPnPDevice} instance
+    @rtype: L{twisted.internet.defer.Deferred}
+    """
+    return defer.maybeDeferred(UPnPProtocol().search_device)