#!/usr/bin/python3

# tlsa - A tool to create DANE/TLSA records. (called 'swede' before)
#
# This tool is loosely based on the 'dane' program in the sshfp package by Paul
# Wouters and Christopher Olah
#
# Copyright 2012 Pieter Lexis <pieter.lexis@os3.nl>
# Copyright 2014 Paul Wouters <pwouters@redhat.com>
# Copyright 2015-2019 Dirk Stoecker <hashslinger@dstoecker.de>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

VERSION="3.1"

import sys
import os
import socket
import unbound
import subprocess
import re
from M2Crypto import X509, SSL, BIO, m2
from binascii import b2a_hex
from hashlib import sha256, sha512
from ipaddress import IPv4Address, IPv6Address

# DNSSEC trust anchor file: the first existing path from the candidate list
# below is used; ROOTKEY stays "none" when none of them is a regular file.
ROOTKEY = "none"
cauldron = (
	"/var/lib/unbound/root.anchor",
	"/var/lib/unbound/root.key",
	"/etc/unbound/root.key",
)
for root in cauldron:
	if not os.path.isfile(root):
		continue
	ROOTKEY = root
	break

def genTLSA(hostname, protocol, port, certificate, output='rfc', usage=1, selector=0, mtype=1):
	"""Build a TLSA record for a certificate and return it as an RR string.

	hostname, protocol and port form the owner name of the record
	(``_port._protocol.hostname.``, or ``*._protocol.hostname.`` when port
	is '*'). certificate is the object handed to getHash(): the whole
	certificate for selector 0, or the result of its get_pubkey() for any
	other selector. output selects the generic (TYPE52) representation
	when it equals 'generic'; every other value yields the TLSA form.

	Raises RecordValidityException when the resulting record is not
	well-formed.
	"""
	# Normalise to a fully qualified name (also safe for an empty hostname)
	if not hostname.endswith('.'):
		hostname += '.'

	# The wildcard port is used verbatim, a numeric port gets an underscore prefix
	if port == '*':
		name = '%s._%s.%s' % (port, protocol, hostname)
	else:
		name = '_%s._%s.%s' % (port, protocol, hostname)

	# Create the record without a certificate
	record = TLSARecord(name=name, usage=usage, selector=selector, mtype=mtype, cert='')

	# Only hash when the other fields are well-formed. isValid has to be
	# *called*: the bare attribute is a bound method and therefore always true.
	if record.isValid():
		if record.selector == 0:
			# Hash the Full certificate
			record.cert = getHash(certificate, record.mtype)
		else:
			# Hash only the SubjectPublicKeyInfo
			record.cert = getHash(certificate.get_pubkey(), record.mtype)

	# Report every problem of the final record in one exception
	record.isValid(raiseException=True)

	if output == 'generic':
		return record.getRecord(generic=True)
	return record.getRecord()

def genRecords(hostname, address, protocol, port, chain, output='rfc', usage=1, selector=0, mtype=1):
	"""Pick a certificate from chain according to usage and print its TLSA record(s).

	For usage 1 and 3 the first certificate of the chain (the end-entity
	certificate) is used. For usage 0 (CA certificates only) and usage 2
	the user is asked for each candidate whether it should be matched; the
	last certificate of the chain is taken without asking.

	output is 'rfc', 'generic' or 'both'; with 'both' the generic form is
	printed first, followed by the RFC form. address is only used in
	messages. Nothing is returned.
	"""
	if not chain:
		# No certificates were supplied (e.g. the peer sent none)
		print('WARN: No certificate chain available for %s' % address)
		return

	cert = None
	chained = checkChainLink(chain)[0]
	if not chained:
		print("WARN: Certificates don't chain")
	certsleft = len(chain)
	for chaincert in chain:
		certsleft -= 1
		if usage == 1 or usage == 3:
			# The first cert is the end-entity cert
			print('Got a certificate for %s with Subject: %s' % (address, chaincert.get_subject()))
			cert = chaincert
			break
		else:
			if (usage == 0 and chaincert.check_ca()) or usage == 2:
				if certsleft: # don't ask for the last one
					sys.stdout.write('Got a certificate with the following Subject:\n\t%s\nUse this as certificate to match? [y/N] ' % chaincert.get_subject())
					input_ok = False
					while not input_ok:
						user_input = input()
						if user_input in ['','n','N']:
							input_ok=True
						elif user_input in ['y', 'Y']:
							input_ok = True
							cert = chaincert
						else:
							sys.stdout.write('Please answer Y or N')
				else:
					print('Using certificate with the following Subject:\n\t%s\n' % chaincert.get_subject())
					cert = chaincert
			if cert:
				break

	if cert: # Print the requested records based on the retrieved certificates
		if output == 'both':
			# genTLSA only knows 'generic'; any other value produces the RFC form
			print(genTLSA(hostname, protocol, port, cert, 'generic', usage, selector, mtype))
			print(genTLSA(hostname, protocol, port, cert, 'rfc', usage, selector, mtype))
		else:
			print(genTLSA(hostname, protocol, port, cert, output, usage, selector, mtype))

	# Clear the cert from memory (to stop M2Crypto from segfaulting)
	cert=None

