~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

TOMOYO Linux Cross Reference
Linux/tools/net/ynl/lib/ynl.py

Version: ~ [ linux-6.12-rc7 ] ~ [ linux-6.11.7 ] ~ [ linux-6.10.14 ] ~ [ linux-6.9.12 ] ~ [ linux-6.8.12 ] ~ [ linux-6.7.12 ] ~ [ linux-6.6.60 ] ~ [ linux-6.5.13 ] ~ [ linux-6.4.16 ] ~ [ linux-6.3.13 ] ~ [ linux-6.2.16 ] ~ [ linux-6.1.116 ] ~ [ linux-6.0.19 ] ~ [ linux-5.19.17 ] ~ [ linux-5.18.19 ] ~ [ linux-5.17.15 ] ~ [ linux-5.16.20 ] ~ [ linux-5.15.171 ] ~ [ linux-5.14.21 ] ~ [ linux-5.13.19 ] ~ [ linux-5.12.19 ] ~ [ linux-5.11.22 ] ~ [ linux-5.10.229 ] ~ [ linux-5.9.16 ] ~ [ linux-5.8.18 ] ~ [ linux-5.7.19 ] ~ [ linux-5.6.19 ] ~ [ linux-5.5.19 ] ~ [ linux-5.4.285 ] ~ [ linux-5.3.18 ] ~ [ linux-5.2.21 ] ~ [ linux-5.1.21 ] ~ [ linux-5.0.21 ] ~ [ linux-4.20.17 ] ~ [ linux-4.19.323 ] ~ [ linux-4.18.20 ] ~ [ linux-4.17.19 ] ~ [ linux-4.16.18 ] ~ [ linux-4.15.18 ] ~ [ linux-4.14.336 ] ~ [ linux-4.13.16 ] ~ [ linux-4.12.14 ] ~ [ linux-4.11.12 ] ~ [ linux-4.10.17 ] ~ [ linux-4.9.337 ] ~ [ linux-4.4.302 ] ~ [ linux-3.10.108 ] ~ [ linux-2.6.32.71 ] ~ [ linux-2.6.0 ] ~ [ linux-2.4.37.11 ] ~ [ unix-v6-master ] ~ [ ccs-tools-1.8.12 ] ~ [ policy-sample ] ~
Architecture: ~ [ i386 ] ~ [ alpha ] ~ [ m68k ] ~ [ mips ] ~ [ ppc ] ~ [ sparc ] ~ [ sparc64 ] ~

  1 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
  2 
  3 from collections import namedtuple
  4 from enum import Enum
  5 import functools
  6 import os
  7 import random
  8 import socket
  9 import struct
 10 from struct import Struct
 11 import sys
 12 import yaml
 13 import ipaddress
 14 import uuid
 15 
 16 from .nlspec import SpecFamily
 17 
 18 #
 19 # Generic Netlink code which should really be in some library, but I can't quickly find one.
 20 #
 21 
 22 
 23 class Netlink:
 24     # Netlink socket
 25     SOL_NETLINK = 270
 26 
 27     NETLINK_ADD_MEMBERSHIP = 1
 28     NETLINK_CAP_ACK = 10
 29     NETLINK_EXT_ACK = 11
 30     NETLINK_GET_STRICT_CHK = 12
 31 
 32     # Netlink message
 33     NLMSG_ERROR = 2
 34     NLMSG_DONE = 3
 35 
 36     NLM_F_REQUEST = 1
 37     NLM_F_ACK = 4
 38     NLM_F_ROOT = 0x100
 39     NLM_F_MATCH = 0x200
 40 
 41     NLM_F_REPLACE = 0x100
 42     NLM_F_EXCL = 0x200
 43     NLM_F_CREATE = 0x400
 44     NLM_F_APPEND = 0x800
 45 
 46     NLM_F_CAPPED = 0x100
 47     NLM_F_ACK_TLVS = 0x200
 48 
 49     NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH
 50 
 51     NLA_F_NESTED = 0x8000
 52     NLA_F_NET_BYTEORDER = 0x4000
 53 
 54     NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER
 55 
 56     # Genetlink defines
 57     NETLINK_GENERIC = 16
 58 
 59     GENL_ID_CTRL = 0x10
 60 
 61     # nlctrl
 62     CTRL_CMD_GETFAMILY = 3
 63 
 64     CTRL_ATTR_FAMILY_ID = 1
 65     CTRL_ATTR_FAMILY_NAME = 2
 66     CTRL_ATTR_MAXATTR = 5
 67     CTRL_ATTR_MCAST_GROUPS = 7
 68 
 69     CTRL_ATTR_MCAST_GRP_NAME = 1
 70     CTRL_ATTR_MCAST_GRP_ID = 2
 71 
 72     # Extack types
 73     NLMSGERR_ATTR_MSG = 1
 74     NLMSGERR_ATTR_OFFS = 2
 75     NLMSGERR_ATTR_COOKIE = 3
 76     NLMSGERR_ATTR_POLICY = 4
 77     NLMSGERR_ATTR_MISS_TYPE = 5
 78     NLMSGERR_ATTR_MISS_NEST = 6
 79 
 80     # Policy types
 81     NL_POLICY_TYPE_ATTR_TYPE = 1
 82     NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
 83     NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
 84     NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
 85     NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
 86     NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
 87     NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
 88     NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
 89     NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
 90     NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
 91     NL_POLICY_TYPE_ATTR_PAD = 11
 92     NL_POLICY_TYPE_ATTR_MASK = 12
 93 
 94     AttrType = Enum('AttrType', ['flag', 'u8', 'u16', 'u32', 'u64',
 95                                   's8', 's16', 's32', 's64',
 96                                   'binary', 'string', 'nul-string',
 97                                   'nested', 'nested-array',
 98                                   'bitfield32', 'sint', 'uint'])
 99 
