add support for RSA_SHA256
[electrum-nvc.git] / lib / x509.py
index 0656820..80f9919 100644 (file)
@@ -17,7 +17,8 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 
 
-from datetime import datetime, timedelta
+from datetime import datetime
+import sys
 
 try:
     import pyasn1
@@ -25,6 +26,11 @@ except ImportError:
     sys.exit("Error: pyasn1 does not seem to be installed. Try 'sudo pip install pyasn1'")
 
 try:
+    import pyasn1_modules
+except ImportError:
+    sys.exit("Error: pyasn1 does not seem to be installed. Try 'sudo pip install pyasn1-modules'")
+
+try:
     import tlslite
 except ImportError:
     sys.exit("Error: tlslite does not seem to be installed. Try 'sudo pip install tlslite'")
@@ -45,7 +51,7 @@ from pyasn1_modules.rfc2459 import id_ce_basicConstraints, BasicConstraints
 XMPP_ADDR = ObjectIdentifier('1.3.6.1.5.5.7.8.5')
 SRV_NAME = ObjectIdentifier('1.3.6.1.5.5.7.8.7')
 ALGO_RSA_SHA1 = ObjectIdentifier('1.2.840.113549.1.1.5')
-
+ALGO_RSA_SHA256 = ObjectIdentifier('1.2.840.113549.1.1.11')
 
 class CertificateError(Exception):
     pass
@@ -56,7 +62,10 @@ def decode_str(data):
 
 
 class X509(tlslite.X509):
-    """ Child class of tlslite.X509 that uses pyasn1 """
+    """Child class of tlslite.X509 that uses pyasn1 to parse cert
+    information. Note: pyasn1 is a lot slower than tlslite, so we
+    should try to do everything in tlslite.
+    """
 
     def slow_parse(self):
         self.cert = decoder.decode(str(self.bytes), asn1Spec=Certificate())[0]
@@ -170,9 +179,8 @@ class X509(tlslite.X509):
             return None
         return not_after - datetime.utcnow()
 
-    def check_name(self, expected):
+    def check_date(self):
         not_before, not_after = self.extract_dates()
-        cert_names = self.extract_names()
         now = datetime.utcnow()
         if not_before > now:
             raise CertificateError(
@@ -180,6 +188,9 @@ class X509(tlslite.X509):
         if not_after <= now:
             raise CertificateError(
                 'Certificate has expired.')
+
+    def check_name(self, expected):
+        cert_names = self.extract_names()
         if '.' in expected:
             expected_wild = expected[expected.index('.'):]
         else: