~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

TOMOYO Linux Cross Reference
Linux/tools/testing/vsock/vsock_diag_test.c

Version: ~ [ linux-6.11.5 ] ~ [ linux-6.10.14 ] ~ [ linux-6.9.12 ] ~ [ linux-6.8.12 ] ~ [ linux-6.7.12 ] ~ [ linux-6.6.58 ] ~ [ linux-6.5.13 ] ~ [ linux-6.4.16 ] ~ [ linux-6.3.13 ] ~ [ linux-6.2.16 ] ~ [ linux-6.1.114 ] ~ [ linux-6.0.19 ] ~ [ linux-5.19.17 ] ~ [ linux-5.18.19 ] ~ [ linux-5.17.15 ] ~ [ linux-5.16.20 ] ~ [ linux-5.15.169 ] ~ [ linux-5.14.21 ] ~ [ linux-5.13.19 ] ~ [ linux-5.12.19 ] ~ [ linux-5.11.22 ] ~ [ linux-5.10.228 ] ~ [ linux-5.9.16 ] ~ [ linux-5.8.18 ] ~ [ linux-5.7.19 ] ~ [ linux-5.6.19 ] ~ [ linux-5.5.19 ] ~ [ linux-5.4.284 ] ~ [ linux-5.3.18 ] ~ [ linux-5.2.21 ] ~ [ linux-5.1.21 ] ~ [ linux-5.0.21 ] ~ [ linux-4.20.17 ] ~ [ linux-4.19.322 ] ~ [ linux-4.18.20 ] ~ [ linux-4.17.19 ] ~ [ linux-4.16.18 ] ~ [ linux-4.15.18 ] ~ [ linux-4.14.336 ] ~ [ linux-4.13.16 ] ~ [ linux-4.12.14 ] ~ [ linux-4.11.12 ] ~ [ linux-4.10.17 ] ~ [ linux-4.9.337 ] ~ [ linux-4.4.302 ] ~ [ linux-3.10.108 ] ~ [ linux-2.6.32.71 ] ~ [ linux-2.6.0 ] ~ [ linux-2.4.37.11 ] ~ [ unix-v6-master ] ~ [ ccs-tools-1.8.9 ] ~ [ policy-sample ] ~
Architecture: ~ [ i386 ] ~ [ alpha ] ~ [ m68k ] ~ [ mips ] ~ [ ppc ] ~ [ sparc ] ~ [ sparc64 ] ~

  1 // SPDX-License-Identifier: GPL-2.0-only
  2 /*
  3  * vsock_diag_test - vsock_diag.ko test suite
  4  *
  5  * Copyright (C) 2017 Red Hat, Inc.
  6  *
  7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
  8  */
  9 
 10 #include <getopt.h>
 11 #include <stdio.h>
 12 #include <stdlib.h>
 13 #include <string.h>
 14 #include <errno.h>
 15 #include <unistd.h>
 16 #include <sys/stat.h>
 17 #include <sys/types.h>
 18 #include <linux/list.h>
 19 #include <linux/net.h>
 20 #include <linux/netlink.h>
 21 #include <linux/sock_diag.h>
 22 #include <linux/vm_sockets_diag.h>
 23 #include <netinet/tcp.h>
 24 
 25 #include "timeout.h"
 26 #include "control.h"
 27 #include "util.h"
 28 
 29 /* Per-socket status: one node per socket reported in the vsock_diag dump */
 30 struct vsock_stat {
 31         struct list_head list;          /* links into the list filled by read_vsock_stat() */
 32         struct vsock_diag_msg msg;      /* copy of the vsock_diag_msg received over netlink */
 33 };
 34 
 /* Return a printable name for a socket type (SOCK_*); unknown types map to "INVALID TYPE". */
static const char *sock_type_str(int type)
{
	const char *name = "INVALID TYPE";

	if (type == SOCK_DGRAM)
		name = "DGRAM";
	else if (type == SOCK_STREAM)
		name = "STREAM";
	else if (type == SOCK_SEQPACKET)
		name = "SEQPACKET";

	return name;
}
 48 
 /*
 * Return a printable name for the vdiag_state of a vsock socket.  The state is
 * expressed with the TCP_* constants (the tests compare it against TCP_LISTEN
 * and TCP_ESTABLISHED); values not listed map to "INVALID STATE".
 */
