# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause

from collections import namedtuple
from enum import Enum
import functools
import os
import random
import socket
import struct
from struct import Struct
import sys
import yaml
import ipaddress
import uuid

from .nlspec import SpecFamily

#
# Generic Netlink code which should really be in some library, but I can't quickly find one.
#


class Netlink:
    """Numeric netlink, genetlink and nlctrl constants used by this module."""

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 1
    NLM_F_ACK = 4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Attribute types reported inside NLMSGERR_ATTR_POLICY. The functional
    # Enum API numbers members from 1 in list order; NlMsg._decode_policy()
    # maps the raw u32 type value through this enum.
    AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
                                  's8', 's16', 's32', 's64',
                                  'binary', 'string', 'nul-string',
                                  'nested', 'nested-array',
                                  'bitfield32', 'sint', 'uint'])


class NlError(Exception):
    """Raised for a netlink reply that carries a non-zero error code."""

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        # The kernel reports errors as negative errno values.
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"


class ConfigError(Exception):
    pass


class NlAttr:
    """One netlink attribute (TLV) parsed from ``raw`` at ``offset``.

    The 4-byte header is unpacked in native byte order as (length, type).
    ``type`` has the NLA_F_NESTED / NLA_F_NET_BYTEORDER bits stripped,
    ``raw`` is the payload without the header, and ``full_len`` is the
    length rounded up to the 4-byte boundary of the next attribute.

    Raises:
        Exception: if the header length is smaller than the header itself.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        if self._len < 4:
            # The length counts the 4-byte header; a shorter value is malformed
            # and a zero length would make NlAttrs spin forever (full_len == 0).
            raise Exception(f"Malformed netlink attribute: length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the precompiled Struct for ``attr_type``.

        With no ``byte_order`` the native format is used; "big-endian"
        selects big endian and any other value selects little endian.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" \
                else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Decode a variable-width 'sint'/'uint' payload (4 or 8 bytes)."""
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar payload length must be 4 or 8 bytes, got {len(self.raw)}")
        # 's' or 'u' followed by the payload width in bits, e.g. 'u64'.
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        # Drops the last character, expected to be the NUL terminator.
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        return self.raw

    def as_c_array(self, type):
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"


class NlAttrs:
    """All consecutive attributes in ``msg`` starting at ``offset``."""

    def __init__(self, msg, offset=0):
        self.attrs = []

        while offset < len(msg):
            attr = NlAttr(msg, offset)
            offset += attr.full_len
            self.attrs.append(attr)

    def __iter__(self):
        yield from self.attrs

    def __repr__(self):
        return '\n'.join(repr(a) for a in self.attrs)


class NlMsg:
    """One netlink message parsed from ``msg`` at ``offset``.

    For NLMSG_ERROR / NLMSG_DONE the leading int is stored in ``error`` and
    ``done`` is set. When NLM_F_ACK_TLVS is set, the extended-ack TLVs are
    decoded into the ``extack`` dict; if ``attr_space`` is given, a top-level
    'miss-type' is translated into the attribute name from that space.

    Raises:
        Exception: if the header length is smaller than the header itself.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)
        if self.nl_len < 16:
            # A zero length would make NlMsgs loop forever on the same offset.
            raise Exception(f"Malformed netlink message: length {self.nl_len} at offset {offset}")

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            # error code (4) + echoed request header (16)
            extack_off = 20
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _decode_policy(self, raw):
        """Decode NLMSGERR_ATTR_POLICY nested attributes into a dict."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_val).name
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg


