-- Internal.hs: private utility functions
-- Copyright © 2012  Clint Adams
-- This software is released under the terms of the ISC license.
-- (See the LICENSE file).

module Codec.Encryption.OpenPGP.Internal (
   countBits
 , beBSToInteger
 , integerToBEBS
 , PktStreamContext(..)
 , asn1Prefix
 , hash
 , issuer
 , emptyPSC
 , pubkeyToMPIs
) where

import qualified Crypto.Cipher.DSA as DSA
import qualified Crypto.Cipher.RSA as RSA

import qualified Crypto.Hash.MD5 as MD5
import qualified Crypto.Hash.RIPEMD160 as RIPEMD160
import qualified Crypto.Hash.SHA1 as SHA1
import qualified Crypto.Hash.SHA224 as SHA224
import qualified Crypto.Hash.SHA256 as SHA256
import qualified Crypto.Hash.SHA384 as SHA384
import qualified Crypto.Hash.SHA512 as SHA512

import qualified Data.ASN1.DER as DER
import Data.Bits (testBit, shiftL, shiftR, (.&.))
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.List (find, mapAccumR, unfoldr)
import Data.Word (Word8, Word16)

import Codec.Encryption.OpenPGP.Types

-- | Number of significant bits in a big-endian unsigned integer given as
-- octets: the total bit length minus the leading zero bits of the first
-- octet. The 'Word16' result presumably feeds an OpenPGP MPI bit-count
-- header.
--
-- An empty string (which is what @integerToBEBS 0@ produces) has zero bits.
-- Only the first octet is inspected for leading zeros. A zero first octet
-- is counted as having 7 leading zeros, matching the original behaviour, so
-- the bit count still implies the same number of octets as the input.
countBits :: ByteString -> Word16
countBits bs
    | B.null bs = 0 -- avoid the partial B.head below
    | otherwise = fromIntegral (B.length bs * 8) - fromIntegral (go (B.head bs) 7)
    where
        -- Leading zero bits of an octet, scanning from bit 7 downwards.
        -- Bit 0 is never tested: reaching it means 7 leading zeros.
        go :: Word8 -> Int -> Word8
        go _ 0 = 7
        go n b = if testBit n b then 7 - fromIntegral b else go n (b-1)

-- | Interpret a string of octets as an unsigned big-endian integer.
-- The empty string is 0; leading zero octets contribute nothing.
beBSToInteger :: ByteString -> Integer
beBSToInteger = B.foldl' pushOctet 0
    where
        -- shift what we have so far up one octet and append the next one
        pushOctet acc octet = (acc `shiftL` 8) + fromIntegral octet

-- | Minimal big-endian octet encoding of a non-negative integer (no leading
-- zero octets). Zero encodes to the empty string.
--
-- Calls 'error' on negative input. Such input previously looped forever,
-- because 'shiftR' on a negative 'Integer' converges to -1 and never reaches
-- the 0 that terminates the unfold.
integerToBEBS :: Integer -> ByteString
integerToBEBS n
    | n < 0 = error "integerToBEBS: negative input"
    | otherwise = B.pack (reverse (unfoldr nextOctet n))
    where
        -- emit the least significant octet, continue with the rest;
        -- octets come out little-endian, hence the reverse above
        nextOctet 0 = Nothing
        nextOctet x = Just (fromIntegral (x .&. 0xff), x `shiftR` 8)

-- | Context threaded through a walk over a packet stream: one slot per
-- packet kind, holding the most recently seen packet of that kind.
-- NOTE(review): the meaning of each slot is inferred from the field names;
-- confirm against the code that consumes this record.
data PktStreamContext = PktStreamContext
    { lastLD         :: Pkt -- ^ last literal-data packet (presumably "LD")
    , lastUIDorUAt   :: Pkt -- ^ last user ID or user attribute packet
    , lastSig        :: Pkt -- ^ last signature packet
    , lastPrimaryKey :: Pkt -- ^ last primary key packet
    , lastSubkey     :: Pkt -- ^ last subkey packet
    }

-- | Starting context: every slot holds the same @MarkerPkt@ with an empty
-- payload, presumably as a placeholder until a real packet of that kind is
-- seen.
emptyPSC :: PktStreamContext
emptyPSC = PktStreamContext
    { lastLD = placeholder
    , lastUIDorUAt = placeholder
    , lastSig = placeholder
    , lastPrimaryKey = placeholder
    , lastSubkey = placeholder
    }
    where
        placeholder = MarkerPkt B.empty