static const char *sock_state_str(int state)
{
	static const struct {
		int state;
		const char *name;
	} names[] = {
		{ TCP_CLOSE, "UNCONNECTED" },
		{ TCP_SYN_SENT, "CONNECTING" },
		{ TCP_ESTABLISHED, "CONNECTED" },
		{ TCP_CLOSING, "DISCONNECTING" },
		{ TCP_LISTEN, "LISTEN" },
	};
	size_t i;

	for (i = 0; i < sizeof(names) / sizeof(names[0]); i++) {
		if (names[i].state == state)
			return names[i].name;
	}

	return "INVALID STATE";
}
 66 
 /*
 * Return a printable description of a vdiag_shutdown value: 1 is the receive
 * side, 2 the send side and 3 both.  Any other value (including 0) yields "".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		[1] = "RCV_SHUTDOWN",
		[2] = "SEND_SHUTDOWN",
		[3] = "RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return "";

	return names[shutdown];
}
 80 
 /* Write "cid:port" to fp, printing '*' for VMADDR_CID_ANY / VMADDR_PORT_ANY. */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fprintf(fp, "*:");

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fprintf(fp, "*");
}
 93 
 94 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
 95 {
 96         print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
 97         fprintf(fp, " ");
 98         print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
 99         fprintf(fp, " %s %s %s %u\n",
100                 sock_type_str(st->msg.vdiag_type),
101                 sock_state_str(st->msg.vdiag_state),
102                 sock_shutdown_str(st->msg.vdiag_shutdown),
103                 st->msg.vdiag_ino);
104 }
105 
106 static void print_vsock_stats(FILE *fp, struct list_head *head)
107 {
108         struct vsock_stat *st;
109 
110         list_for_each_entry(st, head, list)
111                 print_vsock_stat(fp, st);
112 }
113 
114 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
115 {
116         struct vsock_stat *st;
117         struct stat stat;
118 
119         if (fstat(fd, &stat) < 0) {
120                 perror("fstat");
121                 exit(EXIT_FAILURE);
122         }
123 
124         list_for_each_entry(st, head, list)
125                 if (st->msg.vdiag_ino == stat.st_ino)
126                         return st;
127 
128         fprintf(stderr, "cannot find fd %d\n", fd);
129         exit(EXIT_FAILURE);
130 }
131 
/*
 * Fail the test unless @head is empty.  On failure, print the unexpected
 * sockets to stderr and exit with EXIT_FAILURE.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	/* EXIT_FAILURE, matching every other failure path in this file */
	exit(EXIT_FAILURE);
}
140 
141 static void check_num_sockets(struct list_head *head, int expected)
142 {
143         struct list_head *node;
144         int n = 0;
145 
146         list_for_each(node, head)
147                 n++;
148 
149         if (n != expected) {
150                 fprintf(stderr, "expected %d sockets, found %d\n",
151                         expected, n);
152                 print_vsock_stats(stderr, head);
153                 exit(EXIT_FAILURE);
154         }
155 }
156 
157 static void check_socket_state(struct vsock_stat *st, __u8 state)
158 {
159         if (st->msg.vdiag_state != state) {
160                 fprintf(stderr, "expected socket state %#x, got %#x\n",
161                         state, st->msg.vdiag_state);
162                 exit(EXIT_FAILURE);
163         }
164 }
165 
/*
 * Send a SOCK_DIAG_BY_FAMILY dump request (NLM_F_REQUEST | NLM_F_DUMP) for
 * AF_VSOCK sockets on netlink socket @fd, with all vdiag_states bits set.
 * sendmsg() is retried on EINTR; any other error is reported and the process
 * exits with EXIT_FAILURE.
 */
