// BlackoutClient / Assets / Best HTTP / Source / SecureProtocol / crypto / tls / DtlsReliableHandshake.cs
#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.Collections;
using System.IO;

using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities;

namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
    /// <summary>
    /// Reliable delivery of DTLS handshake messages on top of a <see cref="DtlsRecordLayer"/>.
    /// Assigns outgoing message sequence numbers and fragments messages to the record layer's
    /// send limit. Reassembles incoming fragments into complete messages and feeds every message
    /// (except hello_request) into the handshake hash. Retransmits the last outbound flight when a
    /// receive times out, or when the peer repeats its previous flight in full.
    /// </summary>
    internal class DtlsReliableHandshake
    {
        /// <summary>How many sequence numbers past the next expected one are buffered for reassembly.</summary>
        private const int MaxReceiveAhead = 16;

        /// <summary>
        /// Size of a handshake message header: msg_type(1), length(3), message_seq(2),
        /// fragment_offset(3), fragment_length(3).
        /// </summary>
        private const int MessageHeaderLength = 12;

        private readonly DtlsRecordLayer mRecordLayer;

        private TlsHandshakeHash mHandshakeHash;

        // Reassemblers for the flight currently being received, keyed by message_seq (int).
        private IDictionary mCurrentInboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateHashtable();
        // Reassemblers of the peer's previous flight; used to detect a full retransmission of it.
        private IDictionary mPreviousInboundFlight = null;
        // Messages of our latest outbound flight, kept so they can be retransmitted.
        private readonly IList mOutboundFlight = BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateArrayList();
        // True while we are sending a flight, false while we are receiving one.
        private bool mSending = true;

        // mMessageSeq: next outgoing message_seq; mNextReceiveSeq: next expected incoming message_seq.
        private int mMessageSeq = 0, mNextReceiveSeq = 0;

        /// <summary>
        /// Creates the handshake layer over <paramref name="transport"/>. A fresh
        /// <see cref="DeferredHash"/> is initialised with <paramref name="context"/>.
        /// </summary>
        internal DtlsReliableHandshake(TlsContext context, DtlsRecordLayer transport)
        {
            this.mRecordLayer = transport;
            this.mHandshakeHash = new DeferredHash();
            this.mHandshakeHash.Init(context);
        }

        /// <summary>
        /// Notifies the handshake hash that the PRF has been determined. The current hash is
        /// replaced with the object returned by <c>NotifyPrfDetermined()</c>.
        /// </summary>
        internal void NotifyHelloComplete()
        {
            this.mHandshakeHash = mHandshakeHash.NotifyPrfDetermined();
        }

        /// <summary>The handshake hash currently receiving message digests.</summary>
        internal TlsHandshakeHash HandshakeHash
        {
            get { return mHandshakeHash; }
        }

        /// <summary>
        /// Returns the current handshake hash, and replaces it with the object returned by
        /// <c>StopTracking()</c> for any further updates.
        /// </summary>
        internal TlsHandshakeHash PrepareToFinish()
        {
            TlsHandshakeHash result = mHandshakeHash;
            this.mHandshakeHash = mHandshakeHash.StopTracking();
            return result;
        }

        /// <summary>
        /// Sends a handshake message as part of the current outbound flight.
        /// If we were receiving, a new outbound flight is started and the previous one is discarded.
        /// The message is remembered for retransmission, written to the record layer (fragmented
        /// as needed), and added to the handshake hash.
        /// </summary>
        /// <param name="msg_type">Handshake message type.</param>
        /// <param name="body">Message body. Its length must fit in a uint24.</param>
        internal void SendMessage(byte msg_type, byte[] body)
        {
            TlsUtilities.CheckUint24(body.Length);

            if (!mSending)
            {
                // Switching from receiving to sending: the peer's flight is over, start a new one of ours.
                CheckInboundFlight();
                mSending = true;
                mOutboundFlight.Clear();
            }

            Message message = new Message(mMessageSeq++, msg_type, body);

            mOutboundFlight.Add(message);

            WriteMessage(message);
            UpdateHandshakeMessagesDigest(message);
        }

        /// <summary>
        /// Receives the next handshake message and returns its body.
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// <c>unexpected_message</c> if the received message type is not <paramref name="msg_type"/>.
        /// </exception>
        internal byte[] ReceiveMessageBody(byte msg_type)
        {
            Message message = ReceiveMessage();
            if (message.Type != msg_type)
                throw new TlsFatalAlert(AlertDescription.unexpected_message);

            return message.Body;
        }

        /// <summary>
        /// Receives the next complete handshake message (in message_seq order).
        /// If a receive fails with a non-alert <see cref="IOException"/> (treated as a timeout), or
        /// returns a negative count, the current outbound flight is retransmitted. The read timeout
        /// then doubles, up to 60 s.
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// <c>user_canceled</c> if the record layer is closed. Any other fatal alert raised while
        /// receiving or processing records is also propagated.
        /// </exception>
        internal Message ReceiveMessage()
        {
            if (mSending)
            {
                // Switching from sending to receiving: the current inbound flight becomes the previous one.
                mSending = false;
                PrepareInboundFlight(BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.CreateHashtable());
            }

            byte[] buf = null;

            // TODO Check the conditions under which we should reset this
            int readTimeoutMillis = 1000;

            for (;;)
            {
                try
                {
                    for (;;)
                    {
                        if (mRecordLayer.IsClosed)
                            throw new TlsFatalAlert(AlertDescription.user_canceled);

                        Message pending = GetPendingMessage();
                        if (pending != null)
                            return pending;

                        int receiveLimit = mRecordLayer.GetReceiveLimit();
                        if (buf == null || buf.Length < receiveLimit)
                        {
                            buf = new byte[receiveLimit];
                        }

                        int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
                        if (received < 0)
                            break;

                        bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
                        if (resentOutbound)
                        {
                            readTimeoutMillis = BackOff(readTimeoutMillis);
                        }
                    }
                }
                catch (TlsFatalAlert)
                {
                    /*
                     * TlsFatalAlert derives from IOException (via TlsException). Without this clause the
                     * handler below would swallow fatal alerts, including the user_canceled alert thrown
                     * above for a closed record layer. This method would then keep resending and looping
                     * instead of reporting the failure.
                     */
                    throw;
                }
                catch (IOException)
                {
                    // NOTE: Assume this is a timeout for the moment
                }

                ResendOutboundFlight();
                readTimeoutMillis = BackOff(readTimeoutMillis);
            }
        }

        /// <summary>
        /// Completes the handshake and notifies the record layer via <c>HandshakeSuccessful</c>.
        /// If we sent the last flight and a previous inbound flight exists, a <see cref="Retransmit"/>
        /// handler is passed so that a repeat of the peer's last flight triggers a resend of ours.
        /// </summary>
        internal void Finish()
        {
            DtlsHandshakeRetransmit retransmit = null;
            if (!mSending)
            {
                CheckInboundFlight();
            }
            else
            {
                PrepareInboundFlight(null);

                if (mPreviousInboundFlight != null)
                {
                    /*
                     * RFC 6347 4.2.4. In addition, for at least twice the default MSL defined for [TCP],
                     * when in the FINISHED state, the node that transmits the last flight (the server in an
                     * ordinary handshake or the client in a resumed handshake) MUST respond to a retransmit
                     * of the peer's last flight with a retransmit of the last flight.
                     */
                    retransmit = new Retransmit(this);
                }
            }

            mRecordLayer.HandshakeSuccessful(retransmit);
        }

        /// <summary>Resets the current handshake hash.</summary>
        internal void ResetHandshakeMessagesDigest()
        {
            mHandshakeHash.Reset();
        }

        /// <summary>Doubles the retransmission timeout, capped at 60000 ms.</summary>
        private int BackOff(int timeoutMillis)
        {
            /*
             * TODO[DTLS] implementations SHOULD back off handshake packet size during the
             * retransmit backoff.
             */
            return System.Math.Min(timeoutMillis * 2, 60000);
        }

        /**
         * Check that there are no "extra" messages left in the current inbound flight.
         * Currently detection only: leftover messages are not treated as an error.
         */
        private void CheckInboundFlight()
        {
            foreach (int key in mCurrentInboundFlight.Keys)
            {
                if (key >= mNextReceiveSeq)
                {
                    // TODO Should this be considered an error?
                }
            }
        }

        /// <summary>
        /// Returns the next expected message if it has been fully reassembled, or null otherwise.
        /// A returned message advances mNextReceiveSeq and is added to the handshake hash. It also
        /// drops the previous inbound flight, since the peer has clearly moved on from it.
        /// </summary>
        private Message GetPendingMessage()
        {
            DtlsReassembler next = (DtlsReassembler)mCurrentInboundFlight[mNextReceiveSeq];
            if (next != null)
            {
                byte[] body = next.GetBodyIfComplete();
                if (body != null)
                {
                    mPreviousInboundFlight = null;
                    return UpdateHandshakeMessagesDigest(new Message(mNextReceiveSeq++, next.MsgType, body));
                }
            }
            return null;
        }

        /// <summary>
        /// Resets the reassemblers of the current inbound flight and moves that flight to
        /// mPreviousInboundFlight. <paramref name="nextFlight"/> (may be null) becomes the current flight.
        /// </summary>
        private void PrepareInboundFlight(IDictionary nextFlight)
        {
            ResetAll(mCurrentInboundFlight);
            mPreviousInboundFlight = mCurrentInboundFlight;
            mCurrentInboundFlight = nextFlight;
        }

        /// <summary>
        /// Splits a received record into handshake fragments and routes each one:
        /// <list type="bullet">
        /// <item>Fragments with message_seq in [mNextReceiveSeq, mNextReceiveSeq + windowSize) go to the
        /// current flight's reassemblers.</item>
        /// <item>Fragments with an older message_seq go to the matching reassembler of the previous flight,
        /// if any.</item>
        /// <item>Fragments further ahead are ignored.</item>
        /// </list>
        /// A truncated fragment, a malformed fragment, or an epoch mismatch stops processing of the rest
        /// of the record. A <paramref name="windowSize"/> of 0 only considers the previous flight.
        /// </summary>
        /// <returns>
        /// True if the previous inbound flight is now fully received again. In that case our outbound
        /// flight has been retransmitted and the previous flight's reassemblers have been reset.
        /// </returns>
        private bool ProcessRecord(int windowSize, int epoch, byte[] buf, int off, int len)
        {
            bool checkPreviousFlight = false;

            while (len >= MessageHeaderLength)
            {
                int fragment_length = TlsUtilities.ReadUint24(buf, off + 9);
                int message_length = fragment_length + MessageHeaderLength;
                if (len < message_length)
                {
                    // NOTE: Truncated message - ignore it
                    break;
                }

                int length = TlsUtilities.ReadUint24(buf, off + 1);
                int fragment_offset = TlsUtilities.ReadUint24(buf, off + 6);
                if (fragment_offset + fragment_length > length)
                {
                    // NOTE: Malformed fragment - ignore it and the rest of the record
                    break;
                }

                /*
                 * NOTE: This very simple epoch check will only work until we want to support
                 * renegotiation (and we're not likely to do that anyway).
                 */
                byte msg_type = TlsUtilities.ReadUint8(buf, off + 0);
                int expectedEpoch = msg_type == HandshakeType.finished ? 1 : 0;
                if (epoch != expectedEpoch)
                {
                    break;
                }

                int message_seq = TlsUtilities.ReadUint16(buf, off + 4);
                if (message_seq >= (mNextReceiveSeq + windowSize))
                {
                    // NOTE: Too far ahead - ignore
                }
                else if (message_seq >= mNextReceiveSeq)
                {
                    DtlsReassembler reassembler = (DtlsReassembler)mCurrentInboundFlight[message_seq];
                    if (reassembler == null)
                    {
                        reassembler = new DtlsReassembler(msg_type, length);
                        mCurrentInboundFlight[message_seq] = reassembler;
                    }

                    reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
                        fragment_length);
                }
                else if (mPreviousInboundFlight != null)
                {
                    /*
                     * NOTE: If we receive the previous flight of incoming messages in full again,
                     * retransmit our last flight
                     */

                    DtlsReassembler reassembler = (DtlsReassembler)mPreviousInboundFlight[message_seq];
                    if (reassembler != null)
                    {
                        reassembler.ContributeFragment(msg_type, length, buf, off + MessageHeaderLength, fragment_offset,
                            fragment_length);
                        checkPreviousFlight = true;
                    }
                }

                off += message_length;
                len -= message_length;
            }

            bool result = checkPreviousFlight && CheckAll(mPreviousInboundFlight);
            if (result)
            {
                ResendOutboundFlight();
                ResetAll(mPreviousInboundFlight);
            }
            return result;
        }

        /// <summary>Resets the record layer's write epoch, then rewrites every message of the outbound flight.</summary>
        private void ResendOutboundFlight()
        {
            mRecordLayer.ResetWriteEpoch();
            for (int i = 0; i < mOutboundFlight.Count; ++i)
            {
                WriteMessage((Message)mOutboundFlight[i]);
            }
        }

        /// <summary>
        /// Adds <paramref name="message"/> to the handshake hash, unless it is a hello_request.
        /// The message is hashed as a single unfragmented message: fragment_offset 0 and
        /// fragment_length equal to the full body length.
        /// </summary>
        /// <returns>The same <paramref name="message"/>, for call chaining.</returns>
        private Message UpdateHandshakeMessagesDigest(Message message)
        {
            if (message.Type != HandshakeType.hello_request)
            {
                byte[] body = message.Body;
                byte[] buf = new byte[MessageHeaderLength];
                TlsUtilities.WriteUint8(message.Type, buf, 0);
                TlsUtilities.WriteUint24(body.Length, buf, 1);
                TlsUtilities.WriteUint16(message.Seq, buf, 4);
                TlsUtilities.WriteUint24(0, buf, 6);
                TlsUtilities.WriteUint24(body.Length, buf, 9);
                mHandshakeHash.BlockUpdate(buf, 0, buf.Length);
                mHandshakeHash.BlockUpdate(body, 0, body.Length);
            }
            return message;
        }

        /// <summary>
        /// Writes <paramref name="message"/> as one or more fragments, each sized to fit the
        /// record layer's send limit. An empty body still produces one (empty) fragment.
        /// </summary>
        /// <exception cref="TlsFatalAlert"><c>internal_error</c> if the send limit cannot hold a header plus one byte.</exception>
        private void WriteMessage(Message message)
        {
            int sendLimit = mRecordLayer.GetSendLimit();
            int fragmentLimit = sendLimit - MessageHeaderLength;

            // TODO Support a higher minimum fragment size?
            if (fragmentLimit < 1)
            {
                // TODO Should we be throwing an exception here?
                throw new TlsFatalAlert(AlertDescription.internal_error);
            }

            int length = message.Body.Length;

            // NOTE: Must still send a fragment if body is empty
            int fragment_offset = 0;
            do
            {
                int fragment_length = System.Math.Min(length - fragment_offset, fragmentLimit);
                WriteHandshakeFragment(message, fragment_offset, fragment_length);
                fragment_offset += fragment_length;
            }
            while (fragment_offset < length);
        }

        /// <summary>Serialises one fragment (header + body slice) and sends it to the record layer.</summary>
        private void WriteHandshakeFragment(Message message, int fragment_offset, int fragment_length)
        {
            RecordLayerBuffer fragment = new RecordLayerBuffer(MessageHeaderLength + fragment_length);
            TlsUtilities.WriteUint8(message.Type, fragment);
            TlsUtilities.WriteUint24(message.Body.Length, fragment);
            TlsUtilities.WriteUint16(message.Seq, fragment);
            TlsUtilities.WriteUint24(fragment_offset, fragment);
            TlsUtilities.WriteUint24(fragment_length, fragment);
            fragment.Write(message.Body, fragment_offset, fragment_length);

            fragment.SendToRecordLayer(mRecordLayer);
        }

        /// <summary>True if every reassembler in <paramref name="inboundFlight"/> holds a complete body.</summary>
        private static bool CheckAll(IDictionary inboundFlight)
        {
            foreach (DtlsReassembler r in inboundFlight.Values)
            {
                if (r.GetBodyIfComplete() == null)
                {
                    return false;
                }
            }
            return true;
        }

        /// <summary>
        /// Resets every reassembler in <paramref name="inboundFlight"/>. A null flight is ignored:
        /// <see cref="Finish"/> can leave the current flight null (via PrepareInboundFlight(null)).
        /// </summary>
        private static void ResetAll(IDictionary inboundFlight)
        {
            if (inboundFlight == null)
                return;

            foreach (DtlsReassembler r in inboundFlight.Values)
            {
                r.Reset();
            }
        }

        /// <summary>An immutable handshake message: sequence number, type and body.</summary>
        internal class Message
        {
            private readonly int mMessageSeq;
            private readonly byte mMsgType;
            private readonly byte[] mBody;

            internal Message(int message_seq, byte msg_type, byte[] body)
            {
                this.mMessageSeq = message_seq;
                this.mMsgType = msg_type;
                this.mBody = body;
            }

            /// <summary>The message_seq value.</summary>
            public int Seq
            {
                get { return mMessageSeq; }
            }

            /// <summary>The handshake message type.</summary>
            public byte Type
            {
                get { return mMsgType; }
            }

            /// <summary>The message body (not copied).</summary>
            public byte[] Body
            {
                get { return mBody; }
            }
        }

        /// <summary>A memory stream used to build a single fragment before sending it to the record layer.</summary>
        internal class RecordLayerBuffer
            :   MemoryStream
        {
            internal RecordLayerBuffer(int size)
                :   base(size)
            {
            }

            /// <summary>
            /// Sends the bytes written so far to <paramref name="recordLayer"/>, then disposes this buffer.
            /// The buffer is disposed even if the send throws.
            /// </summary>
            internal void SendToRecordLayer(DtlsRecordLayer recordLayer)
            {
                try
                {
#if PORTABLE || NETFX_CORE
                    byte[] buf = ToArray();
                    int bufLen = buf.Length;
#else
                    // Avoid a copy: GetBuffer() may be larger than Length, so pass the logical length.
                    byte[] buf = GetBuffer();
                    int bufLen = (int)Length;
#endif

                    recordLayer.Send(buf, 0, bufLen);
                }
                finally
                {
                    BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Platform.Dispose(this);
                }
            }
        }

        /// <summary>
        /// Post-handshake hook given to the record layer. It feeds received handshake records back into
        /// <see cref="ProcessRecord"/> with a window of 0, so only a repeat of the peer's previous
        /// flight is acted on (by retransmitting our last flight).
        /// </summary>
        internal class Retransmit
            :   DtlsHandshakeRetransmit
        {
            private readonly DtlsReliableHandshake mOuter;

            internal Retransmit(DtlsReliableHandshake outer)
            {
                this.mOuter = outer;
            }

            public void ReceivedHandshakeRecord(int epoch, byte[] buf, int off, int len)
            {
                mOuter.ProcessRecord(0, epoch, buf, off, len);
            }
        }
    }
}
#pragma warning restore
#endif