~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

TOMOYO Linux Cross Reference
Linux/tools/testing/selftests/net/mptcp/mptcp_inq.c

Version: ~ [ linux-6.11.5 ] ~ [ linux-6.10.14 ] ~ [ linux-6.9.12 ] ~ [ linux-6.8.12 ] ~ [ linux-6.7.12 ] ~ [ linux-6.6.58 ] ~ [ linux-6.5.13 ] ~ [ linux-6.4.16 ] ~ [ linux-6.3.13 ] ~ [ linux-6.2.16 ] ~ [ linux-6.1.114 ] ~ [ linux-6.0.19 ] ~ [ linux-5.19.17 ] ~ [ linux-5.18.19 ] ~ [ linux-5.17.15 ] ~ [ linux-5.16.20 ] ~ [ linux-5.15.169 ] ~ [ linux-5.14.21 ] ~ [ linux-5.13.19 ] ~ [ linux-5.12.19 ] ~ [ linux-5.11.22 ] ~ [ linux-5.10.228 ] ~ [ linux-5.9.16 ] ~ [ linux-5.8.18 ] ~ [ linux-5.7.19 ] ~ [ linux-5.6.19 ] ~ [ linux-5.5.19 ] ~ [ linux-5.4.284 ] ~ [ linux-5.3.18 ] ~ [ linux-5.2.21 ] ~ [ linux-5.1.21 ] ~ [ linux-5.0.21 ] ~ [ linux-4.20.17 ] ~ [ linux-4.19.322 ] ~ [ linux-4.18.20 ] ~ [ linux-4.17.19 ] ~ [ linux-4.16.18 ] ~ [ linux-4.15.18 ] ~ [ linux-4.14.336 ] ~ [ linux-4.13.16 ] ~ [ linux-4.12.14 ] ~ [ linux-4.11.12 ] ~ [ linux-4.10.17 ] ~ [ linux-4.9.337 ] ~ [ linux-4.4.302 ] ~ [ linux-3.10.108 ] ~ [ linux-2.6.32.71 ] ~ [ linux-2.6.0 ] ~ [ linux-2.4.37.11 ] ~ [ unix-v6-master ] ~ [ ccs-tools-1.8.9 ] ~ [ policy-sample ] ~
Architecture: ~ [ i386 ] ~ [ alpha ] ~ [ m68k ] ~ [ mips ] ~ [ ppc ] ~ [ sparc ] ~ [ sparc64 ] ~

  1 // SPDX-License-Identifier: GPL-2.0
  2 
  3 #define _GNU_SOURCE
  4 
  5 #include <assert.h>
  6 #include <errno.h>
  7 #include <fcntl.h>
  8 #include <limits.h>
  9 #include <string.h>
 10 #include <stdarg.h>
 11 #include <stdbool.h>
 12 #include <stdint.h>
 13 #include <inttypes.h>
 14 #include <stdio.h>
 15 #include <stdlib.h>
 16 #include <strings.h>
 17 #include <unistd.h>
 18 #include <time.h>
 19 
 20 #include <sys/ioctl.h>
 21 #include <sys/random.h>
 22 #include <sys/socket.h>
 23 #include <sys/types.h>
 24 #include <sys/wait.h>
 25 
 26 #include <netdb.h>
 27 #include <netinet/in.h>
 28 
 29 #include <linux/tcp.h>
 30 #include <linux/sockios.h>
 31 
 32 #ifndef IPPROTO_MPTCP
 33 #define IPPROTO_MPTCP 262
 34 #endif
 35 #ifndef SOL_MPTCP
 36 #define SOL_MPTCP 284
 37 #endif
 38 
