add some more tls wip code

This commit is contained in:
leitner
2015-05-08 01:21:32 +00:00
parent ed9c3d238e
commit e4a6b9268f
7 changed files with 409 additions and 0 deletions

13
fmt_tls_handshake_cert.c Normal file
View File

@@ -0,0 +1,13 @@
#include "tinytls.h"
#include <string.h>
size_t fmt_tls_handshake_cert(char* dest,const char* cert,size_t len) {
if (len>0x1000) return 0; // completely arbitrary decision on my part
if (dest) {
dest[0]=0;
dest[1]=(len>>8);
dest[2]=(len&0xff);
memcpy(dest+3,cert,len);
}
return len+3;
}

View File

@@ -0,0 +1,28 @@
#include "tinytls.h"
size_t fmt_tls_handshake_certs_header(char* dest,size_t len_of_certs) {
if (len_of_certs>0x4000) return 0; // completely arbitrary decision on my part
if (dest) {
/* We need to write two headers containing three lengths */
/* Someone ought to get their ass kicked for this */
/* 1. TLS packet layer header */
dest[0]=22; // handshake protocol
dest[1]=3; // tls 1.2
dest[2]=3;
dest[3]=((len_of_certs+7)>>8);
dest[4]=((len_of_certs+7)&0xff);
/* 2. handshake protocol header */
dest[5]=11; // handshake type: certificate
dest[6]=0;
dest[7]=((len_of_certs+3)>>8);
dest[8]=((len_of_certs+3)&0xff);
/* and now the same length ... a third time! */
dest[9]=0;
dest[10]=(len_of_certs>>8);
dest[11]=(len_of_certs&0xff);
}
return 12;
}

View File

@@ -0,0 +1,8 @@
#include "tinytls.h"
#include <string.h>
size_t fmt_tls_serverhellodone(char* dest) {
if (dest)
memcpy(dest,"\x16\x03\x03\x00\x04\x0e\x00\x00",9);
return 9;
}

85
tls_accept.c Normal file
View File

@@ -0,0 +1,85 @@
#include "tinytls.h"
#include <stdlib.h>
#include <string.h>
tls_error_code tls_accept(uintptr_t fd,struct ssl_context* sc) {
tls_error_code r,ret=PROTOCOLFAIL;
size_t l;
switch (sc->state) {
case READ_CLIENTHELLO:
r=tls_doread(fd,sc);
if (r!=OK) return r;
l=fmt_tls_serverhello(sc->scratch,sc->scratch,sc->ofsinmessage,sc);
sc->ofsinmessage=0;
if (l==7)
// failure, send error message back
goto alertfail;
// figure out which certificates to send
if (sc->readcert) {
enum alerttype a=sc->readcert(sc);
if (a!=0) {
fmt_tls_alert_pkt(sc->scratch,2,a);
goto alertfail;
}
}
{
size_t i,s;
char* x;
for (i=s=0; i<MAXCERT && sc->mycert[i].l; ++i) {
if (sc->mycert[i].l>0x1000) {
nocert:
fmt_tls_alert_pkt(sc->scratch,2,INTERNAL_ERROR);
ret=YOUSUCK;
goto alertfail;
}
s+=sc->mycert[i].l+3; // fmt_tls_handshake_cert shortcut
}
if (l+s+12+9 > sizeof(sc->scratch)) {
// l is the size of the serverhello which we generated, at most 309 bytes
// s is the sum of the sizes of the certificates, at most 0x1003*MAXCERT
// 12 is for fmt_tls_handshake_certs_header
// 9 is for fmt_tls_serverhellodone
// -> no integer overflow
char* x=realloc((char*)sc->message.s,l+s+12+9);
if (!x) {
fmt_tls_alert_pkt(sc->scratch,2,INTERNAL_ERROR);
ret=OOM;
goto alertfail;
}
memcpy(x,sc->scratch,l);
sc->message.s=x;
}
sc->message.l=l+s+12+9;
if (sc->mycert[0].l==0)
goto nocert;
x=(char*)sc->message.s+l;
x+=fmt_tls_handshake_certs_header(x,s);
for (i=0; i<MAXCERT; ++i)
if (sc->mycert[i].l)
x+=fmt_tls_handshake_cert(x,sc->mycert[i].s,sc->mycert[i].l);
x+=fmt_tls_serverhellodone(x);
}
r=WRITE_SERVERHELLO;
sc->ofsinmessage=0;
// fall through
case WRITE_SERVERHELLO:
r=tls_dowrite(fd,sc);
if (r!=OK) return r;
return r;
case WRITE_ALERTFAIL:
alertfail:
sc->state=WRITE_ALERTFAIL;
r=tls_dowrite(fd,sc);
if (r!=OK) return r;
// fall through
default:
sc->state=FAIL;
}
return ret;
}

125
tls_connect.c Normal file
View File

