~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

TOMOYO Linux Cross Reference
Linux/tools/usb/p9_fwd.py

Version: ~ [ linux-6.12-rc7 ] ~ [ linux-6.11.7 ] ~ [ linux-6.10.14 ] ~ [ linux-6.9.12 ] ~ [ linux-6.8.12 ] ~ [ linux-6.7.12 ] ~ [ linux-6.6.60 ] ~ [ linux-6.5.13 ] ~ [ linux-6.4.16 ] ~ [ linux-6.3.13 ] ~ [ linux-6.2.16 ] ~ [ linux-6.1.116 ] ~ [ linux-6.0.19 ] ~ [ linux-5.19.17 ] ~ [ linux-5.18.19 ] ~ [ linux-5.17.15 ] ~ [ linux-5.16.20 ] ~ [ linux-5.15.171 ] ~ [ linux-5.14.21 ] ~ [ linux-5.13.19 ] ~ [ linux-5.12.19 ] ~ [ linux-5.11.22 ] ~ [ linux-5.10.229 ] ~ [ linux-5.9.16 ] ~ [ linux-5.8.18 ] ~ [ linux-5.7.19 ] ~ [ linux-5.6.19 ] ~ [ linux-5.5.19 ] ~ [ linux-5.4.285 ] ~ [ linux-5.3.18 ] ~ [ linux-5.2.21 ] ~ [ linux-5.1.21 ] ~ [ linux-5.0.21 ] ~ [ linux-4.20.17 ] ~ [ linux-4.19.323 ] ~ [ linux-4.18.20 ] ~ [ linux-4.17.19 ] ~ [ linux-4.16.18 ] ~ [ linux-4.15.18 ] ~ [ linux-4.14.336 ] ~ [ linux-4.13.16 ] ~ [ linux-4.12.14 ] ~ [ linux-4.11.12 ] ~ [ linux-4.10.17 ] ~ [ linux-4.9.337 ] ~ [ linux-4.4.302 ] ~ [ linux-3.10.108 ] ~ [ linux-2.6.32.71 ] ~ [ linux-2.6.0 ] ~ [ linux-2.4.37.11 ] ~ [ unix-v6-master ] ~ [ ccs-tools-1.8.12 ] ~ [ policy-sample ] ~
Architecture: ~ [ i386 ] ~ [ alpha ] ~ [ m68k ] ~ [ mips ] ~ [ ppc ] ~ [ sparc ] ~ [ sparc64 ] ~

Diff markup

