/* FIXME: split and cleanup */ int wusb_dev_4way_handshake(struct wusbhc *wusbhc, struct wusb_dev *wusb_dev, struct wusb_ckhdid *ck) { int result = -ENOMEM; struct usb_device *usb_dev = wusb_dev->usb_dev; struct device *dev = &usb_dev->dev; u32 tkid; __le32 tkid_le; struct usb_handshake *hs; struct aes_ccm_nonce ccm_n; u8 mic[8]; struct wusb_keydvt_in keydvt_in; struct wusb_keydvt_out keydvt_out; hs = kzalloc(3*sizeof(hs[0]), GFP_KERNEL); if (hs == NULL) { dev_err(dev, "can't allocate handshake data\n"); goto error_kzalloc; } /* We need to turn encryption before beginning the 4way * hshake (WUSB1.0[.3.2.2]) */ result = wusb_dev_set_encryption(usb_dev, 1); if (result < 0) goto error_dev_set_encryption; tkid = wusbhc_next_tkid(wusbhc, wusb_dev); tkid_le = cpu_to_le32(tkid); hs[0].bMessageNumber = 1; hs[0].bStatus = 0; memcpy(hs[0].tTKID, &tkid_le, sizeof(hs[0].tTKID)); hs[0].bReserved = 0; memcpy(hs[0].CDID, &wusb_dev->cdid, sizeof(hs[0].CDID)); get_random_bytes(&hs[0].nonce, sizeof(hs[0].nonce)); memset(hs[0].MIC, 0, sizeof(hs[0].MIC)); /* Per WUSB1.0[T7-22] */ result = usb_control_msg( usb_dev, usb_sndctrlpipe(usb_dev, 0), USB_REQ_SET_HANDSHAKE, USB_DIR_OUT | USB_TYPE_STANDARD | USB_RECIP_DEVICE, 1, 0, &hs[0], sizeof(hs[0]), 1000 /* FIXME: arbitrary */); if (result < 0) { dev_err(dev, "Handshake1: request failed: %d\n", result); goto error_hs1; } /* Handshake 2, from the device -- need to verify fields */ result = usb_control_msg( usb_dev, usb_rcvctrlpipe(usb_dev, 0), USB_REQ_GET_HANDSHAKE, USB_DIR_IN | USB_TYPE_STANDARD | USB_RECIP_DEVICE, 2, 0, &hs[1], sizeof(hs[1]), 1000 /* FIXME: arbitrary */); if (result < 0) { dev_err(dev, "Handshake2: request failed: %d\n", result); goto error_hs2; } result = -EINVAL; if (hs[1].bMessageNumber != 2) { dev_err(dev, "Handshake2 failed: bad message number %u\n", hs[1].bMessageNumber); goto error_hs2; } if (hs[1].bStatus != 0) { dev_err(dev, "Handshake2 failed: bad status %u\n", hs[1].bStatus); goto error_hs2; } if (memcmp(hs[0].tTKID, hs[1].tTKID, sizeof(hs[0].tTKID))) { dev_err(dev, "Handshake2 failed: TKID mismatch " "(#1 0x%02x%02x%02x vs #2 0x%02x%02x%02x)\n", hs[0].tTKID[0], hs[0].tTKID[1], hs[0].tTKID[2], hs[1].tTKID[0], hs[1].tTKID[1], hs[1].tTKID[2]); goto error_hs2; } if (memcmp(hs[0].CDID, hs[1].CDID, sizeof(hs[0].CDID))) { dev_err(dev, "Handshake2 failed: CDID mismatch\n"); goto error_hs2; } /* Setup the CCM nonce */ memset(&ccm_n.sfn, 0, sizeof(ccm_n.sfn)); /* Per WUSB1.0[6.5.2] */ memcpy(ccm_n.tkid, &tkid_le, sizeof(ccm_n.tkid)); ccm_n.src_addr = wusbhc->uwb_rc->uwb_dev.dev_addr; ccm_n.dest_addr.data[0] = wusb_dev->addr; ccm_n.dest_addr.data[1] = 0; /* Derive the KCK and PTK from CK, the CCM, H and D nonces */ memcpy(keydvt_in.hnonce, hs[0].nonce, sizeof(keydvt_in.hnonce)); memcpy(keydvt_in.dnonce, hs[1].nonce, sizeof(keydvt_in.dnonce)); result = wusb_key_derive(&keydvt_out, ck->data, &ccm_n, &keydvt_in); if (result < 0) { dev_err(dev, "Handshake2 failed: cannot derive keys: %d\n", result); goto error_hs2; } /* Compute MIC and verify it */ result = wusb_oob_mic(mic, keydvt_out.kck, &ccm_n, &hs[1]); if (result < 0) { dev_err(dev, "Handshake2 failed: cannot compute MIC: %d\n", result); goto error_hs2; } if (memcmp(hs[1].MIC, mic, sizeof(hs[1].MIC))) { dev_err(dev, "Handshake2 failed: MIC mismatch\n"); goto error_hs2; } /* Send Handshake3 */ hs[2].bMessageNumber = 3; hs[2].bStatus = 0; memcpy(hs[2].tTKID, &tkid_le, sizeof(hs[2].tTKID)); hs[2].bReserved = 0; memcpy(hs[2].CDID, &wusb_dev->cdid, sizeof(hs[2].CDID)); memcpy(hs[2].nonce, hs[0].nonce, sizeof(hs[2].nonce)); result = wusb_oob_mic(hs[2].MIC, keydvt_out.kck, &ccm_n, &hs[2]); if (result < 0) { dev_err(dev, "Handshake3 failed: cannot compute MIC: %d\n", result); goto error_hs2; } result = usb_control_msg( usb_dev, usb_sndctrlpipe(usb_dev, 0), USB_REQ_SET_HANDSHAKE, USB_DIR_OUT | USB_TYPE_STANDARD | USB_RECIP_DEVICE, 3, 0, &hs[2], sizeof(hs[2]), 1000 /* FIXME: arbitrary */); if (result < 0) { dev_err(dev, "Handshake3: request failed: %d\n", result); goto error_hs3; } result = wusbhc->set_ptk(wusbhc, wusb_dev->port_idx, tkid, keydvt_out.ptk, sizeof(keydvt_out.ptk)); if (result < 0) goto error_wusbhc_set_ptk; result = wusb_dev_set_gtk(wusbhc, wusb_dev); if (result < 0) { dev_err(dev, "Set GTK for device: request failed: %d\n", result); goto error_wusbhc_set_gtk; } /* Update the device's address from unauth to auth */ if (usb_dev->authenticated == 0) { result = wusb_dev_update_address(wusbhc, wusb_dev); if (result < 0) goto error_dev_update_address; } result = 0; dev_info(dev, "device authenticated\n"); error_dev_update_address: error_wusbhc_set_gtk: error_wusbhc_set_ptk: error_hs3: error_hs2: error_hs1: memset(hs, 0, 3*sizeof(hs[0])); memset(&keydvt_out, 0, sizeof(keydvt_out)); memset(&keydvt_in, 0, sizeof(keydvt_in)); memset(&ccm_n, 0, sizeof(ccm_n)); memset(mic, 0, sizeof(mic)); if (result < 0) wusb_dev_set_encryption(usb_dev, 0); error_dev_set_encryption: kfree(hs); error_kzalloc: return result; }
/* These come from WUSB1.0[A.1] + 2006/12 errata * NOTE: can't make this const or global -- somehow it seems * the scatterlists for crypto get confused and we get * bad data. There is no doc on this... */ struct wusb_keydvt_in stv_keydvt_in_a1 = { .hnonce = { 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f }, .dnonce = { 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f } }; result = wusb_key_derive(&keydvt_out, stv_key_a1, &stv_keydvt_n_a1, &stv_keydvt_in_a1); if (result < 0) printk(KERN_ERR "E: WUSB key derivation test: " "derivation failed: %d\n", result); if (memcmp(&stv_keydvt_out_a1, &keydvt_out, sizeof(keydvt_out))) { printk(KERN_ERR "E: WUSB key derivation test: " "mismatch between key derivation result " "and WUSB1.0[A1] Errata 2006/12\n"); printk(KERN_ERR "E: keydvt in: key\n"); wusb_key_dump(stv_key_a1, sizeof(stv_key_a1)); printk(KERN_ERR "E: keydvt in: nonce\n"); wusb_key_dump( &stv_keydvt_n_a1, sizeof(stv_keydvt_n_a1)); printk(KERN_ERR "E: keydvt in: hnonce & dnonce\n"); wusb_key_dump(&stv_keydvt_in_a1, sizeof(stv_keydvt_in_a1)); printk(KERN_ERR "E: keydvt out: KCK\n"); wusb_key_dump(&keydvt_out.kck, sizeof(keydvt_out.kck));