-- | Key ID carried by the first @Issuer@ subpacket of a version 4 signature
-- packet, or 'Nothing' if there is none or the packet is not a V4 signature.
--
-- Only the fifth @SigV4@ field (the subpacket list bound as @usubs@) is
-- searched. NOTE(review): if that is the unhashed subpacket area, an Issuer
-- subpacket placed in the hashed area is ignored — confirm this is intended.
issuer :: Pkt -> Maybe EightOctetKeyId
issuer (SignaturePkt (SigV4 _ _ _ _ usubs _ _)) =
    -- the comprehension's pattern silently skips non-Issuer subpackets, so
    -- no partial match is needed to extract the key ID
    case [i | SigSubPacket _ (Issuer i) <- usubs] of
        (i:_) -> Just i
        [] -> Nothing
issuer _ = Nothing

-- | Digest of the input under the given OpenPGP hash algorithm.
--
-- Calls 'error' for any algorithm not handled below. The previous fallback
-- returned the input unchanged, which silently yields a value that is not a
-- cryptographic digest at all. That is unsafe wherever the result is
-- signed, verified or used as a fingerprint, so this fails loudly instead.
hash :: HashAlgorithm -> ByteString -> ByteString
hash SHA1 = SHA1.hash
hash RIPEMD160 = RIPEMD160.hash
hash SHA256 = SHA256.hash
hash SHA384 = SHA384.hash
hash SHA512 = SHA512.hash
hash SHA224 = SHA224.hash
hash DeprecatedMD5 = MD5.hash
hash _ = error "hash: unsupported hash algorithm"

-- | DER-encoded prefix of an ASN.1 @DigestInfo@ for the given hash algorithm,
-- i.e. @SEQUENCE { SEQUENCE { OID, NULL }, OCTET STRING digest }@ up to and
-- including the OCTET STRING length octet. The real digest is expected to
-- be appended by the caller.
--
-- It is built by DER-encoding a complete @DigestInfo@ around a dummy digest
-- of the correct length (zero octets followed by a single 1). The dummy is
-- then stripped from the end, so the outer SEQUENCE length already accounts
-- for a real digest.
--
-- Calls 'error' for algorithms without a known OID and digest length, or if
-- the ASN.1 encoder fails.
asn1Prefix :: HashAlgorithm -> ByteString
asn1Prefix ha =
    case DER.encodeASN1Stream digestInfo of
        Left _ -> error "encodeASN1 failure"
        Right l -> B.concat . BL.toChunks $ getPrefix l
    where
        (blen, oid) = params ha
        start = DER.Start DER.Sequence
        end = DER.End DER.Sequence
        -- dummy digest of blen bits: zeros then a final 1, so the end of
        -- the encoding can be recognised and removed by getPrefix
        fakeint = DER.OctetString (BL.pack (replicate ((blen `div` 8) - 1) 0 ++ [1]))
        digestInfo = [start, start, oid, DER.Null, end, fakeint, end]
        -- drop the trailing 1, then the dummy zero octets before it; what
        -- remains ends with the (non-zero) OCTET STRING length octet
        getPrefix = BL.reverse . BL.dropWhile (==0) . BL.drop 1 . BL.reverse
        -- (digest length in bits, algorithm OID)
        params DeprecatedMD5 = (128, DER.OID [1,2,840,113549,2,5])
        params SHA1 = (160, DER.OID [1,3,14,3,2,26])
        params RIPEMD160 = (160, DER.OID [1,3,36,3,2,1])
        params SHA224 = (224, DER.OID [2,16,840,1,101,3,4,2,4])
        params SHA256 = (256, DER.OID [2,16,840,1,101,3,4,2,1])
        params SHA384 = (384, DER.OID [2,16,840,1,101,3,4,2,2])
        params SHA512 = (512, DER.OID [2,16,840,1,101,3,4,2,3])
        -- previously a 0-bit length and an empty (invalid) OID were used
        -- here, yielding a garbage prefix or an obscure encoder failure
        params _ = error "asn1Prefix: unsupported hash algorithm"

-- | Public-key material as a list of MPIs, in the order emitted below:
-- RSA as n, e; DSA as p, q, g, y (note that @DSA.public_params@ is
-- destructured as (p, g, q)); ElGamal as its integer list, in order.
pubkeyToMPIs :: PKey -> [MPI]
pubkeyToMPIs (RSAPubKey k) = map MPI [RSA.public_n k, RSA.public_e k]
pubkeyToMPIs (DSAPubKey k) =
    case DSA.public_params k of
        (p, g, q) -> map MPI [p, q, g, DSA.public_y k]
pubkeyToMPIs (ElGamalPubKey k) = map MPI k