fixed error in assertion text
[p2pool.git] / wstools / TimeoutSocket.py
1 """Based on code from timeout_socket.py, with some tweaks for compatibility.
2    These tweaks should really be rolled back into timeout_socket, but it's
3    not totally clear who is maintaining it at this point. In the meantime,
4    we'll use a different module name for our tweaked version to avoid any
5    confusion.
6
7    The original timeout_socket is by:
8
9         Scott Cotton <scott@chronis.pobox.com>
10         Lloyd Zusman <ljz@asfast.com>
11         Phil Mayes <pmayes@olivebr.com>
12         Piers Lauder <piers@cs.su.oz.au>
13         Radovan Garabik <garabik@melkor.dnp.fmph.uniba.sk>
14 """
15
16 ident = "$Id$"
17
18 import string, socket, select, errno
19
20 WSAEINVAL = getattr(errno, 'WSAEINVAL', 10022)
21
22
23 class TimeoutSocket:
24     """A socket imposter that supports timeout limits."""
25
26     def __init__(self, timeout=20, sock=None):
27         self.timeout = float(timeout)
28         self.inbuf = ''
29         if sock is None:
30             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
31         self.sock = sock
32         self.sock.setblocking(0)
33         self._rbuf = ''
34         self._wbuf = ''
35
36     def __getattr__(self, name):
37         # Delegate to real socket attributes.
38         return getattr(self.sock, name)
39
40     def connect(self, *addr):
41         timeout = self.timeout
42         sock = self.sock
43         try:
44             # Non-blocking mode
45             sock.setblocking(0)
46             apply(sock.connect, addr)
47             sock.setblocking(timeout != 0)
48             return 1
49         except socket.error,why:
50             if not timeout:
51                 raise
52             sock.setblocking(1)
53             if len(why.args) == 1:
54                 code = 0
55             else:
56                 code, why = why
57             if code not in (
58                 errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK
59                 ):
60                 raise
61             r,w,e = select.select([],[sock],[],timeout)
62             if w:
63                 try:
64                     apply(sock.connect, addr)
65                     return 1
66                 except socket.error,why:
67                     if len(why.args) == 1:
68                         code = 0
69                     else:
70                         code, why = why
71                     if code in (errno.EISCONN, WSAEINVAL):
72                         return 1
73                     raise
74         raise TimeoutError('socket connect() timeout.')
75
76     def send(self, data, flags=0):
77         total = len(data)
78         next = 0
79         while 1:
80             r, w, e = select.select([],[self.sock], [], self.timeout)
81             if w:
82                 buff = data[next:next + 8192]
83                 sent = self.sock.send(buff, flags)
84                 next = next + sent
85                 if next == total:
86                     return total
87                 continue
88             raise TimeoutError('socket send() timeout.')
89
90     def recv(self, amt, flags=0):
91         if select.select([self.sock], [], [], self.timeout)[0]:
92             return self.sock.recv(amt, flags)
93         raise TimeoutError('socket recv() timeout.')
94
95     buffsize = 4096
96     handles = 1
97
98     def makefile(self, mode="r", buffsize=-1):
99         self.handles = self.handles + 1
100         self.mode = mode
101         return self
102
103     def close(self):
104         self.handles = self.handles - 1
105         if self.handles == 0 and self.sock.fileno() >= 0:
106             self.sock.close()
107
108     def read(self, n=-1):
109         if not isinstance(n, type(1)):
110             n = -1
111         if n >= 0:
112             k = len(self._rbuf)
113             if n <= k:
114                 data = self._rbuf[:n]
115                 self._rbuf = self._rbuf[n:]
116                 return data
117             n = n - k
118             L = [self._rbuf]
119             self._rbuf = ""
120             while n > 0:
121                 new = self.recv(max(n, self.buffsize))
122                 if not new: break
123                 k = len(new)
124                 if k > n:
125                     L.append(new[:n])
126                     self._rbuf = new[n:]
127                     break
128                 L.append(new)
129                 n = n - k
130             return "".join(L)
131         k = max(4096, self.buffsize)
132         L = [self._rbuf]
133         self._rbuf = ""
134         while 1:
135             new = self.recv(k)
136             if not new: break
137             L.append(new)
138             k = min(k*2, 1024**2)
139         return "".join(L)
140
141     def readline(self, limit=-1):
142         data = ""
143         i = self._rbuf.find('\n')
144         while i < 0 and not (0 < limit <= len(self._rbuf)):
145             new = self.recv(self.buffsize)
146             if not new: break
147             i = new.find('\n')
148             if i >= 0: i = i + len(self._rbuf)
149             self._rbuf = self._rbuf + new
150         if i < 0: i = len(self._rbuf)
151         else: i = i+1
152         if 0 <= limit < len(self._rbuf): i = limit
153         data, self._rbuf = self._rbuf[:i], self._rbuf[i:]
154         return data
155
156     def readlines(self, sizehint = 0):
157         total = 0
158         list = []
159         while 1:
160             line = self.readline()
161             if not line: break
162             list.append(line)
163             total += len(line)
164             if sizehint and total >= sizehint:
165                 break
166         return list
167
168     def writelines(self, list):
169         self.send(''.join(list))
170
171     def write(self, data):
172         self.send(data)
173
174     def flush(self):
175         pass
176
177
178 class TimeoutError(Exception):
179     pass