/* Address family for both the listener and the client; AF_INET unless -6 is given. */
static int pf = AF_INET;
/* Protocol of the client (sending) socket; selected with -t tcp|mptcp. */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol of the server (listening) socket; selected with -r tcp|mptcp. */
static int proto_rx = IPPROTO_MPTCP;
 42 
 43 static void die_perror(const char *msg)
 44 {
 45         perror(msg);
 46         exit(1);
 47 }
 48 
 49 static void die_usage(int r)
 50 {
 51         fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
 52         exit(r);
 53 }
 54 
 55 static void xerror(const char *fmt, ...)
 56 {
 57         va_list ap;
 58 
 59         va_start(ap, fmt);
 60         vfprintf(stderr, fmt, ap);
 61         va_end(ap);
 62         fputc('\n', stderr);
 63         exit(1);
 64 }
 65 
 66 static const char *getxinfo_strerr(int err)
 67 {
 68         if (err == EAI_SYSTEM)
 69                 return strerror(errno);
 70 
 71         return gai_strerror(err);
 72 }
 73 
 74 static void xgetaddrinfo(const char *node, const char *service,
 75                          const struct addrinfo *hints,
 76                          struct addrinfo **res)
 77 {
 78         int err = getaddrinfo(node, service, hints, res);
 79 
 80         if (err) {
 81                 const char *errstr = getxinfo_strerr(err);
 82 
 83                 fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
 84                         node ? node : "", service ? service : "", errstr);
 85                 exit(1);
 86         }
 87 }
 88 
 89 static int sock_listen_mptcp(const char * const listenaddr,
 90                              const char * const port)
 91 {
 92         int sock = -1;
 93         struct addrinfo hints = {
 94                 .ai_protocol = IPPROTO_TCP,
 95                 .ai_socktype = SOCK_STREAM,
 96                 .ai_flags = AI_PASSIVE | AI_NUMERICHOST
 97         };
 98 
 99         hints.ai_family = pf;
100 
101         struct addrinfo *a, *addr;
102         int one = 1;
103 
104         xgetaddrinfo(listenaddr, port, &hints, &addr);
105         hints.ai_family = pf;
106 
107         for (a = addr; a; a = a->ai_next) {
108                 sock = socket(a->ai_family, a->ai_socktype, proto_rx);
109                 if (sock < 0)
110                         continue;
111 
112                 if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
113                                      sizeof(one)))
114                         perror("setsockopt");
115 
116                 if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
117                         break; /* success */
118 
119                 perror("bind");
120                 close(sock);
121                 sock = -1;
122         }
123 
124         freeaddrinfo(addr);
125 
126         if (sock < 0)
127                 xerror("could not create listen socket");
128 
129         if (listen(sock, 20))
130                 die_perror("listen");
131 
132         return sock;
133 }
134 
/*
 * Connect to @remoteaddr:@port with a socket of protocol @proto
 * (IPPROTO_TCP or IPPROTO_MPTCP), using the address family in 'pf'.
 *
 * Every address returned by getaddrinfo() is tried in turn.  A failed
 * connect() is reported and the next address is tried; the process
 * exits only if no address can be connected.  This mirrors the bind
 * loop in sock_listen_mptcp().
 *
 * Returns the connected socket descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("connect");
		close(sock);
		sock = -1;
	}

	/* release the list before any early exit, as the listen helper does */
	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}
165 
/*
 * Translate a case-insensitive protocol name ("tcp" or "mptcp") into the
 * matching IPPROTO_* value.  Any other string prints usage and exits(1).
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} protos[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(protos) / sizeof(protos[0]); idx++) {
		if (!strcasecmp(s, protos[idx].name))
			return protos[idx].proto;
	}

	die_usage(1);
	return 0;	/* not reached: die_usage() exits */
}
176 
177 static void parse_opts(int argc, char **argv)
178 {
179         int c;
180 
181         while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
182                 switch (c) {
183                 case 'h':
184                         die_usage(0);
185                         break;
186                 case '6':
187                         pf = AF_INET6;
188                         break;
189                 case 't':
190                         proto_tx = protostr_to_num(optarg);
191                         break;
192                 case 'r':
193                         proto_rx = protostr_to_num(optarg);
194                         break;
195                 default:
196                         die_usage(1);
197                         break;
198                 }
199         }
200 }
201 
/*
 * Wait up to roughly @timeout milliseconds, polling every 1 ms, for the
 * send queue of @fd to drain, i.e. for the peer to ack all data written.
 *
 * TIOCOUTQ reports the bytes still in the send queue and SIOCOUTQNSD the
 * part of those not yet sent, so the latter can never exceed the former.
 * Fails if more than @total bytes are ever reported queued, or if the
 * queue has not drained when the timeout expires.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* xerror() appends the newline itself */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms", timeout);
}
234 
/*
 * Client side of the test.  Each step is triggered by a 4-byte token the
 * server sends over @unixfd:
 *
 *  "xmit": report a random length (128..4094) over @unixfd and send that
 *          many random 'A'..'Z' bytes on @fd.
 *  "huge": report a random total (1 MiB up to just under 17 MiB) over
 *          @unixfd, wait (up to 5 s) for the first payload to be acked,
 *          then stream that many bytes on @fd.
 *  "shut": wait (up to 5 s) for everything to be acked, send one last
 *          byte, close @fd and report "closed" over @unixfd.
 *
 * Closes both @fd and @unixfd before returning.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* check the length before comparing, as for every other token */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
311 
/*
 * Scan the ancillary data of @msgh for the IPPROTO_TCP / TCP_CM_INQ
 * control message and copy its payload into *@inqv.  Exits via xerror()
 * if no such control message is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *cm = CMSG_FIRSTHDR(msgh);

	while (cm) {
		if (cm->cmsg_level == IPPROTO_TCP && cm->cmsg_type == TCP_CM_INQ) {
			/* memcpy: CMSG_DATA() need not be aligned for an int */
			memcpy(inqv, CMSG_DATA(cm), sizeof(*inqv));
			return;
		}
		cm = CMSG_NXTHDR(msgh, cm);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
