add some basic certificate handling (from cinap again) Reference: /n/atom/patch/applied/tlscert Date: Sat Mar 15 17:41:58 CET 2014 Signed-off-by: quanstro@quanstro.net --- /sys/src/libsec/port/tlshand.c Sat Mar 15 17:41:47 2014 +++ /sys/src/libsec/port/tlshand.c Sat Mar 15 17:41:49 2014 @@ -66,9 +66,13 @@ int state; // must be set using setstate // input buffer for handshake messages - uchar buf[MaxChunk+2048]; + uchar recvbuf[MaxChunk]; uchar *rp, *ep; + // output buffer + uchar sendbuf[MaxChunk]; + uchar *sendp; + uchar crandom[RandomSize]; // client random uchar srandom[RandomSize]; // server random int clientVersion; // version in ClientHello @@ -111,6 +115,9 @@ struct { Bytes *key; } clientKeyExchange; + struct { + Bytes *signature; + } certificateVerify; Finished finished; } u; } Msg; @@ -242,7 +249,7 @@ {"rc4_128", "sha1", 2*(16+SHA1dlen), TLS_RSA_WITH_RC4_128_SHA}, {"3des_ede_cbc", "sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA}, {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_RSA_WITH_AES_128_CBC_SHA}, - {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA} + {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}, }; static uchar compressors[] = { @@ -250,7 +257,7 @@ }; static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...), PEMChain *chain); -static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...)); +static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...)); static void msgClear(Msg *m); static char* msgPrint(char *buf, int n, Msg *m); @@ -302,11 +309,15 @@ static int get16(uchar *p); static Bytes* newbytes(int len); static Bytes* makebytes(uchar* buf, int len); +static Bytes* mptobytes(mpint* big); static void freebytes(Bytes* b); static Ints* newints(int len); static Ints* makeints(int* buf, int len); static void freeints(Ints* b); +/* x509.c */ +extern mpint* pkcs1padbuf(uchar *buf, int len, mpint *modulus); + //================= client/server ======================== // push TLS onto fd, returning new (application) file descriptor @@ -330,8 +341,8 @@ return -1; } buf[n] = 0; - sprint(conn->dir, "#a/tls/%s", buf); - sprint(dname, "#a/tls/%s/hand", buf); + snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf); + snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf); hand = open(dname, ORDWR); if(hand < 0){ close(ctl); @@ -339,20 +350,16 @@ } fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain); - sprint(dname, "#a/tls/%s/data", buf); + snprint(dname, sizeof(dname), "#a/tls/%s/data", buf); data = open(dname, ORDWR); - close(fd); close(hand); close(ctl); - if(data < 0){ - return -1; - } - if(tls == nil){ - close(data); + if(data < 0 || tls == nil){ + if(tls != nil) + tlsConnectionFree(tls); return -1; } - if(conn->cert) - free(conn->cert); + free(conn->cert); conn->cert = 0; // client certificates are not yet implemented conn->certlen = 0; conn->sessionIDlen = tls->sid->len; @@ -361,6 +368,7 @@ if(conn->sessionKey != nil && conn->sessionType != nil && strcmp(conn->sessionType, "ttls") == 0) tls->sec->prf(conn->sessionKey, conn->sessionKeylen, tls->sec->sec, MasterSecretSize, conn->sessionConst, tls->sec->crandom, RandomSize, tls->sec->srandom, RandomSize); tlsConnectionFree(tls); + close(fd); return data; } @@ -385,19 +393,22 @@ return -1; } buf[n] = 0; - sprint(conn->dir, "#a/tls/%s", buf); - sprint(dname, "#a/tls/%s/hand", buf); + snprint(conn->dir, sizeof(conn->dir), "#a/tls/%s", buf); + snprint(dname, sizeof(dname), "#a/tls/%s/hand", buf); hand = open(dname, ORDWR); if(hand < 0){ close(ctl); return -1; } - sprint(dname, "#a/tls/%s/data", buf); + snprint(dname, sizeof(dname), "#a/tls/%s/data", buf); data = open(dname, ORDWR); - if(data < 0) + if(data < 0){ + close(hand); + close(ctl); return -1; + } fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); - tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->trace); + tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, conn->trace); close(fd); close(hand); close(ctl); @@ -601,7 +612,7 @@ } static TlsConnection * -tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...)) +tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, int (*trace)(char*fmt, ...)) { TlsConnection *c; Msg m; @@ -703,7 +714,7 @@ if(tlsSecSecretc(c->sec, c->sid->data, c->sid->len, c->srandom, c->cert->data, c->cert->len, c->version, &epm, &nepm, kd, c->nsecret) < 0){ - tlsError(c, EBadCertificate, "invalid x509/rsa certificate"); + tlsError(c, EBadCertificate, "bad certificate: %r"); goto Err; } secrets = (char*)emalloc(2*c->nsecret); @@ -718,7 +729,11 @@ } if(creq) { - /* send a zero length certificate */ + if(cert != nil && certlen > 0){ + m.u.certificate.ncert = 1; + m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes)); + m.u.certificate.certs[0] = makebytes(cert, certlen); + } m.tag = HCertificate; if(!msgSend(c, &m, AFlush)) goto Err; @@ -738,6 +753,53 @@ goto Err; msgClear(&m); + /* CertificateVerify */ + /* + * XXX I should only send this when it is not DH right? + * Also we need to know which TLS key + * we have to use in case there are more than one + */ + if(cert){ + uchar hshashes[MD5dlen+SHA1dlen]; /* content of signature */ + MD5state hsmd5_save; + SHAstate hssha1_save; + mpint *signedMP, *paddedHashes; + + /* save the state for the Finish message */ + + m.tag = HCertificateVerify; + hsmd5_save = c->hsmd5; + hssha1_save = c->hssha1; + md5(nil, 0, hshashes, &c->hsmd5); + sha1(nil, 0, hshashes+MD5dlen, &c->hssha1); + + c->hsmd5 = hsmd5_save; + c->hssha1 = hssha1_save; + + c->sec->rpc = factotum_rsa_open(cert, certlen); + if(c->sec->rpc == nil){ + tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); + goto Err; + } + c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0); + + paddedHashes = pkcs1padbuf(hshashes, 36, c->sec->rsapub->n); + signedMP = factotum_rsa_decrypt(c->sec->rpc, paddedHashes); + m.u.certificateVerify.signature = mptobytes(signedMP); + free(signedMP); + + if(m.u.certificateVerify.signature == nil){ + msgClear(&m); + goto Err; + } + + if(!msgSend(c, &m, AFlush)){ + msgClear(&m); + goto Err; + } + msgClear(&m); + } + /* change cipher spec */ if(fprint(c->ctl, "changecipher") < 0){ tlsError(c, EInternalError, "can't enable cipher: %r"); @@ -800,19 +862,17 @@ //================= message functions ======================== -static uchar sendbuf[9000], *sendp; - static int msgSend(TlsConnection *c, Msg *m, int act) { uchar *p; // sendp = start of new message; p = write pointer int nn, n, i; - if(sendp == nil) - sendp = sendbuf; - p = sendp; + if(c->sendp == nil) + c->sendp = c->sendbuf; + p = c->sendp; if(c->trace) - c->trace("send %s", msgPrint((char*)p, (sizeof sendbuf) - (p-sendbuf), m)); + c->trace("send %s", msgPrint((char*)p, (sizeof(c->sendbuf)) - (p - c->sendbuf), m)); p[0] = m->tag; // header - fill in size later p += 4; @@ -878,7 +938,7 @@ nn = 0; for(i = 0; i < m->u.certificate.ncert; i++) nn += 3 + m->u.certificate.certs[i]->len; - if(p + 3 + nn - sendbuf > sizeof(sendbuf)) { + if(p + 3 + nn - c->sendbuf > sizeof(c->sendbuf)) { tlsError(c, EInternalError, "output buffer too small for certificate"); goto Err; } @@ -891,6 +951,12 @@ p += m->u.certificate.certs[i]->len; } break; + case HCertificateVerify: + put16(p, m->u.certificateVerify.signature->len); + p += 2; + memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len); + p += m->u.certificateVerify.signature->len; + break; case HClientKeyExchange: n = m->u.clientKeyExchange.key->len; if(c->version != SSL3Version){ @@ -907,21 +973,21 @@ } // go back and fill in size - n = p-sendp; - assert(p <= sendbuf+sizeof(sendbuf)); - put24(sendp+1, n-4); + n = p - c->sendp; + assert(p <= c->sendbuf + sizeof(c->sendbuf)); + put24(c->sendp+1, n-4); // remember hash of Handshake messages if(m->tag != HHelloRequest) { - md5(sendp, n, 0, &c->hsmd5); - sha1(sendp, n, 0, &c->hssha1); + md5(c->sendp, n, 0, &c->hsmd5); + sha1(c->sendp, n, 0, &c->hssha1); } - sendp = p; + c->sendp = p; if(act == AFlush){ - sendp = sendbuf; - if(write(c->hand, sendbuf, p-sendbuf) < 0){ - fprint(2, "tls flush: write(%d, %#p, %ld) error: %r\n", c->hand, sendbuf, p-sendbuf); + c->sendp = c->sendbuf; + if(write(c->hand, c->sendbuf, p - c->sendbuf) < 0){ + fprint(2, "tls flush: write(%d, %#p, %ld) error: %r\n", c->hand, c->sendbuf, p-c->sendbuf); goto Err; } } @@ -940,10 +1006,10 @@ nn = c->ep - c->rp; if(nn < n){ - if(c->rp != c->buf){ - memmove(c->buf, c->rp, nn); - c->rp = c->buf; - c->ep = &c->buf[nn]; + if(c->rp != c->recvbuf){ + memmove(c->recvbuf, c->rp, nn); + c->rp = c->recvbuf; + c->ep = &c->recvbuf[nn]; } for(; nn < n; nn += nr) { nr = read(c->hand, &c->rp[nn], n - nn); @@ -978,8 +1044,8 @@ } } - if(n > sizeof(c->buf)) { - tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->buf)); + if(n > sizeof(c->recvbuf)) { + tlsError(c, EDecodeError, "handshake message too long %d %d", n, sizeof(c->recvbuf)); return 0; } @@ -1115,7 +1181,7 @@ nn = get24(p); p += 3; n -= 3; - if(n != nn) + if(nn == 0 && n > 0) goto Short; /* certs */ i = 0; @@ -1249,6 +1315,9 @@ freebytes(m->u.certificateRequest.cas[i]); free(m->u.certificateRequest.cas); break; + case HCertificateVerify: + freebytes(m->u.certificateVerify.signature); + break; case HServerHelloDone: break; case HClientKeyExchange: @@ -1342,6 +1411,10 @@ for(i=0; iu.certificateRequest.nca; i++) bs = bytesPrint(bs, be, "\t\t", m->u.certificateRequest.cas[i], "\n"); break; + case HCertificateVerify: + bs = seprint(bs, be, "HCertificateVerify\n"); + bs = bytesPrint(bs, be, "\tsignature: ", m->u.certificateVerify.signature,"\n"); + break; case HServerHelloDone: bs = seprint(bs, be, "ServerHelloDone\n"); break; @@ -2196,6 +2269,7 @@ exits("out of memory"); } memset(p, 0, n); + setmalloctag(p, getcallerpc(&n)); return p; } @@ -2209,6 +2283,7 @@ else if(!(ReallocP = realloc(ReallocP, ReallocN))){ exits("out of memory"); } + setrealloctag(ReallocP, getcallerpc(&ReallocP)); return(ReallocP); }