class NlError(Exception):
    """Exception wrapping a netlink error reply.

    Attributes:
        nl_msg: the reply message; its ``error`` field holds a negative errno.
        error: the positive errno value (``-nl_msg.error``).
    """

    def __init__(self, nl_msg):
        # Pass the message to Exception so .args is populated like any other
        # exception (helps repr/pickling and generic exception handlers).
        super().__init__(nl_msg)
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        return f"Netlink error: {os.strerror(self.error)}\n{self.nl_msg}"
107 
108 
class ConfigError(Exception):
    """Raised when a YnlFamily is constructed with an invalid setting (e.g. recv_size)."""
111 
112 
class NlAttr:
    """One netlink attribute (TLV) parsed out of a byte buffer.

    The 4-byte header is a native-endian u16 length (header included) followed
    by a u16 type whose top bits carry the NLA_F_* flags.
    """
    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # A length shorter than the header is malformed. A zero length would
        # make every caller that advances by full_len spin forever.
        if self._len < 4:
            raise Exception(f"Invalid netlink attribute length {self._len} at offset {offset}")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        # Attributes are laid out on 4-byte boundaries.
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the Struct for *attr_type*.

        Any truthy *byte_order* other than "big-endian" selects little-endian;
        no byte order selects native.
        """
        fmt = cls.type_formats[attr_type]
        if byte_order:
            return fmt.big if byte_order == "big-endian" else fmt.little
        return fmt.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the payload as a fixed-width scalar of *attr_type*."""
        return self.get_format(attr_type, byte_order).unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width ('uint'/'sint') scalar sized by its payload."""
        if len(self.raw) not in (4, 8):
            raise Exception(f"Auto-scalar payload must be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        return self.get_format(real_type, byte_order).unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII, dropping its final (NUL) byte."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the raw payload bytes."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as a packed native-endian array of *type*."""
        fmt = self.get_format(type)
        return [x[0] for x in fmt.iter_unpack(self.raw)]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
165 
166 
class NlAttrs:
    """Attributes parsed back-to-back from *msg*, starting at *offset*."""

    def __init__(self, msg, offset=0):
        self.attrs = []
        pos = offset
        end = len(msg)
        while pos < end:
            item = NlAttr(msg, pos)
            self.attrs.append(item)
            pos += item.full_len

    def __iter__(self):
        for item in self.attrs:
            yield item

    def __repr__(self):
        # One attribute per line.
        return '\n'.join(repr(item) for item in self.attrs)