Differences between /tools/usb/p9_fwd.py (Architecture alpha) and /tools/usb/p9_fwd.py (Architecture sparc64)


  1 #!/usr/bin/env python3                              1 #!/usr/bin/env python3
  2 # SPDX-License-Identifier: GPL-2.0                  2 # SPDX-License-Identifier: GPL-2.0
  3                                                     3 
  4 import argparse                                     4 import argparse
  5 import errno                                        5 import errno
  6 import logging                                      6 import logging
  7 import socket                                       7 import socket
  8 import struct                                       8 import struct
  9 import time                                         9 import time
 10                                                    10 
 11 import usb.core                                    11 import usb.core
 12 import usb.util                                    12 import usb.util
 13                                                    13 
 14                                                    14 
 15 def path_from_usb_dev(dev):                        15 def path_from_usb_dev(dev):
 16     """Takes a pyUSB device as argument and re     16     """Takes a pyUSB device as argument and returns a string.
 17     The string is a Path representation of the     17     The string is a Path representation of the position of the USB device on the USB bus tree.
 18                                                    18 
 19     This path is used to find a USB device on      19     This path is used to find a USB device on the bus or all devices connected to a HUB.
 20     The path is made up of the number of the U     20     The path is made up of the number of the USB controller followed be the ports of the HUB tree."""
 21     if dev.port_numbers:                           21     if dev.port_numbers:
 22         dev_path = ".".join(str(i) for i in de     22         dev_path = ".".join(str(i) for i in dev.port_numbers)
 23         return f"{dev.bus}-{dev_path}"             23         return f"{dev.bus}-{dev_path}"
 24     return ""                                      24     return ""
 25                                                    25 
 26                                                    26 
 27 HEXDUMP_FILTER = "".join(chr(x).isprintable()      27 HEXDUMP_FILTER = "".join(chr(x).isprintable() and chr(x) or "." for x in range(128)) + "." * 128
 28                                                    28 
 29                                                    29 
 30 class Forwarder:                                   30 class Forwarder:
 31     @staticmethod                                  31     @staticmethod
 32     def _log_hexdump(data):                        32     def _log_hexdump(data):
 33         if not logging.root.isEnabledFor(loggi     33         if not logging.root.isEnabledFor(logging.TRACE):
 34             return                                 34             return
 35         L = 16                                     35         L = 16
 36         for c in range(0, len(data), L):           36         for c in range(0, len(data), L):
 37             chars = data[c : c + L]                37             chars = data[c : c + L]
 38             dump = " ".join(f"{x:02x}" for x i     38             dump = " ".join(f"{x:02x}" for x in chars)
 39             printable = "".join(HEXDUMP_FILTER     39             printable = "".join(HEXDUMP_FILTER[x] for x in chars)
 40             line = f"{c:08x}  {dump:{L*3}s} |{     40             line = f"{c:08x}  {dump:{L*3}s} |{printable:{L}s}|"
 41             logging.root.log(logging.TRACE, "%     41             logging.root.log(logging.TRACE, "%s", line)
 42                                                    42 
 43     def __init__(self, server, vid, pid, path)     43     def __init__(self, server, vid, pid, path):
 44         self.stats = {                             44         self.stats = {
 45             "c2s packets": 0,                      45             "c2s packets": 0,
 46             "c2s bytes": 0,                        46             "c2s bytes": 0,
 47             "s2c packets": 0,                      47             "s2c packets": 0,
 48             "s2c bytes": 0,                        48             "s2c bytes": 0,
 49         }                                          49         }
 50         self.stats_logged = time.monotonic()       50         self.stats_logged = time.monotonic()
 51                                                    51 
 52         def find_filter(dev):                      52         def find_filter(dev):
 53             dev_path = path_from_usb_dev(dev)      53             dev_path = path_from_usb_dev(dev)
 54             if path is not None:                   54             if path is not None:
 55                 return dev_path == path            55                 return dev_path == path
 56             return True                            56             return True
 57                                                    57 
 58         dev = usb.core.find(idVendor=vid, idPr     58         dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
 59         if dev is None:                            59         if dev is None:
 60             raise ValueError("Device not found     60             raise ValueError("Device not found")
 61                                                    61 
 62         logging.info(f"found device: {dev.bus}     62         logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")
 63                                                    63 
 64         # dev.set_configuration() is not neces     64         # dev.set_configuration() is not necessary since g_multi has only one
 65         usb9pfs = None                             65         usb9pfs = None
 66         # g_multi adds 9pfs as last interface      66         # g_multi adds 9pfs as last interface
 67         cfg = dev.get_active_configuration()       67         cfg = dev.get_active_configuration()
 68         for intf in cfg:                           68         for intf in cfg:
 69             # we have to detach the usb-storag     69             # we have to detach the usb-storage driver from multi gadget since
 70             # stall option could be set, which     70             # stall option could be set, which will lead to spontaneous port
 71             # resets and our transfers will ru     71             # resets and our transfers will run dead
 72             if intf.bInterfaceClass == 0x08:       72             if intf.bInterfaceClass == 0x08:
 73                 if dev.is_kernel_driver_active     73                 if dev.is_kernel_driver_active(intf.bInterfaceNumber):
 74                     dev.detach_kernel_driver(i     74                     dev.detach_kernel_driver(intf.bInterfaceNumber)
 75                                                    75 
 76             if intf.bInterfaceClass == 0xFF an     76             if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
 77                 usb9pfs = intf                     77                 usb9pfs = intf
 78         if usb9pfs is None:                        78         if usb9pfs is None:
 79             raise ValueError("Interface not fo     79             raise ValueError("Interface not found")
 80                                                    80 
 81         logging.info(f"claiming interface:\n{u     81         logging.info(f"claiming interface:\n{usb9pfs}")
 82         usb.util.claim_interface(dev, usb9pfs.     82         usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
 83         ep_out = usb.util.find_descriptor(         83         ep_out = usb.util.find_descriptor(
 84             usb9pfs,                               84             usb9pfs,
 85             custom_match=lambda e: usb.util.en     85             custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT,
 86         )                                          86         )
 87         assert ep_out is not None                  87         assert ep_out is not None
 88         ep_in = usb.util.find_descriptor(          88         ep_in = usb.util.find_descriptor(
 89             usb9pfs,                               89             usb9pfs,
 90             custom_match=lambda e: usb.util.en     90             custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN,
 91         )                                          91         )
 92         assert ep_in is not None                   92         assert ep_in is not None
 93         logging.info("interface claimed")          93         logging.info("interface claimed")
 94                                                    94 
 95         self.ep_out = ep_out                       95         self.ep_out = ep_out
 96         self.ep_in = ep_in                         96         self.ep_in = ep_in
 97         self.dev = dev                             97         self.dev = dev
 98                                                    98 
 99         # create and connect socket                99         # create and connect socket
100         self.s = socket.socket(socket.AF_INET,    100         self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
101         self.s.connect(server)                    101         self.s.connect(server)
102                                                   102 
103         logging.info("connected to server")       103         logging.info("connected to server")
104                                                   104 
105     def c2s(self):                                105     def c2s(self):
106         """forward a request from the USB clie    106         """forward a request from the USB client to the TCP server"""
107         data = None                               107         data = None
108         while data is None:                       108         while data is None:
109             try:                                  109             try:
110                 logging.log(logging.TRACE, "c2    110                 logging.log(logging.TRACE, "c2s: reading")
111                 data = self.ep_in.read(self.ep    111                 data = self.ep_in.read(self.ep_in.wMaxPacketSize)
112             except usb.core.USBTimeoutError:      112             except usb.core.USBTimeoutError:
113                 logging.log(logging.TRACE, "c2    113                 logging.log(logging.TRACE, "c2s: reading timed out")
114                 continue                          114                 continue
115             except usb.core.USBError as e:        115             except usb.core.USBError as e:
116                 if e.errno == errno.EIO:          116                 if e.errno == errno.EIO:
117                     logging.debug("c2s: readin    117                     logging.debug("c2s: reading failed with %s, retrying", repr(e))
118                     time.sleep(0.5)               118                     time.sleep(0.5)
119                     continue                      119                     continue
120                 logging.error("c2s: reading fa    120                 logging.error("c2s: reading failed with %s, aborting", repr(e))
121                 raise                             121                 raise
122         size = struct.unpack("<I", data[:4])[0    122         size = struct.unpack("<I", data[:4])[0]
123         while len(data) < size:                   123         while len(data) < size:
124             data += self.ep_in.read(size - len    124             data += self.ep_in.read(size - len(data))
125         logging.log(logging.TRACE, "c2s: writi    125         logging.log(logging.TRACE, "c2s: writing")
126         self._log_hexdump(data)                   126         self._log_hexdump(data)
127         self.s.send(data)                         127         self.s.send(data)
128         logging.debug("c2s: forwarded %i bytes    128         logging.debug("c2s: forwarded %i bytes", size)
129         self.stats["c2s packets"] += 1            129         self.stats["c2s packets"] += 1
130         self.stats["c2s bytes"] += size           130         self.stats["c2s bytes"] += size
131                                                   131 
132     def s2c(self):                                132     def s2c(self):
133         """forward a response from the TCP ser    133         """forward a response from the TCP server to the USB client"""
134         logging.log(logging.TRACE, "s2c: readi    134         logging.log(logging.TRACE, "s2c: reading")
135         data = self.s.recv(4)                     135         data = self.s.recv(4)
136         size = struct.unpack("<I", data[:4])[0    136         size = struct.unpack("<I", data[:4])[0]
137         while len(data) < size:                   137         while len(data) < size:
138             data += self.s.recv(size - len(dat    138             data += self.s.recv(size - len(data))
139         logging.log(logging.TRACE, "s2c: writi    139         logging.log(logging.TRACE, "s2c: writing")
140         self._log_hexdump(data)                   140         self._log_hexdump(data)
141         while data:                               141         while data:
142             written = self.ep_out.write(data)     142             written = self.ep_out.write(data)
143             assert written > 0                    143             assert written > 0
144             data = data[written:]                 144             data = data[written:]
145         if size % self.ep_out.wMaxPacketSize =    145         if size % self.ep_out.wMaxPacketSize == 0:
146             logging.log(logging.TRACE, "sendin    146             logging.log(logging.TRACE, "sending zero length packet")
147             self.ep_out.write(b"")                147             self.ep_out.write(b"")
148         logging.debug("s2c: forwarded %i bytes    148         logging.debug("s2c: forwarded %i bytes", size)
149         self.stats["s2c packets"] += 1            149         self.stats["s2c packets"] += 1
150         self.stats["s2c bytes"] += size           150         self.stats["s2c bytes"] += size
151                                                   151 
152     def log_stats(self):                          152     def log_stats(self):
153         logging.info("statistics:")               153         logging.info("statistics:")
154         for k, v in self.stats.items():           154         for k, v in self.stats.items():
155             logging.info(f"  {k+':':14s} {v}")    155             logging.info(f"  {k+':':14s} {v}")
156                                                   156 
157     def log_stats_interval(self, interval=5):     157     def log_stats_interval(self, interval=5):
158         if (time.monotonic() - self.stats_logg    158         if (time.monotonic() - self.stats_logged) < interval:
159             return                                159             return
160                                                   160 
161         self.log_stats()                          161         self.log_stats()
162         self.stats_logged = time.monotonic()      162         self.stats_logged = time.monotonic()
163                                                   163 
164                                                   164 
165 def try_get_usb_str(dev, name):                   165 def try_get_usb_str(dev, name):
166     try:                                          166     try:
167         with open(f"/sys/bus/usb/devices/{dev.    167         with open(f"/sys/bus/usb/devices/{dev.bus}-{dev.address}/{name}") as f:
168             return f.read().strip()               168             return f.read().strip()
169     except FileNotFoundError:                     169     except FileNotFoundError:
170         return None                               170         return None
171                                                   171 
172                                                   172 
173 def list_usb(args):                               173 def list_usb(args):
174     vid, pid = [int(x, 16) for x in args.id.sp    174     vid, pid = [int(x, 16) for x in args.id.split(":", 1)]
175                                                   175 
176     print("Bus | Addr | Manufacturer     | Pro    176     print("Bus | Addr | Manufacturer     | Product          | ID        | Path")
177     print("--- | ---- | ---------------- | ---    177     print("--- | ---- | ---------------- | ---------------- | --------- | ----")
178     for dev in usb.core.find(find_all=True, id    178     for dev in usb.core.find(find_all=True, idVendor=vid, idProduct=pid):
179         path = path_from_usb_dev(dev) or ""       179         path = path_from_usb_dev(dev) or ""
180         manufacturer = try_get_usb_str(dev, "m    180         manufacturer = try_get_usb_str(dev, "manufacturer") or "unknown"
181         product = try_get_usb_str(dev, "produc    181         product = try_get_usb_str(dev, "product") or "unknown"
182         print(                                    182         print(
183             f"{dev.bus:3} | {dev.address:4} |     183             f"{dev.bus:3} | {dev.address:4} | {manufacturer:16} | {product:16} | {dev.idVendor:04x}:{dev.idProduct:04x} | {path:18}"
184         )                                         184         )
185                                                   185 
186                                                   186 
187 def connect(args):                                187 def connect(args):
188     vid, pid = [int(x, 16) for x in args.id.sp    188     vid, pid = [int(x, 16) for x in args.id.split(":", 1)]
189                                                   189 
190     f = Forwarder(server=(args.server, args.po    190     f = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path)
191                                                   191 
192     try:                                          192     try:
193         while True:                               193         while True:
194             f.c2s()                               194             f.c2s()
195             f.s2c()                               195             f.s2c()
196             f.log_stats_interval()                196             f.log_stats_interval()
197     finally:                                      197     finally:
198         f.log_stats()                             198         f.log_stats()
199                                                   199 
200                                                   200 
201 def main():                                       201 def main():
202     parser = argparse.ArgumentParser(             202     parser = argparse.ArgumentParser(
203         description="Forward 9PFS requests fro    203         description="Forward 9PFS requests from USB to TCP",
204     )                                             204     )
205                                                   205 
206     parser.add_argument("--id", type=str, defa    206     parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
207     parser.add_argument("--path", type=str, re    207     parser.add_argument("--path", type=str, required=False, help="path of target device")
208     parser.add_argument("-v", "--verbose", act    208     parser.add_argument("-v", "--verbose", action="count", default=0)
209                                                   209 
210     subparsers = parser.add_subparsers()          210     subparsers = parser.add_subparsers()
211     subparsers.required = True                    211     subparsers.required = True
212     subparsers.dest = "command"                   212     subparsers.dest = "command"
213                                                   213 
214     parser_list = subparsers.add_parser("list"    214     parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
215     parser_list.set_defaults(func=list_usb)       215     parser_list.set_defaults(func=list_usb)
216                                                   216 
217     parser_connect = subparsers.add_parser(       217     parser_connect = subparsers.add_parser(
218         "connect", help="Forward messages betw    218         "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
219     )                                             219     )
220     parser_connect.set_defaults(func=connect)     220     parser_connect.set_defaults(func=connect)
221     connect_group = parser_connect.add_argumen    221     connect_group = parser_connect.add_argument_group()
222     connect_group.required = True                 222     connect_group.required = True
223     parser_connect.add_argument("-s", "--serve    223     parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
224     parser_connect.add_argument("-p", "--port"    224     parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")
225                                                   225 
226     args = parser.parse_args()                    226     args = parser.parse_args()
227                                                   227 
228     logging.TRACE = logging.DEBUG - 5             228     logging.TRACE = logging.DEBUG - 5
229     logging.addLevelName(logging.TRACE, "TRACE    229     logging.addLevelName(logging.TRACE, "TRACE")
230                                                   230 
231     if args.verbose >= 2:                         231     if args.verbose >= 2:
232         level = logging.TRACE                     232         level = logging.TRACE
233     elif args.verbose:                            233     elif args.verbose:
234         level = logging.DEBUG                     234         level = logging.DEBUG
235     else:                                         235     else:
236         level = logging.INFO                      236         level = logging.INFO
237     logging.basicConfig(level=level, format="%    237     logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")
238                                                   238 
239     args.func(args)                               239     args.func(args)
240                                                   240 
241                                                   241 
242 if __name__ == "__main__":                        242 if __name__ == "__main__":
243     main()                                        243     main()
                                                      

~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

kernel.org | git.kernel.org | LWN.net | Project Home | SVN repository | Mail admin

Linux® is a registered trademark of Linus Torvalds in the United States and other countries.
TOMOYO® is a registered trademark of NTT DATA CORPORATION.

sflogo.php