#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.IO;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Date;
namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
internal class DtlsRecordLayer
: DatagramTransport
{
// Size in bytes of the record header written by SendRecord: content type (1),
// version (2), epoch (2), sequence number (6), fragment length (2).
private const int RECORD_HEADER_LENGTH = 13;
// Initial plaintext limit (2^14 bytes) installed by the constructor.
private const int MAX_FRAGMENT_LENGTH = 1 << 14;
// Two minutes expressed in milliseconds.
private const long TCP_MSL = 1000L * 60 * 2;
// How long (ms) after HandshakeSuccessful the retransmit state is kept alive; see Receive.
private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
// Underlying datagram transport that records are read from and written to.
private readonly DatagramTransport mTransport;
private readonly TlsContext mContext;
// Receives alert notifications (NotifyAlertReceived / NotifyAlertRaised).
private readonly TlsPeer mPeer;
// Holds the bytes that followed the first record of a datagram (see ReceiveRecord).
private readonly ByteQueue mRecordQueue = new ByteQueue();
// Set at the end of CloseTransport.
private volatile bool mClosed = false;
// Set by Failed/Fail; when true CloseTransport does not send close_notify.
private volatile bool mFailed = false;
// mReadVersion: version incoming records must match once it is known.
// mWriteVersion: version written into outgoing headers; SendRecord sends nothing while it is null.
private volatile ProtocolVersion mReadVersion = null, mWriteVersion = null;
// True from construction until HandshakeSuccessful is called.
private volatile bool mInHandshake;
// Upper bound on plaintext size enforced by Receive and SendRecord.
private volatile int mPlaintextLimit;
// mCurrentEpoch: epoch established by the last completed handshake (epoch 0 initially).
// mPendingEpoch: epoch prepared by InitPendingEpoch, promoted in HandshakeSuccessful.
private DtlsEpoch mCurrentEpoch, mPendingEpoch;
// Epochs currently used for decoding incoming and encoding outgoing records.
private DtlsEpoch mReadEpoch, mWriteEpoch;
// Retransmit handler, the epoch it belongs to and its expiry time (Unix ms);
// all three are populated by HandshakeSuccessful and cleared by Receive.
private DtlsHandshakeRetransmit mRetransmit = null;
private DtlsEpoch mRetransmitEpoch = null;
private long mRetransmitExpiry = 0;
/// <summary>
/// Creates a record layer on top of <paramref name="transport"/>. Reading and writing both
/// start in epoch 0, which wraps a <c>TlsNullCipher</c>, and the layer starts in the
/// handshake state with the default plaintext limit.
/// </summary>
/// <param name="transport">Datagram transport used for all record I/O.</param>
/// <param name="context">TLS context; also handed to the epoch-0 cipher.</param>
/// <param name="peer">Peer that is notified about alerts.</param>
/// <param name="contentType">Not used by this constructor.</param>
internal DtlsRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, byte contentType)
{
    mTransport = transport;
    mContext = context;
    mPeer = peer;

    // A single epoch-0 instance is shared by the current/read/write slots.
    DtlsEpoch initialEpoch = new DtlsEpoch(0, new TlsNullCipher(context));
    mCurrentEpoch = initialEpoch;
    mReadEpoch = initialEpoch;
    mWriteEpoch = initialEpoch;
    mPendingEpoch = null;

    mInHandshake = true;

    SetPlaintextLimit(MAX_FRAGMENT_LENGTH);
}
/// <summary>True once the transport has been closed through this record layer.</summary>
internal bool IsClosed
{
    get
    {
        return this.mClosed;
    }
}
/// <summary>Replaces the plaintext size limit applied to sent and received records.</summary>
/// <param name="plaintextLimit">New limit in bytes.</param>
internal virtual void SetPlaintextLimit(int plaintextLimit)
{
    mPlaintextLimit = plaintextLimit;
}
/// <summary>Epoch number of the epoch currently used for reading.</summary>
internal virtual int ReadEpoch
{
    get
    {
        return this.mReadEpoch.Epoch;
    }
}
/// <summary>
/// Protocol version expected on incoming records; <c>null</c> until it is either assigned
/// here or taken from the first record accepted by <c>Receive</c>.
/// </summary>
internal virtual ProtocolVersion ReadVersion
{
    get
    {
        return this.mReadVersion;
    }
    set
    {
        mReadVersion = value;
    }
}
/// <summary>
/// Sets the protocol version written into outgoing record headers. Until a non-null
/// version is set, <c>SendRecord</c> silently sends nothing.
/// </summary>
/// <param name="writeVersion">Version to place in outgoing records.</param>
internal virtual void SetWriteVersion(ProtocolVersion writeVersion)
{
    mWriteVersion = writeVersion;
}
/// <summary>
/// Prepares the next epoch (write epoch number + 1) with the given cipher. The epoch only
/// becomes active later, when a change_cipher_spec is received or a Finished is sent.
/// </summary>
/// <param name="pendingCipher">Cipher to use for the pending epoch.</param>
/// <exception cref="InvalidOperationException">A pending epoch already exists.</exception>
internal virtual void InitPendingEpoch(TlsCipher pendingCipher)
{
    if (mPendingEpoch != null)
    {
        throw new InvalidOperationException();
    }

    /*
     * TODO "In order to ensure that any given sequence/epoch pair is unique, implementations
     * MUST NOT allow the same epoch value to be reused within two times the TCP maximum segment
     * lifetime."
     */

    // TODO Check for overflow
    int followingEpoch = mWriteEpoch.Epoch + 1;
    mPendingEpoch = new DtlsEpoch(followingEpoch, pendingCipher);
}
/// <summary>
/// Marks the handshake as finished: promotes the pending epoch to current and, when a
/// retransmit handler is supplied, keeps the previous current epoch around for
/// <c>RETRANSMIT_TIMEOUT</c> milliseconds so the final flight can be resent.
/// </summary>
/// <param name="retransmit">Handler for late handshake records, or <c>null</c>.</param>
/// <exception cref="InvalidOperationException">
/// The read or the write side is still on the current (pre-handshake) epoch.
/// </exception>
internal virtual void HandshakeSuccessful(DtlsHandshakeRetransmit retransmit)
{
    DtlsEpoch handshakeEpoch = mCurrentEpoch;

    bool readSwitched = mReadEpoch != handshakeEpoch;
    bool writeSwitched = mWriteEpoch != handshakeEpoch;
    if (!(readSwitched && writeSwitched))
    {
        // TODO
        throw new InvalidOperationException();
    }

    if (retransmit != null)
    {
        mRetransmit = retransmit;
        mRetransmitEpoch = handshakeEpoch;
        mRetransmitExpiry = DateTimeUtilities.CurrentUnixMs() + RETRANSMIT_TIMEOUT;
    }

    mInHandshake = false;
    mCurrentEpoch = mPendingEpoch;
    mPendingEpoch = null;
}
/// <summary>
/// Points the write epoch back at the retransmit epoch when one is being kept,
/// otherwise at the current epoch.
/// </summary>
internal virtual void ResetWriteEpoch()
{
    mWriteEpoch = mRetransmitEpoch ?? mCurrentEpoch;
}
/// <summary>
/// Largest plaintext that can be received: the smaller of the configured plaintext limit
/// and what the read cipher reports for the transport's receive limit minus the header.
/// </summary>
/// <returns>The receive limit in bytes.</returns>
public virtual int GetReceiveLimit()
{
    int configuredLimit = this.mPlaintextLimit;
    TlsCipher readCipher = mReadEpoch.Cipher;
    int ciphertextRoom = mTransport.GetReceiveLimit() - RECORD_HEADER_LENGTH;
    int cipherLimit = readCipher.GetPlaintextLimit(ciphertextRoom);
    return System.Math.Min(configuredLimit, cipherLimit);
}
/// <summary>
/// Largest plaintext that can be sent: the smaller of the configured plaintext limit
/// and what the write cipher reports for the transport's send limit minus the header.
/// </summary>
/// <returns>The send limit in bytes.</returns>
public virtual int GetSendLimit()
{
    int configuredLimit = this.mPlaintextLimit;
    TlsCipher writeCipher = mWriteEpoch.Cipher;
    int ciphertextRoom = mTransport.GetSendLimit() - RECORD_HEADER_LENGTH;
    int cipherLimit = writeCipher.GetPlaintextLimit(ciphertextRoom);
    return System.Math.Min(configuredLimit, cipherLimit);
}
/// <summary>
/// Reads records until one yields plaintext for the caller (application data after the
/// handshake, or handshake data during it) and copies that plaintext into
/// <paramref name="buf"/>. Malformed, unexpected, replayed or wrongly-versioned records
/// are silently skipped; alerts, change_cipher_spec and heartbeat records are consumed
/// internally.
/// </summary>
/// <param name="buf">Destination buffer for the plaintext.</param>
/// <param name="off">Offset in <paramref name="buf"/> at which to store the plaintext.</param>
/// <param name="len">Space the caller offers; bounds the size of the record buffer.</param>
/// <param name="waitMillis">Wait time handed to the underlying transport on each read.</param>
/// <returns>
/// Number of plaintext bytes copied, or the negative value returned by the transport read.
/// </returns>
/// <exception cref="TlsFatalAlert">A fatal alert was received from the peer.</exception>
/// <exception cref="IOException">Propagated unchanged from the transport or cipher.</exception>
public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
{
    byte[] record = null;
    for (;;)
    {
        // The limit may change while looping (e.g. after an epoch switch), so it is
        // recomputed every iteration and the buffer only ever grows.
        int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
        if (record == null || record.Length < receiveLimit)
        {
            record = new byte[receiveLimit];
        }
        try
        {
            // Drop the retransmit state once its expiry time has passed.
            if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
            {
                mRetransmit = null;
                mRetransmitEpoch = null;
            }
            int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
            if (received < 0)
            {
                return received;
            }
            if (received < RECORD_HEADER_LENGTH)
            {
                continue;
            }
            // The length field must account for exactly the bytes that follow the header.
            int length = TlsUtilities.ReadUint16(record, 11);
            if (received != (length + RECORD_HEADER_LENGTH))
            {
                continue;
            }
            byte type = TlsUtilities.ReadUint8(record, 0);
            // TODO Support user-specified custom protocols?
            switch (type)
            {
            case ContentType.alert:
            case ContentType.application_data:
            case ContentType.change_cipher_spec:
            case ContentType.handshake:
            case ContentType.heartbeat:
                break;
            default:
                // TODO Exception?
                continue;
            }
            // Accept the current read epoch, or - for handshake records only - the
            // epoch retained for retransmission.
            int epoch = TlsUtilities.ReadUint16(record, 3);
            DtlsEpoch recordEpoch = null;
            if (epoch == mReadEpoch.Epoch)
            {
                recordEpoch = mReadEpoch;
            }
            else if (type == ContentType.handshake && mRetransmitEpoch != null
                && epoch == mRetransmitEpoch.Epoch)
            {
                recordEpoch = mRetransmitEpoch;
            }
            if (recordEpoch == null)
            {
                continue;
            }
            long seq = TlsUtilities.ReadUint48(record, 5);
            if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
            {
                continue;
            }
            ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
            if (!version.IsDtls)
            {
                continue;
            }
            if (mReadVersion != null && !mReadVersion.Equals(version))
            {
                continue;
            }
            byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
                GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
                received - RECORD_HEADER_LENGTH);
            // Only reached when decoding did not throw.
            recordEpoch.ReplayWindow.ReportAuthenticated(seq);
            if (plaintext.Length > this.mPlaintextLimit)
            {
                continue;
            }
            // The first accepted record fixes the version for all later records.
            if (mReadVersion == null)
            {
                mReadVersion = version;
            }
            switch (type)
            {
            case ContentType.alert:
            {
                if (plaintext.Length == 2)
                {
                    byte alertLevel = plaintext[0];
                    byte alertDescription = plaintext[1];
                    mPeer.NotifyAlertReceived(alertLevel, alertDescription);
                    if (alertLevel == AlertLevel.fatal)
                    {
                        Failed();
                        throw new TlsFatalAlert(alertDescription);
                    }
                    // TODO Can close_notify be a fatal alert?
                    if (alertDescription == AlertDescription.close_notify)
                    {
                        CloseTransport();
                    }
                }
                continue;
            }
            case ContentType.application_data:
            {
                if (mInHandshake)
                {
                    // TODO Consider buffering application data for new epoch that arrives
                    // out-of-order with the Finished message
                    continue;
                }
                break;
            }
            case ContentType.change_cipher_spec:
            {
                // Implicitly receive change_cipher_spec and change to pending cipher state
                for (int i = 0; i < plaintext.Length; ++i)
                {
                    byte message = TlsUtilities.ReadUint8(plaintext, i);
                    if (message != ChangeCipherSpec.change_cipher_spec)
                    {
                        // Skips to the next byte of this record (inner loop).
                        continue;
                    }
                    if (mPendingEpoch != null)
                    {
                        mReadEpoch = mPendingEpoch;
                    }
                }
                continue;
            }
            case ContentType.handshake:
            {
                if (!mInHandshake)
                {
                    if (mRetransmit != null)
                    {
                        mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
                    }
                    // TODO Consider support for HelloRequest
                    continue;
                }
                break;
            }
            case ContentType.heartbeat:
            {
                // TODO[RFC 6520]
                continue;
            }
            }
            /*
             * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
             * received our final flight.
             */
            if (!mInHandshake && mRetransmit != null)
            {
                this.mRetransmit = null;
                this.mRetransmitEpoch = null;
            }
            Array.Copy(plaintext, 0, buf, off, plaintext.Length);
            return plaintext.Length;
        }
        catch (IOException)
        {
            // NOTE: Assume this is a timeout for the moment
            // Bare rethrow so the original stack trace is preserved.
            throw;
        }
    }
}
/// <summary>
/// Sends <paramref name="len"/> bytes as one record. While the handshake is running, or
/// while writing in the retransmit epoch, the data is sent as a handshake record;
/// otherwise as application data. A Finished handshake message is preceded by an
/// implicit change_cipher_spec record and a switch of the write epoch.
/// </summary>
/// <param name="buf">Buffer holding the data to send.</param>
/// <param name="off">Offset of the first byte to send.</param>
/// <param name="len">Number of bytes to send.</param>
/// <exception cref="IOException"/>
/// <exception cref="InvalidOperationException">
/// A Finished message is being sent but no epoch is available to switch to.
/// </exception>
public virtual void Send(byte[] buf, int off, int len)
{
    bool handshakeTraffic = this.mInHandshake || this.mWriteEpoch == this.mRetransmitEpoch;
    if (!handshakeTraffic)
    {
        SendRecord(ContentType.application_data, buf, off, len);
        return;
    }

    // The first byte of a handshake message is its type.
    byte messageType = TlsUtilities.ReadUint8(buf, off);
    if (messageType == HandshakeType.finished)
    {
        DtlsEpoch epochAfterFinished;
        if (this.mInHandshake)
        {
            epochAfterFinished = mPendingEpoch;
        }
        else if (this.mWriteEpoch == this.mRetransmitEpoch)
        {
            epochAfterFinished = mCurrentEpoch;
        }
        else
        {
            epochAfterFinished = null;
        }

        if (epochAfterFinished == null)
        {
            // TODO
            throw new InvalidOperationException();
        }

        // Implicitly send change_cipher_spec and change to pending cipher state
        // TODO Send change_cipher_spec and finished records in single datagram?
        byte[] changeCipherSpec = new byte[] { 1 };
        SendRecord(ContentType.change_cipher_spec, changeCipherSpec, 0, changeCipherSpec.Length);
        mWriteEpoch = epochAfterFinished;
    }

    SendRecord(ContentType.handshake, buf, off, len);
}
/// <summary>
/// Closes the record layer. If the handshake is still in progress a user_canceled
/// warning is sent first. The transport is closed even when sending that warning
/// throws; the exception is still propagated to the caller afterwards.
/// </summary>
public virtual void Close()
{
    if (mClosed)
    {
        return;
    }

    try
    {
        if (mInHandshake)
        {
            Warn(AlertDescription.user_canceled, "User canceled handshake");
        }
    }
    finally
    {
        // Must run even if the warning could not be sent, otherwise the layer
        // would stay open and the transport would never be released.
        CloseTransport();
    }
}
/// <summary>
/// Marks the connection as failed and closes the transport without raising an alert
/// of its own. Does nothing if the layer is already closed.
/// </summary>
internal virtual void Failed()
{
    if (mClosed)
    {
        return;
    }

    mFailed = true;
    CloseTransport();
}
/// <summary>
/// Raises a fatal alert with the given description (errors while doing so are ignored),
/// then marks the connection as failed and closes the transport. Does nothing if the
/// layer is already closed.
/// </summary>
/// <param name="alertDescription">Alert description to send to the peer.</param>
internal virtual void Fail(byte alertDescription)
{
    if (mClosed)
    {
        return;
    }

    try
    {
        RaiseAlert(AlertLevel.fatal, alertDescription, null, null);
    }
    catch (Exception)
    {
        // Ignore
    }

    mFailed = true;
    CloseTransport();
}
/// <summary>Raises a warning-level alert with no associated cause.</summary>
/// <param name="alertDescription">Alert description to send.</param>
/// <param name="message">Message passed on to the peer notification; may be null.</param>
internal virtual void Warn(byte alertDescription, string message)
{
    byte level = AlertLevel.warning;
    RaiseAlert(level, alertDescription, message, null);
}
/// <summary>
/// Sends close_notify (unless the connection has failed) and closes the underlying
/// transport, then marks the layer as closed. Both steps are best effort and
/// independent: a failure to send the alert does not prevent the transport from
/// being closed.
/// </summary>
private void CloseTransport()
{
    if (mClosed)
    {
        return;
    }

    /*
     * RFC 5246 7.2.1. Unless some other fatal alert has been transmitted, each party is
     * required to send a close_notify alert before closing the write side of the
     * connection. The other party MUST respond with a close_notify alert of its own and
     * close down the connection immediately, discarding any pending writes.
     */
    try
    {
        if (!mFailed)
        {
            Warn(AlertDescription.close_notify, null);
        }
    }
    catch (Exception)
    {
        // Ignore: the alert is a courtesy, the transport still has to be closed below.
    }

    try
    {
        mTransport.Close();
    }
    catch (Exception)
    {
        // Ignore
    }

    mClosed = true;
}
/// <summary>
/// Notifies the peer object that an alert is being raised and sends the two-byte
/// alert (level, description) as an alert record.
/// </summary>
/// <param name="alertLevel">Alert level byte.</param>
/// <param name="alertDescription">Alert description byte.</param>
/// <param name="message">Message forwarded to the peer notification; may be null.</param>
/// <param name="cause">Exception forwarded to the peer notification; may be null.</param>
private void RaiseAlert(byte alertLevel, byte alertDescription, string message, Exception cause)
{
    mPeer.NotifyAlertRaised(alertLevel, alertDescription, message, cause);

    byte[] alertBody = new byte[] { alertLevel, alertDescription };
    SendRecord(ContentType.alert, alertBody, 0, alertBody.Length);
}
/// <summary>
/// Produces the next raw record in <paramref name="buf"/>. Bytes left over from an
/// earlier datagram are served first; otherwise one datagram is read from the transport
/// and anything after its first record is queued for later calls.
/// </summary>
/// <param name="buf">Buffer that receives the record.</param>
/// <param name="off">Offset in <paramref name="buf"/> at which the record is stored.</param>
/// <param name="len">Maximum number of bytes to read from the transport.</param>
/// <param name="waitMillis">Wait time handed to the transport.</param>
/// <returns>Number of bytes placed in the buffer, or the transport's return value.</returns>
private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
{
    int queued = mRecordQueue.Available;
    if (queued > 0)
    {
        // Peek at the length field of the queued record when a full header is present.
        int queuedFragmentLength = 0;
        if (queued >= RECORD_HEADER_LENGTH)
        {
            byte[] lengthField = new byte[2];
            mRecordQueue.Read(lengthField, 0, 2, 11);
            queuedFragmentLength = TlsUtilities.ReadUint16(lengthField, 0);
        }

        int taken = System.Math.Min(mRecordQueue.Available, RECORD_HEADER_LENGTH + queuedFragmentLength);
        mRecordQueue.RemoveData(buf, off, taken, 0);
        return taken;
    }

    int count = mTransport.Receive(buf, off, len, waitMillis);
    if (count < RECORD_HEADER_LENGTH)
    {
        return count;
    }

    int wholeRecord = RECORD_HEADER_LENGTH + TlsUtilities.ReadUint16(buf, off + 11);
    if (count <= wholeRecord)
    {
        return count;
    }

    // The datagram carried more than one record: keep the remainder for later.
    mRecordQueue.AddData(buf, off + wholeRecord, count - wholeRecord);
    return wholeRecord;
}
/// <summary>
/// Encodes the given plaintext with the write epoch's cipher, prepends the 13-byte
/// record header and hands the record to the transport. Nothing is sent while no
/// write version has been set.
/// </summary>
/// <param name="contentType">Content type byte for the record header.</param>
/// <param name="buf">Buffer holding the plaintext.</param>
/// <param name="off">Offset of the plaintext in <paramref name="buf"/>.</param>
/// <param name="len">Length of the plaintext.</param>
/// <exception cref="TlsFatalAlert">
/// internal_error when the plaintext exceeds the plaintext limit, or is empty for a
/// content type other than application_data.
/// </exception>
private void SendRecord(byte contentType, byte[] buf, int off, int len)
{
    // Never send anything until a valid ClientHello has been received
    if (mWriteVersion == null)
    {
        return;
    }

    /*
     * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
     * or ChangeCipherSpec content types.
     */
    if (len > this.mPlaintextLimit
        || (len < 1 && contentType != ContentType.application_data))
    {
        throw new TlsFatalAlert(AlertDescription.internal_error);
    }

    int epochNumber = mWriteEpoch.Epoch;
    long sequenceNumber = mWriteEpoch.AllocateSequenceNumber();
    long macSequenceNumber = GetMacSequenceNumber(epochNumber, sequenceNumber);
    byte[] encoded = mWriteEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, buf, off, len);

    // TODO Check the ciphertext length?
    byte[] datagram = new byte[RECORD_HEADER_LENGTH + encoded.Length];

    // Header layout: type @0, version @1, epoch @3, sequence number @5, length @11.
    TlsUtilities.WriteUint8(contentType, datagram, 0);
    TlsUtilities.WriteVersion(mWriteVersion, datagram, 1);
    TlsUtilities.WriteUint16(epochNumber, datagram, 3);
    TlsUtilities.WriteUint48(sequenceNumber, datagram, 5);
    TlsUtilities.WriteUint16(encoded.Length, datagram, 11);
    Array.Copy(encoded, 0, datagram, RECORD_HEADER_LENGTH, encoded.Length);

    mTransport.Send(datagram, 0, datagram.Length);
}
/// <summary>
/// Combines an epoch and a record sequence number into a single 64-bit value: the
/// epoch (masked to 32 bits) is shifted into bit 48 and above, and the sequence number
/// is OR-ed into the result.
/// </summary>
/// <param name="epoch">Epoch number of the record.</param>
/// <param name="sequence_number">Sequence number of the record within the epoch.</param>
/// <returns>The combined value.</returns>
private static long GetMacSequenceNumber(int epoch, long sequence_number)
{
    long epochBits = epoch & 0xFFFFFFFFL;
    long shiftedEpoch = epochBits << 48;
    return shiftedEpoch | sequence_number;
}
}
}
#pragma warning restore
#endif