@@ -0,0 +1,125 @@
#include "tinytls.h"
#include <stdlib.h>
#include "uint16.h"
#include "uint32.h"
#include "buffer.h"
#include <string.h>
inline int puts(const char* s) {
buffer_putmflush(buffer_1,s,"\n");
return 0;
}
tls_error_code tls_connect(uintptr_t fd,struct ssl_context* sc) {
size_t l;
tls_error_code r,ret=PROTOCOLFAIL;
switch (sc->state) {
case NONE:
puts("TLS_CONNECT");
// initial connect attempt; send client hello
sc->message.l=fmt_tls_clienthello(NULL,sc);
// scratch should be enough to hold the client hello
// depending on session data length and sc->hostname
if (sc->message.l<=sizeof(sc->scratch))
sc->message.s=sc->scratch;
else {
if (!(sc->message.s=malloc(sc->message.l)))
return OOM;
}
sc->message.l=fmt_tls_clienthello((char*)sc->message.s,sc);
sc->ofsinmessage=0; sc->message.s=sc->scratch;
sc->state=WRITE_CLIENTHELLO;
// fall through
case WRITE_CLIENTHELLO:
puts("WRITE_CLIENTHELLO");
r=tls_dowrite(fd,sc);
if (r!=OK) return r;
r=READ_SERVERHELLO;
sc->ofsinmessage=0;
// fall through
case READ_SERVERHELLO:
puts("READ_SERVERHELLO");
r=tls_doread(fd,sc);
if (r!=OK) return r;
if (sc->message.s[0]!=22) { // "handshake"
nothandshake:
fmt_tls_alert_pkt(sc->scratch,2,UNEXPECTED_MESSAGE);
goto alertfail;
}
if ((l=uint16_read_big(sc->message.s+3))<54) { // outer length
decodeerror:
fmt_tls_alert_pkt(sc->scratch,2,DECODE_ERROR);
goto alertfail;
}
if (sc->message.s[5]!=2) goto nothandshake; // "server hello"
if ((uint32_read_big(sc->message.s+5)&0xffffff)+4!=l) goto decodeerror; // inner length
if ((size_t)(unsigned char)(sc->message.s[5+38])+54<l) goto decodeerror;
{
const char* x=sc->message.s+sc->message.s[5+38]+5+38+1;
// make sure they don't pull a fast one on us
// and "agree" to a cipher/compression method we did not offer
uint16_t cipher=uint16_read_big(x);
if (tls_cipherprio(cipher)<0) goto decodeerror;
if (x[2]!=0) goto decodeerror;
sc->cipher=cipher;
sc->compressionmethod=0;
}
r=READ_CERT;
sc->ofsinmessage=0;
// fall through
case READ_CERT:
puts("READ_CERT");
r=tls_doread(fd,sc);
if (r!=OK) return r;
if (sc->message.s[0]!=22) goto nothandshake; // "handshake"
if ((l=uint16_read_big(sc->message.s+3))<50) goto decodeerror;
if (sc->message.s[5]!=11) goto nothandshake; // "certificate"
if ((uint32_read_big(sc->message.s+5)&0xffffff)+4!=l) goto decodeerror; // inner length
if ((uint32_read_big(sc->message.s+8)&0xffffff)+7!=l) goto decodeerror; // innerer length
{
const char* x=sc->message.s+9+3;
const char* max=x+l-7;
size_t i;
sc->theircert[0].s=malloc(l);
for (i=0; i<MAXCERT; ++i) {
if (x>=max) break;
if (x[0]) goto decodeerror;
sc->theircert[i].l=uint16_read_big(x+1);
x+=3;
if ((uintptr_t)(max-x) < sc->theircert[i].l) goto decodeerror;
if (i!=0) sc->theircert[i].s=sc->theircert[i-1].s+sc->theircert[i-1].l;
memcpy((char*)sc->theircert[i].s,x,sc->theircert[i].l);
x+=sc->theircert[i].l;
}
}
r=READ_SERVERHELLODONE;
sc->ofsinmessage=0;
// fall through
case READ_SERVERHELLODONE:
puts("READ_SERVERHELLODONE");
r=tls_doread(fd,sc);
if (r!=OK) return r;
if (sc->message.s[0]!=22) goto nothandshake; // "handshake"
if ((l=uint16_read_big(sc->message.s+3))!=4) goto decodeerror;
if (sc->message.s[5]!=14) goto nothandshake; // "server hello done"
if ((uint32_read_big(sc->message.s+5)&0xffffff)+4!=l) goto decodeerror; // inner length
return OK;
case WRITE_ALERTFAIL:
alertfail:
sc->state=WRITE_ALERTFAIL;
sc->message.s=sc->scratch;
sc->message.l=7;
r=tls_dowrite(fd,sc);
if (r!=OK) return r;
default:
sc->state=FAIL;
}
return ret;
}

120
tls_doread.c Normal file
View File

