~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

TOMOYO Linux Cross Reference
Linux/tools/usb/p9_fwd.py

Version: ~ [ linux-6.12-rc7 ] ~ [ linux-6.11.7 ] ~ [ linux-6.10.14 ] ~ [ linux-6.9.12 ] ~ [ linux-6.8.12 ] ~ [ linux-6.7.12 ] ~ [ linux-6.6.60 ] ~ [ linux-6.5.13 ] ~ [ linux-6.4.16 ] ~ [ linux-6.3.13 ] ~ [ linux-6.2.16 ] ~ [ linux-6.1.116 ] ~ [ linux-6.0.19 ] ~ [ linux-5.19.17 ] ~ [ linux-5.18.19 ] ~ [ linux-5.17.15 ] ~ [ linux-5.16.20 ] ~ [ linux-5.15.171 ] ~ [ linux-5.14.21 ] ~ [ linux-5.13.19 ] ~ [ linux-5.12.19 ] ~ [ linux-5.11.22 ] ~ [ linux-5.10.229 ] ~ [ linux-5.9.16 ] ~ [ linux-5.8.18 ] ~ [ linux-5.7.19 ] ~ [ linux-5.6.19 ] ~ [ linux-5.5.19 ] ~ [ linux-5.4.285 ] ~ [ linux-5.3.18 ] ~ [ linux-5.2.21 ] ~ [ linux-5.1.21 ] ~ [ linux-5.0.21 ] ~ [ linux-4.20.17 ] ~ [ linux-4.19.323 ] ~ [ linux-4.18.20 ] ~ [ linux-4.17.19 ] ~ [ linux-4.16.18 ] ~ [ linux-4.15.18 ] ~ [ linux-4.14.336 ] ~ [ linux-4.13.16 ] ~ [ linux-4.12.14 ] ~ [ linux-4.11.12 ] ~ [ linux-4.10.17 ] ~ [ linux-4.9.337 ] ~ [ linux-4.4.302 ] ~ [ linux-3.10.108 ] ~ [ linux-2.6.32.71 ] ~ [ linux-2.6.0 ] ~ [ linux-2.4.37.11 ] ~ [ unix-v6-master ] ~ [ ccs-tools-1.8.12 ] ~ [ policy-sample ] ~
Architecture: ~ [ i386 ] ~ [ alpha ] ~ [ m68k ] ~ [ mips ] ~ [ ppc ] ~ [ sparc ] ~ [ sparc64 ] ~

  1 #!/usr/bin/env python3
  2 # SPDX-License-Identifier: GPL-2.0
  3 
  4 import argparse
  5 import errno
  6 import logging
  7 import socket
  8 import struct
  9 import time
 10 
 11 import usb.core
 12 import usb.util
 13 
 14 
 15 def path_from_usb_dev(dev):
 16     """Takes a pyUSB device as argument and returns a string.
 17     The string is a Path representation of the position of the USB device on the USB bus tree.
 18 
 19     This path is used to find a USB device on the bus or all devices connected to a HUB.
 20     The path is made up of the number of the USB controller followed be the ports of the HUB tree."""
 21     if dev.port_numbers:
 22         dev_path = ".".join(str(i) for i in dev.port_numbers)
 23         return f"{dev.bus}-{dev_path}"
 24     return ""
 25 
 26 
 27 HEXDUMP_FILTER = "".join(chr(x).isprintable() and chr(x) or "." for x in range(128)) + "." * 128
 28 
 29 
 30 class Forwarder:
 31     @staticmethod
 32     def _log_hexdump(data):
 33         if not logging.root.isEnabledFor(logging.TRACE):
 34             return
 35         L = 16
 36         for c in range(0, len(data), L):
 37             chars = data[c : c + L]
 38             dump = " ".join(f"{x:02x}" for x in chars)
 39             printable = "".join(HEXDUMP_FILTER[x] for x in chars)
 40             line = f"{c:08x}  {dump:{L*3}s} |{printable:{L}s}|"
 41             logging.root.log(logging.TRACE, "%s", line)
 42 
 43     def __init__(self, server, vid, pid, path):
 44         self.stats = {
 45             "c2s packets": 0,
 46             "c2s bytes": 0,
 47             "s2c packets": 0,
 48             "s2c bytes": 0,
 49         }
 50         self.stats_logged = time.monotonic()
 51 
 52         def find_filter(dev):
 53             dev_path = path_from_usb_dev(dev)
 54             if path is not None:
 55                 return dev_path == path
 56             return True
 57 
 58         dev = usb.core.find(idVendor=vid, idProduct=pid, custom_match=find_filter)
 59         if dev is None:
 60             raise ValueError("Device not found")
 61 
 62         logging.info(f"found device: {dev.bus}/{dev.address} located at {path_from_usb_dev(dev)}")
 63 
 64         # dev.set_configuration() is not necessary since g_multi has only one
 65         usb9pfs = None
 66         # g_multi adds 9pfs as last interface
 67         cfg = dev.get_active_configuration()
 68         for intf in cfg:
 69             # we have to detach the usb-storage driver from multi gadget since
 70             # stall option could be set, which will lead to spontaneous port
 71             # resets and our transfers will run dead
 72             if intf.bInterfaceClass == 0x08:
 73                 if dev.is_kernel_driver_active(intf.bInterfaceNumber):
 74                     dev.detach_kernel_driver(intf.bInterfaceNumber)
 75 
 76             if intf.bInterfaceClass == 0xFF and intf.bInterfaceSubClass == 0xFF and intf.bInterfaceProtocol == 0x09:
 77                 usb9pfs = intf
 78         if usb9pfs is None:
 79             raise ValueError("Interface not found")
 80 
 81         logging.info(f"claiming interface:\n{usb9pfs}")
 82         usb.util.claim_interface(dev, usb9pfs.bInterfaceNumber)
 83         ep_out = usb.util.find_descriptor(
 84             usb9pfs,
 85             custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_OUT,
 86         )
 87         assert ep_out is not None
 88         ep_in = usb.util.find_descriptor(
 89             usb9pfs,
 90             custom_match=lambda e: usb.util.endpoint_direction(e.bEndpointAddress) == usb.util.ENDPOINT_IN,
 91         )
 92         assert ep_in is not None
 93         logging.info("interface claimed")
 94 
 95         self.ep_out = ep_out
 96         self.ep_in = ep_in
 97         self.dev = dev
 98 
 99         # create and connect socket