def getA(hostname, secure=True):
	"""Look up the A records of hostname and return them as a list of ARecord objects.

	Lookup failures (insecure answer, resolution error) are printed and
	result in an empty list.
	"""
	try:
		raw_answers = getRecords(hostname, rrtype='A', secure=secure)
	except InsecureLookupException as insecure:
		print(str(insecure))
		raw_answers = []
	except DNSLookupError as failure:
		print('Unable to resolve %s: %s' % (hostname, str(failure)))
		raw_answers = []
	# Every answer is converted via its hex form into an integer address
	return [ARecord(hostname, str(IPv4Address(int(b2a_hex(raw), 16)))) for raw in raw_answers]

def getAAAA(hostname, secure=True):
	"""Look up the AAAA records of hostname and return them as a list of AAAARecord objects.

	Lookup failures (insecure answer, resolution error) are printed and
	result in an empty list.
	"""
	try:
		raw_answers = getRecords(hostname, rrtype='AAAA', secure=secure)
	except InsecureLookupException as insecure:
		print(str(insecure))
		raw_answers = []
	except DNSLookupError as failure:
		print('Unable to resolve %s: %s' % (hostname, str(failure)))
		raw_answers = []
	# Every answer is converted via its hex form into an integer address
	return [AAAARecord(hostname, str(IPv6Address(int(b2a_hex(raw), 16)))) for raw in raw_answers]

def getVerificationErrorReason(num):
	"""Return the name of the X509 verification error for int(num).

	Codes that are not part of the table below do not raise; a generic
	description containing the numeric code is returned instead, so that
	reporting a verification failure can never fail itself.
	"""
	code = int(num)
	# These were taken from the M2Crypto.m2 code
	reasons = {
		50: "X509_V_ERR_APPLICATION_VERIFICATION",
		22: "X509_V_ERR_CERT_CHAIN_TOO_LONG",
		10: "X509_V_ERR_CERT_HAS_EXPIRED",
		9:  "X509_V_ERR_CERT_NOT_YET_VALID",
		28: "X509_V_ERR_CERT_REJECTED",
		23: "X509_V_ERR_CERT_REVOKED",
		7:  "X509_V_ERR_CERT_SIGNATURE_FAILURE",
		27: "X509_V_ERR_CERT_UNTRUSTED",
		12: "X509_V_ERR_CRL_HAS_EXPIRED",
		11: "X509_V_ERR_CRL_NOT_YET_VALID",
		8:  "X509_V_ERR_CRL_SIGNATURE_FAILURE",
		18: "X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT",
		14: "X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD",
		13: "X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD",
		15: "X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD",
		16: "X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD",
		24: "X509_V_ERR_INVALID_CA",
		26: "X509_V_ERR_INVALID_PURPOSE",
		17: "X509_V_ERR_OUT_OF_MEM",
		25: "X509_V_ERR_PATH_LENGTH_EXCEEDED",
		19: "X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN",
		6:  "X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY",
		4:  "X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE",
		5:  "X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE",
		3:  "X509_V_ERR_UNABLE_TO_GET_CRL",
		2:  "X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT",
		20: "X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY",
		21: "X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE",
		0:  "X509_V_OK",
	}
	return reasons.get(code, "unknown X509 verification error (code %d)" % code)

def getRecords(hostname, rrtype='A', secure=True, noresolver=False):
	"""Resolve hostname for the given rrtype and return the raw record data.

	rrtype is either the name of a record type known to the unbound module
	(e.g. 'A') or its integer value. Only rr_class IN is queried.

	With secure=True an answer that is not DNSSEC-secure raises
	InsecureLookupException; if the query went through the resolv.conf
	resolver, it is first repeated once without it, and the module-level
	resolvconf is cleared afterwards. With secure=False an insecure answer
	only produces a warning on stderr. A failed lookup or an empty answer
	raises DNSLookupError.
	"""
	global resolvconf
	resolver = unbound.ub_ctx()
	if os.path.isfile(ROOTKEY):
		# file format guessing - unbound should deal with this :(
		try:
			resolver.add_ta_file(ROOTKEY)
		except:
			unbound.ub_ctx_trustedkeys(resolver, ROOTKEY)
	# Use the local cache and emulate /etc/hosts
	resolver.hosts("/etc/hosts")
	if resolvconf and not noresolver:
		resolver.resolvconf(resolvconf)

	# Human readable type for the error message
	rrtypestr = "TLSA" if rrtype == 52 else rrtype
	if type(rrtype) == str:
		constant = 'RR_TYPE_' + rrtype
		if constant not in dir(unbound):
			raise Exception('Error: unknown RR TYPE: %s.' % rrtype)
		rrtype = getattr(unbound, constant)
	elif type(rrtype) != int:
		raise Exception('Error: rrtype in wrong format, neither int nor str.')

	status, result = resolver.resolve(hostname, rrtype=rrtype)
	if status != 0 or not result.havedata:
		raise DNSLookupError('Unsuccessful DNS lookup or no data returned for rrtype %s (%s).' % (rrtypestr, rrtype))

	if result.secure:
		return result.data.raw

	if not secure:
		# Insecure data is accepted, but the user is told about it
		print('Warning: query data is not secure.', file=sys.stderr)
		return result.data.raw

	if not noresolver and resolvconf:
		# Second attempt without the resolver config file
		retval = getRecords(hostname, rrtype, secure, True)
		print('Warning: initial query using resolver config file was not secure (try option --resolvconf="").', file=sys.stderr)
		resolvconf = None
		return retval

	# The data is insecure and a secure lookup was requested
	raise InsecureLookupException('Error: Answer was not DNSSEC-secure')