@@ -0,0 +1,120 @@
#include "tinytls.h"
#include "uint16.h"
#include <errno.h>
#include <unistd.h>
#include <stdlib.h>
tls_error_code tls_doread(uintptr_t fd,struct ssl_context* sc) {
size_t l;
ssize_t r;
again:
if (sc->ofsinmessage < 5) {
// we have not read anything yet.
// point message to scratch and read the first bit
sc->message.s=sc->scratch;
sc->message.l=0;
l=5-sc->ofsinmessage;
} else {
// we have read enough to know how much we are supposed to be reading
// in this case s->message is setup right for us already
l=sc->message.l-sc->ofsinmessage;
}
if (sc->_read)
r=sc->_read(fd,(char*)sc->message.s+sc->ofsinmessage,l);
else
r=read(fd,(char*)sc->message.s+sc->ofsinmessage,l);
if (r==0) // EOF when we expected something -> protocol error
return PROTOCOLFAIL;
if (r<0) {
// we accept the traditional -1+errno
// and the libowfat -3+errno for error and -1 for EAGAIN
// as long as errno is still set to EAGAIN
if (r==-3) return IOFAIL;
if (r==-1)
return errno==EAGAIN ? WANTREAD : IOFAIL;
return YOUSUCK;
}
if ((size_t)r>l)
return YOUSUCK; // callback says it read more than we asked for
sc->ofsinmessage+=l;
if (sc->ofsinmessage>=5 && sc->ofsinmessage-l<5) {
// we did not know how much we wanted before, but we do now
sc->message.l=5+uint16_read_big(sc->scratch+3);
if (sc->message.l>sizeof(sc->scratch)) {
char* x;
if (!(x=realloc((char*)sc->message.s,sc->message.l)))
return OOM;
sc->message.s=(char*)x; // make sure we don't clobber sc->message.s in the OOM case
memcpy((char*)sc->message.s,sc->scratch,sc->ofsinmessage-l);
}
/* attempt to read the rest */
goto again;
}
if (sc->ofsinmessage >= sc->message.l) {
// we read one full packet. See if it is an alert.
if (sc->message.s[0]==ALERT) {
if (sc->message.l!=7) // alerts are 5 bytes header plus 2 bytes alert
return PROTOCOLFAIL;
// it is an alert; skip warnings, signal errors
if (sc->message.s[5]==1) {
// it's a warning, we can ignore it.
if (sc->ofsinmessage>7) {
// since we initially read into scratch, we could have read
// more than 7 bytes. Move latter part forward.
memmove((char*)sc->message.s+7,sc->message.s,sc->ofsinmessage-7);
sc->ofsinmessage-=7;
goto again;
}
}
switch (sc->message.s[6]) {
case BAD_RECORD_MAC:
case DECRYPTION_FAILED:
case DECRYPT_ERROR:
case EXPORT_RESTRICTION:
case INSUFFICIENT_SECURITY:
return CRYPTOFAIL;
case HANDSHAKE_FAILURE:
case INTERNAL_ERROR:
case USER_CANCELED:
case NO_RENEGOTIATION:
return NEGOTIATIONFAIL;
case NO_CERT:
case BAD_CERT:
case UNSUPPORTED_CERT:
case CERT_REVOKED:
case CERT_EXPIRED:
case CERT_UNKNOWN:
case UNKNOWN_CA:
return CERTFAIL;
default:
return PROTOCOLFAIL;
}
}
return OK;
}
return WANTREAD;
}
void tls_prepare_next_read(struct ssl_context* sc) {
size_t psize;
if (sc->message.s==sc->scratch && // we are reading into scratch
sc->message.l>5 && // we have a header
sc->message.l>(psize=5+uint16_read_big(sc->scratch+3))) {
// we have a handled packet and extra data in the scratch buffer
// memmove the rest over the handled packet
memmove(sc->scratch,sc->scratch+psize,sc->message.l-psize);
sc->message.l-=psize;
}
sc->ofsinmessage=0;
}

30
tls_dowrite.c Normal file
View File

@@ -0,0 +1,30 @@
#include "tinytls.h"
#include <unistd.h>
#include <errno.h>
tls_error_code tls_dowrite(uintptr_t fd,struct ssl_context* sc) {
size_t l=sc->message.l-sc->ofsinmessage;
ssize_t r;
if (sc->_write)
r=sc->_write(fd,sc->message.s+sc->ofsinmessage,l);
else
r=write(fd,sc->message.s+sc->ofsinmessage,l);
if (r==0) // EOF when we expected something -> protocol error
return PROTOCOLFAIL;
if (r<0) {
// we accept the traditional -1+errno
// and the libowfat -3+errno for error and -1 for EAGAIN
// as long as errno is still set to EAGAIN
if (r==-3) return IOFAIL;
if (r==-1)
return errno==EAGAIN ? WANTWRITE : IOFAIL;
return YOUSUCK;
}
if ((size_t)r>l)
return YOUSUCK; // callback says it read more than we asked for
sc->ofsinmessage+=l;
return sc->ofsinmessage < sc->message.l ? WANTWRITE : OK;
}