186 
187 
class NlMsg:
    """One netlink message (16-byte nlmsghdr + payload) parsed from a buffer.

    For NLMSG_ERROR / NLMSG_DONE the leading s32 error code is exposed as
    ``error`` and ``done`` is set. When the NLM_F_ACK_TLVS flag is present,
    the extended-ACK TLVs are decoded into the ``extack`` dict.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        self.nl_len, self.nl_type, self.nl_flags, self.nl_seq, self.nl_portid = \
            struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0

        extack_off = None
        if self.nl_type == Netlink.NLMSG_ERROR:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = self._error_extack_offset()
        elif self.nl_type == Netlink.NLMSG_DONE:
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 4

        self.extack = None
        if self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off:
            self.extack = dict()
            extack_attrs = NlAttrs(self.raw[extack_off:])
            for extack in extack_attrs:
                if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                    self.extack['msg'] = extack.as_strz()
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
                    self.extack['miss-type'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
                    self.extack['miss-nest'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                elif extack.type == Netlink.NLMSGERR_ATTR_POLICY:
                    self.extack['policy'] = self._decode_policy(extack.raw)
                else:
                    self.extack.setdefault('unknown', []).append(extack)

            if attr_space:
                # We don't have the ability to parse nests yet, so only do global
                if 'miss-type' in self.extack and 'miss-nest' not in self.extack:
                    miss_type = self.extack['miss-type']
                    if miss_type in attr_space.attrs_by_val:
                        spec = attr_space.attrs_by_val[miss_type]
                        self.extack['miss-type'] = spec['name']
                        if 'doc' in spec:
                            self.extack['miss-type-doc'] = spec['doc']

    def _error_extack_offset(self):
        """Return where the extack TLVs start inside an NLMSG_ERROR payload.

        The payload is an s32 error code followed by an echo of the request.
        With NLM_F_CAPPED set, only the request's 16-byte nlmsghdr is echoed.
        Otherwise the whole request, padded to 4 bytes, is echoed, and its
        length is read from that echoed header.
        """
        capped_off = 4 + 16
        if self.nl_flags & Netlink.NLM_F_CAPPED or len(self.raw) < capped_off:
            return capped_off
        orig_len = struct.unpack_from("I", self.raw, 4)[0]
        return max(capped_off, 4 + ((orig_len + 3) & ~3))

    def _decode_policy(self, raw):
        """Decode an NLMSGERR_ATTR_POLICY nest into a dict of policy limits."""
        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_val = attr.as_scalar('u32')
                try:
                    policy['type'] = Netlink.AttrType(type_val).name
                except ValueError:
                    # Type newer than AttrType knows about: keep the number.
                    policy['type'] = type_val
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S:
                policy['min-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S:
                policy['max-value'] = attr.as_scalar('s64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U:
                policy['min-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U:
                policy['max-value'] = attr.as_scalar('u64')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH:
                policy['min-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH:
                policy['max-length'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK:
                policy['bitfield32-mask'] = attr.as_scalar('u32')
            elif attr.type == Netlink.NL_POLICY_TYPE_ATTR_MASK:
                policy['mask'] = attr.as_scalar('u64')
        return policy

    def cmd(self):
        """Return nlmsg_type, which identifies the command for raw netlink."""
        return self.nl_type

    def __repr__(self):
        msg = f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"
        if self.error:
            msg += '\n\terror: ' + str(self.error)
        if self.extack:
            msg += '\n\textack: ' + repr(self.extack)
        return msg
274 
275 
class NlMsgs:
    """All netlink messages packed back-to-back in one receive buffer."""

    def __init__(self, data, attr_space=None):
        self.msgs = []

        offset = 0
        while offset < len(data):
            msg = NlMsg(data, offset, attr_space=attr_space)
            # nlmsg_len always covers at least the 16-byte header. A smaller
            # value is malformed and would stall (0) or misparse this loop.
            if msg.nl_len < 16:
                raise Exception(f"Invalid netlink message length {msg.nl_len} at offset {offset}")
            # Messages start on 4-byte boundaries (NLMSG_ALIGN, as NLMSG_NEXT
            # does); this is a no-op for the usual already-aligned lengths.
            offset += (msg.nl_len + 3) & ~3
            self.msgs.append(msg)

    def __iter__(self):
        yield from self.msgs
288 
289 
# Process-wide cache of generic netlink families, keyed by family name.
# Despite the name, values are the parsed family dicts ('id', 'name', and
# optionally 'maxattr' / 'mcast'), not bare ids. None until
# _genl_load_families() fills it on first use.
genl_family_name_to_id = None
291 
292 
def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
    """Build an nlmsghdr (minus its length field) followed by a genlmsghdr.

    The leading u32 length is prepended later by _genl_msg_finalize(). When
    *seq* is omitted, a random sequence number in [1, 1024] is used.
    """
    sequence = random.randint(1, 1024) if seq is None else seq
    nl_part = struct.pack("HHII", nl_type, nl_flags, sequence, 0)
    genl_part = struct.pack("BBH", genl_cmd, genl_version, 0)
    return b''.join((nl_part, genl_part))
300 
301 
def _genl_msg_finalize(msg):
    """Prepend the u32 nlmsg_len field; the length includes the field itself."""
    total_len = 4 + len(msg)
    return b''.join([struct.pack("I", total_len), msg])
304 
305 
def _genl_parse_family(gm):
    """Parse one nlctrl CTRL_CMD_GETFAMILY reply (a GenlMsg) into a dict.

    The dict may hold 'id', 'name', 'maxattr' and 'mcast' (a map of multicast
    group name -> group id), depending on which attributes the reply carried.
    """
    fam = dict()
    for attr in NlAttrs(gm.raw):
        if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
            fam['id'] = attr.as_scalar('u16')
        elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
            fam['name'] = attr.as_strz()
        elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
            fam['maxattr'] = attr.as_scalar('u32')
        elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
            fam['mcast'] = dict()
            for entry in NlAttrs(attr.raw):
                mcast_name = None
                mcast_id = None
                for entry_attr in NlAttrs(entry.raw):
                    if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                        mcast_name = entry_attr.as_strz()
                    elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                        mcast_id = entry_attr.as_scalar('u32')
                if mcast_name and mcast_id is not None:
                    fam['mcast'][mcast_name] = mcast_id
    return fam


def _genl_load_families():
    """Dump every generic netlink family from nlctrl into genl_family_name_to_id.

    Families lacking a name or an id are skipped. If the kernel reports an
    error, a diagnostic is printed to stderr and the dump stops; the cache
    keeps whatever was collected up to that point.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        msg = _genl_msg(Netlink.GENL_ID_CTRL,
                        Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP,
                        Netlink.CTRL_CMD_GETFAMILY, 1)
        sock.send(_genl_msg_finalize(msg), 0)

        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    # Diagnostics go to stderr so tools printing results on
                    # stdout are not polluted.
                    print("Netlink error:", nl_msg.error, file=sys.stderr)
                    return
                if nl_msg.done:
                    return

                fam = _genl_parse_family(GenlMsg(nl_msg))
                if 'name' in fam and 'id' in fam:
                    genl_family_name_to_id[fam['name']] = fam