# identical to Connection.connect() except for the last parameter
def sslStartTLSConnect(connection, addr, starttls=None):
        """Connect an SSL.Connection, optionally performing a plain-text STARTTLS exchange first.

        connection is the M2Crypto SSL.Connection to use and addr the
        (host, port) pair handed to its socket. starttls names the
        application protocol ('smtp', 'imap', 'ftp' or 'pop3'); any other
        value skips the plain-text exchange.

        Returns the result of connection.connect_ssl(). Raises
        SSL.Checker.SSLVerificationError when the post connection check
        returns a false value.
        """
        # Commands sent after the server greeting, each followed by one read
        plaintext_commands = {
            "smtp": ("EHLO M2Crypto\r\n", "STARTTLS\r\n"),
            "imap": (". STARTTLS\r\n",),
            "ftp": ("AUTH TLS\r\n",),
            "pop3": ("STLS\r\n",),
        }

        connection.socket.connect(addr)
        connection.addr = addr

        commands = plaintext_commands.get(starttls) if starttls else None
        if commands:
            # primitive method, no error checks yet: replies are read and discarded
            connection.socket.recv(500)
            for command in commands:
                connection.socket.send(command.encode('ascii'))
                connection.socket.recv(500)

        connection.setup_ssl()
        connection.set_connect_state()
        ret = connection.connect_ssl()
        check = getattr(connection, 'postConnectionCheck', connection.clientPostConnectionCheck)
        if check is not None:
            if not check(connection.get_peer_cert(), connection.addr[0]):
                # Checker is only reachable through the imported SSL package
                raise SSL.Checker.SSLVerificationError('post connection check failed')
        return ret

def getHash(certificate, mtype):
	"""Return the hex string of certificate as selected by the matching type.

	certificate must offer as_der() (an M2Crypto.X509.X509 object or the
	result of its get_pubkey()). mtype 0 returns the DER data itself,
	1 its SHA-256 and 2 its SHA-512 digest; every other value raises an
	Exception.
	"""
	der = certificate.as_der()
	if mtype == 1:
		return sha256(der).hexdigest()
	if mtype == 2:
		return sha512(der).hexdigest()
	if mtype == 0:
		return der.hex()
	raise Exception('mtype should be 0,1,2')

def getTLSA(hostname, port=443, protocol='tcp', secure=True):
	"""
	Look up the TLSA records for a service and return them as a list of TLSARecord objects.
	The record is requested as TYPE52 and the first three bytes of each answer
	are split off as usage, selector and matching type; the rest is the
	certificate association data as hex string.
	An insecure answer or a failed lookup prints a message and exits with status 1.
	"""
	if hostname[-1] != '.':
		hostname += '.'

	transport = protocol.lower()
	if transport not in ['tcp', 'udp', 'sctp']:
		raise Exception('Error: unknown protocol: %s. Should be one of tcp, udp or sctp' % protocol)

	# Owner name of the record, the wildcard port carries no underscore
	if port == '*':
		owner = '*._%s.%s' % (transport, hostname)
	else:
		owner = '_%s._%s.%s' % (port, transport, hostname)

	try:
		answers = getRecords(owner, rrtype=52, secure=secure)
	except InsecureLookupException as insecure:
		print(str(insecure))
		sys.exit(1)
	except DNSLookupError as failure:
		print('Unable to resolve %s: %s' % (hostname, str(failure)))
		sys.exit(1)

	parsed = []
	for rdata in answers:
		hexdata = b2a_hex(rdata)
		fields = [int(hexdata[pos:pos + 2], 16) for pos in (0, 2, 4)]
		parsed.append(TLSARecord(owner, fields[0], fields[1], fields[2], hexdata[6:].decode('ascii')))
	return parsed

def verifyCertMatch(record, cert):
	"""
	Compare a certificate with the association data of a TLSA record.
	record has to be a TLSARecord and cert a M2Crypto.X509.X509, otherwise
	None is returned. Returns True when the data selected by the record
	(public key for selector 1, full certificate otherwise) equals
	record.cert, else False.
	"""
	if not (isinstance(cert, X509.X509) and isinstance(record, TLSARecord)):
		return

	selected = cert.get_pubkey() if record.selector == 1 else cert
	digest = getHash(selected, record.mtype)
	if not digest:
		return

	return digest == record.cert

def verifyCertNameWithHostName(cert, hostname, with_msg=False):
	"""Verify the name on the certificate with a hostname.

	We need this because we get the cert based on IP address and thusly
	cannot rely on M2Crypto to verify this.

	The candidate names are the CN entries of the subject and the DNS
	entries of the subjectAltName extension (when present). A wildcard
	name ('*.example.com') matches exactly one additional left-most label.

	Returns True on a match and False otherwise; None is returned when
	cert is not an X509.X509 or hostname is not a str. With with_msg a
	warning is printed when nothing matches.
	"""
	if not isinstance(cert, X509.X509):
		return
	if not isinstance(hostname, str):
		return

	if hostname.endswith('.'):
		hostname = hostname[:-1]
	hostname = hostname.lower()

	subject = str(cert.get_subject())
	certnames = [entry[3:] for entry in subject.lower().split("/") if entry.startswith("cn=")]

	# Best effort: a certificate without the extension simply has no alternative names
	try:
		alt_names = cert.get_ext('subjectAltName').get_value()
	except Exception:
		alt_names = None
	if alt_names:
		for value in re.split(", ?", alt_names.lower()):
			if value.startswith("dns:"):
				certnames.append(value[4:])

	if hostname in certnames:
		return True

	# Everything after the first label of the hostname; empty for a single label
	parent_domain = hostname.partition('.')[2]
	for certname in certnames:
		if certname.startswith('*.') and parent_domain and certname[2:] == parent_domain:
			return True

	if with_msg:
		shown_alt_names = alt_names if alt_names is not None else '(not present)'
		print('WARNING: Name on the certificate (Subject: %s, SubjectAltName: %s) doesn\'t match requested hostname (%s).' % (subject, shown_alt_names, hostname))
	return False

