From e4a6b9268f9eac48c339fbe5f4d7e4c98846024f Mon Sep 17 00:00:00 2001 From: leitner Date: Fri, 8 May 2015 01:21:32 +0000 Subject: [PATCH] add some more tls wip code --- fmt_tls_handshake_cert.c | 13 ++++ fmt_tls_handshake_certs_header.c | 28 +++++++ fmt_tls_serverhellodone.c | 8 ++ tls_accept.c | 85 +++++++++++++++++++++ tls_connect.c | 125 +++++++++++++++++++++++++++++++ tls_doread.c | 120 +++++++++++++++++++++++++++++ tls_dowrite.c | 30 ++++++++ 7 files changed, 409 insertions(+) create mode 100644 fmt_tls_handshake_cert.c create mode 100644 fmt_tls_handshake_certs_header.c create mode 100644 fmt_tls_serverhellodone.c create mode 100644 tls_accept.c create mode 100644 tls_connect.c create mode 100644 tls_doread.c create mode 100644 tls_dowrite.c diff --git a/fmt_tls_handshake_cert.c b/fmt_tls_handshake_cert.c new file mode 100644 index 0000000..7295267 --- /dev/null +++ b/fmt_tls_handshake_cert.c @@ -0,0 +1,13 @@ +#include "tinytls.h" +#include + +size_t fmt_tls_handshake_cert(char* dest,const char* cert,size_t len) { + if (len>0x1000) return 0; // completely arbitrary decision on my part + if (dest) { + dest[0]=0; + dest[1]=(len>>8); + dest[2]=(len&0xff); + memcpy(dest+3,cert,len); + } + return len+3; +} diff --git a/fmt_tls_handshake_certs_header.c b/fmt_tls_handshake_certs_header.c new file mode 100644 index 0000000..bc18f09 --- /dev/null +++ b/fmt_tls_handshake_certs_header.c @@ -0,0 +1,28 @@ +#include "tinytls.h" + +size_t fmt_tls_handshake_certs_header(char* dest,size_t len_of_certs) { + if (len_of_certs>0x4000) return 0; // completely arbitrary decision on my part + if (dest) { + /* We need to write two headers containing three lengths */ + /* Someone ought to get their ass kicked for this */ + /* 1. TLS packet layer header */ + dest[0]=22; // handshake protocol + dest[1]=3; // tls 1.2 + dest[2]=3; + dest[3]=((len_of_certs+7)>>8); + dest[4]=((len_of_certs+7)&0xff); + + /* 2. handshake protocol header */ + dest[5]=11; // handshake type: certificate + dest[6]=0; + dest[7]=((len_of_certs+3)>>8); + dest[8]=((len_of_certs+3)&0xff); + + /* and now the same length ... a third time! */ + dest[9]=0; + dest[10]=(len_of_certs>>8); + dest[11]=(len_of_certs&0xff); + } + return 12; +} + diff --git a/fmt_tls_serverhellodone.c b/fmt_tls_serverhellodone.c new file mode 100644 index 0000000..70e2d6f --- /dev/null +++ b/fmt_tls_serverhellodone.c @@ -0,0 +1,8 @@ +#include "tinytls.h" +#include + +size_t fmt_tls_serverhellodone(char* dest) { + if (dest) + memcpy(dest,"\x16\x03\x03\x00\x04\x0e\x00\x00",9); + return 9; +} diff --git a/tls_accept.c b/tls_accept.c new file mode 100644 index 0000000..0f58932 --- /dev/null +++ b/tls_accept.c @@ -0,0 +1,85 @@ +#include "tinytls.h" +#include +#include + +tls_error_code tls_accept(uintptr_t fd,struct ssl_context* sc) { + tls_error_code r,ret=PROTOCOLFAIL; + size_t l; + switch (sc->state) { + case READ_CLIENTHELLO: + r=tls_doread(fd,sc); + if (r!=OK) return r; + + l=fmt_tls_serverhello(sc->scratch,sc->scratch,sc->ofsinmessage,sc); + sc->ofsinmessage=0; + if (l==7) + // failure, send error message back + goto alertfail; + // figure out which certificates to send + if (sc->readcert) { + enum alerttype a=sc->readcert(sc); + if (a!=0) { + fmt_tls_alert_pkt(sc->scratch,2,a); + goto alertfail; + } + } + { + size_t i,s; + char* x; + for (i=s=0; imycert[i].l; ++i) { + if (sc->mycert[i].l>0x1000) { +nocert: + fmt_tls_alert_pkt(sc->scratch,2,INTERNAL_ERROR); + ret=YOUSUCK; + goto alertfail; + } + s+=sc->mycert[i].l+3; // fmt_tls_handshake_cert shortcut + } + if (l+s+12+9 > sizeof(sc->scratch)) { + // l is the size of the serverhello which we generated, at most 309 bytes + // s is the sum of the sizes of the certificates, at most 0x1003*MAXCERT + // 12 is for fmt_tls_handshake_certs_header + // 9 is for fmt_tls_serverhellodone + // -> no integer overflow + char* x=realloc((char*)sc->message.s,l+s+12+9); + if (!x) { + fmt_tls_alert_pkt(sc->scratch,2,INTERNAL_ERROR); + ret=OOM; + goto alertfail; + } + memcpy(x,sc->scratch,l); + sc->message.s=x; + } + sc->message.l=l+s+12+9; + if (sc->mycert[0].l==0) + goto nocert; + + x=(char*)sc->message.s+l; + x+=fmt_tls_handshake_certs_header(x,s); + for (i=0; imycert[i].l) + x+=fmt_tls_handshake_cert(x,sc->mycert[i].s,sc->mycert[i].l); + x+=fmt_tls_serverhellodone(x); + } + + r=WRITE_SERVERHELLO; + sc->ofsinmessage=0; + // fall through + + case WRITE_SERVERHELLO: + r=tls_dowrite(fd,sc); + if (r!=OK) return r; + return r; + + case WRITE_ALERTFAIL: +alertfail: + sc->state=WRITE_ALERTFAIL; + r=tls_dowrite(fd,sc); + if (r!=OK) return r; + // fall through + + default: + sc->state=FAIL; + } + return ret; +} diff --git a/tls_connect.c b/tls_connect.c new file mode 100644 index 0000000..d9c47de --- /dev/null +++ b/tls_connect.c @@ -0,0 +1,125 @@ +#include "tinytls.h" +#include +#include "uint16.h" +#include "uint32.h" +#include "buffer.h" +#include + +inline int puts(const char* s) { + buffer_putmflush(buffer_1,s,"\n"); + return 0; +} + +tls_error_code tls_connect(uintptr_t fd,struct ssl_context* sc) { + size_t l; + tls_error_code r,ret=PROTOCOLFAIL; + switch (sc->state) { + case NONE: + puts("TLS_CONNECT"); + // initial connect attempt; send client hello + sc->message.l=fmt_tls_clienthello(NULL,sc); + // scratch should be enough to hold the client hello + // depending on session data length and sc->hostname + if (sc->message.l<=sizeof(sc->scratch)) + sc->message.s=sc->scratch; + else { + if (!(sc->message.s=malloc(sc->message.l))) + return OOM; + } + sc->message.l=fmt_tls_clienthello((char*)sc->message.s,sc); + sc->ofsinmessage=0; sc->message.s=sc->scratch; + sc->state=WRITE_CLIENTHELLO; + // fall through + + case WRITE_CLIENTHELLO: + puts("WRITE_CLIENTHELLO"); + r=tls_dowrite(fd,sc); + if (r!=OK) return r; + r=READ_SERVERHELLO; + sc->ofsinmessage=0; + // fall through + + case READ_SERVERHELLO: + puts("READ_SERVERHELLO"); + r=tls_doread(fd,sc); + if (r!=OK) return r; + if (sc->message.s[0]!=22) { // "handshake" +nothandshake: + fmt_tls_alert_pkt(sc->scratch,2,UNEXPECTED_MESSAGE); + goto alertfail; + } + if ((l=uint16_read_big(sc->message.s+3))<54) { // outer length +decodeerror: + fmt_tls_alert_pkt(sc->scratch,2,DECODE_ERROR); + goto alertfail; + } + if (sc->message.s[5]!=2) goto nothandshake; // "server hello" + if ((uint32_read_big(sc->message.s+5)&0xffffff)+4!=l) goto decodeerror; // inner length + if ((size_t)(unsigned char)(sc->message.s[5+38])+54message.s+sc->message.s[5+38]+5+38+1; + // make sure they don't pull a fast one on us + // and "agree" to a cipher/compression method we did not offer + uint16_t cipher=uint16_read_big(x); + if (tls_cipherprio(cipher)<0) goto decodeerror; + if (x[2]!=0) goto decodeerror; + sc->cipher=cipher; + sc->compressionmethod=0; + } + r=READ_CERT; + sc->ofsinmessage=0; + // fall through + + case READ_CERT: + puts("READ_CERT"); + r=tls_doread(fd,sc); + if (r!=OK) return r; + if (sc->message.s[0]!=22) goto nothandshake; // "handshake" + if ((l=uint16_read_big(sc->message.s+3))<50) goto decodeerror; + if (sc->message.s[5]!=11) goto nothandshake; // "certificate" + if ((uint32_read_big(sc->message.s+5)&0xffffff)+4!=l) goto decodeerror; // inner length + if ((uint32_read_big(sc->message.s+8)&0xffffff)+7!=l) goto decodeerror; // innerer length + { + const char* x=sc->message.s+9+3; + const char* max=x+l-7; + size_t i; + sc->theircert[0].s=malloc(l); + for (i=0; i=max) break; + if (x[0]) goto decodeerror; + sc->theircert[i].l=uint16_read_big(x+1); + x+=3; + if ((uintptr_t)(max-x) < sc->theircert[i].l) goto decodeerror; + if (i!=0) sc->theircert[i].s=sc->theircert[i-1].s+sc->theircert[i-1].l; + memcpy((char*)sc->theircert[i].s,x,sc->theircert[i].l); + x+=sc->theircert[i].l; + } + } + + r=READ_SERVERHELLODONE; + sc->ofsinmessage=0; + // fall through + +case READ_SERVERHELLODONE: + puts("READ_SERVERHELLODONE"); + r=tls_doread(fd,sc); + if (r!=OK) return r; + if (sc->message.s[0]!=22) goto nothandshake; // "handshake" + if ((l=uint16_read_big(sc->message.s+3))!=4) goto decodeerror; + if (sc->message.s[5]!=14) goto nothandshake; // "server hello done" + if ((uint32_read_big(sc->message.s+5)&0xffffff)+4!=l) goto decodeerror; // inner length + return OK; + + case WRITE_ALERTFAIL: +alertfail: + sc->state=WRITE_ALERTFAIL; + sc->message.s=sc->scratch; + sc->message.l=7; + r=tls_dowrite(fd,sc); + if (r!=OK) return r; + + default: + sc->state=FAIL; + } + return ret; +} diff --git a/tls_doread.c b/tls_doread.c new file mode 100644 index 0000000..a5d4781 --- /dev/null +++ b/tls_doread.c @@ -0,0 +1,120 @@ +#include "tinytls.h" +#include "uint16.h" +#include +#include +#include + +tls_error_code tls_doread(uintptr_t fd,struct ssl_context* sc) { + size_t l; + ssize_t r; +again: + if (sc->ofsinmessage < 5) { + // we have not read anything yet. + // point message to scratch and read the first bit + sc->message.s=sc->scratch; + sc->message.l=0; + l=5-sc->ofsinmessage; + } else { + // we have read enough to know how much we are supposed to be reading + // in this case s->message is setup right for us already + l=sc->message.l-sc->ofsinmessage; + } + + if (sc->_read) + r=sc->_read(fd,(char*)sc->message.s+sc->ofsinmessage,l); + else + r=read(fd,(char*)sc->message.s+sc->ofsinmessage,l); + + if (r==0) // EOF when we expected something -> protocol error + return PROTOCOLFAIL; + + if (r<0) { + // we accept the traditional -1+errno + // and the libowfat -3+errno for error and -1 for EAGAIN + // as long as errno is still set to EAGAIN + if (r==-3) return IOFAIL; + if (r==-1) + return errno==EAGAIN ? WANTREAD : IOFAIL; + return YOUSUCK; + } + + if ((size_t)r>l) + return YOUSUCK; // callback says it read more than we asked for + + sc->ofsinmessage+=l; + if (sc->ofsinmessage>=5 && sc->ofsinmessage-l<5) { + // we did not know how much we wanted before, but we do now + sc->message.l=5+uint16_read_big(sc->scratch+3); + if (sc->message.l>sizeof(sc->scratch)) { + char* x; + if (!(x=realloc((char*)sc->message.s,sc->message.l))) + return OOM; + sc->message.s=(char*)x; // make sure we don't clobber sc->message.s in the OOM case + memcpy((char*)sc->message.s,sc->scratch,sc->ofsinmessage-l); + } + /* attempt to read the rest */ + goto again; + } + + if (sc->ofsinmessage >= sc->message.l) { + // we read one full packet. See if it is an alert. + if (sc->message.s[0]==ALERT) { + if (sc->message.l!=7) // alerts are 5 bytes header plus 2 bytes alert + return PROTOCOLFAIL; + // it is an alert; skip warnings, signal errors + if (sc->message.s[5]==1) { + // it's a warning, we can ignore it. + if (sc->ofsinmessage>7) { + // since we initially read into scratch, we could have read + // more than 7 bytes. Move latter part forward. + memmove((char*)sc->message.s+7,sc->message.s,sc->ofsinmessage-7); + sc->ofsinmessage-=7; + goto again; + } + } + + switch (sc->message.s[6]) { + case BAD_RECORD_MAC: + case DECRYPTION_FAILED: + case DECRYPT_ERROR: + case EXPORT_RESTRICTION: + case INSUFFICIENT_SECURITY: + return CRYPTOFAIL; + + case HANDSHAKE_FAILURE: + case INTERNAL_ERROR: + case USER_CANCELED: + case NO_RENEGOTIATION: + return NEGOTIATIONFAIL; + + case NO_CERT: + case BAD_CERT: + case UNSUPPORTED_CERT: + case CERT_REVOKED: + case CERT_EXPIRED: + case CERT_UNKNOWN: + case UNKNOWN_CA: + return CERTFAIL; + + default: + return PROTOCOLFAIL; + } + } + return OK; + } + return WANTREAD; +} + +void tls_prepare_next_read(struct ssl_context* sc) { + size_t psize; + if (sc->message.s==sc->scratch && // we are reading into scratch + sc->message.l>5 && // we have a header + sc->message.l>(psize=5+uint16_read_big(sc->scratch+3))) { + // we have a handled packet and extra data in the scratch buffer + // memmove the rest over the handled packet + memmove(sc->scratch,sc->scratch+psize,sc->message.l-psize); + sc->message.l-=psize; + } + sc->ofsinmessage=0; +} + diff --git a/tls_dowrite.c b/tls_dowrite.c new file mode 100644 index 0000000..ce4f044 --- /dev/null +++ b/tls_dowrite.c @@ -0,0 +1,30 @@ +#include "tinytls.h" +#include +#include + +tls_error_code tls_dowrite(uintptr_t fd,struct ssl_context* sc) { + size_t l=sc->message.l-sc->ofsinmessage; + ssize_t r; + if (sc->_write) + r=sc->_write(fd,sc->message.s+sc->ofsinmessage,l); + else + r=write(fd,sc->message.s+sc->ofsinmessage,l); + + if (r==0) // EOF when we expected something -> protocol error + return PROTOCOLFAIL; + if (r<0) { + // we accept the traditional -1+errno + // and the libowfat -3+errno for error and -1 for EAGAIN + // as long as errno is still set to EAGAIN + if (r==-3) return IOFAIL; + if (r==-1) + return errno==EAGAIN ? WANTWRITE : IOFAIL; + return YOUSUCK; + } + if ((size_t)r>l) + return YOUSUCK; // callback says it read more than we asked for + + sc->ofsinmessage+=l; + + return sc->ofsinmessage < sc->message.l ? WANTWRITE : OK; +}