bundled wstools-0.3
[p2pool.git] / wstools / TimeoutSocket.py
diff --git a/wstools/TimeoutSocket.py b/wstools/TimeoutSocket.py
new file mode 100755 (executable)
index 0000000..48b898d
--- /dev/null
@@ -0,0 +1,179 @@
+"""Based on code from timeout_socket.py, with some tweaks for compatibility.
+   These tweaks should really be rolled back into timeout_socket, but it's
+   not totally clear who is maintaining it at this point. In the meantime,
+   we'll use a different module name for our tweaked version to avoid any
+   confusion.
+
+   The original timeout_socket is by:
+
+       Scott Cotton <scott@chronis.pobox.com>
+       Lloyd Zusman <ljz@asfast.com>
+       Phil Mayes <pmayes@olivebr.com>
+       Piers Lauder <piers@cs.su.oz.au>
+       Radovan Garabik <garabik@melkor.dnp.fmph.uniba.sk>
+"""
+
+ident = "$Id$"
+
+import string, socket, select, errno
+
+WSAEINVAL = getattr(errno, 'WSAEINVAL', 10022)
+
+
+class TimeoutSocket:
+    """A socket imposter that supports timeout limits."""
+
+    def __init__(self, timeout=20, sock=None):
+        self.timeout = float(timeout)
+        self.inbuf = ''
+        if sock is None:
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.sock = sock
+        self.sock.setblocking(0)
+        self._rbuf = ''
+        self._wbuf = ''
+
+    def __getattr__(self, name):
+        # Delegate to real socket attributes.
+        return getattr(self.sock, name)
+
+    def connect(self, *addr):
+        timeout = self.timeout
+        sock = self.sock
+        try:
+            # Non-blocking mode
+            sock.setblocking(0)
+            apply(sock.connect, addr)
+            sock.setblocking(timeout != 0)
+            return 1
+        except socket.error,why:
+            if not timeout:
+                raise
+            sock.setblocking(1)
+            if len(why.args) == 1:
+                code = 0
+            else:
+                code, why = why
+            if code not in (
+                errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK
+                ):
+                raise
+            r,w,e = select.select([],[sock],[],timeout)
+            if w:
+                try:
+                    apply(sock.connect, addr)
+                    return 1
+                except socket.error,why:
+                    if len(why.args) == 1:
+                        code = 0
+                    else:
+                        code, why = why
+                    if code in (errno.EISCONN, WSAEINVAL):
+                        return 1
+                    raise
+        raise TimeoutError('socket connect() timeout.')
+
+    def send(self, data, flags=0):
+        total = len(data)
+        next = 0
+        while 1:
+            r, w, e = select.select([],[self.sock], [], self.timeout)
+            if w:
+                buff = data[next:next + 8192]
+                sent = self.sock.send(buff, flags)
+                next = next + sent
+                if next == total:
+                    return total
+                continue
+            raise TimeoutError('socket send() timeout.')
+
+    def recv(self, amt, flags=0):
+        if select.select([self.sock], [], [], self.timeout)[0]:
+            return self.sock.recv(amt, flags)
+        raise TimeoutError('socket recv() timeout.')
+
+    buffsize = 4096
+    handles = 1
+
+    def makefile(self, mode="r", buffsize=-1):
+        self.handles = self.handles + 1
+        self.mode = mode
+        return self
+
+    def close(self):
+        self.handles = self.handles - 1
+        if self.handles == 0 and self.sock.fileno() >= 0:
+            self.sock.close()
+
+    def read(self, n=-1):
+        if not isinstance(n, type(1)):
+            n = -1
+        if n >= 0:
+            k = len(self._rbuf)
+            if n <= k:
+                data = self._rbuf[:n]
+                self._rbuf = self._rbuf[n:]
+                return data
+            n = n - k
+            L = [self._rbuf]
+            self._rbuf = ""
+            while n > 0:
+                new = self.recv(max(n, self.buffsize))
+                if not new: break
+                k = len(new)
+                if k > n:
+                    L.append(new[:n])
+                    self._rbuf = new[n:]
+                    break
+                L.append(new)
+                n = n - k
+            return "".join(L)
+        k = max(4096, self.buffsize)
+        L = [self._rbuf]
+        self._rbuf = ""
+        while 1:
+            new = self.recv(k)
+            if not new: break
+            L.append(new)
+            k = min(k*2, 1024**2)
+        return "".join(L)
+
+    def readline(self, limit=-1):
+        data = ""
+        i = self._rbuf.find('\n')
+        while i < 0 and not (0 < limit <= len(self._rbuf)):
+            new = self.recv(self.buffsize)
+            if not new: break
+            i = new.find('\n')
+            if i >= 0: i = i + len(self._rbuf)
+            self._rbuf = self._rbuf + new
+        if i < 0: i = len(self._rbuf)
+        else: i = i+1
+        if 0 <= limit < len(self._rbuf): i = limit
+        data, self._rbuf = self._rbuf[:i], self._rbuf[i:]
+        return data
+
+    def readlines(self, sizehint = 0):
+        total = 0
+        list = []
+        while 1:
+            line = self.readline()
+            if not line: break
+            list.append(line)
+            total += len(line)
+            if sizehint and total >= sizehint:
+                break
+        return list
+
+    def writelines(self, list):
+        self.send(''.join(list))
+
+    def write(self, data):
+        self.send(data)
+
+    def flush(self):
+        pass
+
+
+class TimeoutError(Exception):
+    pass