def checkChainLink(chain, record=None):
	"""Walk a certificate chain and check that every issuer is the subject of the next certificate.

	Returns a tuple (chained, matched): chained is False as soon as a
	certificate's subject differs from the issuer of its predecessor
	(checking stops there). matched is the last certificate seen so far
	for which verifyCertMatch(record, cert) was true, or None; without a
	record no matching is done.
	"""
	expected_subject = None
	matched = None
	for link in chain:
		if expected_subject and str(expected_subject) != str(link.get_subject()):
			# The chain cannot be valid
			return False, matched
		expected_subject = link.get_issuer()
		if record and verifyCertMatch(record, link):
			matched = link
	return True, matched

def getLocalChain(filename, debug):
	"""Load the PEM certificates of a file and verify the file with the openssl tool.

	Returns a tuple (chain, verify_result): chain is the list of
	M2Crypto.X509.X509 objects read from filename. verify_result is 0 when
	the output of "openssl verify" ends with "OK", otherwise the combined
	standard output and standard error text of that command. With debug
	the error output of openssl is echoed to stderr.

	Raises the read error, or a generic Exception when there was none,
	if not a single certificate could be loaded.
	"""
	chain = []
	bio = BIO.openfile(filename)
	read_error = None
	while True:
		try:
			cptr = m2.x509_read_pem(bio._ptr())
		except Exception as e:
			# Remembered so it can be reported when nothing was loaded at all
			read_error = e
			break
		if not cptr:
			break
		chain.append(X509.X509(cptr, _pyfree=1))
	if not chain:
		if read_error:
			# A bare "raise" is not possible here: we are outside the except block
			raise read_error
		raise Exception("Could not load %s" % filename)
	# FIXME - is this possible using the library (without a call to openssl tool)?
	cmd = ["openssl", "verify", filename]
	process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
	(stdout, stderr) = process.communicate()
	# Work with text so that messages do not show up as b'...' literals
	stdout = stdout.decode(errors='replace')
	stderr = stderr.decode(errors='replace')
	if debug:
		print(stderr, file=sys.stderr)
	if stdout.endswith("OK\n"):
		verify_result = 0
	else:
		verify_result = "%s%s" % (stdout, stderr)
	return chain, verify_result

def usageText(usage):
	"""Return a label like "Usage 3 [DANE-EE]" for the certificate usage value."""
	labels = ("PKIX-CA", "PKIX-EE", "DANE-TA", "DANE-EE")
	label = labels[usage]
	return "Usage %d [%s]" % (usage, label)

