#!/usr/bin/env python

# This file is part of Window-Switch.
# Copyright (c) 2009-2013 Antoine Martin <antoine@nagafix.co.uk>
# Window-Switch is released under the terms of the GNU GPL v3

from winswitch.util.simple_logger import Logger, msig
logger = Logger("crypt_util", log_colour=Logger.WHITE)
debug_import = logger.get_debug_import()

debug_import("time, os, binascii, hashlib")
import time
import os
import binascii
import hashlib

debug_import("Crypto")
from Crypto.PublicKey import RSA
from Crypto.Util.number import long_to_bytes

debug_import("common")
from winswitch.util.common import alphanumfile, visible_command
from winswitch.globals import OSX, WIN32

#RSA key size (in bits) used by generate_key():
#osx is way too slow with large keys
#and win32 can have problems with lack of entropy..
if OSX or WIN32:
	KEY_SIZE = 2048
else:
	KEY_SIZE = 4096
#separator between the hex-encoded parts produced by hex_str() / parsed by unhex_str()
PART_SEPARATOR = ":"
#separator placed between the salt and the message by the salted encrypt/decrypt helpers
SALT_SEPARATOR = "/"

#when True, sign_long() emits "0x"-prefixed hex signatures instead of decimal ones;
#hex signatures are not understood by versions<=0.12.4, so this stays off
HEX_SIGNATURES = False


def make_key_fingerprint(modulus, public_exponent):
	"""
	Returns the hex SHA-1 digest of the "%s" text of the modulus
	immediately followed by the "%s" text of the public exponent.

	The two values are hashed back to back with no delimiter between them.
	This format is kept unchanged on purpose: altering it would change
	every fingerprint that has already been computed and stored.
	"""
	sha1 = hashlib.sha1()
	for value in (modulus, public_exponent):
		#hash explicit bytes: a text str is only accepted implicitly on python2
		#(for numbers / digit strings the bytes hashed are exactly the same)
		sha1.update(("%s" % value).encode("ascii"))
	return	sha1.hexdigest()

def generate_key():
	"""
	Generates and returns a new RSA key of KEY_SIZE bits, using os.urandom
	as the source of randomness.
	The time taken is logged, truncated to a tenth of a second.
	"""
	logger.slog("key size is %s bits, this may take a few seconds..." % KEY_SIZE)
	started_at = time.time()
	new_key = RSA.generate(KEY_SIZE, os.urandom)
	elapsed = time.time() - started_at
	#truncate (not round) to one decimal place for the log message
	logger.slog("finished, this took %s seconds" % (int(elapsed*10)/10.0))
	return new_key

def recreate_key(crypto_modulus, crypto_public_exponent, crypto_private_exponent=0):
	try:
		if crypto_private_exponent==0:
			return RSA.construct((long(crypto_modulus), long(crypto_public_exponent)))
		else:
			return RSA.construct((long(crypto_modulus), long(crypto_public_exponent), long(crypto_private_exponent)))
	except Exception, e:
		logger.serror("failed: %s" % e, crypto_modulus, crypto_public_exponent, crypto_private_exponent)
		return	None

def key_to_string(key):
	"""
	Returns the "%s" text of the key's modulus (key.n),
	or an empty string when the modulus is 0.
	"""
	modulus = key.n
	if modulus!=0:
		return	"%s" % modulus
	return	""


def encrypt(key, message):
	"""
	Encrypts message with key by delegating to key.encrypt(message, '').
	Whatever key.encrypt returns is passed back unchanged.
	"""
	#NOTE(review): with a PyCrypto RSA key this is, as far as I know, raw/textbook RSA
	#with no padding - confirm, and consider a padded scheme if the protocol ever changes
	ciphertext = key.encrypt(message, '')
	return ciphertext

def encrypt_hex(key, message):
	"""
	Encrypts message with key and returns the encrypted parts as a
	PART_SEPARATOR separated hex string (see hex_str).
	An empty or None message yields an empty string without touching the key.
	"""
	if not message:
		return	""
	return	hex_str(encrypt(key, message))

def encrypt_salted_hex(key, salt, message):
	"""
	Prefixes message with salt and SALT_SEPARATOR, then encrypts it to a hex string
	with encrypt_hex(). decrypt_salted_hex() reverses this and checks the salt.
	"""
	salted = "%s%s%s" % (salt, SALT_SEPARATOR, message)
	return encrypt_hex(key, salted)

