// SPDX-License-Identifier: GPL-2.0
#include <linux/init.h>
#include <linux/static_call.h>
#include <linux/bug.h>
#include <linux/smp.h>
#include <linux/sort.h>
#include <linux/slab.h>
#include <linux/module.h>
#include <linux/cpu.h>
#include <linux/processor.h>
#include <asm/sections.h>

extern struct static_call_site __start_static_call_sites[],
			       __stop_static_call_sites[];
extern struct static_call_tramp_key __start_static_call_tramp_key[],
				    __stop_static_call_tramp_key[];

static bool static_call_initialized;

/* mutex to protect key modules/sites */
static DEFINE_MUTEX(static_call_mutex);

static void static_call_lock(void)
{
	mutex_lock(&static_call_mutex);
}

static void static_call_unlock(void)
{
	mutex_unlock(&static_call_mutex);
}
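/*
 * Each struct static_call_site stores ->addr and ->key as offsets relative
 * to their own location; the helpers below add the field's address back to
 * recover the absolute pointers.  Since the key is word-aligned, its low
 * bits double as the STATIC_CALL_SITE_INIT/TAIL flags, which is why
 * static_call_key() masks them off.
 */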
static inline void *static_call_addr(struct static_call_site *site)
{
	return (void *)((long)site->addr + (long)&site->addr);
}

static inline unsigned long __static_call_key(const struct static_call_site *site)
{
	return (long)site->key + (long)&site->key;
}

static inline struct static_call_key *static_call_key(const struct static_call_site *site)
{
	return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS);
}

/* These assume the key is word-aligned. */
static inline bool static_call_is_init(struct static_call_site *site)
{
	return __static_call_key(site) & STATIC_CALL_SITE_INIT;
}

static inline bool static_call_is_tail(struct static_call_site *site)
{
	return __static_call_key(site) & STATIC_CALL_SITE_TAIL;
}

static inline void static_call_set_init(struct static_call_site *site)
{
	site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) -
		    (long)&site->key;
}

static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_site *a = _a;
	const struct static_call_site *b = _b;
	const struct static_call_key *key_a = static_call_key(a);
	const struct static_call_key *key_b = static_call_key(b);

	if (key_a < key_b)
		return -1;

	if (key_a > key_b)
		return 1;

	return 0;
}

static void static_call_site_swap(void *_a, void *_b, int size)
{
	long delta = (unsigned long)_a - (unsigned long)_b;
	struct static_call_site *a = _a;
	struct static_call_site *b = _b;
	struct static_call_site tmp = *a;

	a->addr = b->addr - delta;
	a->key  = b->key  - delta;

	b->addr = tmp.addr + delta;
	b->key  = tmp.key  + delta;
}

static inline void static_call_sort_entries(struct static_call_site *start,
					    struct static_call_site *stop)
{
	sort(start, stop - start, sizeof(struct static_call_site),
	     static_call_site_cmp, static_call_site_swap);
}
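/*
 * A static_call_key tracks its call sites in one of two ways: for vmlinux
 * keys ->sites points straight at the (sorted) __static_call_sites entries
 * and bit 0 of ->type is set as a marker; otherwise ->mods points at a list
 * of struct static_call_mod entries (one per module using the key, plus
 * possibly one carrying the vmlinux sites).  The accessors below decode
 * that union; also see the !mod shortcut in __static_call_init().
 */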
static inline bool static_call_key_has_mods(struct static_call_key *key)
{
	return !(key->type & 1);
}

static inline struct static_call_mod *static_call_key_next(struct static_call_key *key)
{
	if (!static_call_key_has_mods(key))
		return NULL;

	return key->mods;
}

static inline struct static_call_site *static_call_key_sites(struct static_call_key *key)
{
	if (static_call_key_has_mods(key))
		return NULL;

	return (struct static_call_site *)(key->type & ~1);
}

void __static_call_update(struct static_call_key *key, void *tramp, void *func)
{
	struct static_call_site *site, *stop;
	struct static_call_mod *site_mod, first;

	cpus_read_lock();
	static_call_lock();

	if (key->func == func)
		goto done;

	key->func = func;

	arch_static_call_transform(NULL, tramp, func, false);

	/*
	 * If uninitialized, we'll not update the callsites, but they still
	 * point to the trampoline and we just patched that.
	 */
	if (WARN_ON_ONCE(!static_call_initialized))
		goto done;

	first = (struct static_call_mod){
		.next = static_call_key_next(key),
		.mod = NULL,
		.sites = static_call_key_sites(key),
	};

	for (site_mod = &first; site_mod; site_mod = site_mod->next) {
		bool init = system_state < SYSTEM_RUNNING;
		struct module *mod = site_mod->mod;

		if (!site_mod->sites) {
			/*
			 * This can happen if the static call key is defined in
			 * a module which doesn't use it.
			 *
			 * It also happens in the has_mods case, where the
			 * 'first' entry has no sites associated with it.
			 */
			continue;
		}

		stop = __stop_static_call_sites;

		if (mod) {
#ifdef CONFIG_MODULES
			stop = mod->static_call_sites +
			       mod->num_static_call_sites;
			init = mod->state == MODULE_STATE_COMING;
#endif
		}

		for (site = site_mod->sites;
		     site < stop && static_call_key(site) == key; site++) {
			void *site_addr = static_call_addr(site);

			if (!init && static_call_is_init(site))
				continue;

			if (!kernel_text_address((unsigned long)site_addr)) {
				/*
				 * This skips patching built-in __exit, which
				 * is part of init_section_contains() but is
				 * not part of kernel_text_address().
				 *
				 * Skipping built-in __exit is fine since it
				 * will never be executed.
				 */
				WARN_ONCE(!static_call_is_init(site),
					  "can't patch static call site at %pS",
					  site_addr);
				continue;
			}

			arch_static_call_transform(site_addr, NULL, func,
						   static_call_is_tail(site));
		}
	}

done:
	static_call_unlock();
	cpus_read_unlock();
}
EXPORT_SYMBOL_GPL(__static_call_update);
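/*
 * Typical usage, with illustrative names (see also the
 * CONFIG_STATIC_CALL_SELFTEST code at the bottom of this file):
 *
 *	DEFINE_STATIC_CALL(my_call, default_func);
 *
 *	static_call(my_call)(args);
 *	static_call_update(my_call, new_func);
 *
 * static_call_update() roughly expands to __static_call_update() on the
 * key and trampoline generated for 'my_call', which then patches the
 * trampoline and every inline call site recorded for that key.
 */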
static int __static_call_init(struct module *mod,
			      struct static_call_site *start,
			      struct static_call_site *stop)
{
	struct static_call_site *site;
	struct static_call_key *key, *prev_key = NULL;
	struct static_call_mod *site_mod;

	if (start == stop)
		return 0;

	static_call_sort_entries(start, stop);

	for (site = start; site < stop; site++) {
		void *site_addr = static_call_addr(site);

		if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
		    (!mod && init_section_contains(site_addr, 1)))
			static_call_set_init(site);

		key = static_call_key(site);
		if (key != prev_key) {
			prev_key = key;

			/*
			 * For vmlinux (!mod) avoid the allocation by storing
			 * the sites pointer in the key itself. Also see
			 * __static_call_update()'s @first.
			 *
			 * This allows architectures (eg. x86) to call
			 * static_call_init() before memory allocation works.
			 */
			if (!mod) {
				key->sites = site;
				key->type |= 1;
				goto do_transform;
			}

			site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
			if (!site_mod)
				return -ENOMEM;

			/*
			 * When the key has a direct sites pointer, extract
			 * that into an explicit struct static_call_mod, so we
			 * can have a list of modules.
			 */
			if (static_call_key_sites(key)) {
				site_mod->mod = NULL;
				site_mod->next = NULL;
				site_mod->sites = static_call_key_sites(key);

				key->mods = site_mod;

				site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
				if (!site_mod)
					return -ENOMEM;
			}

			site_mod->mod = mod;
			site_mod->sites = site;
			site_mod->next = static_call_key_next(key);
			key->mods = site_mod;
		}

do_transform:
		arch_static_call_transform(site_addr, NULL, key->func,
					   static_call_is_tail(site));
	}

	return 0;
}
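/*
 * The helpers below let callers such as kprobes check whether a text range
 * overlaps a patched call site (each site covers CALL_INSN_SIZE bytes);
 * static_call_text_reserved() further down is the public entry point.
 */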
static int addr_conflict(struct static_call_site *site, void *start, void *end)
{
	unsigned long addr = (unsigned long)static_call_addr(site);

	if (addr <= (unsigned long)end &&
	    addr + CALL_INSN_SIZE > (unsigned long)start)
		return 1;

	return 0;
}

static int __static_call_text_reserved(struct static_call_site *iter_start,
				       struct static_call_site *iter_stop,
				       void *start, void *end, bool init)
{
	struct static_call_site *iter = iter_start;

	while (iter < iter_stop) {
		if (init || !static_call_is_init(iter)) {
			if (addr_conflict(iter, start, end))
				return 1;
		}
		iter++;
	}

	return 0;
}

#ifdef CONFIG_MODULES

static int __static_call_mod_text_reserved(void *start, void *end)
{
	struct module *mod;
	int ret;

	preempt_disable();
	mod = __module_text_address((unsigned long)start);
	WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
	if (!try_module_get(mod))
		mod = NULL;
	preempt_enable();

	if (!mod)
		return 0;

	ret = __static_call_text_reserved(mod->static_call_sites,
			mod->static_call_sites + mod->num_static_call_sites,
			start, end, mod->state == MODULE_STATE_COMING);

	module_put(mod);

	return ret;
}

static unsigned long tramp_key_lookup(unsigned long addr)
{
	struct static_call_tramp_key *start = __start_static_call_tramp_key;
	struct static_call_tramp_key *stop = __stop_static_call_tramp_key;
	struct static_call_tramp_key *tramp_key;

	for (tramp_key = start; tramp_key != stop; tramp_key++) {
		unsigned long tramp;

		tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp;
		if (tramp == addr)
			return (long)tramp_key->key + (long)&tramp_key->key;
	}

	return 0;
}

static int static_call_add_module(struct module *mod)
{
	struct static_call_site *start = mod->static_call_sites;
	struct static_call_site *stop = start + mod->num_static_call_sites;
	struct static_call_site *site;

	for (site = start; site != stop; site++) {
		unsigned long s_key = __static_call_key(site);
		unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS;
		unsigned long key;

		/*
		 * If the key is exported, 'addr' points to the key, which
		 * means modules are allowed to call static_call_update() on
		 * it.
		 *
		 * Otherwise, the key isn't exported, and 'addr' points to the
		 * trampoline so we need to look up the key.
		 *
		 * We go through this dance to prevent crazy modules from
		 * abusing sensitive static calls.
		 */
		if (!kernel_text_address(addr))
			continue;

		key = tramp_key_lookup(addr);
		if (!key) {
			pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n",
				static_call_addr(site));
			return -EINVAL;
		}

		key |= s_key & STATIC_CALL_SITE_FLAGS;
		site->key = key - (long)&site->key;
	}

	return __static_call_init(mod, start, stop);
}

static void static_call_del_module(struct module *mod)
{
	struct static_call_site *start = mod->static_call_sites;
	struct static_call_site *stop = mod->static_call_sites +
					mod->num_static_call_sites;
	struct static_call_key *key, *prev_key = NULL;
	struct static_call_mod *site_mod, **prev;
	struct static_call_site *site;

	for (site = start; site < stop; site++) {
		key = static_call_key(site);
		if (key == prev_key)
			continue;

		prev_key = key;

		for (prev = &key->mods, site_mod = key->mods;
		     site_mod && site_mod->mod != mod;
		     prev = &site_mod->next, site_mod = site_mod->next)
			;

		if (!site_mod)
			continue;

		*prev = site_mod->next;
		kfree(site_mod);
	}
}

static int static_call_module_notify(struct notifier_block *nb,
				     unsigned long val, void *data)
{
	struct module *mod = data;
	int ret = 0;

	cpus_read_lock();
	static_call_lock();

	switch (val) {
	case MODULE_STATE_COMING:
		ret = static_call_add_module(mod);
		if (ret) {
			WARN(1, "Failed to allocate memory for static calls");
			static_call_del_module(mod);
		}
		break;
	case MODULE_STATE_GOING:
		static_call_del_module(mod);
		break;
	}

	static_call_unlock();
	cpus_read_unlock();

	return notifier_from_errno(ret);
}

static struct notifier_block static_call_module_nb = {
	.notifier_call = static_call_module_notify,
};

#else

static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	return 0;
}

#endif /* CONFIG_MODULES */

int static_call_text_reserved(void *start, void *end)
{
	bool init = system_state < SYSTEM_RUNNING;
	int ret = __static_call_text_reserved(__start_static_call_sites,
			__stop_static_call_sites, start, end, init);

	if (ret)
		return ret;

	return __static_call_mod_text_reserved(start, end);
}

int __init static_call_init(void)
{
	int ret;

	if (static_call_initialized)
		return 0;

	cpus_read_lock();
	static_call_lock();
	ret = __static_call_init(NULL, __start_static_call_sites,
				 __stop_static_call_sites);
	static_call_unlock();
	cpus_read_unlock();

	if (ret) {
		pr_err("Failed to allocate memory for static_call!\n");
		BUG();
	}

	static_call_initialized = true;

#ifdef CONFIG_MODULES
	register_module_notifier(&static_call_module_nb);
#endif
	return 0;
}
early_initcall(static_call_init);
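/*
 * Generic "return 0" target, usable as the function for a static call
 * (e.g. via DEFINE_STATIC_CALL_RET0()).  Exported so that modules can
 * retarget their static calls to it; architectures may patch such sites
 * with a cheaper instruction sequence than an actual call.
 */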
long __static_call_return0(void)
{
	return 0;
}
EXPORT_SYMBOL_GPL(__static_call_return0);

#ifdef CONFIG_STATIC_CALL_SELFTEST

static int func_a(int x)
{
	return x + 1;
}

static int func_b(int x)
{
	return x + 2;
}

DEFINE_STATIC_CALL(sc_selftest, func_a);

static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ NULL,   2, 3 },
	{ func_b, 2, 4 },
	{ func_a, 2, 3 }
};

static int __init test_static_call_init(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(static_call_data); i++) {
		struct static_call_data *scd = &static_call_data[i];

		if (scd->func)
			static_call_update(sc_selftest, scd->func);

		WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
	}

	return 0;
}
early_initcall(test_static_call_init);

#endif /* CONFIG_STATIC_CALL_SELFTEST */