def verifyChain(chain, verify_result, pre_exit, local, source):
	"""Check a certificate chain against the TLSA record under test and print the outcome.

	NOTE: besides its parameters this function reads the module-level names
	record (the TLSARecord being verified) and args (the parsed command
	line: host and debug), which are set by the __main__ section.

	chain is the list of certificates with the end-entity certificate
	first. verify_result is the PKIX validation result: 0 for success,
	otherwise an error text (local) or an X509 error number (server).
	local tells whether the chain was read from a local file; source is
	only used in messages.

	Returns pre_exit, changed to 2 when it was 0 and the check failed.
	"""
	if local:
		text_cert = "Local certificate"
		text_chain = "certificate in the local certificate chain"
	else:
		text_cert = "Certificate offered by the server"
		text_chain = "certificate in the certificate chain offered by the server"

	def pkix_reason():
		"""Describe the PKIX validation failure."""
		if local:
			return verify_result
		return getVerificationErrorReason(verify_result)

	if not chain:
		# Nothing to compare the record with
		print('FAIL: No certificate chain available (%s)' % source)
		if pre_exit == 0: pre_exit = 2
		return pre_exit

	if not verifyCertNameWithHostName(cert=chain[0], hostname=str(args.host), with_msg=True):
		# The name on the cert doesn't match the hostname... we still check the TLSA record
		print('Caution: %s does not match hostname (%s)' % (text_cert, source))
	ut = usageText(record.usage)
	if record.usage == 1: # End-host cert
		chained, matched = checkChainLink((chain[0],), record)
		if not chained:
			print("WARN: Certificates don't chain")
		if matched:
			if verify_result == 0: # The cert chains to a valid CA cert according to the system-certificates
				print('SUCCESS (%s): %s matches the one mentioned in the TLSA record and chains to a valid CA certificate (%s)' % (ut, text_cert, source))
			else:
				print('FAIL (%s): %s matches the one mentioned in the TLSA record but the following error was raised during PKIX validation (%s): %s' % (ut, text_cert, source, pkix_reason()))
				if pre_exit == 0: pre_exit = 2
			if args.debug: print('The matched certificate has Subject: %s' % matched.get_subject())
		else:
			print('FAIL: %s does not match the TLSA record (%s)' % (text_cert, source))
			if pre_exit == 0: pre_exit = 2

	elif record.usage == 0: # CA constraint
		# FIXME: Does neither handle chain data in full cert TLSA records (type 0) correctly
		# nor handles short certification chains
		# In these case the chain should be completed with missing records either
		# from TLSA or local certificate store

		# Remove the first (= End-Entity cert) from the chain
		chained, matched = checkChainLink(chain[1:], record)
		if not chained:
			print("WARN: Certificates don't chain")
		if matched:
			# The certificate that matched the record has to be a CA certificate
			if matched.check_ca():
				if verify_result == 0:
					print('SUCCESS (%s): A %s matches the one mentioned in the TLSA record and is a CA certificate (%s)' % (ut, text_chain, source))
				else:
					print('FAIL (%s): A %s matches the one mentioned in the TLSA record and is a CA certificate, but the following error was raised during PKIX validation (%s): %s' % (ut, text_chain, source, pkix_reason()))
					if pre_exit == 0: pre_exit = 2
			else:
				print('FAIL (%s): A %s matches the one mentioned in the TLSA record but is not a CA certificate (%s)' % (ut, text_chain, source))
				if pre_exit == 0: pre_exit = 2
			if args.debug: print('The matched certificate has Subject: %s' % matched.get_subject())
		else:
			print('FAIL (%s): No %s matches the TLSA record (%s)' % (ut, text_chain, source))
			if pre_exit == 0: pre_exit = 2

	elif record.usage == 2: # Usage 2, use the cert in the record as trust anchor
		# FIXME: Does not handle chain data in full cert TLSA records (type 0) correctly
		# In this case the chain should be completed with missing records if chain is
		# not complete
		chained, matched = checkChainLink(chain, record)
		if matched:
			if not chained:
				print("WARN: Certificates don't chain")
			print('SUCCESS (%s): A %s (including the end-entity certificate) matches the TLSA record (%s)' % (ut, text_chain, source))
			if args.debug: print('The matched certificate has Subject: %s' % matched.get_subject())
		else:
			if not chained:
				print("FAIL: Certificates don't chain")
			print('FAIL (%s): No %s (including the end-entity certificate) matches the TLSA record (%s)' % (ut, text_chain, source))
			if pre_exit == 0: pre_exit = 2

	elif record.usage == 3: # EE cert MUST match
		chained, matched = checkChainLink((chain[0],), record)
		if not chained:
			print("WARN: Certificates don't chain")
		if matched:
			print('SUCCESS (%s): %s matches the TLSA record (%s)' % (ut, text_cert, source))
			if args.debug: print('The matched certificate has Subject: %s' % chain[0].get_subject())
		else:
			print('FAIL (%s): %s does not match the TLSA record (%s)' % (ut, text_cert, source))
			if pre_exit == 0: pre_exit = 2
	return pre_exit

class TLSARecord:
	"""When instantiated, this class contains all the fields of a TLSA record.
	"""
	def __init__(self, name, usage, selector, mtype, cert):
		r"""name is the name of the RR in the format: /^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([a-z0-9]*\.){2,}$/
		usage, selector and mtype should be an integer
		cert should be a hexadecimal string representing the certificate to be matched field

		Raises an Exception when a value cannot be converted.
		"""
		try:
			self.rrtype = 52    # TLSA per https://www.iana.org/assignments/dns-parameters
			self.rrclass = 1    # IN
			self.name = str(name)
			self.usage = int(usage)
			self.selector = int(selector)
			self.mtype = int(mtype)
			self.cert = cert
		except Exception as e:
			raise Exception('Invalid value passed, unable to create a TLSARecord') from e

	def getRecord(self, generic=False):
		"""Returns the RR string of this TLSARecord, either in rfc (default) or generic format"""
		if generic:
			# RDATA length in bytes: two hex digits per byte of cert plus the
			# three one-byte fields; integer division keeps it an integer
			rdlength = len(self.cert) // 2 + 3
			return '%s IN TYPE52 \\# %s %s%s%s%s' % (self.name, rdlength, self._toHex(self.usage), self._toHex(self.selector), self._toHex(self.mtype), self.cert)
		return '%s IN TLSA %s %s %s %s' % (self.name, self.usage, self.selector, self.mtype, self.cert)

	def _toHex(self, val):
		"""Helper function to create hex strings from integers"""
		return "%0.2x" % val

	def isValid(self, raiseException=False):
		"""Check whether all fields in the TLSA record are conforming to the spec and check if the port, protocol and name are good

		Returns True for a well-formed record. Otherwise False is returned,
		or, with raiseException, a RecordValidityException listing all
		problems is raised.
		"""
		err = []
		port = self.getPort()
		if port != '*':
			try:
				if not 1 <= int(port) <= 65535:
					err.append('Port %s not within correct range (1 <= port <= 65535)' % port)
			except ValueError:
				err.append('Port %s not a number' % port)
		if self.usage not in [0,1,2,3]:
			err.append('Usage: invalid (%s is not one of 0 [PKIX-CA], 1 [PKIX-EE], 2 [DANE-TA] or 3 [DANE-EE])' % self.usage)
		if self.selector not in [0,1]:
			err.append('Selector: invalid (%s is not one of 0 [Cert] or 1 [SPKI])' % self.selector)
		if self.mtype not in [0,1,2]:
			err.append('Matching Type: invalid (%s is not one of 0 [FULL], 1 [SHA-256] or 2 [SHA-512])' % self.mtype)
		if not self.isNameValid():
			err.append('Name (%s) is not in the correct format: _portnumber._transportprotocol.hostname.dom.' % self.name)
		# A certificate length of 0 is accepted
		if self.mtype in [1,2] and len(self.cert) != 0:
			if len(self.cert) != {1:64,2:128}[self.mtype]:
				err.append('Certificate for Association: invalid (Hash length does not match hash-type in Matching Type(%s))' % {1:'SHA-256',2:'SHA-512'}[self.mtype])
		if not err:
			return True
		if not raiseException:
			return False
		msg = 'The TLSA record is invalid.'
		for error in err:
			msg += '\n\t%s' % error
		raise RecordValidityException(msg)

	def isNameValid(self):
		"""Check if the name is in the correct format"""
		return bool(re.match(r'^(_\d{1,5}|\*)\._(tcp|udp|sctp)\.([-a-z0-9]*\.){2,}$', self.name))

	def getProtocol(self):
		"""Returns the protocol based on the name"""
		return self.name.split('.')[1][1:]

	def getPort(self):
		"""Returns the port based on the name ('*' for the wildcard port)"""
		first_label = self.name.split('.')[0]
		# startswith also copes with an empty first label
		if first_label.startswith('*'):
			return '*'
		return first_label[1:]