def sign(key, message):
	"""
	Signs message with key by delegating to key.sign(message, '').
	The raw signature returned by the key is passed back unchanged.
	"""
	signature = key.sign(message, '')
	return signature

def sign_long(key, message):
	"""
	Signs message and returns the first element of the signature as a string.

	When HEX_SIGNATURES is set, the value is "0x" followed by the hex of its bytes.
	Otherwise it is the str() form (decimal for an integer). verify_long()
	accepts both forms.
	"""
	value = sign(key, message)[0]
	if not HEX_SIGNATURES:
		sig = str(value)
	else:
		sig = "0x"+binascii.hexlify(long_to_bytes(value))
	logger.sdebug("=%s" % visible_command(sig), type(key), message)
	return	sig

def verify(key, message, sig):
	"""
	Checks sig against message by delegating to key.verify(message, sig)
	and returns its result unchanged.
	"""
	outcome = key.verify(message, sig)
	return outcome

def verify_long(key, message, sig):
	"""
	Verifies a signature string as produced by sign_long().

	"0x"-prefixed strings are parsed as hex, anything else as decimal.
	The number is passed to verify() wrapped in a one-element list.
	Malformed strings raise from long().
	"""
	number = long(sig, 16) if sig.startswith("0x") else long(sig)
	return	verify(key, message, [number])

def hex_str(enc_msgs):
	"""
	Takes a sequence of messages and turns it into a PART_SEPARATOR separated hex string.

	Each element is converted with "%s" and hex-encoded, and the parts are joined
	with PART_SEPARATOR, so unhex_str() can split them back apart.
	Empty parts keep their separator (["", "a"] -> ":61").
	Returns None when enc_msgs is empty (historical behaviour, kept for callers).
	"""
	#join instead of repeated += : a falsy (empty) accumulated string used to
	#swallow the separator after leading empty parts, and += is quadratic
	hex_parts = [binascii.hexlify("%s" % enc) for enc in enc_msgs]
	if not hex_parts:
		return	None
	return	PART_SEPARATOR.join(hex_parts)

def decrypt(key, message):
	"""
	Decrypts message with key by delegating to key.decrypt(message)
	and returns the plaintext unchanged.
	"""
	plaintext = key.decrypt(message)
	return plaintext

def decrypt_hex(key, message):
	"""
	Decrypts a PART_SEPARATOR separated hex string, such as encrypt_hex() produces.
	The parts are decoded with unhex_str() and then decrypted with key.
	An empty or None message yields an empty string without touching the key.
	"""
	if message:
		return	decrypt(key, unhex_str(message))
	return	""

def decrypt_salted_hex(key, salt, message):
	"""
	Decrypts a hex message from encrypt_salted_hex() and strips its salt prefix.

	The plaintext must start with salt followed by SALT_SEPARATOR. If it does not,
	the mismatch is logged and an empty string is returned.
	An empty or None message also yields an empty string.
	"""
	if not message:
		return	""
	sig = msig(key,salt, "sanitized message: %s" % alphanumfile(message))
	plaintext = decrypt_hex(key, message)
	salt_prefix = "%s%s" % (salt, SALT_SEPARATOR)
	if plaintext.startswith(salt_prefix):
		#remove salt:
		return plaintext[len(salt_prefix):]
	#NOTE(review): on a salt mismatch the alphanumeric characters of the decrypted
	#text are logged below - check this cannot leak secrets into the logs
	pos = plaintext.find(salt_prefix)
	logger.log(sig+" decrypted message (alphanum only)='%s'" % alphanumfile(plaintext))
	logger.error(sig+" failed: salt '%s' not found at the start of the message! (pos=%s)" % (salt, pos))
	return ""

def unhex_str(message):
	"""
	Takes a PART_SEPARATOR separated hex string and turns it into a tuple of strings,
	one per hex-decoded part (the inverse of hex_str).
	Returns None for an empty or None message.
	Invalid hex raises from binascii.unhexlify.
	"""
	if not message:
		return	None
	return	tuple(binascii.unhexlify(part) for part in message.split(PART_SEPARATOR))