353 
354 
class GenlMsg:
    """Generic netlink view of an NlMsg: its 4-byte genlmsghdr plus the rest.

    ``raw`` is the payload after the genlmsghdr. ``raw_attrs`` is not set
    here; NetlinkProtocol.decode() attaches it after skipping any fixed header.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the genl command byte."""
        return self.genl_cmd

    def __repr__(self):
        # NlMsg's repr has no trailing newline, so add one before the genl line.
        msg = repr(self.nl) + '\n'
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs only exists once the message went through decode().
        for a in getattr(self, 'raw_attrs', ()):
            msg += '\t\t' + repr(a) + '\n'
        return msg
370 
371 
class NetlinkProtocol:
    """Request framing and reply decoding for a classic (raw) netlink family.

    GenlProtocol overrides the hooks below to add the generic netlink header.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack an nlmsghdr without its length: type, flags, seq, portid=0.

        A random sequence number in [1, 1024] is used when *seq* is None.
        """
        if seq is None:
            seq = random.randint(1, 1024)
        return struct.pack("HHII", nl_type, nl_flags, seq, 0)

    def message(self, flags, command, version, seq=None):
        """Build a request header; raw netlink puts *command* in nlmsg_type.

        *version* is accepted for interface parity and is unused here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Hook for subclasses to wrap an NlMsg; identity for raw netlink."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Wrap *nl_msg* and attach its attributes (after any fixed header)."""
        wrapped = self._decode(nl_msg)
        if op is None:
            op = ynl.rsp_by_value[wrapped.cmd()]
        header_len = ynl._struct_size(op.fixed_header)
        wrapped.raw_attrs = NlAttrs(wrapped.raw, header_len)
        return wrapped

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the group id declared in the spec's multicast groups."""
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size of the nlmsghdr in bytes."""
        return 16
404 
405 
class GenlProtocol(NetlinkProtocol):
    """Generic netlink framing: resolves the family id and adds a genlmsghdr."""

    def __init__(self, family_name):
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        # The family table is fetched from nlctrl once and then reused.
        if genl_family_name_to_id is None:
            _genl_load_families()

        # A KeyError here means the running kernel does not know the family.
        self.genl_family = genl_family_name_to_id[family_name]
        self.family_id = genl_family_name_to_id[family_name]['id']

    def message(self, flags, command, version, seq=None):
        """Build nlmsghdr (typed with the family id) + genlmsghdr."""
        nlmsg = self._message(self.family_id, flags, seq)
        genlmsg = struct.pack("BBH", command, version, 0)
        return nlmsg + genlmsg

    def _decode(self, nl_msg):
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the kernel-assigned id of multicast group *mcast_name*.

        *mcast_groups* (from the spec) is ignored; ids come from nlctrl.
        """
        # Families without multicast groups have no 'mcast' key at all.
        family_groups = self.genl_family.get('mcast', {})
        if mcast_name not in family_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return family_groups[mcast_name]

    def msghdr_size(self):
        """nlmsghdr plus the 4-byte genlmsghdr."""
        return super().msghdr_size() + 4