class ARecord:
	"""Holds one A record: a hostname together with its IPv4 address string"""
	def __init__(self, hostname, address):
		self.hostname = hostname
		self.address = address
		self.rrtype = 1

	def __str__(self):
		"""The record is represented by its address"""
		return self.address

	def isValid(self):
		"""True when the address can be parsed as IPv4 address, otherwise False"""
		try:
			IPv4Address(self.address)
		except:
			return False
		return True

class AAAARecord:
	"""Holds one AAAA record: a hostname together with its IPv6 address string"""
	def __init__(self, hostname, address):
		self.hostname = hostname
		self.address = address
		self.rrtype = 28

	def __str__(self):
		"""The record is represented by its address"""
		return self.address

	def isValid(self):
		"""True when the address can be parsed as IPv6 address, otherwise False"""
		try:
			IPv6Address(self.address)
		except:
			return False
		return True

# Exceptions
class RecordValidityException(Exception):
	"""Raised by TLSARecord.isValid() when a record is not well-formed"""

class InsecureLookupException(Exception):
	"""Raised when a secure lookup was requested but the answer is not DNSSEC-secure"""

class DNSLookupError(Exception):
	"""Raised when a DNS lookup fails or returns no data"""

if __name__ == '__main__':
	# Command line entry point: either verify the TLSA records of a service
	# (--verify) or create a TLSA record for a certificate (default).
	# Exit status: 0 success, 1 error, 2 a TLSA record did not match.
	import argparse
	# create the parser
	parser = argparse.ArgumentParser(description='Create and verify TLSA records.', epilog='For bugs. see tlsa@nohats.ca')

	# Caveat: For TLSA validation, this program chases through the certificate chain offered by the server, not its local certificates.'
	parser.add_argument('--verify','-v', action='store_true', help='Verify a TLSA record, exit 0 when all TLSA records are matched, exit 2 when a record does not match the received certificate, exit 1 on error.')

	parser.add_argument('--create','-c', action='store_true', help='Create a TLSA record')
	parser.add_argument('--version', action='version', version='tlsa version: %s'%VERSION, help='show version and exit')

	parser.add_argument('-4', '--ipv4', dest='ipv4', action='store_true',help='use ipv4 networking only')
	parser.add_argument('-6', '--ipv6', dest='ipv6', action='store_true',help='use ipv6 networking only')
	parser.add_argument('--insecure', action='store_true', default=False, help='Allow use of non-dnssec secured answers')
	parser.add_argument('--resolvconf', metavar='/PATH/TO/RESOLV.CONF', action='store', default='/etc/resolv.conf', help='Use a recursive resolver listed in a resolv.conf file (default: /etc/resolv.conf)')
	parser.add_argument('host', metavar="hostname")

	parser.add_argument('--port', '-p', action='store', default='443', help='The port, or \'*\' where running TLS is located (default: %(default)s).')
	parser.add_argument('--starttls', '-t', action='store', choices=['no', 'smtp', 'imap', 'pop3', 'ftp'], help='Protocol needs special start procedure, requires protocol name.')
	parser.add_argument('--protocol', action='store', choices=['tcp','udp','sctp'], default='tcp', help='The protocol the TLS service is using (default: %(default)s).')
	parser.add_argument('--only-rr', action='store_true', help='Only verify that the TLSA resource record is correct (do not check certificate)')
	parser.add_argument('--rootkey', metavar='/PATH/TO/ROOT.KEY', action='store', help='Specify file location containing the DNSSEC root key')

	# Default CA store: first existing system directory, else the current directory
	if os.path.isdir("/etc/pki/tls/certs/"):
		cadir = "/etc/pki/tls/certs/"
	elif os.path.isdir("/etc/ssl/certs/"):
		cadir = "/etc/ssl/certs/"
	else:
		cadir = "."
		print ("warning: no system wide CAdir found, using current directory")

	parser.add_argument('--ca-cert', metavar='/PATH/TO/CERTSTORE', action='store', default = cadir, help='Path to a CA certificate or a directory containing the certificates (default: %(default)s)')
	parser.add_argument('--debug', '-d', action='store_true', help='Print details plus the result of the validation')
	parser.add_argument('--quiet', '-q', action='store_true', help='Ignored for backwards compatibility')

	parser.add_argument('--certificate', help='The certificate used for the host. If certificate is empty, the certificate will be downloaded from the server')
	parser.add_argument('--output', '-o', action='store', default='rfc', choices=['generic','rfc','both'], help='The type of output. Generic (RFC 3597, TYPE52), RFC (TLSA) or both (default: %(default)s).')

	# Usage of the certificate
	parser.add_argument('--usage', '-u', action='store', type=int, default=3, choices=[0,1,2,3], help='The Usage of the Certificate for Association. \'0\' for CA [PKIX-CA], \'1\' for End Entity [PKIX-EE], \'2\' for trust-anchor [DANE-TA], \'3\' for ONLY End-Entity match [DANE-EE] (default: %(default)s).')
	parser.add_argument('--selector', '-s', action='store', type=int, default=0, choices=[0,1], help='The Selector for the Certificate for Association. \'0\' for Full Certificate [Cert], \'1\' for SubjectPublicKeyInfo [SPKI] (default: %(default)s).')
	parser.add_argument('--mtype', '-m', action='store', type=int, default=1, choices=[0,1,2], help='The Matching Type of the Certificate for Association. \'0\' for Exact match, \'1\' for SHA-256 hash, \'2\' for SHA-512 (default: %(default)s).')

	args = parser.parse_args()

	if args.rootkey:
		if os.path.isfile(args.rootkey):
			ROOTKEY = args.rootkey
		else:
			print("ignored specified non-existing rootkey file %s"%args.rootkey)

	# check whether it's ASCII or needs to be converted into Punycode
	try:
		args.host.encode('ascii')
	except UnicodeEncodeError:
		args.host = str(args.host.encode('idna'), 'utf-8')

	# Work with the fully qualified name; SNI wants it without the trailing dot
	if args.host[-1] != '.':
		args.host += '.'
	snihost = args.host[0:-1]

	# Without an explicit --starttls the procedure is derived from well-known ports
	if not args.starttls:
		if args.port == "25" or args.port == "587":
			args.starttls = "smtp"
		elif args.port == "143":
			args.starttls = "imap"
		elif args.port == "110":
			args.starttls = "pop3"
		elif args.port == "21":
			args.starttls = "ftp"
	elif args.starttls == "no":
		args.starttls = None

	# resolvconf is the module-level name read (and possibly cleared) by getRecords()
	global resolvconf
	if args.resolvconf:
		if os.path.isfile(args.resolvconf):
			resolvconf = args.resolvconf
		else:
			print('%s is not a file. Unable to use it as resolv.conf' % args.resolvconf, file=sys.stdout)
			sys.exit(1)
	else:
		resolvconf = None

	# not operations are fun!
	secure = not args.insecure

	if args.verify:
		records = getTLSA(args.host, args.port, args.protocol, secure)
		if len(records) == 0:
			sys.exit(1)

		# Accumulated over ALL records and addresses, so that a mismatch of
		# an earlier record is not forgotten when a later record matches
		pre_exit = 0
		for record in records:
			# First, check if the first three fields have correct values.
			if args.debug:
				print('Received the following record for name %s:' % record.name)
				print('\tUsage:\t\t\t\t%d (%s)' % (record.usage, {0:'CA Constraint [PKIX-CA]', 1:'End-Entity Constraint + chain to CA [PKIX-EE]', 2:'Trust Anchor [DANE-TA]', 3:'End-Entity [DANE-EE]'}.get(record.usage, 'INVALID')))
				print('\tSelector:\t\t\t%d (%s)' % (record.selector, {0:'Certificate [Cert]', 1:'SubjectPublicKeyInfo [SPKI]'}.get(record.selector, 'INVALID')))
				print('\tMatching Type:\t\t\t%d (%s)' % (record.mtype, {0:'Full Certificate', 1:'SHA-256', 2:'SHA-512'}.get(record.mtype, 'INVALID')))
				print('\tCertificate for Association:\t%s' % record.cert)

			try:
				record.isValid(raiseException=True)
			except RecordValidityException as e:
				print('Error: %s' % str(e), file=sys.stderr)
				continue
			else:
				if args.debug:
					print('This record is valid (well-formed).')

			if args.only_rr:
				# Go to the next record
				continue

			# When we are here, The user also wants to verify the certificates with the record
			if args.protocol != 'tcp':
				print('Only SSL over TCP is supported (sorry)', file=sys.stderr)
				# This is an error, same status as in the create branch
				sys.exit(1)

			if args.debug:
				print('Attempting to verify the record with the TLS service...')

			if not args.ipv4 and not args.ipv6:
				addresses = getA(args.host, secure=secure) + getAAAA(args.host, secure=secure)
			elif args.ipv4:
				addresses = getA(args.host, secure=secure)
			else:
				addresses = getAAAA(args.host, secure=secure)

			if args.certificate:
				try:
					chain, verify_result = getLocalChain(args.certificate, args.debug)
					pre_exit = verifyChain(chain, verify_result, pre_exit, True, args.certificate)
				except Exception as e:
					print('Could not verify local certificate: %s.' % e)
					sys.exit(1)
			for address in addresses:
				if args.debug:
					print('Got the following IP: %s' % str(address))
				# We do the certificate handling here, as M2Crypto keeps segfaulting when we do it in a method
				ctx = SSL.Context()
				if os.path.isfile(args.ca_cert):
					if ctx.load_verify_locations(cafile=args.ca_cert) != 1: raise Exception('No CA cert')
				elif os.path.exists(args.ca_cert):
					if ctx.load_verify_locations(capath=args.ca_cert) != 1: raise Exception('No CA certs')
				else:
					print('%s is neither a file nor a directory, unable to continue' % args.ca_cert, file=sys.stderr)
					sys.exit(1)
				# Don't error when the verification fails in the SSL handshake
				ctx.set_verify(SSL.verify_none, depth=9)
				if isinstance(address, AAAARecord):
					sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
					sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
				else:
					sock = None
				connection = SSL.Connection(ctx, sock=sock)
				try:
					connection.set_tlsext_host_name(snihost)
					if args.debug:
						print('Did set servername %s' % snihost)
				except AttributeError:
					print('Could not set SNI (old M2Crypto version?)')
				except:
					print('Could not set SNI')
				try:
					if args.starttls:
						sslStartTLSConnect(connection, (str(address), int(args.port)), args.starttls)
					else:
						connection.connect((str(address), int(args.port)))
				except SSL.Checker.WrongHost as e:
					# The name on the remote cert doesn't match the hostname because we connect on IP, not hostname (as we want secure lookup)
					pass
				except socket.error as e:
					print('Cannot connect to %s: %s' % (address, str(e)))
					continue
				chain = connection.get_peer_cert_chain()
				verify_result = connection.get_verify_result()

				# Good, now let's verify
				pre_exit = verifyChain(chain, verify_result, pre_exit, False, address)
				# Cleanup, just in case
				connection.clear()
				connection.close()
				ctx.close()

			# END for address in addresses
		# END for record in records
		sys.exit(pre_exit)
	# END if args.verify

	else: # we want to create
		cert = None
		if not args.certificate:
			if args.protocol != 'tcp':
				print('Only SSL over TCP is supported (sorry)', file=sys.stderr)
				sys.exit(1)
			if args.debug:
				print('No certificate specified on the commandline, attempting to retrieve it from the server %s' % (args.host))
			connection_port = args.port
			if args.port == '*':
				sys.stdout.write('The port specified on the commandline is *, please specify the port of the TLS service on %s (443): ' % args.host)
				while True:
					user_input = input()
					if user_input == '':
						connection_port = 443
						break
					try:
						port_ok = 1 <= int(user_input) <= 65535
					except ValueError:
						port_ok = False
					if port_ok:
						connection_port = user_input
						break
					# Ask again for non-numerical as well as out-of-range input
					sys.stdout.write('Port %s not numerical or within correct range (1 <= port <= 65535), try again (hit enter for default 443): ' % user_input)
			# Get the address records for the host
			try:
				if not args.ipv4 and not args.ipv6:
					addresses = getA(args.host, secure=secure) + getAAAA(args.host, secure=secure)

				elif args.ipv4:
					addresses = getA(args.host, secure=secure)
				else:
					addresses = getAAAA(args.host, secure=secure)

			except InsecureLookupException as e:
				print(str(e), file=sys.stderr)
				sys.exit(1)

			for address in addresses:
				if args.debug:
					print('Attempting to get certificate from %s' % str(address))
				# We do the certificate handling here, as M2Crypto keeps segfaulting when try to do stuff with the cert if we don't
				ctx = SSL.Context()
				ctx.set_verify(SSL.verify_none, depth=9)
				if isinstance(address, AAAARecord):
					sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
					sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
				else:
					sock = None
				connection = SSL.Connection(ctx, sock=sock)
				try:
					connection.set_tlsext_host_name(snihost)
					if args.debug:
						print('Did set servername %s' % snihost)
				except AttributeError:
					print('Could not set SNI (old M2Crypto version?)')
				except:
					print('Could not set SNI')
				try:
					if args.starttls:
						sslStartTLSConnect(connection, (str(address), int(connection_port)), args.starttls)
					else:
						connection.connect((str(address), int(connection_port)))
				except SSL.Checker.WrongHost:
					pass
				except socket.error as e:
					print('Cannot connect to %s: %s' % (address, str(e)))
					continue

				chain = connection.get_peer_cert_chain()
				genRecords(args.host, address, args.protocol, args.port, chain, args.output, args.usage, args.selector, args.mtype)
				# Cleanup the connection and context
				connection.clear()
				connection.close()
				ctx.close()

		else: # Pass the path to the certificate to the genTLSA function
			try:
				chain, verify_result = getLocalChain(args.certificate, args.debug)
				# The PKIX usages (0 and 1) depend on a valid chain: report a
				# failed validation (verify_result is 0 on success)
				if args.usage in (0, 1) and verify_result != 0:
					print('Following error was raised during PKIX validation: %s' % verify_result)
				genRecords(args.host, args.host, args.protocol, args.port, chain, args.output, args.usage, args.selector, args.mtype)
			except Exception as e:
				print('Could not verify local certificate: %s.' % e)
