# ----------------------------------------------------------------------------- # Copyright (c) 2019-2020, Arm Limited. All rights reserved. # # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- from collections.abc import Iterable from copy import deepcopy import cbor2 import yaml import base64 from ecdsa import SigningKey, VerifyingKey from pycose.sign1message import Sign1Message from pycose.mac0message import Mac0Message from iatverifier import const FIELD_NAMES = {v: k for k, v in const.NAMES.items()} def sign_eat(token, key=None): signed_msg = Sign1Message() signed_msg.payload = token if key: signed_msg.key = key signed_msg.signature = signed_msg.compute_signature() return signed_msg.encode() def hmac_eat(token, key=None): hmac_msg = Mac0Message(payload=token, key=key) hmac_msg.compute_auth_tag('HS256') return hmac_msg.encode() def convert_map_to_token_files(mapfile, keyfile, outfile, method='sign'): token_map = read_token_map(mapfile) if method == 'sign': with open(keyfile) as fh: signing_key = SigningKey.from_pem(fh.read()) else: with open(keyfile, 'rb') as fh: signing_key = fh.read() with open(outfile, 'wb') as wfh: convert_map_to_token(token_map, signing_key, wfh, method) def convert_map_to_token(token_map, signing_key, wfh, method='sign'): token = cbor2.dumps(token_map) if method == 'raw': signed_token = token elif method == 'sign': signed_token = sign_eat(token, signing_key) elif method == 'mac': signed_token = hmac_eat(token, signing_key) else: err_msg = 'Unexpected method "{}"; must be one of: raw, sign, mac' raise ValueError(err_msg.format(method)) wfh.write(signed_token) def convert_token_to_map(raw_data): payload = get_cose_payload(raw_data) token_map = cbor2.loads(payload) return _relabel_keys(token_map) def read_token_map(f): if hasattr(f, 'read'): raw = yaml.safe_load(f) else: with open(f) as fh: raw = yaml.safe_load(fh) return _parse_raw_token(raw) def extract_iat_from_cose(keyfile, tokenfile, keep_going=False, method='sign'): key = read_keyfile(keyfile, method) try: with open(tokenfile, 'rb') as wfh: return get_cose_payload(wfh.read(), key, method) except Exception as e: msg = 'Bad COSE file "{}": {}' raise ValueError(msg.format(tokenfile, e)) def get_cose_payload(cose, key=None, method='sign'): if method == 'sign': return get_cose_sign1_payload(cose, key) elif method == 'mac': return get_cose_mac0_pyload(cose, key) else: err_msg = 'Unexpected method "{}"; must be one of: sign, mac' raise ValueError(err_msg.format(method)) def get_cose_sign1_payload(cose, key=None): msg = Sign1Message.decode(cose) if key: msg.key = key msg.signature = msg.signers try: msg.verify_signature(alg='ES256') except Exception as e: raise ValueError('Bad signature ({})'.format(e)) return msg.payload def get_cose_mac0_pyload(cose, key=None): msg = Mac0Message.decode(cose) if key: msg.key = key try: msg.verify_auth_tag(alg='HS256') except Exception as e: raise ValueError('Bad signature ({})'.format(e)) return msg.payload def recursive_bytes_to_strings(d, in_place=False): if in_place: result = d else: result = deepcopy(d) if hasattr(result, 'items'): for k, v in result.items(): result[k] = recursive_bytes_to_strings(v, in_place=True) elif (isinstance(result, Iterable) and not isinstance(result, (str, bytes))): result = [recursive_bytes_to_strings(r, in_place=True) for r in result] elif isinstance(result, bytes): result = str(base64.b16encode(result)) return result def read_keyfile(keyfile, method='sign'): if keyfile: if method == 'sign': return read_sign1_key(keyfile) elif method == 'mac': return read_hmac_key(keyfile) else: err_msg = 'Unexpected method "{}"; must be one of: sign, mac' raise ValueError(err_msg.format(method)) else: # no keyfile key = None return key def read_sign1_key(keyfile): try: key = SigningKey.from_pem(open(keyfile, 'rb').read()) except Exception as e: signing_key_error = str(e) try: key = VerifyingKey.from_pem(open(keyfile, 'rb').read()) except Exception as e: verifying_key_error = str(e) msg = 'Bad key file "{}":\n\tpubkey error: {}\n\tprikey error: {}' raise ValueError(msg.format(keyfile, verifying_key_error, signing_key_error)) return key def read_hmac_key(keyfile): return open(keyfile, 'rb').read() def _parse_raw_token(raw): result = {} for raw_key, raw_value in raw.items(): if isinstance(raw_key, int): key = raw_key else: field_name = raw_key.upper() try: key = FIELD_NAMES[field_name] except KeyError: msg = 'Unknown field "{}" in token.'.format(field_name) raise ValueError(msg) if key == const.SECURITY_LIFECYCLE: name_idx = const.SL_NAMES.index(raw_value.upper()) value = (name_idx + 1) << const.SL_SHIFT elif hasattr(raw_value, 'items'): value = _parse_raw_token(raw_value) elif (isinstance(raw_value, Iterable) and not isinstance(raw_value, (str, bytes))): # TODO -- asumes dict elements value = [_parse_raw_token(v) for v in raw_value] else: value = raw_value result[key] = value return result def _relabel_keys(token_map): result = {} for key, value in token_map.items(): if hasattr(value, 'items'): value = _relabel_keys(value) elif (isinstance(value, Iterable) and not isinstance(value, (str, bytes))): # TODO -- asumes dict elements value = [_relabel_keys(v) for v in value] if key == const.SECURITY_LIFECYCLE: value = (const.SL_NAMES[(value >> const.SL_SHIFT) - 1]) if key in const.NAMES: new_key = const.NAMES[key].lower() else: new_key = key result[new_key] = value return result