class NlMsgs:
    """All netlink messages contained in one recv() buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            offset += msg.nl_len
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs


# Cache of family name -> {'id', 'name', 'maxattr', 'mcast'} dicts, filled
# lazily by _genl_load_families().
genl_family_name_to_id = None


def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    # we prepend length in _genl_msg_finalize()
    if seq is None:
        seq = random.randint(1, 1024)
    nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
    genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
    return nlmsg + genlmsg


def _genl_msg_finalize(msg):
    """Prepend the u32 nlmsg_len (which counts itself) to ``msg``."""
    return struct.pack("I", len(msg) + 4) + msg


def _genl_load_families():
    """Dump all generic netlink families into ``genl_family_name_to_id``.

    On a netlink error the error is printed and the cache is left with
    whatever was collected so far.
    """
    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        msg = _genl_msg_finalize(msg)

        sock.send(msg, 0)

        global genl_family_name_to_id
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            nms = NlMsgs(reply)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                gm = GenlMsg(nl_msg)
                fam = dict()
                for attr in NlAttrs(gm.raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        fam['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        fam['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        fam['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        fam['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            mcast_name = None
                            mcast_id = None
                            for entry_attr in NlAttrs(entry.raw):
                                if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    mcast_name = entry_attr.as_strz()
                                elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    mcast_id = entry_attr.as_scalar('u32')
                            if mcast_name and mcast_id is not None:
                                fam['mcast'][mcast_name] = mcast_id
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam


class GenlMsg:
    """A generic netlink message: genlmsghdr fields plus the remaining payload."""

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is only attached by NetlinkProtocol.decode(); messages that
        # never went through it (e.g. during family discovery) have none.
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg


class NetlinkProtocol:
    """Message framing for a raw (non-genetlink) netlink protocol."""

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        # nlmsg header without the leading length; see _genl_msg_finalize().
        if seq is None:
            seq = random.randint(1, 1024)
        nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
        return nlmsg

    def message(self, flags, command, version, seq=None):
        # Raw netlink has no version field; the command is the nlmsg type.
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap ``nl_msg`` and attach ``raw_attrs`` (attributes after the fixed header).

        When ``op`` is None it is looked up from the message command.
        """
        msg = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[msg.cmd()]
        fixed_header_size = ynl._struct_size(op.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        if mcast_name not in mcast_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the spec')
        return mcast_groups[mcast_name].value

    def msghdr_size(self):
        return 16


class GenlProtocol(NetlinkProtocol):
    """Message framing for a generic netlink family resolved via nlctrl.

    Raises KeyError if the kernel does not know ``family_name``.
    """

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        if genl_family_name_to_id is None:
            _genl_load_families()

        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = self.genl_family['id']

    def message(self, flags, command, version, seq=None):
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        # Genetlink group IDs are assigned at runtime, so use the kernel's view.
        if mcast_name not in self.genl_family['mcast']:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return self.genl_family['mcast'][mcast_name]

    def msghdr_size(self):
        return super().msghdr_size() + 4


class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    Used to resolve sub-message selector attributes by name.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        outer_scopes = outer.scopes if outer else []
        inner_scope = self.SpecValuesPair(attr_space, attrs)
        self.scopes = [inner_scope] + outer_scopes

    def lookup(self, name):
        """Return the value of attribute ``name`` from the innermost scope defining it.

        Raises:
            Exception: if the defining scope has no value for it, or no
                scope defines it.
        """
        for scope in self.scopes:
            if name in scope.spec:
                if name in scope.values:
                    return scope.values[name]
                spec_name = scope.spec.yaml['name']
                raise Exception(
                    f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")


#
# YNL implementation details.
#


class YnlFamily(SpecFamily):
    """Netlink family described by a YNL spec.

    Each operation from the spec is exposed as a method named after its
    ``ident_name``, bound to :meth:`_op` (so it accepts ``vals``, ``flags``
    and ``dump``).
    """

    def __init__(self, def_path, schema=None, process_unknown=False,
                 recv_size=0):
        super().__init__(def_path, schema)

        self.include_raw = False
        self.process_unknown = process_unknown

        try:
            if self.proto == "netlink-raw":
                self.nlproto = NetlinkProtocol(self.yaml['name'],
                                               self.yaml['protonum'])
            else:
                self.nlproto = GenlProtocol(self.yaml['name'])
        except KeyError:
            raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

        self._recv_dbg = False
        # Note that netlink will use conservative (min) message size for
        # the first dump recv() on the socket, our setting will only matter
        # from the second recv() on.
        self._recv_size = recv_size if recv_size else 131072
        # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
        # for a message, so smaller receive sizes will lead to truncation.
        # Note that the min size for other families may be larger than 4k!
        if self._recv_size < 4000:
            raise ConfigError(f"recv_size {self._recv_size} is too small, must be at least 4000")

        self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)

        self.async_msg_ids = set()
        self.async_msg_queue = []

        for msg in self.msgs.values():
            if msg.is_async:
                self.async_msg_ids.add(msg.rsp_value)

        for op_name, op in self.ops.items():
            bound_f = functools.partial(self._op, op_name)
            setattr(self, op.ident_name, bound_f)


    def ntf_subscribe(self, mcast_name):
        """Join multicast group ``mcast_name`` on this family's socket."""
        mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
        self.sock.bind((0, 0))
        self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
                             mcast_id)

    def set_recv_dbg(self, enabled):
        self._recv_dbg = enabled

    def _recv_dbg_print(self, reply, nl_msgs):
        if not self._recv_dbg:
            return
        print("Recv: read", len(reply), "bytes,",
              len(nl_msgs.msgs), "messages", file=sys.stderr)
        for nl_msg in nl_msgs:
            print("  ", nl_msg, file=sys.stderr)

    def _encode_enum(self, attr_spec, value):
        """Map an enum name (or, for flags, a name or list of names) to an int."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            scalar = 0
            if isinstance(value, str):
                value = [value]
            for single_value in value:
                scalar += enum.entries[single_value].user_value(as_flags = True)
            return scalar
        else:
            return enum.entries[value].user_value()

    def _get_scalar(self, attr_spec, value):
        """Convert ``value`` to int, falling back to enum names if the spec has an enum."""
        try:
            return int(value)
        except (ValueError, TypeError) as e:
            if 'enum' not in attr_spec:
                raise e
            return self._encode_enum(attr_spec, value)

    def _add_attr(self, space, name, value, search_attrs):
        """Encode attribute ``name`` from attribute set ``space`` as netlink TLV(s).

        ``search_attrs`` is the SpaceAttrs scope chain used to resolve
        sub-message selectors. Returns the encoded bytes, padded to 4 bytes.
        A multi-attr given a list produces one TLV per element.
        """
        try:
            attr = self.attr_sets[space][name]
        except KeyError:
            raise Exception(f"Space '{space}' has no attribute '{name}'")
        nl_type = attr.value

        if attr.is_multi and isinstance(value, list):
            attr_payload = b''
            for subvalue in value:
                attr_payload += self._add_attr(space, name, subvalue, search_attrs)
            return attr_payload

        if attr["type"] == 'nest':
            nl_type |= Netlink.NLA_F_NESTED
            attr_payload = b''
            sub_space = attr['nested-attributes']
            # The scope describes the nested values, so it must use the
            # nested attribute set rather than the one being encoded.
            sub_attrs = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
            for subname, subvalue in value.items():
                attr_payload += self._add_attr(sub_space, subname, subvalue, sub_attrs)
        elif attr["type"] == 'flag':
            if not value:
                # If value is absent or false then skip attribute creation.
                return b''
            attr_payload = b''
        elif attr["type"] == 'string':
            attr_payload = str(value).encode('ascii') + b'\x00'
        elif attr["type"] == 'binary':
            if isinstance(value, bytes):
                attr_payload = value
            elif isinstance(value, str):
                attr_payload = bytes.fromhex(value)
            elif isinstance(value, dict) and attr.struct_name:
                attr_payload = self._encode_struct(attr.struct_name, value)
            else:
                raise Exception(f'Unknown type for binary attribute, value: {value}')
        elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
            scalar = self._get_scalar(attr, value)
            if attr.is_auto_scalar:
                attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
            else:
                attr_type = attr["type"]
            fmt = NlAttr.get_format(attr_type, attr.byte_order)
            attr_payload = fmt.pack(scalar)
        elif attr['type'] == "bitfield32":
            scalar_value = self._get_scalar(attr, value["value"])
            scalar_selector = self._get_scalar(attr, value["selector"])
            attr_payload = struct.pack("II", scalar_value, scalar_selector)
        elif attr['type'] == 'sub-message':
            msg_format = self._resolve_selector(attr, search_attrs)
            attr_payload = b''
            if msg_format.fixed_header:
                # _encode_struct() pops the header members it consumes; work
                # on a copy so the caller's dict is left untouched.
                value = dict(value)
                attr_payload += self._encode_struct(msg_format.fixed_header, value)
            if msg_format.attr_set:
                if msg_format.attr_set in self.attr_sets:
                    nl_type |= Netlink.NLA_F_NESTED
                    # SpaceAttrs needs the attribute-set object, not its name.
                    sub_attrs = SpaceAttrs(self.attr_sets[msg_format.attr_set],
                                           value, search_attrs)
                    for subname, subvalue in value.items():
                        attr_payload += self._add_attr(msg_format.attr_set,
                                                       subname, subvalue, sub_attrs)
                else:
                    raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
        else:
            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

        pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
        return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad

    def _decode_enum(self, raw, attr_spec):
        """Map an int to its enum name, or to a set of names for flag enums."""
        enum = self.consts[attr_spec['enum']]
        if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
            i = 0
            value = set()
            while raw:
                if raw & 1:
                    value.add(enum.entries_by_val[i].name)
                raw >>= 1
                i += 1
        else:
            value = enum.entries_by_val[raw].name
        return value

    def _decode_binary(self, attr, attr_spec):
        if attr_spec.struct_name:
            decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
        elif attr_spec.sub_type:
            decoded = attr.as_c_array(attr_spec.sub_type)
        else:
            decoded = attr.as_bin()
            if attr_spec.display_hint:
                decoded = self._formatted_string(decoded, attr_spec.display_hint)
        return decoded

    def _decode_array_attr(self, attr, attr_spec):
        """Decode an indexed-array attribute into a list of its elements."""
        decoded = []
        offset = 0
        while offset < len(attr.raw):
            item = NlAttr(attr.raw, offset)
            offset += item.full_len

            if attr_spec["sub-type"] == 'nest':
                subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
                decoded.append({ item.type: subattrs })
            elif attr_spec["sub-type"] == 'binary':
                subattrs = item.as_bin()
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            elif attr_spec["sub-type"] in NlAttr.type_formats:
                subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
                if attr_spec.display_hint:
                    subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
                decoded.append(subattrs)
            else:
                raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
        return decoded

    def _decode_nest_type_value(self, attr, attr_spec):
        """Decode a nest whose successive nesting levels carry values in their types."""
        decoded = {}
        value = attr
        for name in attr_spec['type-value']:
            value = NlAttr(value.raw, 0)
            decoded[name] = value.type
        subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
        decoded.update(subattrs)
        return decoded

    def _decode_unknown(self, attr):
        if attr.is_nest:
            return self._decode(NlAttrs(attr.raw), None)
        else:
            return attr.as_bin()

    def _rsp_add(self, rsp, name, is_multi, decoded):
        """Store ``decoded`` under ``name`` in ``rsp``.

        ``is_multi`` None means the spec does not say (unknown attributes):
        the entry becomes a list once the name has been seen before.
        """
        if is_multi is None:
            if name in rsp and not isinstance(rsp[name], list):
                rsp[name] = [rsp[name]]
            # Keep appending once a list exists, instead of overwriting it.
            is_multi = name in rsp

        if not is_multi:
            rsp[name] = decoded
        elif name in rsp:
            rsp[name].append(decoded)
        else:
            rsp[name] = [decoded]

    def _resolve_selector(self, attr_spec, search_attrs):
        """Pick the sub-message format selected by the attribute's selector value."""
        sub_msg = attr_spec.sub_message
        if sub_msg not in self.sub_msgs:
            raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
        sub_msg_spec = self.sub_msgs[sub_msg]

        selector = attr_spec.selector
        value = search_attrs.lookup(selector)
        if value not in sub_msg_spec.formats:
            raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")

        spec = sub_msg_spec.formats[value]
        return spec

    def _decode_sub_msg(self, attr, attr_spec, search_attrs):
        msg_format = self._resolve_selector(attr_spec, search_attrs)
        decoded = {}
        offset = 0
        if msg_format.fixed_header:
            decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
            offset = self._struct_size(msg_format.fixed_header)
        if msg_format.attr_set:
            if msg_format.attr_set in self.attr_sets:
                subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
                decoded.update(subdict)
            else:
                raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
        return decoded

    def _decode(self, attrs, space, outer_attrs = None):
        """Decode ``attrs`` using attribute set ``space`` into a dict keyed by name.

        With ``space`` None (or attributes missing from it) attributes are
        reported as ``UnknownAttr(<type>)`` if ``process_unknown`` is set,
        otherwise an exception is raised.
        """
        rsp = dict()
        attr_space = None
        search_attrs = None
        if space:
            attr_space = self.attr_sets[space]
            search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)

        for attr in attrs:
            if attr_space is None or attr.type not in attr_space.attrs_by_val:
                if not self.process_unknown:
                    raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
                attr_name = f"UnknownAttr({attr.type})"
                self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
                continue
            attr_spec = attr_space.attrs_by_val[attr.type]

            if attr_spec["type"] == 'nest':
                subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
                decoded = subdict
            elif attr_spec["type"] == 'string':
                decoded = attr.as_strz()
            elif attr_spec["type"] == 'binary':
                decoded = self._decode_binary(attr, attr_spec)
            elif attr_spec["type"] == 'flag':
                decoded = True
            elif attr_spec.is_auto_scalar:
                decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
            elif attr_spec["type"] in NlAttr.type_formats:
                decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
                if 'enum' in attr_spec:
                    decoded = self._decode_enum(decoded, attr_spec)
                elif attr_spec.display_hint:
                    decoded = self._formatted_string(decoded, attr_spec.display_hint)
            elif attr_spec["type"] == 'indexed-array':
                decoded = self._decode_array_attr(attr, attr_spec)
            elif attr_spec["type"] == 'bitfield32':
                value, selector = struct.unpack("II", attr.raw)
                if 'enum' in attr_spec:
                    value = self._decode_enum(value, attr_spec)
                    selector = self._decode_enum(selector, attr_spec)
                decoded = {"value": value, "selector": selector}
            elif attr_spec["type"] == 'sub-message':
                decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
            elif attr_spec["type"] == 'nest-type-value':
                decoded = self._decode_nest_type_value(attr, attr_spec)
            else:
                if not self.process_unknown:
                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
                decoded = self._decode_unknown(attr)

            self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)

        return rsp

    def _decode_extack_path(self, attrs, attr_set, offset, target):
        """Return a '.a.b' path to the attribute starting at byte ``target``, or None."""
        for attr in attrs:
            try:
                attr_spec = attr_set.attrs_by_val[attr.type]
            except KeyError:
                raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
            if offset > target:
                break
            if offset == target:
                return '.' + attr_spec.name

            if offset + attr.full_len <= target:
                offset += attr.full_len
                continue
            if attr_spec['type'] != 'nest':
                raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
            # Step over the nest's own TLV header and search inside it.
            offset += 4
            subpath = self._decode_extack_path(NlAttrs(attr.raw),
                                               self.attr_sets[attr_spec['nested-attributes']],
                                               offset, target)
            if subpath is None:
                return None
            return '.' + attr_spec.name + subpath

        return None

    def _decode_extack(self, request, op, extack):
        """Replace 'bad-attr-offs' in ``extack`` with an attribute path, when resolvable."""
        if 'bad-attr-offs' not in extack:
            return

        msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
        offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
        path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
                                        extack['bad-attr-offs'])
        if path:
            del extack['bad-attr-offs']
            extack['bad-attr'] = path

    def _struct_size(self, name):
        """Byte size of struct ``name``; 0 when ``name`` is falsy."""
        if name:
            members = self.consts[name].members
            size = 0
            for m in members:
                if m.type in ['pad', 'binary']:
                    if m.struct:
                        size += self._struct_size(m.struct)
                    else:
                        size += m.len
                else:
                    fmt = NlAttr.get_format(m.type, m.byte_order)
                    size += fmt.size
            return size
        else:
            return 0

    def _decode_struct(self, data, name):
        """Decode struct ``name`` from ``data`` into a dict; pad members are skipped."""
        members = self.consts[name].members
        attrs = dict()
        offset = 0
        for m in members:
            value = None
            if m.type == 'pad':
                offset += m.len
            elif m.type == 'binary':
                if m.struct:
                    size = self._struct_size(m.struct)
                    value = self._decode_struct(data[offset : offset + size],
                                                m.struct)
                    offset += size
                else:
                    value = data[offset : offset + m.len]
                    offset += m.len
            else:
                fmt = NlAttr.get_format(m.type, m.byte_order)
                [ value ] = fmt.unpack_from(data, offset)
                offset += fmt.size
            if value is not None:
                if m.enum:
                    value = self._decode_enum(value, m)
                elif m.display_hint:
                    value = self._formatted_string(value, m.display_hint)
                attrs[m.name] = value
        return attrs

    def _encode_struct(self, name, vals):
        """Encode struct ``name`` from ``vals``; missing members are zero-filled.

        Members are popped from ``vals`` so the caller can treat what is left
        as attributes. Nested struct dicts are copied before being consumed.
        Binary members accept bytes or a hex string.
        """
        members = self.consts[name].members
        attr_payload = b''
        for m in members:
            value = vals.pop(m.name) if m.name in vals else None
            if m.type == 'pad':
                attr_payload += bytearray(m.len)
            elif m.type == 'binary':
                if m.struct:
                    if value is None:
                        value = dict()
                    attr_payload += self._encode_struct(m.struct, dict(value))
                else:
                    if value is None:
                        attr_payload += bytearray(m.len)
                    elif isinstance(value, (bytes, bytearray)):
                        attr_payload += value
                    else:
                        attr_payload += bytes.fromhex(value)
            else:
                if value is None:
                    value = 0
                fmt = NlAttr.get_format(m.type, m.byte_order)
                attr_payload += fmt.pack(value)
        return attr_payload

    def _formatted_string(self, raw, display_hint):
        if display_hint == 'mac':
            formatted = ':'.join('%02x' % b for b in raw)
        elif display_hint == 'hex':
            if isinstance(raw, int):
                formatted = hex(raw)
            else:
                formatted = bytes.hex(raw, ' ')
        elif display_hint in [ 'ipv4', 'ipv6' ]:
            formatted = format(ipaddress.ip_address(raw))
        elif display_hint == 'uuid':
            formatted = str(uuid.UUID(bytes=raw))
        else:
            formatted = raw
        return formatted

    def handle_ntf(self, decoded):
        """Decode a notification and append {'name', 'msg'} to ``async_msg_queue``."""
        msg = dict()
        if self.include_raw:
            msg['raw'] = decoded
        op = self.rsp_by_value[decoded.cmd()]
        attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
        if op.fixed_header:
            attrs.update(self._decode_struct(decoded.raw, op.fixed_header))

        msg['name'] = op['name']
        msg['msg'] = attrs
        self.async_msg_queue.append(msg)

    def check_ntf(self):
        """Drain pending notifications from the socket without blocking."""
        while True:
            try:
                reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
            except BlockingIOError:
                return

            nms = NlMsgs(reply)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.error:
                    print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                    print(nl_msg)
                    continue
                if nl_msg.done:
                    print("Netlink done while checking for ntf!?")
                    continue

                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() not in self.async_msg_ids:
                    print("Unexpected msg id done while checking for ntf", decoded)
                    continue

                self.handle_ntf(decoded)

    def operation_do_attributes(self, name):
        """
        For a given operation name, find and return a supported
        set of attributes (as a dict).
        """
        op = self.find_operation(name)
        if not op:
            return None

        return op['do']['request']['attributes'].copy()

    def _encode_message(self, op, vals, flags, req_seq):
        """Build a complete, length-prefixed request for ``op`` from ``vals``."""
        nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
        for flag in flags or []:
            nl_flags |= flag

        msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
        if op.fixed_header:
            # _encode_struct() pops the header members so that only attributes
            # remain; use a copy so the caller's dict is not modified.
            vals = dict(vals)
            msg += self._encode_struct(op.fixed_header, vals)
        search_attrs = SpaceAttrs(op.attr_set, vals)
        for name, value in vals.items():
            msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
        msg = _genl_msg_finalize(msg)
        return msg

    def _ops(self, ops):
        """Send a batch of (method, vals, flags) requests and collect replies.

        Returns one entry per request: a list for dumps, otherwise None, the
        single reply, or a list of replies. Raises NlError on an error reply.
        Notifications received meanwhile are queued via handle_ntf().
        """
        reqs_by_seq = {}
        req_seq = random.randint(1024, 65535)
        payload = b''
        for (method, vals, flags) in ops:
            op = self.ops[method]
            msg = self._encode_message(op, vals, flags, req_seq)
            reqs_by_seq[req_seq] = (op, msg, flags or [])
            payload += msg
            req_seq += 1

        # Nothing to send: don't block in recv() waiting for replies.
        if not reqs_by_seq:
            return []

        # 'op' is rebound (possibly to None) while replies are processed, so
        # capture the attribute set used for extack annotation up front.
        extack_space = op.attr_set

        self.sock.send(payload, 0)

        done = False
        rsp = []
        op_rsp = []
        while not done:
            reply = self.sock.recv(self._recv_size)
            nms = NlMsgs(reply, attr_space=extack_space)
            self._recv_dbg_print(reply, nms)
            for nl_msg in nms:
                if nl_msg.nl_seq in reqs_by_seq:
                    (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                    if nl_msg.extack:
                        self._decode_extack(req_msg, op, nl_msg.extack)
                else:
                    op = None
                    req_flags = []

                if nl_msg.error:
                    raise NlError(nl_msg)
                if nl_msg.done:
                    if nl_msg.extack:
                        print("Netlink warning:")
                        print(nl_msg)

                    if Netlink.NLM_F_DUMP in req_flags:
                        rsp.append(op_rsp)
                    elif not op_rsp:
                        rsp.append(None)
                    elif len(op_rsp) == 1:
                        rsp.append(op_rsp[0])
                    else:
                        rsp.append(op_rsp)
                    op_rsp = []

                    del reqs_by_seq[nl_msg.nl_seq]
                    done = len(reqs_by_seq) == 0
                    break

                decoded = self.nlproto.decode(self, nl_msg, op)

                # Check if this is a reply to our request
                if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                    if decoded.cmd() in self.async_msg_ids:
                        self.handle_ntf(decoded)
                        continue
                    else:
                        print('Unexpected message: ' + repr(decoded))
                        continue

                rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
                if op.fixed_header:
                    rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
                op_rsp.append(rsp_msg)

        return rsp

    def _op(self, method, vals, flags=None, dump=False):
        # Copy the flags so appending NLM_F_DUMP never modifies the caller's list.
        req_flags = list(flags) if flags else []
        if dump:
            req_flags.append(Netlink.NLM_F_DUMP)

        ops = [(method, vals, req_flags)]
        return self._ops(ops)[0]

    def do(self, method, vals, flags=None):
        return self._op(method, vals, flags)

    def dump(self, method, vals):
        return self._op(method, vals, dump=True)

    def do_multi(self, ops):
        return self._ops(ops)
Linux® is a registered trademark of Linus Torvalds in the United States and other countries.
TOMOYO® is a registered trademark of NTT DATA CORPORATION.