100         self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
101         self.s.connect(server)
102 
103         logging.info("connected to server")
104 
105     def c2s(self):
106         """forward a request from the USB client to the TCP server"""
107         data = None
108         while data is None:
109             try:
110                 logging.log(logging.TRACE, "c2s: reading")
111                 data = self.ep_in.read(self.ep_in.wMaxPacketSize)
112             except usb.core.USBTimeoutError:
113                 logging.log(logging.TRACE, "c2s: reading timed out")
114                 continue
115             except usb.core.USBError as e:
116                 if e.errno == errno.EIO:
117                     logging.debug("c2s: reading failed with %s, retrying", repr(e))
118                     time.sleep(0.5)
119                     continue
120                 logging.error("c2s: reading failed with %s, aborting", repr(e))
121                 raise
122         size = struct.unpack("<I", data[:4])[0]
123         while len(data) < size:
124             data += self.ep_in.read(size - len(data))
125         logging.log(logging.TRACE, "c2s: writing")
126         self._log_hexdump(data)
127         self.s.send(data)
128         logging.debug("c2s: forwarded %i bytes", size)
129         self.stats["c2s packets"] += 1
130         self.stats["c2s bytes"] += size
131 
132     def s2c(self):
133         """forward a response from the TCP server to the USB client"""
134         logging.log(logging.TRACE, "s2c: reading")
135         data = self.s.recv(4)
136         size = struct.unpack("<I", data[:4])[0]
137         while len(data) < size:
138             data += self.s.recv(size - len(data))
139         logging.log(logging.TRACE, "s2c: writing")
140         self._log_hexdump(data)
141         while data:
142             written = self.ep_out.write(data)
143             assert written > 0
144             data = data[written:]
145         if size % self.ep_out.wMaxPacketSize == 0:
146             logging.log(logging.TRACE, "sending zero length packet")
147             self.ep_out.write(b"")
148         logging.debug("s2c: forwarded %i bytes", size)
149         self.stats["s2c packets"] += 1
150         self.stats["s2c bytes"] += size
151 
152     def log_stats(self):
153         logging.info("statistics:")
154         for k, v in self.stats.items():
155             logging.info(f"  {k+':':14s} {v}")
156 
157     def log_stats_interval(self, interval=5):
158         if (time.monotonic() - self.stats_logged) < interval:
159             return
160 
161         self.log_stats()
162         self.stats_logged = time.monotonic()
163 
164 
def try_get_usb_str(dev, name):
    """Best-effort read of a sysfs string attribute of a USB device.

    :param dev: pyUSB device.
    :param name: sysfs attribute file name, e.g. ``"manufacturer"`` or
        ``"product"``.
    :returns: the stripped attribute contents, or ``None`` if the device's
        sysfs directory cannot be determined or the attribute cannot be read.
    """
    # sysfs names USB devices "<bus>-<port>[.<port>...]" (e.g. "1-1.4"), i.e.
    # by their position in the hub tree, which is what path_from_usb_dev()
    # returns. The device address is not part of that name, so the former
    # "<bus>-<address>" lookup missed the device or read a different one.
    sysfs_name = path_from_usb_dev(dev)
    if not sysfs_name:
        return None
    try:
        # errors="replace": an undecodable descriptor string must not abort the listing.
        with open(f"/sys/bus/usb/devices/{sysfs_name}/{name}", encoding="utf-8", errors="replace") as f:
            return f.read().strip()
    except OSError:
        return None
