1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Cl 1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause 2 2 3 from collections import namedtuple 3 from collections import namedtuple 4 from enum import Enum 4 from enum import Enum 5 import functools 5 import functools 6 import os 6 import os 7 import random 7 import random 8 import socket 8 import socket 9 import struct 9 import struct 10 from struct import Struct 10 from struct import Struct 11 import sys 11 import sys 12 import yaml 12 import yaml 13 import ipaddress 13 import ipaddress 14 import uuid 14 import uuid 15 15 16 from .nlspec import SpecFamily 16 from .nlspec import SpecFamily 17 17 18 # 18 # 19 # Generic Netlink code which should really be 19 # Generic Netlink code which should really be in some library, but I can't quickly find one. 20 # 20 # 21 21 22 22 23 class Netlink: 23 class Netlink: 24 # Netlink socket 24 # Netlink socket 25 SOL_NETLINK = 270 25 SOL_NETLINK = 270 26 26 27 NETLINK_ADD_MEMBERSHIP = 1 27 NETLINK_ADD_MEMBERSHIP = 1 28 NETLINK_CAP_ACK = 10 28 NETLINK_CAP_ACK = 10 29 NETLINK_EXT_ACK = 11 29 NETLINK_EXT_ACK = 11 30 NETLINK_GET_STRICT_CHK = 12 30 NETLINK_GET_STRICT_CHK = 12 31 31 32 # Netlink message 32 # Netlink message 33 NLMSG_ERROR = 2 33 NLMSG_ERROR = 2 34 NLMSG_DONE = 3 34 NLMSG_DONE = 3 35 35 36 NLM_F_REQUEST = 1 36 NLM_F_REQUEST = 1 37 NLM_F_ACK = 4 37 NLM_F_ACK = 4 38 NLM_F_ROOT = 0x100 38 NLM_F_ROOT = 0x100 39 NLM_F_MATCH = 0x200 39 NLM_F_MATCH = 0x200 40 40 41 NLM_F_REPLACE = 0x100 41 NLM_F_REPLACE = 0x100 42 NLM_F_EXCL = 0x200 42 NLM_F_EXCL = 0x200 43 NLM_F_CREATE = 0x400 43 NLM_F_CREATE = 0x400 44 NLM_F_APPEND = 0x800 44 NLM_F_APPEND = 0x800 45 45 46 NLM_F_CAPPED = 0x100 46 NLM_F_CAPPED = 0x100 47 NLM_F_ACK_TLVS = 0x200 47 NLM_F_ACK_TLVS = 0x200 48 48 49 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH 49 NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH 50 50 51 NLA_F_NESTED = 0x8000 51 NLA_F_NESTED = 0x8000 52 NLA_F_NET_BYTEORDER = 0x4000 52 NLA_F_NET_BYTEORDER = 0x4000 53 53 54 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_B 54 NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER 55 55 56 # Genetlink defines 56 # Genetlink defines 57 NETLINK_GENERIC = 16 57 NETLINK_GENERIC = 16 58 58 59 GENL_ID_CTRL = 0x10 59 GENL_ID_CTRL = 0x10 60 60 61 # nlctrl 61 # nlctrl 62 CTRL_CMD_GETFAMILY = 3 62 CTRL_CMD_GETFAMILY = 3 63 63 64 CTRL_ATTR_FAMILY_ID = 1 64 CTRL_ATTR_FAMILY_ID = 1 65 CTRL_ATTR_FAMILY_NAME = 2 65 CTRL_ATTR_FAMILY_NAME = 2 66 CTRL_ATTR_MAXATTR = 5 66 CTRL_ATTR_MAXATTR = 5 67 CTRL_ATTR_MCAST_GROUPS = 7 67 CTRL_ATTR_MCAST_GROUPS = 7 68 68 69 CTRL_ATTR_MCAST_GRP_NAME = 1 69 CTRL_ATTR_MCAST_GRP_NAME = 1 70 CTRL_ATTR_MCAST_GRP_ID = 2 70 CTRL_ATTR_MCAST_GRP_ID = 2 71 71 72 # Extack types 72 # Extack types 73 NLMSGERR_ATTR_MSG = 1 73 NLMSGERR_ATTR_MSG = 1 74 NLMSGERR_ATTR_OFFS = 2 74 NLMSGERR_ATTR_OFFS = 2 75 NLMSGERR_ATTR_COOKIE = 3 75 NLMSGERR_ATTR_COOKIE = 3 76 NLMSGERR_ATTR_POLICY = 4 76 NLMSGERR_ATTR_POLICY = 4 77 NLMSGERR_ATTR_MISS_TYPE = 5 77 NLMSGERR_ATTR_MISS_TYPE = 5 78 NLMSGERR_ATTR_MISS_NEST = 6 78 NLMSGERR_ATTR_MISS_NEST = 6 79 79 80 # Policy types 80 # Policy types 81 NL_POLICY_TYPE_ATTR_TYPE = 1 81 NL_POLICY_TYPE_ATTR_TYPE = 1 82 NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2 82 NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2 83 NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3 83 NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3 84 NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4 84 NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4 85 NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5 85 NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5 86 NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6 86 NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6 87 NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7 87 NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7 88 NL_POLICY_TYPE_ATTR_POLICY_IDX = 8 88 NL_POLICY_TYPE_ATTR_POLICY_IDX = 8 89 NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9 89 NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9 90 NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10 90 NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10 91 NL_POLICY_TYPE_ATTR_PAD = 11 91 NL_POLICY_TYPE_ATTR_PAD = 11 92 NL_POLICY_TYPE_ATTR_MASK = 12 92 NL_POLICY_TYPE_ATTR_MASK = 12 93 93 94 AttrType = Enum('AttrType', ['flag', 'u8', 94 AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64', 95 's8', 's16', 95 's8', 's16', 's32', 's64', 96 'binary', 's 96 'binary', 'string', 'nul-string', 97 'nested', 'n 97 'nested', 'nested-array', 98 'bitfield32' 98 'bitfield32', 'sint', 'uint']) 99 99 100 class NlError(Exception): 100 class NlError(Exception): 101 def __init__(self, nl_msg): 101 def __init__(self, nl_msg): 102 self.nl_msg = nl_msg 102 self.nl_msg = nl_msg 103 self.error = -nl_msg.error 103 self.error = -nl_msg.error 104 104 105 def __str__(self): 105 def __str__(self): 106 return f"Netlink error: {os.strerror(self. 106 return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}" 107 107 108 108 109 class ConfigError(Exception): 109 class ConfigError(Exception): 110 pass 110 pass 111 111 112 112 113 class NlAttr: 113 class NlAttr: 114 ScalarFormat = namedtuple('ScalarFormat', 114 ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little']) 115 type_formats = { 115 type_formats = { 116 'u8' : ScalarFormat(Struct('B'), Struc 116 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")), 117 's8' : ScalarFormat(Struct('b'), Struc 117 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")), 118 'u16': ScalarFormat(Struct('H'), Struc 118 'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")), 119 's16': ScalarFormat(Struct('h'), Struc 119 's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")), 120 'u32': ScalarFormat(Struct('I'), Struc 120 'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")), 121 's32': ScalarFormat(Struct('i'), Struc 121 's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")), 122 'u64': ScalarFormat(Struct('Q'), Struc 122 'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")), 123 's64': ScalarFormat(Struct('q'), Struc 123 's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q")) 124 } 124 } 125 125 126 def __init__(self, raw, offset): 126 def __init__(self, raw, offset): 127 self._len, self._type = struct.unpack( 127 self._len, self._type = struct.unpack("HH", raw[offset : offset + 4]) 128 self.type = self._type & ~Netlink.NLA_ 128 self.type = self._type & ~Netlink.NLA_TYPE_MASK 129 self.is_nest = self._type & Netlink.NL 129 self.is_nest = self._type & Netlink.NLA_F_NESTED 130 self.payload_len = self._len 130 self.payload_len = self._len 131 self.full_len = (self.payload_len + 3) 131 self.full_len = (self.payload_len + 3) & ~3 132 self.raw = raw[offset + 4 : offset + s 132 self.raw = raw[offset + 4 : offset + self.payload_len] 133 133 134 @classmethod 134 @classmethod 135 def get_format(cls, attr_type, byte_order= 135 def get_format(cls, attr_type, byte_order=None): 136 format = cls.type_formats[attr_type] 136 format = cls.type_formats[attr_type] 137 if byte_order: 137 if byte_order: 138 return format.big if byte_order == 138 return format.big if byte_order == "big-endian" \ 139 else format.little 139 else format.little 140 return format.native 140 return format.native 141 141 142 def as_scalar(self, attr_type, byte_order= 142 def as_scalar(self, attr_type, byte_order=None): 143 format = self.get_format(attr_type, by 143 format = self.get_format(attr_type, byte_order) 144 return format.unpack(self.raw)[0] 144 return format.unpack(self.raw)[0] 145 145 146 def as_auto_scalar(self, attr_type, byte_o 146 def as_auto_scalar(self, attr_type, byte_order=None): 147 if len(self.raw) != 4 and len(self.raw 147 if len(self.raw) != 4 and len(self.raw) != 8: 148 raise Exception(f"Auto-scalar len 148 raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}") 149 real_type = attr_type[0] + str(len(sel 149 real_type = attr_type[0] + str(len(self.raw) * 8) 150 format = self.get_format(real_type, by 150 format = self.get_format(real_type, byte_order) 151 return format.unpack(self.raw)[0] 151 return format.unpack(self.raw)[0] 152 152 153 def as_strz(self): 153 def as_strz(self): 154 return self.raw.decode('ascii')[:-1] 154 return self.raw.decode('ascii')[:-1] 155 155 156 def as_bin(self): 156 def as_bin(self): 157 return self.raw 157 return self.raw 158 158 159 def as_c_array(self, type): 159 def as_c_array(self, type): 160 format = self.get_format(type) 160 format = self.get_format(type) 161 return [ x[0] for x in format.iter_unp 161 return [ x[0] for x in format.iter_unpack(self.raw) ] 162 162 163 def __repr__(self): 163 def __repr__(self): 164 return f"[type:{self.type} len:{self._ 164 return f"[type:{self.type} len:{self._len}] {self.raw}" 165 165 166 166 167 class NlAttrs: 167 class NlAttrs: 168 def __init__(self, msg, offset=0): 168 def __init__(self, msg, offset=0): 169 self.attrs = [] 169 self.attrs = [] 170 170 171 while offset < len(msg): 171 while offset < len(msg): 172 attr = NlAttr(msg, offset) 172 attr = NlAttr(msg, offset) 173 offset += attr.full_len 173 offset += attr.full_len 174 self.attrs.append(attr) 174 self.attrs.append(attr) 175 175 176 def __iter__(self): 176 def __iter__(self): 177 yield from self.attrs 177 yield from self.attrs 178 178 179 def __repr__(self): 179 def __repr__(self): 180 msg = '' 180 msg = '' 181 for a in self.attrs: 181 for a in self.attrs: 182 if msg: 182 if msg: 183 msg += '\n' 183 msg += '\n' 184 msg += repr(a) 184 msg += repr(a) 185 return msg 185 return msg 186 186 187 187 188 class NlMsg: 188 class NlMsg: 189 def __init__(self, msg, offset, attr_space 189 def __init__(self, msg, offset, attr_space=None): 190 self.hdr = msg[offset : offset + 16] 190 self.hdr = msg[offset : offset + 16] 191 191 192 self.nl_len, self.nl_type, self.nl_fla 192 self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \ 193 struct.unpack("IHHII", self.hdr) 193 struct.unpack("IHHII", self.hdr) 194 194 195 self.raw = msg[offset + 16 : offset + 195 self.raw = msg[offset + 16 : offset + self.nl_len] 196 196 197 self.error = 0 197 self.error = 0 198 self.done = 0 198 self.done = 0 199 199 200 extack_off = None 200 extack_off = None 201 if self.nl_type == Netlink.NLMSG_ERROR 201 if self.nl_type == Netlink.NLMSG_ERROR: 202 self.error = struct.unpack("i", se 202 self.error = struct.unpack("i", self.raw[0:4])[0] 203 self.done = 1 203 self.done = 1 204 extack_off = 20 204 extack_off = 20 205 elif self.nl_type == Netlink.NLMSG_DON 205 elif self.nl_type == Netlink.NLMSG_DONE: 206 self.error = struct.unpack("i", se 206 self.error = struct.unpack("i", self.raw[0:4])[0] 207 self.done = 1 207 self.done = 1 208 extack_off = 4 208 extack_off = 4 209 209 210 self.extack = None 210 self.extack = None 211 if self.nl_flags & Netlink.NLM_F_ACK_T 211 if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off: 212 self.extack = dict() 212 self.extack = dict() 213 extack_attrs = NlAttrs(self.raw[ex 213 extack_attrs = NlAttrs(self.raw[extack_off:]) 214 for extack in extack_attrs: 214 for extack in extack_attrs: 215 if extack.type == Netlink.NLMS 215 if extack.type == Netlink.NLMSGERR_ATTR_MSG: 216 self.extack['msg'] = extac 216 self.extack['msg'] = extack.as_strz() 217 elif extack.type == Netlink.NL 217 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: 218 self.extack['miss-type'] = 218 self.extack['miss-type'] = extack.as_scalar('u32') 219 elif extack.type == Netlink.NL 219 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: 220 self.extack['miss-nest'] = 220 self.extack['miss-nest'] = extack.as_scalar('u32') 221 elif extack.type == Netlink.NL 221 elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: 222 self.extack['bad-attr-offs 222 self.extack['bad-attr-offs'] = extack.as_scalar('u32') 223 elif extack.type == Netlink.NL 223 elif extack.type == Netlink.NLMSGERR_ATTR_POLICY: 224 self.extack['policy'] = se 224 self.extack['policy'] = self._decode_policy(extack.raw) 225 else: 225 else: 226 if 'unknown' not in self.e 226 if 'unknown' not in self.extack: 227 self.extack['unknown'] 227 self.extack['unknown'] = [] 228 self.extack['unknown'].app 228 self.extack['unknown'].append(extack) 229 229 230 if attr_space: 230 if attr_space: 231 # We don't have the ability to 231 # We don't have the ability to parse nests yet, so only do global 232 if 'miss-type' in self.extack 232 if 'miss-type' in self.extack and 'miss-nest' not in self.extack: 233 miss_type = self.extack['m 233 miss_type = self.extack['miss-type'] 234 if miss_type in attr_space 234 if miss_type in attr_space.attrs_by_val: 235 spec = attr_space.attr 235 spec = attr_space.attrs_by_val[miss_type] 236 self.extack['miss-type 236 self.extack['miss-type'] = spec['name'] 237 if 'doc' in spec: 237 if 'doc' in spec: 238 self.extack['miss- 238 self.extack['miss-type-doc'] = spec['doc'] 239 239 240 def _decode_policy(self, raw): 240 def _decode_policy(self, raw): 241 policy = {} 241 policy = {} 242 for attr in NlAttrs(raw): 242 for attr in NlAttrs(raw): 243 if attr.type == Netlink.NL_POLICY_ 243 if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE: 244 type = attr.as_scalar('u32') 244 type = attr.as_scalar('u32') 245 policy['type'] = Netlink.AttrT 245 policy['type'] = Netlink.AttrType(type).name 246 elif attr.type == Netlink.NL_POLIC 246 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: 247 policy['min-value'] = attr.as_ 247 policy['min-value'] = attr.as_scalar('s64') 248 elif attr.type == Netlink.NL_POLIC 248 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: 249 policy['max-value'] = attr.as_ 249 policy['max-value'] = attr.as_scalar('s64') 250 elif attr.type == Netlink.NL_POLIC 250 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: 251 policy['min-value'] = attr.as_ 251 policy['min-value'] = attr.as_scalar('u64') 252 elif attr.type == Netlink.NL_POLIC 252 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: 253 policy['max-value'] = attr.as_ 253 policy['max-value'] = attr.as_scalar('u64') 254 elif attr.type == Netlink.NL_POLIC 254 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: 255 policy['min-length'] = attr.as 255 policy['min-length'] = attr.as_scalar('u32') 256 elif attr.type == Netlink.NL_POLIC 256 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: 257 policy['max-length'] = attr.as 257 policy['max-length'] = attr.as_scalar('u32') 258 elif attr.type == Netlink.NL_POLIC 258 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: 259 policy['bitfield32-mask'] = at 259 policy['bitfield32-mask'] = attr.as_scalar('u32') 260 elif attr.type == Netlink.NL_POLIC 260 elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK: 261 policy['mask'] = attr.as_scala 261 policy['mask'] = attr.as_scalar('u64') 262 return policy 262 return policy 263 263 264 def cmd(self): 264 def cmd(self): 265 return self.nl_type 265 return self.nl_type 266 266 267 def __repr__(self): 267 def __repr__(self): 268 msg = f"nl_len = {self.nl_len} ({len(s 268 msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}" 269 if self.error: 269 if self.error: 270 msg += '\n\terror: ' + str(self.er 270 msg += '\n\terror: ' + str(self.error) 271 if self.extack: 271 if self.extack: 272 msg += '\n\textack: ' + repr(self. 272 msg += '\n\textack: ' + repr(self.extack) 273 return msg 273 return msg 274 274 275 275 276 class NlMsgs: 276 class NlMsgs: 277 def __init__(self, data, attr_space=None): 277 def __init__(self, data, attr_space=None): 278 self.msgs = [] 278 self.msgs = [] 279 279 280 offset = 0 280 offset = 0 281 while offset < len(data): 281 while offset < len(data): 282 msg = NlMsg(data, offset, attr_spa 282 msg = NlMsg(data, offset, attr_space=attr_space) 283 offset += msg.nl_len 283 offset += msg.nl_len 284 self.msgs.append(msg) 284 self.msgs.append(msg) 285 285 286 def __iter__(self): 286 def __iter__(self): 287 yield from self.msgs 287 yield from self.msgs 288 288 289 289 290 genl_family_name_to_id = None 290 genl_family_name_to_id = None 291 291 292 292 293 def _genl_msg(nl_type, nl_flags, genl_cmd, gen 293 def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None): 294 # we prepend length in _genl_msg_finalize( 294 # we prepend length in _genl_msg_finalize() 295 if seq is None: 295 if seq is None: 296 seq = random.randint(1, 1024) 296 seq = random.randint(1, 1024) 297 nlmsg = struct.pack("HHII", nl_type, nl_fl 297 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) 298 genlmsg = struct.pack("BBH", genl_cmd, gen 298 genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0) 299 return nlmsg + genlmsg 299 return nlmsg + genlmsg 300 300 301 301 302 def _genl_msg_finalize(msg): 302 def _genl_msg_finalize(msg): 303 return struct.pack("I", len(msg) + 4) + ms 303 return struct.pack("I", len(msg) + 4) + msg 304 304 305 305 306 def _genl_load_families(): 306 def _genl_load_families(): 307 with socket.socket(socket.AF_NETLINK, sock 307 with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock: 308 sock.setsockopt(Netlink.SOL_NETLINK, N 308 sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 309 309 310 msg = _genl_msg(Netlink.GENL_ID_CTRL, 310 msg = _genl_msg(Netlink.GENL_ID_CTRL, 311 Netlink.NLM_F_REQUEST 311 Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP, 312 Netlink.CTRL_CMD_GETFA 312 Netlink.CTRL_CMD_GETFAMILY, 1) 313 msg = _genl_msg_finalize(msg) 313 msg = _genl_msg_finalize(msg) 314 314 315 sock.send(msg, 0) 315 sock.send(msg, 0) 316 316 317 global genl_family_name_to_id 317 global genl_family_name_to_id 318 genl_family_name_to_id = dict() 318 genl_family_name_to_id = dict() 319 319 320 while True: 320 while True: 321 reply = sock.recv(128 * 1024) 321 reply = sock.recv(128 * 1024) 322 nms = NlMsgs(reply) 322 nms = NlMsgs(reply) 323 for nl_msg in nms: 323 for nl_msg in nms: 324 if nl_msg.error: 324 if nl_msg.error: 325 print("Netlink error:", nl 325 print("Netlink error:", nl_msg.error) 326 return 326 return 327 if nl_msg.done: 327 if nl_msg.done: 328 return 328 return 329 329 330 gm = GenlMsg(nl_msg) 330 gm = GenlMsg(nl_msg) 331 fam = dict() 331 fam = dict() 332 for attr in NlAttrs(gm.raw): 332 for attr in NlAttrs(gm.raw): 333 if attr.type == Netlink.CT 333 if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: 334 fam['id'] = attr.as_sc 334 fam['id'] = attr.as_scalar('u16') 335 elif attr.type == Netlink. 335 elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: 336 fam['name'] = attr.as_ 336 fam['name'] = attr.as_strz() 337 elif attr.type == Netlink. 337 elif attr.type == Netlink.CTRL_ATTR_MAXATTR: 338 fam['maxattr'] = attr. 338 fam['maxattr'] = attr.as_scalar('u32') 339 elif attr.type == Netlink. 339 elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: 340 fam['mcast'] = dict() 340 fam['mcast'] = dict() 341 for entry in NlAttrs(a 341 for entry in NlAttrs(attr.raw): 342 mcast_name = None 342 mcast_name = None 343 mcast_id = None 343 mcast_id = None 344 for entry_attr in 344 for entry_attr in NlAttrs(entry.raw): 345 if entry_attr. 345 if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: 346 mcast_name 346 mcast_name = entry_attr.as_strz() 347 elif entry_att 347 elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: 348 mcast_id = 348 mcast_id = entry_attr.as_scalar('u32') 349 if mcast_name and 349 if mcast_name and mcast_id is not None: 350 fam['mcast'][m 350 fam['mcast'][mcast_name] = mcast_id 351 if 'name' in fam and 'id' in f 351 if 'name' in fam and 'id' in fam: 352 genl_family_name_to_id[fam 352 genl_family_name_to_id[fam['name']] = fam 353 353 354 354 355 class GenlMsg: 355 class GenlMsg: 356 def __init__(self, nl_msg): 356 def __init__(self, nl_msg): 357 self.nl = nl_msg 357 self.nl = nl_msg 358 self.genl_cmd, self.genl_version, _ = 358 self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0) 359 self.raw = nl_msg.raw[4:] 359 self.raw = nl_msg.raw[4:] 360 360 361 def cmd(self): 361 def cmd(self): 362 return self.genl_cmd 362 return self.genl_cmd 363 363 364 def __repr__(self): 364 def __repr__(self): 365 msg = repr(self.nl) 365 msg = repr(self.nl) 366 msg += f"\tgenl_cmd = {self.genl_cmd} 366 msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n" 367 for a in self.raw_attrs: 367 for a in self.raw_attrs: 368 msg += '\t\t' + repr(a) + '\n' 368 msg += '\t\t' + repr(a) + '\n' 369 return msg 369 return msg 370 370 371 371 372 class NetlinkProtocol: 372 class NetlinkProtocol: 373 def __init__(self, family_name, proto_num) 373 def __init__(self, family_name, proto_num): 374 self.family_name = family_name 374 self.family_name = family_name 375 self.proto_num = proto_num 375 self.proto_num = proto_num 376 376 377 def _message(self, nl_type, nl_flags, seq= 377 def _message(self, nl_type, nl_flags, seq=None): 378 if seq is None: 378 if seq is None: 379 seq = random.randint(1, 1024) 379 seq = random.randint(1, 1024) 380 nlmsg = struct.pack("HHII", nl_type, n 380 nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0) 381 return nlmsg 381 return nlmsg 382 382 383 def message(self, flags, command, version, 383 def message(self, flags, command, version, seq=None): 384 return self._message(command, flags, s 384 return self._message(command, flags, seq) 385 385 386 def _decode(self, nl_msg): 386 def _decode(self, nl_msg): 387 return nl_msg 387 return nl_msg 388 388 389 def decode(self, ynl, nl_msg, op): 389 def decode(self, ynl, nl_msg, op): 390 msg = self._decode(nl_msg) 390 msg = self._decode(nl_msg) 391 if op is None: 391 if op is None: 392 op = ynl.rsp_by_value[msg.cmd()] 392 op = ynl.rsp_by_value[msg.cmd()] 393 fixed_header_size = ynl._struct_size(o 393 fixed_header_size = ynl._struct_size(op.fixed_header) 394 msg.raw_attrs = NlAttrs(msg.raw, fixed 394 msg.raw_attrs = NlAttrs(msg.raw, fixed_header_size) 395 return msg 395 return msg 396 396 397 def get_mcast_id(self, mcast_name, mcast_g 397 def get_mcast_id(self, mcast_name, mcast_groups): 398 if mcast_name not in mcast_groups: 398 if mcast_name not in mcast_groups: 399 raise Exception(f'Multicast group 399 raise Exception(f'Multicast group "{mcast_name}" not present in the spec') 400 return mcast_groups[mcast_name].value 400 return mcast_groups[mcast_name].value 401 401 402 def msghdr_size(self): 402 def msghdr_size(self): 403 return 16 403 return 16 404 404 405 405 406 class GenlProtocol(NetlinkProtocol): 406 class GenlProtocol(NetlinkProtocol): 407 def __init__(self, family_name): 407 def __init__(self, family_name): 408 super().__init__(family_name, Netlink. 408 super().__init__(family_name, Netlink.NETLINK_GENERIC) 409 409 410 global genl_family_name_to_id 410 global genl_family_name_to_id 411 if genl_family_name_to_id is None: 411 if genl_family_name_to_id is None: 412 _genl_load_families() 412 _genl_load_families() 413 413 414 self.genl_family = genl_family_name_to 414 self.genl_family = genl_family_name_to_id[family_name] 415 self.family_id = genl_family_name_to_i 415 self.family_id = genl_family_name_to_id[family_name]['id'] 416 416 417 def message(self, flags, command, version, 417 def message(self, flags, command, version, seq=None): 418 nlmsg = self._message(self.family_id, 418 nlmsg = self._message(self.family_id, flags, seq) 419 genlmsg = struct.pack("BBH", command, 419 genlmsg = struct.pack("BBH", command, version, 0) 420 return nlmsg + genlmsg 420 return nlmsg + genlmsg 421 421 422 def _decode(self, nl_msg): 422 def _decode(self, nl_msg): 423 return GenlMsg(nl_msg) 423 return GenlMsg(nl_msg) 424 424 425 def get_mcast_id(self, mcast_name, mcast_g 425 def get_mcast_id(self, mcast_name, mcast_groups): 426 if mcast_name not in self.genl_family[ 426 if mcast_name not in self.genl_family['mcast']: 427 raise Exception(f'Multicast group 427 raise Exception(f'Multicast group "{mcast_name}" not present in the family') 428 return self.genl_family['mcast'][mcast 428 return self.genl_family['mcast'][mcast_name] 429 429 430 def msghdr_size(self): 430 def msghdr_size(self): 431 return super().msghdr_size() + 4 431 return super().msghdr_size() + 4 432 432 433 433 434 class SpaceAttrs: 434 class SpaceAttrs: 435 SpecValuesPair = namedtuple('SpecValuesPai 435 SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values']) 436 436 437 def __init__(self, attr_space, attrs, oute 437 def __init__(self, attr_space, attrs, outer = None): 438 outer_scopes = outer.scopes if outer e 438 outer_scopes = outer.scopes if outer else [] 439 inner_scope = self.SpecValuesPair(attr 439 inner_scope = self.SpecValuesPair(attr_space, attrs) 440 self.scopes = [inner_scope] + outer_sc 440 self.scopes = [inner_scope] + outer_scopes 441 441 442 def lookup(self, name): 442 def lookup(self, name): 443 for scope in self.scopes: 443 for scope in self.scopes: 444 if name in scope.spec: 444 if name in scope.spec: 445 if name in scope.values: 445 if name in scope.values: 446 return scope.values[name] 446 return scope.values[name] 447 spec_name = scope.spec.yaml['n 447 spec_name = scope.spec.yaml['name'] 448 raise Exception( 448 raise Exception( 449 f"No value for '{name}' in 449 f"No value for '{name}' in attribute space '{spec_name}'") 450 raise Exception(f"Attribute '{name}' n 450 raise Exception(f"Attribute '{name}' not defined in any attribute-set") 451 451 452 452 453 # 453 # 454 # YNL implementation details. 454 # YNL implementation details. 455 # 455 # 456 456 457 457 458 class YnlFamily(SpecFamily): 458 class YnlFamily(SpecFamily): 459 def __init__(self, def_path, schema=None, 459 def __init__(self, def_path, schema=None, process_unknown=False, 460 recv_size=0): 460 recv_size=0): 461 super().__init__(def_path, schema) 461 super().__init__(def_path, schema) 462 462 463 self.include_raw = False 463 self.include_raw = False 464 self.process_unknown = process_unknown 464 self.process_unknown = process_unknown 465 465 466 try: 466 try: 467 if self.proto == "netlink-raw": 467 if self.proto == "netlink-raw": 468 self.nlproto = NetlinkProtocol 468 self.nlproto = NetlinkProtocol(self.yaml['name'], 469 469 self.yaml['protonum']) 470 else: 470 else: 471 self.nlproto = GenlProtocol(se 471 self.nlproto = GenlProtocol(self.yaml['name']) 472 except KeyError: 472 except KeyError: 473 raise Exception(f"Family '{self.ya 473 raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel") 474 474 475 self._recv_dbg = False 475 self._recv_dbg = False 476 # Note that netlink will use conservat 476 # Note that netlink will use conservative (min) message size for 477 # the first dump recv() on the socket, 477 # the first dump recv() on the socket, our setting will only matter 478 # from the second recv() on. 478 # from the second recv() on. 479 self._recv_size = recv_size if recv_si 479 self._recv_size = recv_size if recv_size else 131072 480 # Netlink will always allocate at leas 480 # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo) 481 # for a message, so smaller receive si 481 # for a message, so smaller receive sizes will lead to truncation. 482 # Note that the min size for other fam 482 # Note that the min size for other families may be larger than 4k! 483 if self._recv_size < 4000: 483 if self._recv_size < 4000: 484 raise ConfigError() 484 raise ConfigError() 485 485 486 self.sock = socket.socket(socket.AF_NE 486 self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num) 487 self.sock.setsockopt(Netlink.SOL_NETLI 487 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) 488 self.sock.setsockopt(Netlink.SOL_NETLI 488 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) 489 self.sock.setsockopt(Netlink.SOL_NETLI 489 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1) 490 490 491 self.async_msg_ids = set() 491 self.async_msg_ids = set() 492 self.async_msg_queue = [] 492 self.async_msg_queue = [] 493 493 494 for msg in self.msgs.values(): 494 for msg in self.msgs.values(): 495 if msg.is_async: 495 if msg.is_async: 496 self.async_msg_ids.add(msg.rsp 496 self.async_msg_ids.add(msg.rsp_value) 497 497 498 for op_name, op in self.ops.items(): 498 for op_name, op in self.ops.items(): 499 bound_f = functools.partial(self._ 499 bound_f = functools.partial(self._op, op_name) 500 setattr(self, op.ident_name, bound 500 setattr(self, op.ident_name, bound_f) 501 501 502 502 503 def ntf_subscribe(self, mcast_name): 503 def ntf_subscribe(self, mcast_name): 504 mcast_id = self.nlproto.get_mcast_id(m 504 mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups) 505 self.sock.bind((0, 0)) 505 self.sock.bind((0, 0)) 506 self.sock.setsockopt(Netlink.SOL_NETLI 506 self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP, 507 mcast_id) 507 mcast_id) 508 508 509 def set_recv_dbg(self, enabled): 509 def set_recv_dbg(self, enabled): 510 self._recv_dbg = enabled 510 self._recv_dbg = enabled 511 511 512 def _recv_dbg_print(self, reply, nl_msgs): 512 def _recv_dbg_print(self, reply, nl_msgs): 513 if not self._recv_dbg: 513 if not self._recv_dbg: 514 return 514 return 515 print("Recv: read", len(reply), "bytes 515 print("Recv: read", len(reply), "bytes,", 516 len(nl_msgs.msgs), "messages", f 516 len(nl_msgs.msgs), "messages", file=sys.stderr) 517 for nl_msg in nl_msgs: 517 for nl_msg in nl_msgs: 518 print(" ", nl_msg, file=sys.stder 518 print(" ", nl_msg, file=sys.stderr) 519 519 520 def _encode_enum(self, attr_spec, value): 520 def _encode_enum(self, attr_spec, value): 521 enum = self.consts[attr_spec['enum']] 521 enum = self.consts[attr_spec['enum']] 522 if enum.type == 'flags' or attr_spec.g 522 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 523 scalar = 0 523 scalar = 0 524 if isinstance(value, str): 524 if isinstance(value, str): 525 value = [value] 525 value = [value] 526 for single_value in value: 526 for single_value in value: 527 scalar += enum.entries[single_ 527 scalar += enum.entries[single_value].user_value(as_flags = True) 528 return scalar 528 return scalar 529 else: 529 else: 530 return enum.entries[value].user_va 530 return enum.entries[value].user_value() 531 531 532 def _get_scalar(self, attr_spec, value): 532 def _get_scalar(self, attr_spec, value): 533 try: 533 try: 534 return int(value) 534 return int(value) 535 except (ValueError, TypeError) as e: 535 except (ValueError, TypeError) as e: 536 if 'enum' not in attr_spec: 536 if 'enum' not in attr_spec: 537 raise e 537 raise e 538 return self._encode_enum(attr_spec, va 538 return self._encode_enum(attr_spec, value) 539 539 540 def _add_attr(self, space, name, value, se 540 def _add_attr(self, space, name, value, search_attrs): 541 try: 541 try: 542 attr = self.attr_sets[space][name] 542 attr = self.attr_sets[space][name] 543 except KeyError: 543 except KeyError: 544 raise Exception(f"Space '{space}' 544 raise Exception(f"Space '{space}' has no attribute '{name}'") 545 nl_type = attr.value 545 nl_type = attr.value 546 546 547 if attr.is_multi and isinstance(value, 547 if attr.is_multi and isinstance(value, list): 548 attr_payload = b'' 548 attr_payload = b'' 549 for subvalue in value: 549 for subvalue in value: 550 attr_payload += self._add_attr 550 attr_payload += self._add_attr(space, name, subvalue, search_attrs) 551 return attr_payload 551 return attr_payload 552 552 553 if attr["type"] == 'nest': 553 if attr["type"] == 'nest': 554 nl_type |= Netlink.NLA_F_NESTED 554 nl_type |= Netlink.NLA_F_NESTED 555 attr_payload = b'' 555 attr_payload = b'' 556 sub_attrs = SpaceAttrs(self.attr_s 556 sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs) 557 for subname, subvalue in value.ite 557 for subname, subvalue in value.items(): 558 attr_payload += self._add_attr 558 attr_payload += self._add_attr(attr['nested-attributes'], 559 559 subname, subvalue, sub_attrs) 560 elif attr["type"] == 'flag': 560 elif attr["type"] == 'flag': 561 if not value: 561 if not value: 562 # If value is absent or false 562 # If value is absent or false then skip attribute creation. 563 return b'' 563 return b'' 564 attr_payload = b'' 564 attr_payload = b'' 565 elif attr["type"] == 'string': 565 elif attr["type"] == 'string': 566 attr_payload = str(value).encode(' 566 attr_payload = str(value).encode('ascii') + b'\x00' 567 elif attr["type"] == 'binary': 567 elif attr["type"] == 'binary': 568 if isinstance(value, bytes): 568 if isinstance(value, bytes): 569 attr_payload = value 569 attr_payload = value 570 elif isinstance(value, str): 570 elif isinstance(value, str): 571 attr_payload = bytes.fromhex(v 571 attr_payload = bytes.fromhex(value) 572 elif isinstance(value, dict) and a 572 elif isinstance(value, dict) and attr.struct_name: 573 attr_payload = self._encode_st 573 attr_payload = self._encode_struct(attr.struct_name, value) 574 else: 574 else: 575 raise Exception(f'Unknown type 575 raise Exception(f'Unknown type for binary attribute, value: {value}') 576 elif attr['type'] in NlAttr.type_forma 576 elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar: 577 scalar = self._get_scalar(attr, va 577 scalar = self._get_scalar(attr, value) 578 if attr.is_auto_scalar: 578 if attr.is_auto_scalar: 579 attr_type = attr["type"][0] + 579 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64') 580 else: 580 else: 581 attr_type = attr["type"] 581 attr_type = attr["type"] 582 format = NlAttr.get_format(attr_ty 582 format = NlAttr.get_format(attr_type, attr.byte_order) 583 attr_payload = format.pack(scalar) 583 attr_payload = format.pack(scalar) 584 elif attr['type'] in "bitfield32": 584 elif attr['type'] in "bitfield32": 585 scalar_value = self._get_scalar(at 585 scalar_value = self._get_scalar(attr, value["value"]) 586 scalar_selector = self._get_scalar 586 scalar_selector = self._get_scalar(attr, value["selector"]) 587 attr_payload = struct.pack("II", s 587 attr_payload = struct.pack("II", scalar_value, scalar_selector) 588 elif attr['type'] == 'sub-message': 588 elif attr['type'] == 'sub-message': 589 msg_format = self._resolve_selecto 589 msg_format = self._resolve_selector(attr, search_attrs) 590 attr_payload = b'' 590 attr_payload = b'' 591 if msg_format.fixed_header: 591 if msg_format.fixed_header: 592 attr_payload += self._encode_s 592 attr_payload += self._encode_struct(msg_format.fixed_header, value) 593 if msg_format.attr_set: 593 if msg_format.attr_set: 594 if msg_format.attr_set in self 594 if msg_format.attr_set in self.attr_sets: 595 nl_type |= Netlink.NLA_F_N 595 nl_type |= Netlink.NLA_F_NESTED 596 sub_attrs = SpaceAttrs(msg 596 sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs) 597 for subname, subvalue in v 597 for subname, subvalue in value.items(): 598 attr_payload += self._ 598 attr_payload += self._add_attr(msg_format.attr_set, 599 599 subname, subvalue, sub_attrs) 600 else: 600 else: 601 raise Exception(f"Unknown 601 raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'") 602 else: 602 else: 603 raise Exception(f'Unknown type at 603 raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') 604 604 605 pad = b'\x00' * ((4 - len(attr_payload 605 pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4) 606 return struct.pack('HH', len(attr_payl 606 return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad 607 607 608 def _decode_enum(self, raw, attr_spec): 608 def _decode_enum(self, raw, attr_spec): 609 enum = self.consts[attr_spec['enum']] 609 enum = self.consts[attr_spec['enum']] 610 if enum.type == 'flags' or attr_spec.g 610 if enum.type == 'flags' or attr_spec.get('enum-as-flags', False): 611 i = 0 611 i = 0 612 value = set() 612 value = set() 613 while raw: 613 while raw: 614 if raw & 1: 614 if raw & 1: 615 value.add(enum.entries_by_ 615 value.add(enum.entries_by_val[i].name) 616 raw >>= 1 616 raw >>= 1 617 i += 1 617 i += 1 618 else: 618 else: 619 value = enum.entries_by_val[raw].n 619 value = enum.entries_by_val[raw].name 620 return value 620 return value 621 621 622 def _decode_binary(self, attr, attr_spec): 622 def _decode_binary(self, attr, attr_spec): 623 if attr_spec.struct_name: 623 if attr_spec.struct_name: 624 decoded = self._decode_struct(attr 624 decoded = self._decode_struct(attr.raw, attr_spec.struct_name) 625 elif attr_spec.sub_type: 625 elif attr_spec.sub_type: 626 decoded = attr.as_c_array(attr_spe 626 decoded = attr.as_c_array(attr_spec.sub_type) 627 else: 627 else: 628 decoded = attr.as_bin() 628 decoded = attr.as_bin() 629 if attr_spec.display_hint: 629 if attr_spec.display_hint: 630 decoded = self._formatted_stri 630 decoded = self._formatted_string(decoded, attr_spec.display_hint) 631 return decoded 631 return decoded 632 632 633 def _decode_array_attr(self, attr, attr_sp 633 def _decode_array_attr(self, attr, attr_spec): 634 decoded = [] 634 decoded = [] 635 offset = 0 635 offset = 0 636 while offset < len(attr.raw): 636 while offset < len(attr.raw): 637 item = NlAttr(attr.raw, offset) 637 item = NlAttr(attr.raw, offset) 638 offset += item.full_len 638 offset += item.full_len 639 639 640 if attr_spec["sub-type"] == 'nest' 640 if attr_spec["sub-type"] == 'nest': 641 subattrs = self._decode(NlAttr 641 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes']) 642 decoded.append({ item.type: su 642 decoded.append({ item.type: subattrs }) 643 elif attr_spec["sub-type"] == 'bin 643 elif attr_spec["sub-type"] == 'binary': 644 subattrs = item.as_bin() 644 subattrs = item.as_bin() 645 if attr_spec.display_hint: 645 if attr_spec.display_hint: 646 subattrs = self._formatted 646 subattrs = self._formatted_string(subattrs, attr_spec.display_hint) 647 decoded.append(subattrs) 647 decoded.append(subattrs) 648 elif attr_spec["sub-type"] in NlAt 648 elif attr_spec["sub-type"] in NlAttr.type_formats: 649 subattrs = item.as_scalar(attr 649 subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order) 650 if attr_spec.display_hint: 650 if attr_spec.display_hint: 651 subattrs = self._formatted 651 subattrs = self._formatted_string(subattrs, attr_spec.display_hint) 652 decoded.append(subattrs) 652 decoded.append(subattrs) 653 else: 653 else: 654 raise Exception(f'Unknown {att 654 raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}') 655 return decoded 655 return decoded 656 656 657 def _decode_nest_type_value(self, attr, at 657 def _decode_nest_type_value(self, attr, attr_spec): 658 decoded = {} 658 decoded = {} 659 value = attr 659 value = attr 660 for name in attr_spec['type-value']: 660 for name in attr_spec['type-value']: 661 value = NlAttr(value.raw, 0) 661 value = NlAttr(value.raw, 0) 662 decoded[name] = value.type 662 decoded[name] = value.type 663 subattrs = self._decode(NlAttrs(value. 663 subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes']) 664 decoded.update(subattrs) 664 decoded.update(subattrs) 665 return decoded 665 return decoded 666 666 667 def _decode_unknown(self, attr): 667 def _decode_unknown(self, attr): 668 if attr.is_nest: 668 if attr.is_nest: 669 return self._decode(NlAttrs(attr.r 669 return self._decode(NlAttrs(attr.raw), None) 670 else: 670 else: 671 return attr.as_bin() 671 return attr.as_bin() 672 672 673 def _rsp_add(self, rsp, name, is_multi, de 673 def _rsp_add(self, rsp, name, is_multi, decoded): 674 if is_multi == None: 674 if is_multi == None: 675 if name in rsp and type(rsp[name]) 675 if name in rsp and type(rsp[name]) is not list: 676 rsp[name] = [rsp[name]] 676 rsp[name] = [rsp[name]] 677 is_multi = True 677 is_multi = True 678 else: 678 else: 679 is_multi = False 679 is_multi = False 680 680 681 if not is_multi: 681 if not is_multi: 682 rsp[name] = decoded 682 rsp[name] = decoded 683 elif name in rsp: 683 elif name in rsp: 684 rsp[name].append(decoded) 684 rsp[name].append(decoded) 685 else: 685 else: 686 rsp[name] = [decoded] 686 rsp[name] = [decoded] 687 687 688 def _resolve_selector(self, attr_spec, sea 688 def _resolve_selector(self, attr_spec, search_attrs): 689 sub_msg = attr_spec.sub_message 689 sub_msg = attr_spec.sub_message 690 if sub_msg not in self.sub_msgs: 690 if sub_msg not in self.sub_msgs: 691 raise Exception(f"No sub-message s 691 raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}") 692 sub_msg_spec = self.sub_msgs[sub_msg] 692 sub_msg_spec = self.sub_msgs[sub_msg] 693 693 694 selector = attr_spec.selector 694 selector = attr_spec.selector 695 value = search_attrs.lookup(selector) 695 value = search_attrs.lookup(selector) 696 if value not in sub_msg_spec.formats: 696 if value not in sub_msg_spec.formats: 697 raise Exception(f"No message forma 697 raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'") 698 698 699 spec = sub_msg_spec.formats[value] 699 spec = sub_msg_spec.formats[value] 700 return spec 700 return spec 701 701 702 def _decode_sub_msg(self, attr, attr_spec, 702 def _decode_sub_msg(self, attr, attr_spec, search_attrs): 703 msg_format = self._resolve_selector(at 703 msg_format = self._resolve_selector(attr_spec, search_attrs) 704 decoded = {} 704 decoded = {} 705 offset = 0 705 offset = 0 706 if msg_format.fixed_header: 706 if msg_format.fixed_header: 707 decoded.update(self._decode_struct 707 decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header)); 708 offset = self._struct_size(msg_for 708 offset = self._struct_size(msg_format.fixed_header) 709 if msg_format.attr_set: 709 if msg_format.attr_set: 710 if msg_format.attr_set in self.att 710 if msg_format.attr_set in self.attr_sets: 711 subdict = self._decode(NlAttrs 711 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set) 712 decoded.update(subdict) 712 decoded.update(subdict) 713 else: 713 else: 714 raise Exception(f"Unknown attr 714 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'") 715 return decoded 715 return decoded 716 716 717 def _decode(self, attrs, space, outer_attr 717 def _decode(self, attrs, space, outer_attrs = None): 718 rsp = dict() 718 rsp = dict() 719 if space: 719 if space: 720 attr_space = self.attr_sets[space] 720 attr_space = self.attr_sets[space] 721 search_attrs = SpaceAttrs(attr_spa 721 search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs) 722 722 723 for attr in attrs: 723 for attr in attrs: 724 try: 724 try: 725 attr_spec = attr_space.attrs_b 725 attr_spec = attr_space.attrs_by_val[attr.type] 726 except (KeyError, UnboundLocalErro 726 except (KeyError, UnboundLocalError): 727 if not self.process_unknown: 727 if not self.process_unknown: 728 raise Exception(f"Space '{ 728 raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'") 729 attr_name = f"UnknownAttr({att 729 attr_name = f"UnknownAttr({attr.type})" 730 self._rsp_add(rsp, attr_name, 730 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr)) 731 continue 731 continue 732 732 733 if attr_spec["type"] == 'nest': 733 if attr_spec["type"] == 'nest': 734 subdict = self._decode(NlAttrs 734 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs) 735 decoded = subdict 735 decoded = subdict 736 elif attr_spec["type"] == 'string' 736 elif attr_spec["type"] == 'string': 737 decoded = attr.as_strz() 737 decoded = attr.as_strz() 738 elif attr_spec["type"] == 'binary' 738 elif attr_spec["type"] == 'binary': 739 decoded = self._decode_binary( 739 decoded = self._decode_binary(attr, attr_spec) 740 elif attr_spec["type"] == 'flag': 740 elif attr_spec["type"] == 'flag': 741 decoded = True 741 decoded = True 742 elif attr_spec.is_auto_scalar: 742 elif attr_spec.is_auto_scalar: 743 decoded = attr.as_auto_scalar( 743 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order) 744 elif attr_spec["type"] in NlAttr.t 744 elif attr_spec["type"] in NlAttr.type_formats: 745 decoded = attr.as_scalar(attr_ 745 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order) 746 if 'enum' in attr_spec: 746 if 'enum' in attr_spec: 747 decoded = self._decode_enu 747 decoded = self._decode_enum(decoded, attr_spec) 748 elif attr_spec.display_hint: 748 elif attr_spec.display_hint: 749 decoded = self._formatted_ 749 decoded = self._formatted_string(decoded, attr_spec.display_hint) 750 elif attr_spec["type"] == 'indexed 750 elif attr_spec["type"] == 'indexed-array': 751 decoded = self._decode_array_a 751 decoded = self._decode_array_attr(attr, attr_spec) 752 elif attr_spec["type"] == 'bitfiel 752 elif attr_spec["type"] == 'bitfield32': 753 value, selector = struct.unpac 753 value, selector = struct.unpack("II", attr.raw) 754 if 'enum' in attr_spec: 754 if 'enum' in attr_spec: 755 value = self._decode_enum( 755 value = self._decode_enum(value, attr_spec) 756 selector = self._decode_en 756 selector = self._decode_enum(selector, attr_spec) 757 decoded = {"value": value, "se 757 decoded = {"value": value, "selector": selector} 758 elif attr_spec["type"] == 'sub-mes 758 elif attr_spec["type"] == 'sub-message': 759 decoded = self._decode_sub_msg 759 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs) 760 elif attr_spec["type"] == 'nest-ty 760 elif attr_spec["type"] == 'nest-type-value': 761 decoded = self._decode_nest_ty 761 decoded = self._decode_nest_type_value(attr, attr_spec) 762 else: 762 else: 763 if not self.process_unknown: 763 if not self.process_unknown: 764 raise Exception(f'Unknown 764 raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}') 765 decoded = self._decode_unknown 765 decoded = self._decode_unknown(attr) 766 766 767 self._rsp_add(rsp, attr_spec["name 767 self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded) 768 768 769 return rsp 769 return rsp 770 770 771 def _decode_extack_path(self, attrs, attr_ 771 def _decode_extack_path(self, attrs, attr_set, offset, target): 772 for attr in attrs: 772 for attr in attrs: 773 try: 773 try: 774 attr_spec = attr_set.attrs_by_ 774 attr_spec = attr_set.attrs_by_val[attr.type] 775 except KeyError: 775 except KeyError: 776 raise Exception(f"Space '{attr 776 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'") 777 if offset > target: 777 if offset > target: 778 break 778 break 779 if offset == target: 779 if offset == target: 780 return '.' + attr_spec.name 780 return '.' + attr_spec.name 781 781 782 if offset + attr.full_len <= targe 782 if offset + attr.full_len <= target: 783 offset += attr.full_len 783 offset += attr.full_len 784 continue 784 continue 785 if attr_spec['type'] != 'nest': 785 if attr_spec['type'] != 'nest': 786 raise Exception(f"Can't dive i 786 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack") 787 offset += 4 787 offset += 4 788 subpath = self._decode_extack_path 788 subpath = self._decode_extack_path(NlAttrs(attr.raw), 789 789 self.attr_sets[attr_spec['nested-attributes']], 790 790 offset, target) 791 if subpath is None: 791 if subpath is None: 792 return None 792 return None 793 return '.' + attr_spec.name + subp 793 return '.' + attr_spec.name + subpath 794 794 795 return None 795 return None 796 796 797 def _decode_extack(self, request, op, exta 797 def _decode_extack(self, request, op, extack): 798 if 'bad-attr-offs' not in extack: 798 if 'bad-attr-offs' not in extack: 799 return 799 return 800 800 801 msg = self.nlproto.decode(self, NlMsg( 801 msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op) 802 offset = self.nlproto.msghdr_size() + 802 offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header) 803 path = self._decode_extack_path(msg.ra 803 path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset, 804 extack 804 extack['bad-attr-offs']) 805 if path: 805 if path: 806 del extack['bad-attr-offs'] 806 del extack['bad-attr-offs'] 807 extack['bad-attr'] = path 807 extack['bad-attr'] = path 808 808 809 def _struct_size(self, name): 809 def _struct_size(self, name): 810 if name: 810 if name: 811 members = self.consts[name].member 811 members = self.consts[name].members 812 size = 0 812 size = 0 813 for m in members: 813 for m in members: 814 if m.type in ['pad', 'binary'] 814 if m.type in ['pad', 'binary']: 815 if m.struct: 815 if m.struct: 816 size += self._struct_s 816 size += self._struct_size(m.struct) 817 else: 817 else: 818 size += m.len 818 size += m.len 819 else: 819 else: 820 format = NlAttr.get_format 820 format = NlAttr.get_format(m.type, m.byte_order) 821 size += format.size 821 size += format.size 822 return size 822 return size 823 else: 823 else: 824 return 0 824 return 0 825 825 826 def _decode_struct(self, data, name): 826 def _decode_struct(self, data, name): 827 members = self.consts[name].members 827 members = self.consts[name].members 828 attrs = dict() 828 attrs = dict() 829 offset = 0 829 offset = 0 830 for m in members: 830 for m in members: 831 value = None 831 value = None 832 if m.type == 'pad': 832 if m.type == 'pad': 833 offset += m.len 833 offset += m.len 834 elif m.type == 'binary': 834 elif m.type == 'binary': 835 if m.struct: 835 if m.struct: 836 len = self._struct_size(m. 836 len = self._struct_size(m.struct) 837 value = self._decode_struc 837 value = self._decode_struct(data[offset : offset + len], 838 838 m.struct) 839 offset += len 839 offset += len 840 else: 840 else: 841 value = data[offset : offs 841 value = data[offset : offset + m.len] 842 offset += m.len 842 offset += m.len 843 else: 843 else: 844 format = NlAttr.get_format(m.t 844 format = NlAttr.get_format(m.type, m.byte_order) 845 [ value ] = format.unpack_from 845 [ value ] = format.unpack_from(data, offset) 846 offset += format.size 846 offset += format.size 847 if value is not None: 847 if value is not None: 848 if m.enum: 848 if m.enum: 849 value = self._decode_enum( 849 value = self._decode_enum(value, m) 850 elif m.display_hint: 850 elif m.display_hint: 851 value = self._formatted_st 851 value = self._formatted_string(value, m.display_hint) 852 attrs[m.name] = value 852 attrs[m.name] = value 853 return attrs 853 return attrs 854 854 855 def _encode_struct(self, name, vals): 855 def _encode_struct(self, name, vals): 856 members = self.consts[name].members 856 members = self.consts[name].members 857 attr_payload = b'' 857 attr_payload = b'' 858 for m in members: 858 for m in members: 859 value = vals.pop(m.name) if m.name 859 value = vals.pop(m.name) if m.name in vals else None 860 if m.type == 'pad': 860 if m.type == 'pad': 861 attr_payload += bytearray(m.le 861 attr_payload += bytearray(m.len) 862 elif m.type == 'binary': 862 elif m.type == 'binary': 863 if m.struct: 863 if m.struct: 864 if value is None: 864 if value is None: 865 value = dict() 865 value = dict() 866 attr_payload += self._enco 866 attr_payload += self._encode_struct(m.struct, value) 867 else: 867 else: 868 if value is None: 868 if value is None: 869 attr_payload += bytear 869 attr_payload += bytearray(m.len) 870 else: 870 else: 871 attr_payload += bytes. 871 attr_payload += bytes.fromhex(value) 872 else: 872 else: 873 if value is None: 873 if value is None: 874 value = 0 874 value = 0 875 format = NlAttr.get_format(m.t 875 format = NlAttr.get_format(m.type, m.byte_order) 876 attr_payload += format.pack(va 876 attr_payload += format.pack(value) 877 return attr_payload 877 return attr_payload 878 878 879 def _formatted_string(self, raw, display_h 879 def _formatted_string(self, raw, display_hint): 880 if display_hint == 'mac': 880 if display_hint == 'mac': 881 formatted = ':'.join('%02x' % b fo 881 formatted = ':'.join('%02x' % b for b in raw) 882 elif display_hint == 'hex': 882 elif display_hint == 'hex': 883 if isinstance(raw, int): 883 if isinstance(raw, int): 884 formatted = hex(raw) 884 formatted = hex(raw) 885 else: 885 else: 886 formatted = bytes.hex(raw, ' ' 886 formatted = bytes.hex(raw, ' ') 887 elif display_hint in [ 'ipv4', 'ipv6' 887 elif display_hint in [ 'ipv4', 'ipv6' ]: 888 formatted = format(ipaddress.ip_ad 888 formatted = format(ipaddress.ip_address(raw)) 889 elif display_hint == 'uuid': 889 elif display_hint == 'uuid': 890 formatted = str(uuid.UUID(bytes=ra 890 formatted = str(uuid.UUID(bytes=raw)) 891 else: 891 else: 892 formatted = raw 892 formatted = raw 893 return formatted 893 return formatted 894 894 895 def handle_ntf(self, decoded): 895 def handle_ntf(self, decoded): 896 msg = dict() 896 msg = dict() 897 if self.include_raw: 897 if self.include_raw: 898 msg['raw'] = decoded 898 msg['raw'] = decoded 899 op = self.rsp_by_value[decoded.cmd()] 899 op = self.rsp_by_value[decoded.cmd()] 900 attrs = self._decode(decoded.raw_attrs 900 attrs = self._decode(decoded.raw_attrs, op.attr_set.name) 901 if op.fixed_header: 901 if op.fixed_header: 902 attrs.update(self._decode_struct(d 902 attrs.update(self._decode_struct(decoded.raw, op.fixed_header)) 903 903 904 msg['name'] = op['name'] 904 msg['name'] = op['name'] 905 msg['msg'] = attrs 905 msg['msg'] = attrs 906 self.async_msg_queue.append(msg) 906 self.async_msg_queue.append(msg) 907 907 908 def check_ntf(self): 908 def check_ntf(self): 909 while True: 909 while True: 910 try: 910 try: 911 reply = self.sock.recv(self._r 911 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT) 912 except BlockingIOError: 912 except BlockingIOError: 913 return 913 return 914 914 915 nms = NlMsgs(reply) 915 nms = NlMsgs(reply) 916 self._recv_dbg_print(reply, nms) 916 self._recv_dbg_print(reply, nms) 917 for nl_msg in nms: 917 for nl_msg in nms: 918 if nl_msg.error: 918 if nl_msg.error: 919 print("Netlink error in nt 919 print("Netlink error in ntf!?", os.strerror(-nl_msg.error)) 920 print(nl_msg) 920 print(nl_msg) 921 continue 921 continue 922 if nl_msg.done: 922 if nl_msg.done: 923 print("Netlink done while 923 print("Netlink done while checking for ntf!?") 924 continue 924 continue 925 925 926 decoded = self.nlproto.decode( 926 decoded = self.nlproto.decode(self, nl_msg, None) 927 if decoded.cmd() not in self.a 927 if decoded.cmd() not in self.async_msg_ids: 928 print("Unexpected msg id d 928 print("Unexpected msg id done while checking for ntf", decoded) 929 continue 929 continue 930 930 931 self.handle_ntf(decoded) 931 self.handle_ntf(decoded) 932 932 933 def operation_do_attributes(self, name): 933 def operation_do_attributes(self, name): 934 """ 934 """ 935 For a given operation name, find and ret 935 For a given operation name, find and return a supported 936 set of attributes (as a dict). 936 set of attributes (as a dict). 937 """ 937 """ 938 op = self.find_operation(name) 938 op = self.find_operation(name) 939 if not op: 939 if not op: 940 return None 940 return None 941 941 942 return op['do']['request']['attributes'] 942 return op['do']['request']['attributes'].copy() 943 943 944 def _encode_message(self, op, vals, flags, 944 def _encode_message(self, op, vals, flags, req_seq): 945 nl_flags = Netlink.NLM_F_REQUEST | Net 945 nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK 946 for flag in flags or []: 946 for flag in flags or []: 947 nl_flags |= flag 947 nl_flags |= flag 948 948 949 msg = self.nlproto.message(nl_flags, o 949 msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq) 950 if op.fixed_header: 950 if op.fixed_header: 951 msg += self._encode_struct(op.fixe 951 msg += self._encode_struct(op.fixed_header, vals) 952 search_attrs = SpaceAttrs(op.attr_set, 952 search_attrs = SpaceAttrs(op.attr_set, vals) 953 for name, value in vals.items(): 953 for name, value in vals.items(): 954 msg += self._add_attr(op.attr_set. 954 msg += self._add_attr(op.attr_set.name, name, value, search_attrs) 955 msg = _genl_msg_finalize(msg) 955 msg = _genl_msg_finalize(msg) 956 return msg 956 return msg 957 957 958 def _ops(self, ops): 958 def _ops(self, ops): 959 reqs_by_seq = {} 959 reqs_by_seq = {} 960 req_seq = random.randint(1024, 65535) 960 req_seq = random.randint(1024, 65535) 961 payload = b'' 961 payload = b'' 962 for (method, vals, flags) in ops: 962 for (method, vals, flags) in ops: 963 op = self.ops[method] 963 op = self.ops[method] 964 msg = self._encode_message(op, val 964 msg = self._encode_message(op, vals, flags, req_seq) 965 reqs_by_seq[req_seq] = (op, msg, f 965 reqs_by_seq[req_seq] = (op, msg, flags) 966 payload += msg 966 payload += msg 967 req_seq += 1 967 req_seq += 1 968 968 969 self.sock.send(payload, 0) 969 self.sock.send(payload, 0) 970 970 971 done = False 971 done = False 972 rsp = [] 972 rsp = [] 973 op_rsp = [] 973 op_rsp = [] 974 while not done: 974 while not done: 975 reply = self.sock.recv(self._recv_ 975 reply = self.sock.recv(self._recv_size) 976 nms = NlMsgs(reply, attr_space=op. 976 nms = NlMsgs(reply, attr_space=op.attr_set) 977 self._recv_dbg_print(reply, nms) 977 self._recv_dbg_print(reply, nms) 978 for nl_msg in nms: 978 for nl_msg in nms: 979 if nl_msg.nl_seq in reqs_by_se 979 if nl_msg.nl_seq in reqs_by_seq: 980 (op, req_msg, req_flags) = 980 (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq] 981 if nl_msg.extack: 981 if nl_msg.extack: 982 self._decode_extack(re 982 self._decode_extack(req_msg, op, nl_msg.extack) 983 else: 983 else: 984 op = None 984 op = None 985 req_flags = [] 985 req_flags = [] 986 986 987 if nl_msg.error: 987 if nl_msg.error: 988 raise NlError(nl_msg) 988 raise NlError(nl_msg) 989 if nl_msg.done: 989 if nl_msg.done: 990 if nl_msg.extack: 990 if nl_msg.extack: 991 print("Netlink warning 991 print("Netlink warning:") 992 print(nl_msg) 992 print(nl_msg) 993 993 994 if Netlink.NLM_F_DUMP in r 994 if Netlink.NLM_F_DUMP in req_flags: 995 rsp.append(op_rsp) 995 rsp.append(op_rsp) 996 elif not op_rsp: 996 elif not op_rsp: 997 rsp.append(None) 997 rsp.append(None) 998 elif len(op_rsp) == 1: 998 elif len(op_rsp) == 1: 999 rsp.append(op_rsp[0]) 999 rsp.append(op_rsp[0]) 1000 else: 1000 else: 1001 rsp.append(op_rsp) 1001 rsp.append(op_rsp) 1002 op_rsp = [] 1002 op_rsp = [] 1003 1003 1004 del reqs_by_seq[nl_msg.nl 1004 del reqs_by_seq[nl_msg.nl_seq] 1005 done = len(reqs_by_seq) = 1005 done = len(reqs_by_seq) == 0 1006 break 1006 break 1007 1007 1008 decoded = self.nlproto.decode 1008 decoded = self.nlproto.decode(self, nl_msg, op) 1009 1009 1010 # Check if this is a reply to 1010 # Check if this is a reply to our request 1011 if nl_msg.nl_seq not in reqs_ 1011 if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value: 1012 if decoded.cmd() in self. 1012 if decoded.cmd() in self.async_msg_ids: 1013 self.handle_ntf(decod 1013 self.handle_ntf(decoded) 1014 continue 1014 continue 1015 else: 1015 else: 1016 print('Unexpected mes 1016 print('Unexpected message: ' + repr(decoded)) 1017 continue 1017 continue 1018 1018 1019 rsp_msg = self._decode(decode 1019 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name) 1020 if op.fixed_header: 1020 if op.fixed_header: 1021 rsp_msg.update(self._deco 1021 rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header)) 1022 op_rsp.append(rsp_msg) 1022 op_rsp.append(rsp_msg) 1023 1023 1024 return rsp 1024 return rsp 1025 1025 1026 def _op(self, method, vals, flags=None, d 1026 def _op(self, method, vals, flags=None, dump=False): 1027 req_flags = flags or [] 1027 req_flags = flags or [] 1028 if dump: 1028 if dump: 1029 req_flags.append(Netlink.NLM_F_DU 1029 req_flags.append(Netlink.NLM_F_DUMP) 1030 1030 1031 ops = [(method, vals, req_flags)] 1031 ops = [(method, vals, req_flags)] 1032 return self._ops(ops)[0] 1032 return self._ops(ops)[0] 1033 1033 1034 def do(self, method, vals, flags=None): 1034 def do(self, method, vals, flags=None): 1035 return self._op(method, vals, flags) 1035 return self._op(method, vals, flags) 1036 1036 1037 def dump(self, method, vals): 1037 def dump(self, method, vals): 1038 return self._op(method, vals, dump=Tr 1038 return self._op(method, vals, dump=True) 1039 1039 1040 def do_multi(self, ops): 1040 def do_multi(self, ops): 1041 return self._ops(ops) 1041 return self._ops(ops)
Linux® is a registered trademark of Linus Torvalds in the United States and other countries.
TOMOYO® is a registered trademark of NTT DATA CORPORATION.