325 
/*
 * Server side of the test.  It drives the client through @unixfd and
 * checks FIONREAD and the TCP_CM_INQ control message returned by
 * recvmsg() on @fd at each stage:
 *
 *  1. "xmit": wait for the small payload, then read 1 byte and expect
 *     inq == len - 1; read the rest and expect inq == 0.
 *  2. "huge": read the multi-MiB stream and check that inq never exceeds
 *     the bytes still outstanding.
 *  3. "shut": after the client reports "closed", read the final byte and
 *     expect inq == 1 (the FIN), then read EOF and expect inq == 1 again.
 *
 * Any mismatch exits via xerror()/die_perror() or aborts via assert().
 * Closes @fd before returning.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* poll until the whole first payload is readable (no timeout here) */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/*
	 * msg_controllen is value-result: recvmsg() shrinks it to the amount
	 * of control data actually returned, so restore the full buffer size
	 * before every call.
	 */

	/* read one byte, expect cmsg to return expected - 1 */
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");

		/* EOF before the full stream arrived would otherwise spin forever */
		if (ret == 0)
			xerror("unexpected EOF after %zd of %zu bytes", tot, expect_len);

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* inq must never exceed the bytes the peer still has to deliver */
		if (tcp_inq > expect_len - (size_t)tot)
			xerror("inq %u, remaining %zu total_len %zu",
			       tcp_inq, expect_len - (size_t)tot, expect_len);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
458 
/*
 * accept() a connection on listening socket @s, discarding the peer
 * address.  Exits via die_perror() on failure; returns the new descriptor.
 */
static int xaccept(int s)
{
	/* addrlen is a socklen_t pointer: pass NULL, not the integer 0 */
	int fd = accept(s, NULL, NULL);

	if (fd < 0)
		die_perror("accept");

	return fd;
}
468 
469 static int server(int unixfd)
470 {
471         int fd = -1, r, on = 1;
472 
473         switch (pf) {
474         case AF_INET:
475                 fd = sock_listen_mptcp("127.0.0.1", "15432");
476                 break;
477         case AF_INET6:
478                 fd = sock_listen_mptcp("::1", "15432");
479                 break;
480         default:
481                 xerror("Unknown pf %d\n", pf);
482                 break;
483         }
484 
485         r = write(unixfd, "conn", 4);
486         assert(r == 4);
487 
488         alarm(15);
489         r = xaccept(fd);
490 
491         if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
492                 die_perror("setsockopt");
493 
494         process_one_client(r, unixfd);
495 
496         return 0;
497 }
498 
499 static int client(int unixfd)
500 {
501         int fd = -1;
502 
503         alarm(15);
504 
505         switch (pf) {
506         case AF_INET:
507                 fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
508                 break;
509         case AF_INET6:
510                 fd = sock_connect_mptcp("::1", "15432", proto_tx);
511                 break;
512         default:
513                 xerror("Unknown pf %d\n", pf);
514         }
515 
516         connect_one_server(fd, unixfd);
517 
518         return 0;
519 }
520 
/*
 * Seed rand() from getrandom().  If getrandom() fails, report the error
 * and exit with status 1.
 */
static void init_rng(void)
{
	unsigned int seed;

	if (getrandom(&seed, sizeof(seed), 0) != -1) {
		srand(seed);
		return;
	}

	perror("getrandom");
	exit(1);
}
532 
/*
 * fork() that exits on failure.  In the child, rand() is reseeded via
 * init_rng() so that each child gets its own random sequence.
 * Returns 0 in the child and the child's pid in the parent.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid == 0)
		init_rng();
	else if (pid < 0)
		die_perror("fork");

	return pid;
}
544 
/*
 * Interpret the waitpid() status @wstatus of the child named @what.
 *
 * Returns 0 if the child exited cleanly, or its non-zero exit status
 * (after logging it).  If the child was killed or stopped by a signal,
 * the parent itself exits via xerror().  Returns 111 for any other
 * status.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		if (WEXITSTATUS(wstatus) == 0)
			return 0;
		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
		return WEXITSTATUS(wstatus);
	} else if (WIFSIGNALED(wstatus)) {
		/* xerror() appends the newline itself */
		xerror("%s killed by signal %d", what, WTERMSIG(wstatus));
	} else if (WIFSTOPPED(wstatus)) {
		xerror("%s stopped by signal %d", what, WSTOPSIG(wstatus));
	}

	return 111;
}
560 
/*
 * Test driver.  Forks a server child and, once the server reports "conn"
 * over the shared datagram socketpair, a client child.  It then waits for
 * both.  Returns the server's failure status if any, otherwise the
 * client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char sync_msg[4];
	ssize_t n;

	parse_opts(argc, argv);

	/* control channel: server uses unixfds[1], client uses unixfds[0] */
	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket: it sends "conn" after listen() */
	n = read(unixfds[0], sync_msg, sizeof(sync_msg));
	assert(n == (ssize_t)sizeof(sync_msg));
	assert(strncmp(sync_msg, "conn", sizeof(sync_msg)) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
600 

~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

kernel.org | git.kernel.org | LWN.net | Project Home | SVN repository | Mail admin

Linux® is a registered trademark of Linus Torvalds in the United States and other countries.
TOMOYO® is a registered trademark of NTT DATA CORPORATION.

sflogo.php