static void send_req(int fd)
{
	struct sockaddr_nl dst = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} request = {
		.nlh = {
			.nlmsg_len = sizeof(request),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec iov = {
		.iov_base = &request,
		.iov_len = sizeof(request),
	};
	struct msghdr msg = {
		.msg_name = &dst,
		.msg_namelen = sizeof(dst),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t sent;

	do {
		sent = sendmsg(fd, &msg, 0);
	} while (sent < 0 && errno == EINTR);

	if (sent < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
208 
/*
 * Receive one datagram from netlink socket @fd into @buf (at most @len bytes)
 * and return the number of bytes received.  recvmsg() is retried on EINTR;
 * any other error is reported and the process exits with EXIT_FAILURE.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl src = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr msg = {
		.msg_name = &src,
		.msg_namelen = sizeof(src),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};

	for (;;) {
		ssize_t n = recvmsg(fd, &msg, 0);

		if (n >= 0)
			return n;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
237 
238 static void add_vsock_stat(struct list_head *sockets,
239                            const struct vsock_diag_msg *resp)
240 {
241         struct vsock_stat *st;
242 
243         st = malloc(sizeof(*st));
244         if (!st) {
245                 perror("malloc");
246                 exit(EXIT_FAILURE);
247         }
248 
249         st->msg = *resp;
250         list_add_tail(&st->list, sockets);
251 }
252 
/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends the dump request built by
 * send_req(), and appends one struct vsock_stat per SOCK_DIAG_BY_FAMILY reply
 * to the tail of @sockets, in the order received.  Reading stops at NLMSG_DONE
 * or when recvmsg() returns 0.  A short read, NLMSG_ERROR, an unexpected
 * message type or a truncated vsock_diag_msg is reported on stderr and ends
 * the process with EXIT_FAILURE.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* presumably long[] rather than char[] so nlmsghdr access is aligned */
	long buf[8192 / sizeof(long)];
	int finished = 0;
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	while (!finished) {
		const struct nlmsghdr *h;
		/* recv_resp() exits on error, so len is never negative here */
		ssize_t len = recv_resp(fd, buf, sizeof(buf));

		if (len == 0)
			break;

		if ((size_t)len < sizeof(*h)) {
			fprintf(stderr, "short read of %zd bytes\n", len);
			exit(EXIT_FAILURE);
		}

		/* NLMSG_NEXT() advances h and shrinks len by the aligned message size */
		for (h = (const struct nlmsghdr *)buf; NLMSG_OK(h, len);
		     h = NLMSG_NEXT(h, len)) {
			if (h->nlmsg_type == NLMSG_DONE) {
				finished = 1;
				break;
			}

			if (h->nlmsg_type == NLMSG_ERROR) {
				const struct nlmsgerr *err = NLMSG_DATA(h);

				if (h->nlmsg_len >= NLMSG_LENGTH(sizeof(*err))) {
					/* kernel reports a negative errno value */
					errno = -err->error;
					perror("NLMSG_ERROR");
				} else {
					fprintf(stderr, "NLMSG_ERROR\n");
				}

				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
				fprintf(stderr, "unexpected nlmsg_type %#x\n",
					h->nlmsg_type);
				exit(EXIT_FAILURE);
			}

			if (h->nlmsg_len <
			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
				fprintf(stderr, "short vsock_diag_msg\n");
				exit(EXIT_FAILURE);
			}

			add_vsock_stat(sockets, NLMSG_DATA(h));
		}
	}

	close(fd);
}
320 
321 static void free_sock_stat(struct list_head *sockets)
322 {
323         struct vsock_stat *st;
324         struct vsock_stat *next;
325 
326         list_for_each_entry_safe(st, next, sockets, list)
327                 free(st);
328 }
329 
330 static void test_no_sockets(const struct test_opts *opts)
331 {
332         LIST_HEAD(sockets);
333 
334         read_vsock_stat(&sockets);
335 
336         check_no_sockets(&sockets);
337 }
338 
339 static void test_listen_socket_server(const struct test_opts *opts)
340 {
341         union {
342                 struct sockaddr sa;
343                 struct sockaddr_vm svm;
344         } addr = {
345                 .svm = {
346                         .svm_family = AF_VSOCK,
347                         .svm_port = opts->peer_port,
348                         .svm_cid = VMADDR_CID_ANY,
349                 },
350         };
351         LIST_HEAD(sockets);
352         struct vsock_stat *st;
353         int fd;
354 
355         fd = socket(AF_VSOCK, SOCK_STREAM, 0);
356 
357         if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
358                 perror("bind");
359                 exit(EXIT_FAILURE);
360         }
361 
362         if (listen(fd, 1) < 0) {
363                 perror("listen");
364                 exit(EXIT_FAILURE);
365         }
366 
367         read_vsock_stat(&sockets);
368 
369         check_num_sockets(&sockets, 1);
370         st = find_vsock_stat(&sockets, fd);
371         check_socket_state(st, TCP_LISTEN);
372 
373         close(fd);
374         free_sock_stat(&sockets);
375 }
376 
377 static void test_connect_client(const struct test_opts *opts)
378 {
379         int fd;
380         LIST_HEAD(sockets);
381         struct vsock_stat *st;
382 
383         fd = vsock_stream_connect(opts->peer_cid, opts->peer_port);
384         if (fd < 0) {
385                 perror("connect");
386                 exit(EXIT_FAILURE);
387         }
388 
389         read_vsock_stat(&sockets);
390 
391         check_num_sockets(&sockets, 1);
392         st = find_vsock_stat(&sockets, fd);
393         check_socket_state(st, TCP_ESTABLISHED);
394 
395         control_expectln("DONE");
396         control_writeln("DONE");
397 
398         close(fd);
399         free_sock_stat(&sockets);
400 }
401 
402 static void test_connect_server(const struct test_opts *opts)
403 {
404         struct vsock_stat *st;
405         LIST_HEAD(sockets);
406         int client_fd;
407 
408         client_fd = vsock_stream_accept(VMADDR_CID_ANY, opts->peer_port, NULL);
409         if (client_fd < 0) {
410                 perror("accept");
411                 exit(EXIT_FAILURE);
412         }
413 
414         read_vsock_stat(&sockets);
415 
416         check_num_sockets(&sockets, 1);
417         st = find_vsock_stat(&sockets, client_fd);
418         check_socket_state(st, TCP_ESTABLISHED);
419 
420         control_writeln("DONE");
421         control_expectln("DONE");
422 
423         close(client_fd);
424         free_sock_stat(&sockets);
425 }
426 
/*
 * Test table passed to run_tests(), list_tests() and skip_test() by main().
 * Entries without .run_client have no client-side step; presumably
 * run_tests() skips them on the client (confirm in util.c).
 */
427 static struct test_case test_cases[] = {
428         {
429                 .name = "No sockets",           /* dump must report no vsock sockets */
430                 .run_server = test_no_sockets,
431         },
432         {
433                 .name = "Listen socket",        /* one socket, in TCP_LISTEN */
434                 .run_server = test_listen_socket_server,
435         },
436         {
437                 .name = "Connect",              /* one socket per side, in TCP_ESTABLISHED */
438                 .run_client = test_connect_client,
439                 .run_server = test_connect_server,
440         },
441         {},     /* terminator; presumably why main() passes ARRAY_SIZE(test_cases) - 1 to skip_test() */
442 };
443 
/*
 * Command-line options.  optstring is empty, so getopt_long() accepts only the
 * long forms below.  Each .val is the character main() switches on.  "help"
 * maps to '?', the same value getopt_long() returns for an unrecognised
 * option, so both cases end in usage().
 */
444 static const char optstring[] = "";
445 static const struct option longopts[] = {
446         {
447                 .name = "control-host",
448                 .has_arg = required_argument,
449                 .val = 'H',
450         },
451         {
452                 .name = "control-port",
453                 .has_arg = required_argument,
454                 .val = 'P',
455         },
456         {
457                 .name = "mode",
458                 .has_arg = required_argument,
459                 .val = 'm',
460         },
461         {
462                 .name = "peer-cid",
463                 .has_arg = required_argument,
464                 .val = 'p',
465         },
466         {
467                 .name = "peer-port",
468                 .has_arg = required_argument,
469                 .val = 'q',
470         },
471         {
472                 .name = "list",
473                 .has_arg = no_argument,
474                 .val = 'l',
475         },
476         {
477                 .name = "skip",
478                 .has_arg = required_argument,
479                 .val = 's',
480         },
481         {
482                 .name = "help",
483                 .has_arg = no_argument,
484                 .val = '?',
485         },
486         {},     /* getopt_long() requires an all-zero final element */
487 };
488 
/*
 * Print the synopsis, examples and option list to stderr, then exit with
 * EXIT_FAILURE.  main() calls this for --help, for unrecognised options, when
 * --control-port, --mode or --peer-cid is missing, and when a client omits
 * --control-host.
 *
 * NOTE(review): the text says the server "requires a listen address", but
 * main() defaults --control-host to "0.0.0.0" in server mode.
 */
489 static void usage(void)
490 {
491         fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--peer-port=<port>] [--list] [--skip=<test_id>]\n"
492                 "\n"
493                 "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
494                 "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
495                 "\n"
496                 "Run vsock_diag.ko tests.  Must be launched in both\n"
497                 "guest and host.  One side must use --mode=client and\n"
498                 "the other side must use --mode=server.\n"
499                 "\n"
500                 "A TCP control socket connection is used to coordinate tests\n"
501                 "between the client and the server.  The server requires a\n"
502                 "listen address and the client requires an address to\n"
503                 "connect to.\n"
504                 "\n"
505                 "The CID of the other side must be given with --peer-cid=<cid>.\n"
506                 "\n"
507                 "Options:\n"
508                 "  --help                 This help message\n"
509                 "  --control-host <host>  Server IP address to connect to\n"
510                 "  --control-port <port>  Server port to listen on/connect to\n"
511                 "  --mode client|server   Server or client mode\n"
512                 "  --peer-cid <cid>       CID of the other side\n"
513                 "  --peer-port <port>     AF_VSOCK port used for the test [default: %d]\n"
514                 "  --list                 List of tests that will be executed\n"
515                 "  --skip <test_id>       Test ID to skip;\n"
516                 "                         use multiple --skip options to skip more tests\n",
517                 DEFAULT_PEER_PORT       /* fills the %d in the --peer-port line */
518                 );
519         exit(EXIT_FAILURE);
520 }
521 
522 int main(int argc, char **argv)
523 {
524         const char *control_host = NULL;
525         const char *control_port = NULL;
526         struct test_opts opts = {
527                 .mode = TEST_MODE_UNSET,
528                 .peer_cid = VMADDR_CID_ANY,
529                 .peer_port = DEFAULT_PEER_PORT,
530         };
531 
532         init_signals();
533 
534         for (;;) {
535                 int opt = getopt_long(argc, argv, optstring, longopts, NULL);
536 
537                 if (opt == -1)
538                         break;
539 
540                 switch (opt) {
541                 case 'H':
542                         control_host = optarg;
543                         break;
544                 case 'm':
545                         if (strcmp(optarg, "client") == 0)
546                                 opts.mode = TEST_MODE_CLIENT;
547                         else if (strcmp(optarg, "server") == 0)
548                                 opts.mode = TEST_MODE_SERVER;
549                         else {
550                                 fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
551                                 return EXIT_FAILURE;
552                         }
553                         break;
554                 case 'p':
555                         opts.peer_cid = parse_cid(optarg);
556                         break;
557                 case 'q':
558                         opts.peer_port = parse_port(optarg);
559                         break;
560                 case 'P':
561                         control_port = optarg;
562                         break;
563                 case 'l':
564                         list_tests(test_cases);
565                         break;
566                 case 's':
567                         skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
568                                   optarg);
569                         break;
570                 case '?':
571                 default:
572                         usage();
573                 }
574         }
575 
576         if (!control_port)
577                 usage();
578         if (opts.mode == TEST_MODE_UNSET)
579                 usage();
580         if (opts.peer_cid == VMADDR_CID_ANY)
581                 usage();
582 
583         if (!control_host) {
584                 if (opts.mode != TEST_MODE_SERVER)
585                         usage();
586                 control_host = "0.0.0.0";
587         }
588 
589         control_init(control_host, control_port,
590                      opts.mode == TEST_MODE_SERVER);
591 
592         run_tests(test_cases, &opts);
593 
594         control_cleanup();
595         return EXIT_SUCCESS;
596 }
597 

~ [ source navigation ] ~ [ diff markup ] ~ [ identifier search ] ~

kernel.org | git.kernel.org | LWN.net | Project Home | SVN repository | Mail admin

Linux® is a registered trademark of Linus Torvalds in the United States and other countries.
TOMOYO® is a registered trademark of NTT DATA CORPORATION.

sflogo.php