// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2024, SUSE LLC
 *
 * Authors: Enzo Matsumiya <ematsumiya@suse.de>
 *
 * This file implements I/O compression support for SMB2 messages (SMB 3.1.1 only).
 * See compress/ for implementation details of each algorithm.
 *
 * References:
 * MS-SMB2 "3.1.4.4 Compressing the Message"
 * MS-SMB2 "3.1.5.3 Decompressing the Chained Message"
 * MS-XCA - for details of the supported algorithms
 */
#include <linux/slab.h>
#include <linux/kernel.h>
#include <linux/uio.h>
#include <linux/sort.h>

#include "cifsglob.h"
#include "../common/smb2pdu.h"
#include "cifsproto.h"
#include "smb2proto.h"

#include "compress/lz77.h"
#include "compress.h"

/*
 * The heuristic_*() functions below try to determine data compressibility.
 *
 * Derived from fs/btrfs/compression.c, changing coding style, some parameters, and removing
 * unused parts.
 *
 * Read that file for a better and more detailed explanation of the calculations.
 *
 * The algorithms are run on a collected sample of the input (uncompressed) data.
 * The sample is formed of 2K reads at PAGE_SIZE intervals, with a maximum size of 4M.
 *
 * Parsing the sample goes from "low-hanging fruits" (fastest algorithms, likely compressible)
 * to "need more analysis" (likely incompressible).
 */
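
/*
 * Illustrative sizing, assuming PAGE_SIZE == 4K (numbers not from the
 * original): sampling 2K out of every 4K page reads at most half of the
 * input, e.g. a 1M write yields a 256-page, 512K sample; the 4M sample
 * cap is only reached once the input approaches 8M.
 */
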
struct bucket {
        unsigned int count;
};

/**
 * has_low_entropy() - Compute Shannon entropy of the sampled data.
 * @bkt:	Byte counts of the sample.
 * @slen:	Size of the sample.
 *
 * Return: true if the level (percentage of number of bits that would be required to
 *	   compress the data) is below the minimum threshold.
 *
 * Note:
 * There _are_ entropy levels > 65 (the minimum threshold) that would still indicate a
 * possibility of compression, but compressing, or even further analysing, such data would
 * waste so many resources that it's simply not worth it.
 *
 * Also, Shannon entropy is the last heuristic computed; if we got this far and still ended
 * up with uncertainty, just stay on the safe side and call the data incompressible.
 */
static bool has_low_entropy(struct bucket *bkt, size_t slen)
{
        const size_t threshold = 65, max_entropy = 8 * ilog2(16);
        size_t i, p, p2, len, sum = 0;

#define pow4(n) ((n) * (n) * (n) * (n))
        len = ilog2(pow4(slen));

        for (i = 0; i < 256 && bkt[i].count > 0; i++) {
                p = bkt[i].count;
                p2 = ilog2(pow4(p));
                sum += p * (len - p2);
        }

        sum /= slen;

        return ((sum * 100 / max_entropy) <= threshold);
}
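
/*
 * Worked example for has_low_entropy(), with illustrative numbers not taken
 * from the original: pow4() + ilog2() approximate 4 * log2() in integer
 * math, so @sum accumulates 4 * sum(count * (log2(slen) - log2(count))),
 * and sum / slen is 4x the Shannon entropy in bits per byte; max_entropy =
 * 8 * ilog2(16) = 32 is the matching 4 * 8-bit ceiling.  For slen = 4096
 * split evenly between two byte values (2048 occurrences each):
 *
 *	len = ilog2(4096^4) = 48, p2 = ilog2(2048^4) = 44
 *	sum = 2 * 2048 * (48 - 44) = 16384, 16384 / 4096 = 4 (1 bit/byte)
 *	4 * 100 / 32 = 12 <= 65 -> low entropy, worth compressing
 *
 * Caveat (an observation, not from the original): pow4() wraps 64-bit
 * arithmetic once its argument reaches 64K, so the computed level degrades
 * for samples (or single byte counts) of 64K and up; btrfs sidesteps this
 * by capping its sample at 8K.
 */
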
#define BYTE_DIST_BAD		0
#define BYTE_DIST_GOOD		1
#define BYTE_DIST_MAYBE		2
/**
 * calc_byte_distribution() - Compute byte distribution on the sampled data.
 * @bkt:	Byte counts of the sample, sorted in descending order.
 * @slen:	Size of the sample.
 *
 * Return:
 * BYTE_DIST_BAD:	A "hard no" for compression -- a computed uniform distribution of
 *			the bytes (e.g. random or encrypted data).
 * BYTE_DIST_GOOD:	High probability (normal (Gaussian) distribution) of the data being
 *			compressible.
 * BYTE_DIST_MAYBE:	The computed byte distribution landed in the middle ground
 *			(low < n < high).  has_low_entropy() should be used for a final
 *			decision.
 */
static int calc_byte_distribution(struct bucket *bkt, size_t slen)
{
        const size_t low = 64, high = 200, threshold = slen * 90 / 100;
        size_t sum = 0;
        int i;

        for (i = 0; i < low; i++)
                sum += bkt[i].count;

        /* The 64 most frequent byte values already cover >90% of the sample. */
        if (sum > threshold)
                return BYTE_DIST_GOOD;

        for (; i < high && bkt[i].count > 0; i++) {
                sum += bkt[i].count;
                if (sum > threshold)
                        break;
        }

        if (i <= low)
                return BYTE_DIST_GOOD;

        if (i >= high)
                return BYTE_DIST_BAD;

        return BYTE_DIST_MAYBE;
}
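
/*
 * Worked example (illustrative numbers, assuming a 10000-byte sample, so
 * threshold = 9000): plain English text concentrates well over 90% of its
 * bytes in fewer than 64 distinct values, so the first loop alone exceeds
 * the threshold -> BYTE_DIST_GOOD.  Uniformly random bytes average ~39
 * counts per bucket, needing ~230 buckets to reach 9000, so the second
 * loop runs out at i == high -> BYTE_DIST_BAD.
 */
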
static bool is_mostly_ascii(const struct bucket *bkt)
{
        size_t count = 0;
        int i;

        for (i = 0; i < 256; i++)
                if (bkt[i].count > 0)
                        /* More than 64 distinct byte values: not mostly ASCII text. */
                        if (++count > 64)
                                return false;

        return true;
}

static bool has_repeated_data(const u8 *sample, size_t len)
{
        size_t s = len / 2;

        return (!memcmp(&sample[0], &sample[s], s));
}

static int cmp_bkt(const void *_a, const void *_b)
{
        const struct bucket *a = _a, *b = _b;

        /* Reverse sort. */
        if (a->count > b->count)
                return -1;

        return 1;
}

/*
 * TODO:
 * Support other iter types, if required.
 * Only ITER_XARRAY is supported for now.
 */
static int collect_sample(const struct iov_iter *iter, ssize_t max, u8 *sample)
{
        struct folio *folios[16], *folio;
        unsigned int nr, i, j, npages;
        loff_t start = iter->xarray_start + iter->iov_offset;
        pgoff_t last, index = start / PAGE_SIZE;
        size_t len, off, foff;
        void *p;
        int s = 0;

        last = (start + max - 1) / PAGE_SIZE;
        do {
                nr = xa_extract(iter->xarray, (void **)folios, index, last,
                                ARRAY_SIZE(folios), XA_PRESENT);
                if (nr == 0)
                        return -EIO;

                for (i = 0; i < nr; i++) {
                        folio = folios[i];
                        npages = folio_nr_pages(folio);
                        foff = start - folio_pos(folio);
                        off = foff % PAGE_SIZE;

                        for (j = foff / PAGE_SIZE; j < npages; j++) {
                                size_t len2;

                                len = min_t(size_t, max, PAGE_SIZE - off);
                                len2 = min_t(size_t, len, SZ_2K);

                                p = kmap_local_page(folio_page(folio, j));
                                memcpy(&sample[s], p, len2);
                                kunmap_local(p);

                                s += len2;

                                if (len2 < SZ_2K || s >= max - SZ_2K)
                                        return s;

                                max -= len;
                                if (max <= 0)
                                        return s;

                                start += len;
                                off = 0;
                                index++;
                        }
                }
        } while (nr == ARRAY_SIZE(folios));

        return s;
}
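
/*
 * Sampling sketch for collect_sample(), assuming PAGE_SIZE == 4K
 * (illustrative only): for a 12K write spanning three full pages, the
 * inner loop copies the first 2K of each page into @sample (6K total),
 * advancing page by page and returning early once @max is consumed or a
 * partial (sub-2K) tail is read.
 */
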
/**
 * is_compressible() - Determines if a chunk of data is compressible.
 * @data: Iterator containing uncompressed data.
 *
 * Return: true if @data is compressible, false otherwise.
 *
 * Tests show that this function is quite reliable in predicting data compressibility,
 * matching close to 1:1 with the behaviour of LZ77 compression successes and failures.
 */
static bool is_compressible(const struct iov_iter *data)
{
        const size_t read_size = SZ_2K, bkt_size = 256, max = SZ_4M;
        struct bucket *bkt = NULL;
        size_t len;
        u8 *sample;
        bool ret = false;
        int i;

        /* Preventive double check -- already checked in should_compress(). */
        len = iov_iter_count(data);
        if (unlikely(len < read_size))
                return ret;

        if (len - read_size > max)
                len = max;

        sample = kvzalloc(len, GFP_KERNEL);
        if (!sample) {
                WARN_ON_ONCE(1);

                return ret;
        }

        /* Sample 2K bytes per page of the uncompressed data. */
        i = collect_sample(data, len, sample);
        if (i <= 0) {
                WARN_ON_ONCE(1);

                goto out;
        }

        len = i;
        ret = true;

        if (has_repeated_data(sample, len))
                goto out;

        bkt = kcalloc(bkt_size, sizeof(*bkt), GFP_KERNEL);
        if (!bkt) {
                WARN_ON_ONCE(1);
                ret = false;

                goto out;
        }

        for (i = 0; i < len; i++)
                bkt[sample[i]].count++;

        if (is_mostly_ascii(bkt))
                goto out;

        /* Sort in descending order. */
        sort(bkt, bkt_size, sizeof(*bkt), cmp_bkt, NULL);

        i = calc_byte_distribution(bkt, len);
        if (i != BYTE_DIST_MAYBE) {
                ret = !!i;

                goto out;
        }

        ret = has_low_entropy(bkt, len);
out:
        kvfree(sample);
        kfree(bkt);

        return ret;
}

bool should_compress(const struct cifs_tcon *tcon, const struct smb_rqst *rq)
{
        const struct smb2_hdr *shdr = rq->rq_iov->iov_base;

        if (unlikely(!tcon || !tcon->ses || !tcon->ses->server))
                return false;

        if (!tcon->ses->server->compression.enabled)
                return false;

        if (!(tcon->share_flags & SMB2_SHAREFLAG_COMPRESS_DATA))
                return false;

        if (shdr->Command == SMB2_WRITE) {
                const struct smb2_write_req *wreq = rq->rq_iov->iov_base;

                if (le32_to_cpu(wreq->Length) < SMB_COMPRESS_MIN_LEN)
                        return false;

                return is_compressible(&rq->rq_iter);
        }

        return (shdr->Command == SMB2_READ);
}
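
/*
 * Wire layout built by smb_compress() below for a successful, unchained
 * LZ77 transform (a sketch of the iov construction; see MS-SMB2 3.1.4.4
 * for the authoritative format):
 *
 *	iov[0]: smb2_compression_hdr (ProtocolId 0xFC 'S' 'M' 'B')
 *	iov[1]: original (uncompressed) SMB2 WRITE request header
 *	iov[2]: LZ77-compressed write payload
 *
 * hdr.Offset is the size of the uncompressed part following the transform
 * header, i.e. iov[1].iov_len, telling the server where the compressed
 * data begins.
 */
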
int smb_compress(struct TCP_Server_Info *server, struct smb_rqst *rq, compress_send_fn send_fn)
{
        struct iov_iter iter;
        u32 slen, dlen;
        void *src, *dst = NULL;
        int ret;

        if (!server || !rq || !rq->rq_iov || !rq->rq_iov->iov_base)
                return -EINVAL;

        if (rq->rq_iov->iov_len != sizeof(struct smb2_write_req))
                return -EINVAL;

        slen = iov_iter_count(&rq->rq_iter);
        src = kvzalloc(slen, GFP_KERNEL);
        if (!src) {
                ret = -ENOMEM;
                goto err_free;
        }

        /* Keep the original iter intact. */
        iter = rq->rq_iter;

        if (!copy_from_iter_full(src, slen, &iter)) {
                ret = -EIO;
                goto err_free;
        }

        /*
         * This is just overprovisioning, as the algorithm will error out if @dst reaches 7/8
         * of @slen.
         */
        dlen = slen;
        dst = kvzalloc(dlen, GFP_KERNEL);
        if (!dst) {
                ret = -ENOMEM;
                goto err_free;
        }

        ret = lz77_compress(src, slen, dst, &dlen);
        if (!ret) {
                struct smb2_compression_hdr hdr = { 0 };
                struct smb_rqst comp_rq = { .rq_nvec = 3, };
                struct kvec iov[3];

                hdr.ProtocolId = SMB2_COMPRESSION_TRANSFORM_ID;
                hdr.OriginalCompressedSegmentSize = cpu_to_le32(slen);
                hdr.CompressionAlgorithm = SMB3_COMPRESS_LZ77;
                hdr.Flags = SMB2_COMPRESSION_FLAG_NONE;
                hdr.Offset = cpu_to_le32(rq->rq_iov[0].iov_len);

                iov[0].iov_base = &hdr;
                iov[0].iov_len = sizeof(hdr);
                iov[1] = rq->rq_iov[0];
                iov[2].iov_base = dst;
                iov[2].iov_len = dlen;

                comp_rq.rq_iov = iov;

                ret = send_fn(server, 1, &comp_rq);
        } else if (ret == -EMSGSIZE || dlen >= slen) {
                ret = send_fn(server, 1, rq);
        }
err_free:
        kvfree(dst);
        kvfree(src);

        return ret;
}