171 
172 
def list_usb(args):
    """Print a table of connected devices whose ids match ``args.id``.

    ``args.id`` is a ``"vid:pid"`` string of two hexadecimal numbers. Each row
    shows bus number, device address, the sysfs manufacturer and product
    strings (``"unknown"`` when unavailable), the ids and the bus path.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))

    print("Bus | Addr | Manufacturer     | Product          | ID        | Path")
    print("--- | ---- | ---------------- | ---------------- | --------- | ----")
    matches = usb.core.find(find_all=True, idVendor=vid, idProduct=pid)
    for dev in matches:
        location = path_from_usb_dev(dev) or ""
        maker = try_get_usb_str(dev, "manufacturer") or "unknown"
        product_name = try_get_usb_str(dev, "product") or "unknown"
        ids = f"{dev.idVendor:04x}:{dev.idProduct:04x}"
        row = f"{dev.bus:3} | {dev.address:4} | {maker:16} | {product_name:16} | {ids} | {location:18}"
        print(row)
185 
186 
def connect(args):
    """Forward messages between the selected gadget and the 9P server.

    Runs until an exception escapes the forwarding loop. Statistics are
    logged periodically while running and once more on the way out.
    """
    vid, pid = (int(part, 16) for part in args.id.split(":", 1))

    forwarder = Forwarder(server=(args.server, args.port), vid=vid, pid=pid, path=args.path)

    try:
        while True:
            # Alternate: forward one request, then forward one response.
            forwarder.c2s()
            forwarder.s2c()
            forwarder.log_stats_interval()
    finally:
        forwarder.log_stats()
199 
200 
def main():
    """Parse the command line, configure logging and run the chosen subcommand."""
    parser = argparse.ArgumentParser(
        description="Forward 9PFS requests from USB to TCP",
    )

    parser.add_argument("--id", type=str, default="1d6b:0109", help="vid:pid of target device")
    parser.add_argument("--path", type=str, required=False, help="path of target device")
    parser.add_argument("-v", "--verbose", action="count", default=0)

    # A subcommand is mandatory; its name is stored in args.command.
    subparsers = parser.add_subparsers(dest="command", required=True)

    parser_list = subparsers.add_parser("list", help="List all connected 9p gadgets")
    parser_list.set_defaults(func=list_usb)

    parser_connect = subparsers.add_parser(
        "connect", help="Forward messages between the usb9pfs gadget and the 9p server"
    )
    parser_connect.set_defaults(func=connect)
    parser_connect.add_argument("-s", "--server", type=str, default="127.0.0.1", help="server hostname")
    parser_connect.add_argument("-p", "--port", type=int, default=564, help="server port")

    args = parser.parse_args()

    # Register a custom TRACE level below DEBUG, used for per-transfer
    # messages and hex dumps; -vv enables it.
    logging.TRACE = logging.DEBUG - 5
    logging.addLevelName(logging.TRACE, "TRACE")

    if args.verbose >= 2:
        level = logging.TRACE
    elif args.verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(level=level, format="%(asctime)-15s %(levelname)-8s %(message)s")

    args.func(args)


if __name__ == "__main__":
    main()

~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

kernel.org | git.kernel.org | LWN.net | Project Home | SVN repository | Mail admin

Linux® is a registered trademark of Linus Torvalds in the United States and other countries.
TOMOYO® is a registered trademark of NTT DATA CORPORATION.

sflogo.php