432 
433 
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes for resolving selectors.

    Lookups search the innermost scope first, then walk outwards.
    """
    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        parent_scopes = outer.scopes if outer else []
        self.scopes = [self.SpecValuesPair(attr_space, attrs), *parent_scopes]

    def lookup(self, name):
        """Return the value of *name* from the nearest scope whose spec defines it.

        Raises:
            Exception: if that scope has no value for *name*, or no scope's
                spec defines *name*.
        """
        owner = next((scope for scope in self.scopes if name in scope.spec), None)
        if owner is None:
            raise Exception(f"Attribute '{name}' not defined in any attribute-set")
        if name not in owner.values:
            spec_name = owner.spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        return owner.values[name]
451 
452 
453 #
454 # YNL implementation details.
455 #
456 
457 
458 class YnlFamily(SpecFamily):
459     def __init__(self, def_path, schema=None, process_unknown=False,
460                  recv_size=0):
461         super().__init__(def_path, schema)
462 
463         self.include_raw = False
464         self.process_unknown = process_unknown
465 
466         try:
467             if self.proto == "netlink-raw":
468                 self.nlproto = NetlinkProtocol(self.yaml['name'],
469                                                self.yaml['protonum'])
470             else:
471                 self.nlproto = GenlProtocol(self.yaml['name'])
472         except KeyError:
473             raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")
474 
475         self._recv_dbg = False
476         # Note that netlink will use conservative (min) message size for
477         # the first dump recv() on the socket, our setting will only matter
478         # from the second recv() on.
479         self._recv_size = recv_size if recv_size else 131072
480         # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
481         # for a message, so smaller receive sizes will lead to truncation.
482         # Note that the min size for other families may be larger than 4k!
483         if self._recv_size < 4000:
484             raise ConfigError()
485 
486         self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
487         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
488         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
489         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_GET_STRICT_CHK, 1)
490 
491         self.async_msg_ids = set()
492         self.async_msg_queue = []
493 
494         for msg in self.msgs.values():
495             if msg.is_async:
496                 self.async_msg_ids.add(msg.rsp_value)
497 
498         for op_name, op in self.ops.items():
499             bound_f = functools.partial(self._op, op_name)
500             setattr(self, op.ident_name, bound_f)
501 
502 
503     def ntf_subscribe(self, mcast_name):
504         mcast_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
505         self.sock.bind((0, 0))
506         self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_ADD_MEMBERSHIP,
507                              mcast_id)
508 
509     def set_recv_dbg(self, enabled):
510         self._recv_dbg = enabled
511 
512     def _recv_dbg_print(self, reply, nl_msgs):
513         if not self._recv_dbg:
514             return
515         print("Recv: read", len(reply), "bytes,",
516               len(nl_msgs.msgs), "messages", file=sys.stderr)
517         for nl_msg in nl_msgs:
518             print("  ", nl_msg, file=sys.stderr)
519 
520     def _encode_enum(self, attr_spec, value):
521         enum = self.consts[attr_spec['enum']]
522         if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
523             scalar = 0
524             if isinstance(value, str):
525                 value = [value]
526             for single_value in value:
527                 scalar += enum.entries[single_value].user_value(as_flags = True)
528             return scalar
529         else:
530             return enum.entries[value].user_value()
531 
532     def _get_scalar(self, attr_spec, value):
533         try:
534             return int(value)
535         except (ValueError, TypeError) as e:
536             if 'enum' not in attr_spec:
537                 raise e
538         return self._encode_enum(attr_spec, value)
539 
540     def _add_attr(self, space, name, value, search_attrs):
541         try:
542             attr = self.attr_sets[space][name]
543         except KeyError:
544             raise Exception(f"Space '{space}' has no attribute '{name}'")
545         nl_type = attr.value
546 
547         if attr.is_multi and isinstance(value, list):
548             attr_payload = b''
549             for subvalue in value:
550                 attr_payload += self._add_attr(space, name, subvalue, search_attrs)
551             return attr_payload
552 
553         if attr["type"] == 'nest':
554             nl_type |= Netlink.NLA_F_NESTED
555             attr_payload = b''
556             sub_attrs = SpaceAttrs(self.attr_sets[space], value, search_attrs)
557             for subname, subvalue in value.items():
558                 attr_payload += self._add_attr(attr['nested-attributes'],
559                                                subname, subvalue, sub_attrs)
560         elif attr["type"] == 'flag':
561             if not value:
562                 # If value is absent or false then skip attribute creation.
563                 return b''
564             attr_payload = b''
565         elif attr["type"] == 'string':
566             attr_payload = str(value).encode('ascii') + b'\x00'
567         elif attr["type"] == 'binary':
568             if isinstance(value, bytes):
569                 attr_payload = value
570             elif isinstance(value, str):
571                 attr_payload = bytes.fromhex(value)
572             elif isinstance(value, dict) and attr.struct_name:
573                 attr_payload = self._encode_struct(attr.struct_name, value)
574             else:
575                 raise Exception(f'Unknown type for binary attribute, value: {value}')
576         elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
577             scalar = self._get_scalar(attr, value)
578             if attr.is_auto_scalar:
579                 attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
580             else:
581                 attr_type = attr["type"]
582             format = NlAttr.get_format(attr_type, attr.byte_order)
583             attr_payload = format.pack(scalar)
584         elif attr['type'] in "bitfield32":
585             scalar_value = self._get_scalar(attr, value["value"])
586             scalar_selector = self._get_scalar(attr, value["selector"])
587             attr_payload = struct.pack("II", scalar_value, scalar_selector)
588         elif attr['type'] == 'sub-message':
589             msg_format = self._resolve_selector(attr, search_attrs)
590             attr_payload = b''
591             if msg_format.fixed_header:
592                 attr_payload += self._encode_struct(msg_format.fixed_header, value)
593             if msg_format.attr_set:
594                 if msg_format.attr_set in self.attr_sets:
595                     nl_type |= Netlink.NLA_F_NESTED
596                     sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
597                     for subname, subvalue in value.items():
598                         attr_payload += self._add_attr(msg_format.attr_set,
599                                                        subname, subvalue, sub_attrs)
600                 else:
601                     raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
602         else:
603             raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
604 
605         pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
606         return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
607 
608     def _decode_enum(self, raw, attr_spec):
609         enum = self.consts[attr_spec['enum']]
610         if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
611             i = 0
612             value = set()
613             while raw:
614                 if raw & 1:
615                     value.add(enum.entries_by_val[i].name)
616                 raw >>= 1
617                 i += 1
618         else:
619             value = enum.entries_by_val[raw].name
620         return value
621 
622     def _decode_binary(self, attr, attr_spec):
623         if attr_spec.struct_name:
624             decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
625         elif attr_spec.sub_type:
626             decoded = attr.as_c_array(attr_spec.sub_type)
627         else:
628             decoded = attr.as_bin()
629             if attr_spec.display_hint:
630                 decoded = self._formatted_string(decoded, attr_spec.display_hint)
631         return decoded
632 
633     def _decode_array_attr(self, attr, attr_spec):
634         decoded = []
635         offset = 0
636         while offset < len(attr.raw):
637             item = NlAttr(attr.raw, offset)
638             offset += item.full_len
639 
640             if attr_spec["sub-type"] == 'nest':
641                 subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
642                 decoded.append({ item.type: subattrs })
643             elif attr_spec["sub-type"] == 'binary':
644                 subattrs = item.as_bin()
645                 if attr_spec.display_hint:
646                     subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
647                 decoded.append(subattrs)
648             elif attr_spec["sub-type"] in NlAttr.type_formats:
649                 subattrs = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
650                 if attr_spec.display_hint:
651                     subattrs = self._formatted_string(subattrs, attr_spec.display_hint)
652                 decoded.append(subattrs)
653             else:
654                 raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
655         return decoded
656 
657     def _decode_nest_type_value(self, attr, attr_spec):
658         decoded = {}
659         value = attr
660         for name in attr_spec['type-value']:
661             value = NlAttr(value.raw, 0)
662             decoded[name] = value.type
663         subattrs = self._decode(NlAttrs(value.raw), attr_spec['nested-attributes'])
664         decoded.update(subattrs)
665         return decoded
666 
667     def _decode_unknown(self, attr):
668         if attr.is_nest:
669             return self._decode(NlAttrs(attr.raw), None)
670         else:
671             return attr.as_bin()
672 
673     def _rsp_add(self, rsp, name, is_multi, decoded):
674         if is_multi == None:
675             if name in rsp and type(rsp[name]) is not list:
676                 rsp[name] = [rsp[name]]
677                 is_multi = True
678             else:
679                 is_multi = False
680 
681         if not is_multi:
682             rsp[name] = decoded
683         elif name in rsp:
684             rsp[name].append(decoded)
685         else:
686             rsp[name] = [decoded]
687 
688     def _resolve_selector(self, attr_spec, search_attrs):
689         sub_msg = attr_spec.sub_message
690         if sub_msg not in self.sub_msgs:
691             raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
692         sub_msg_spec = self.sub_msgs[sub_msg]
693 
694         selector = attr_spec.selector
695         value = search_attrs.lookup(selector)
696         if value not in sub_msg_spec.formats:
697             raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
698 
699         spec = sub_msg_spec.formats[value]
700         return spec
701 
702     def _decode_sub_msg(self, attr, attr_spec, search_attrs):
703         msg_format = self._resolve_selector(attr_spec, search_attrs)
704         decoded = {}
705         offset = 0
706         if msg_format.fixed_header:
707             decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header));
708             offset = self._struct_size(msg_format.fixed_header)
709         if msg_format.attr_set:
710             if msg_format.attr_set in self.attr_sets:
711                 subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
712                 decoded.update(subdict)
713             else:
714                 raise Exception(f"Unknown attribute-set '{attr_space}' when decoding '{attr_spec.name}'")
715         return decoded
716 
717     def _decode(self, attrs, space, outer_attrs = None):
718         rsp = dict()
719         if space:
720             attr_space = self.attr_sets[space]
721             search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
722 
723         for attr in attrs:
724             try:
725                 attr_spec = attr_space.attrs_by_val[attr.type]
726             except (KeyError, UnboundLocalError):
727                 if not self.process_unknown:
728                     raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
729                 attr_name = f"UnknownAttr({attr.type})"
730                 self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
731                 continue
732 
733             if attr_spec["type"] == 'nest':
734                 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
735                 decoded = subdict
736             elif attr_spec["type"] == 'string':
737                 decoded = attr.as_strz()
738             elif attr_spec["type"] == 'binary':
739                 decoded = self._decode_binary(attr, attr_spec)
740             elif attr_spec["type"] == 'flag':
741                 decoded = True
742             elif attr_spec.is_auto_scalar:
743                 decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
744             elif attr_spec["type"] in NlAttr.type_formats:
745                 decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
746                 if 'enum' in attr_spec:
747                     decoded = self._decode_enum(decoded, attr_spec)
748                 elif attr_spec.display_hint:
749                     decoded = self._formatted_string(decoded, attr_spec.display_hint)
750             elif attr_spec["type"] == 'indexed-array':
751                 decoded = self._decode_array_attr(attr, attr_spec)
752             elif attr_spec["type"] == 'bitfield32':
753                 value, selector = struct.unpack("II", attr.raw)
754                 if 'enum' in attr_spec:
755                     value = self._decode_enum(value, attr_spec)
756                     selector = self._decode_enum(selector, attr_spec)
757                 decoded = {"value": value, "selector": selector}
758             elif attr_spec["type"] == 'sub-message':
759                 decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
760             elif attr_spec["type"] == 'nest-type-value':
761                 decoded = self._decode_nest_type_value(attr, attr_spec)
762             else:
763                 if not self.process_unknown:
764                     raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
765                 decoded = self._decode_unknown(attr)
766 
767             self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
768 
769         return rsp
770 
771     def _decode_extack_path(self, attrs, attr_set, offset, target):
772         for attr in attrs:
773             try:
774                 attr_spec = attr_set.attrs_by_val[attr.type]
775             except KeyError:
776                 raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
777             if offset > target:
778                 break
779             if offset == target:
780                 return '.' + attr_spec.name
781 
782             if offset + attr.full_len <= target:
783                 offset += attr.full_len
784                 continue
785             if attr_spec['type'] != 'nest':
786                 raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
787             offset += 4
788             subpath = self._decode_extack_path(NlAttrs(attr.raw),
789                                                self.attr_sets[attr_spec['nested-attributes']],
790                                                offset, target)
791             if subpath is None:
792                 return None
793             return '.' + attr_spec.name + subpath
794 
795         return None
796 
797     def _decode_extack(self, request, op, extack):
798         if 'bad-attr-offs' not in extack:
799             return
800 
801         msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
802         offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
803         path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
804                                         extack['bad-attr-offs'])
805         if path:
806             del extack['bad-attr-offs']
807             extack['bad-attr'] = path
808 
809     def _struct_size(self, name):
810         if name:
811             members = self.consts[name].members
812             size = 0
813             for m in members:
814                 if m.type in ['pad', 'binary']:
815                     if m.struct:
816                         size += self._struct_size(m.struct)
817                     else:
818                         size += m.len
819                 else:
820                     format = NlAttr.get_format(m.type, m.byte_order)
821                     size += format.size
822             return size
823         else:
824             return 0
825 
826     def _decode_struct(self, data, name):
827         members = self.consts[name].members
828         attrs = dict()
829         offset = 0
830         for m in members:
831             value = None
832             if m.type == 'pad':
833                 offset += m.len
834             elif m.type == 'binary':
835                 if m.struct:
836                     len = self._struct_size(m.struct)
837                     value = self._decode_struct(data[offset : offset + len],
838                                                 m.struct)
839                     offset += len
840                 else:
841                     value = data[offset : offset + m.len]
842                     offset += m.len
843             else:
844                 format = NlAttr.get_format(m.type, m.byte_order)
845                 [ value ] = format.unpack_from(data, offset)
846                 offset += format.size
847             if value is not None:
848                 if m.enum:
849                     value = self._decode_enum(value, m)
850                 elif m.display_hint:
851                     value = self._formatted_string(value, m.display_hint)
852                 attrs[m.name] = value
853         return attrs
854 
855     def _encode_struct(self, name, vals):
856         members = self.consts[name].members
857         attr_payload = b''
858         for m in members:
859             value = vals.pop(m.name) if m.name in vals else None
860             if m.type == 'pad':
861                 attr_payload += bytearray(m.len)
862             elif m.type == 'binary':
863                 if m.struct:
864                     if value is None:
865                         value = dict()
866                     attr_payload += self._encode_struct(m.struct, value)
867                 else:
868                     if value is None:
869                         attr_payload += bytearray(m.len)
870                     else:
871                         attr_payload += bytes.fromhex(value)
872             else:
873                 if value is None:
874                     value = 0
875                 format = NlAttr.get_format(m.type, m.byte_order)
876                 attr_payload += format.pack(value)
877         return attr_payload
878 
879     def _formatted_string(self, raw, display_hint):
880         if display_hint == 'mac':
881             formatted = ':'.join('%02x' % b for b in raw)
882         elif display_hint == 'hex':
883             if isinstance(raw, int):
884                 formatted = hex(raw)
885             else:
886                 formatted = bytes.hex(raw, ' ')
887         elif display_hint in [ 'ipv4', 'ipv6' ]:
888             formatted = format(ipaddress.ip_address(raw))
889         elif display_hint == 'uuid':
890             formatted = str(uuid.UUID(bytes=raw))
891         else:
892             formatted = raw
893         return formatted
894 
895     def handle_ntf(self, decoded):
896         msg = dict()
897         if self.include_raw:
898             msg['raw'] = decoded
899         op = self.rsp_by_value[decoded.cmd()]
900         attrs = self._decode(decoded.raw_attrs, op.attr_set.name)
901         if op.fixed_header:
902             attrs.update(self._decode_struct(decoded.raw, op.fixed_header))
903 
904         msg['name'] = op['name']
905         msg['msg'] = attrs
906         self.async_msg_queue.append(msg)
907 
908     def check_ntf(self):
909         while True:
910             try:
911                 reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
912             except BlockingIOError:
913                 return
914 
915             nms = NlMsgs(reply)
916             self._recv_dbg_print(reply, nms)
917             for nl_msg in nms:
918                 if nl_msg.error:
919                     print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
920                     print(nl_msg)
921                     continue
922                 if nl_msg.done:
923                     print("Netlink done while checking for ntf!?")
924                     continue
925 
926                 decoded = self.nlproto.decode(self, nl_msg, None)
927                 if decoded.cmd() not in self.async_msg_ids:
928                     print("Unexpected msg id done while checking for ntf", decoded)
929                     continue
930 
931                 self.handle_ntf(decoded)
932 
933     def operation_do_attributes(self, name):
934       """
935       For a given operation name, find and return a supported
936       set of attributes (as a dict).
937       """
938       op = self.find_operation(name)
939       if not op:
940         return None
941 
942       return op['do']['request']['attributes'].copy()
943 
944     def _encode_message(self, op, vals, flags, req_seq):
945         nl_flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK
946         for flag in flags or []:
947             nl_flags |= flag
948 
949         msg = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
950         if op.fixed_header:
951             msg += self._encode_struct(op.fixed_header, vals)
952         search_attrs = SpaceAttrs(op.attr_set, vals)
953         for name, value in vals.items():
954             msg += self._add_attr(op.attr_set.name, name, value, search_attrs)
955         msg = _genl_msg_finalize(msg)
956         return msg
957 
958     def _ops(self, ops):
959         reqs_by_seq = {}
960         req_seq = random.randint(1024, 65535)
961         payload = b''
962         for (method, vals, flags) in ops:
963             op = self.ops[method]
964             msg = self._encode_message(op, vals, flags, req_seq)
965             reqs_by_seq[req_seq] = (op, msg, flags)
966             payload += msg
967             req_seq += 1
968 
969         self.sock.send(payload, 0)
970 
971         done = False
972         rsp = []
973         op_rsp = []
974         while not done:
975             reply = self.sock.recv(self._recv_size)
976             nms = NlMsgs(reply, attr_space=op.attr_set)
977             self._recv_dbg_print(reply, nms)
978             for nl_msg in nms:
979                 if nl_msg.nl_seq in reqs_by_seq:
980                     (op, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
981                     if nl_msg.extack:
982                         self._decode_extack(req_msg, op, nl_msg.extack)
983                 else:
984                     op = None
985                     req_flags = []
986 
987                 if nl_msg.error:
988                     raise NlError(nl_msg)
989                 if nl_msg.done:
990                     if nl_msg.extack:
991                         print("Netlink warning:")
992                         print(nl_msg)
993 
994                     if Netlink.NLM_F_DUMP in req_flags:
995                         rsp.append(op_rsp)
996                     elif not op_rsp:
997                         rsp.append(None)
998                     elif len(op_rsp) == 1:
999                         rsp.append(op_rsp[0])
1000                     else:
1001                         rsp.append(op_rsp)
1002                     op_rsp = []
1003 
1004                     del reqs_by_seq[nl_msg.nl_seq]
1005                     done = len(reqs_by_seq) == 0
1006                     break
1007 
1008                 decoded = self.nlproto.decode(self, nl_msg, op)
1009 
1010                 # Check if this is a reply to our request
1011                 if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
1012                     if decoded.cmd() in self.async_msg_ids:
1013                         self.handle_ntf(decoded)
1014                         continue
1015                     else:
1016                         print('Unexpected message: ' + repr(decoded))
1017                         continue
1018 
1019                 rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
1020                 if op.fixed_header:
1021                     rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
1022                 op_rsp.append(rsp_msg)
1023 
1024         return rsp
1025 
1026     def _op(self, method, vals, flags=None, dump=False):
1027         req_flags = flags or []
1028         if dump:
1029             req_flags.append(Netlink.NLM_F_DUMP)
1030 
1031         ops = [(method, vals, req_flags)]
1032         return self._ops(ops)[0]
1033 
1034     def do(self, method, vals, flags=None):
1035         return self._op(method, vals, flags)
1036 
1037     def dump(self, method, vals):
1038         return self._op(method, vals, dump=True)
1039 
1040     def do_multi(self, ops):
1041         return self._ops(ops)

~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

kernel.org | git.kernel.org | LWN.net | Project Home | SVN repository | Mail admin

Linux® is a registered trademark of Linus Torvalds in the United States and other countries.
TOMOYO® is a registered trademark of NTT DATA CORPORATION.

sflogo.php