Cryptographic shuffle

What if I needed to shuffle a list but couldn't hold the whole thing in memory? Or what if I didn't want to shuffle a list, but just traverse it in a shuffled manner? (That is, visit each element once and only once, in a randomized way.) What if I wanted to traverse it, but didn't want to precompute or store the traversal for some reason?

This would allow me to publish items from a list in an order that was unpredictable from the outside, but in fact deterministic and based on a secret key, and without precomputing anything (or worrying about collisions). Or I could use it to assign small non-sequential IDs that would eventually saturate the space of n-character strings in a pseudorandom order, obscuring the true size of the set for anyone who could just view some subset of the assigned IDs. They wouldn't even be able to tell if there were gaps in the list of IDs they could observe.

Essentially, what I'd want is a pseudorandom permutation of the indexes of the list. If the list had 1000 elements, what I'd need would be a shuffled list of [0, 1, 2, ..., 998, 999], or rather, a way to produce that list on demand, the same way every time. I first ran into this question about 15 years ago and have idly pondered it at various times since then, but didn't really have the tools to answer it. And then a few months ago I asked in ##crypto on Freenode IRC and Alipha came up with a nice solution: Encrypt the indexes.

Encrypting indexes

Specifically, what they suggested was using a Feistel cipher in CTR mode. I don't really know the math, but the important thing is that the Feistel cipher produces a keyed pseudorandom permutation. That is, given an encryption key, if you used it to encrypt (say) every possible 16-byte input, the outputs would also be a list of every possible 16-byte value, but in a different order, and one you couldn't predict without knowing the encryption key. Different key, different order. (There are other ciphers with this property, but a Feistel cipher can be built for blocks of any even number of bits, since it just splits the input into two equal halves, rather than only for fixed n-byte blocks. That becomes important later.)
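A tiny experiment shows the permutation property without any real cryptography. This is a minimal sketch, assuming a 4-bit block split into two 2-bit halves and a throwaway round function built from `hashlib.sha256`; the names `round_fn` and `feistel_encrypt` are illustrative, not from any library. Because each Feistel round only XORs one half with a function of the other half, every round can be undone, so encrypting all 16 inputs yields all 16 outputs exactly once no matter what the round function does.

```python
import hashlib

HALF_BITS = 2                      # each half of the 4-bit block is 2 bits
HALF_MASK = (1 << HALF_BITS) - 1


def round_fn(half: int, round_key: bytes) -> int:
    """Arbitrary, non-invertible round function: hash, then keep 2 bits."""
    digest = hashlib.sha256(bytes([half]) + round_key).digest()
    return digest[0] & HALF_MASK


def feistel_encrypt(value: int, round_keys: list[bytes]) -> int:
    left, right = value >> HALF_BITS, value & HALF_MASK
    for rk in round_keys:
        # Standard Feistel round: (L, R) -> (R, L xor F(R))
        left, right = right, left ^ round_fn(right, rk)
    # Final swap, so decryption is the same network with the keys reversed
    return (right << HALF_BITS) | left


keys = [b"k1", b"k2", b"k3", b"k4"]
encrypted = [feistel_encrypt(i, keys) for i in range(16)]
print(encrypted)

# Every 4-bit value appears exactly once, whatever round_fn does...
assert sorted(encrypted) == list(range(16))
# ...and running the network with reversed keys undoes it.
assert [feistel_encrypt(e, keys[::-1]) for e in encrypted] == list(range(16))
```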

CTR mode usually refers to a way of encrypting a stream of data by producing a keystream from a counter, but here it's used a bit metaphorically. The indexes of a list are essentially a counter incrementing from 0 to n-1. For simplicity's sake, we'll assume n is a power of 2 (say 2^k). Then, every index can be represented in binary as a message k bits long. Encrypting the binary representation of an index value with a Feistel cipher and reading the resulting binary back as a number, we have a new index into the list. If you encrypt each index this way, the result is a shuffled list of indexes. Here's what that might look like for n=16:

  Indexes: [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]
Encrypted: [ 9,  4,  8,  1,  7, 11, 13, 10,  6, 14,  0,  3,  5, 15,  2, 12]

Those can then be used as indexes into the original list: Index 9, index 4, index 8, and so on. And of course you don't have to hold onto this list of encrypted indexes. They can be generated on the fly, as long as you have the key that was used to encrypt them. (They can also be decrypted using the same key.) But the key insight is that you can directly traverse the list in shuffled order. For example, if you were currently at index 8 in the list and wanted to know the next "random" location, you would first decrypt 8 to find 2 (go from the 8 in the bottom list to the matching position in the top list, a 2), increment 2 to 3, then encrypt (drop down again) from 3 to 1. So the next step in your traversal is from 8 to 1. Decrypt, increment, encrypt. (If incrementing would drop you off the end of the list, just wrap around.)
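The decrypt, increment, encrypt step can be checked by hand against the n=16 example. In this sketch the example permutation is simply hard-coded as a lookup table standing in for the keyed cipher (`ENCRYPTED`, `encrypt`, `decrypt`, and `next_index` are illustrative names), and decryption is the inverse lookup.

```python
# The n=16 example permutation, standing in for a keyed cipher.
ENCRYPTED = [9, 4, 8, 1, 7, 11, 13, 10, 6, 14, 0, 3, 5, 15, 2, 12]
DECRYPTED = {enc: i for i, enc in enumerate(ENCRYPTED)}  # inverse permutation
N = len(ENCRYPTED)


def encrypt(i: int) -> int:
    return ENCRYPTED[i]


def decrypt(e: int) -> int:
    return DECRYPTED[e]


def next_index(current: int) -> int:
    """Decrypt, increment (wrapping around), encrypt."""
    return encrypt((decrypt(current) + 1) % N)


print(next_index(8))  # -> 1

# Hopping from the first shuffled index visits every index exactly once.
walk = [encrypt(0)]
for _ in range(N - 1):
    walk.append(next_index(walk[-1]))
print(walk == ENCRYPTED)  # -> True
```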

Generalizing to any length

The limitation to lists of length 2^k is a bit of a pain, but as long as timing attacks are not a concern, it can be mitigated by repeating the encryption or decryption step until the output is back in range (a technique usually called cycle walking). Below is an example using the same key but restricting the shuffle to a list of length 10. Let's say you were performing a shuffled traversal and were currently at index 5. Decrypting 5 (look at the earlier pair of lists for reference) gives 12, which is not in the range [0..9]. So decrypt 12 to get 15, 15 to get 13, and 13 to get 6, which is finally in range. Increment to get 7, and then encrypt to get 10. That's out of range too, so encrypt again to get 0. The diagram may help illustrate the skipped out-of-range indexes (marked with *):

                                         __
  Indexes: [ 0,  1,  2,  3,  4,  5,  6, /  \  7,  8,  9]
                                11* 13* |  | 10*     14*
                                    15* |  |  
                                    12* |  |  
Encrypted: [ 9,  4,  8,  1,  7,  3,  5, ^  v  0,  6,  2]
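The same table can be used to check the cycle-walking arithmetic for a list of length 10. Again, this is a sketch that hard-codes the n=16 example permutation in place of a real cipher; `encrypt_in_range`, `decrypt_in_range`, and `next_index` are illustrative names.

```python
# The n=16 example permutation, standing in for a keyed cipher.
ENCRYPTED = [9, 4, 8, 1, 7, 11, 13, 10, 6, 14, 0, 3, 5, 15, 2, 12]
DECRYPTED = {enc: i for i, enc in enumerate(ENCRYPTED)}
SIZE = 10  # length of the list we actually want to shuffle


def encrypt_in_range(i: int) -> int:
    """Keep encrypting until the result lands in [0, SIZE)."""
    i = ENCRYPTED[i]
    while i >= SIZE:
        i = ENCRYPTED[i]
    return i


def decrypt_in_range(e: int) -> int:
    """Keep decrypting until the result lands in [0, SIZE)."""
    e = DECRYPTED[e]
    while e >= SIZE:
        e = DECRYPTED[e]
    return e


def next_index(current: int) -> int:
    """Decrypt, increment (wrapping around), encrypt."""
    return encrypt_in_range((decrypt_in_range(current) + 1) % SIZE)


print([encrypt_in_range(i) for i in range(SIZE)])  # -> [9, 4, 8, 1, 7, 3, 5, 0, 6, 2]
print(next_index(5))  # -> 0
```

The first printed list matches the bottom row of the diagram: 5 reaches 3 via the skipped 11, 6 reaches 5 via 13, 15, and 12, and so on.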

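The upshot is that if you have a list of length n that isn't a power of 2, you can use the smallest power of 2 that's greater than n instead. The downside is that you'll sometimes need to skip up to (almost) n elements, so it's not a constant-time traversal. On average it's cheap, though: every value in the larger range belongs to at most one skip chain, so across all n indexes the total number of encryptions is at most that power of 2, which is less than 2n. That works out to fewer than two encryptions per step on average. (My vague understanding is that the occasional long "skip chain" is what to worry about, but... I haven't actually studied combinatorics. The Golomb–Dickman constant seems to be relevant here if you want to look into this.)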
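Here's a quick way to look at skip-chain lengths empirically. It's a sketch that uses a seeded `random.shuffle` of the 2^k values as a stand-in for the keyed cipher (an assumption; a real Feistel permutation isn't literally a uniform random shuffle), then counts how many encryptions each index in [0, n) needs before it lands back in range.

```python
import random


def walk_lengths(perm: list[int], size: int) -> list[int]:
    """Encryptions needed to map each index in [0, size) back into range."""
    lengths = []
    for i in range(size):
        steps, x = 1, perm[i]
        while x >= size:          # out of range: encrypt again
            steps, x = steps + 1, perm[x]
        lengths.append(steps)
    return lengths


size = 600
domain = 1 << (size - 1).bit_length()   # smallest power of 2 >= size (1024)
perm = list(range(domain))
random.Random(1234).shuffle(perm)        # stand-in for the keyed cipher

lengths = walk_lengths(perm, size)
print(f"mean {sum(lengths) / size:.2f}, max {max(lengths)}")

# Each out-of-range value is visited by at most one walk, so the total
# work is bounded by the domain size and the mean is below 2.
assert sum(lengths) <= domain
```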

Demo

If you want to play around with this, here's a hacked-together Python implementation. The code illustrates mapping the list to a permutation as well as directly traversing the shuffled list. It has terrible performance since I'm not actually working with bits, but with bytes: rather than using the next 2^k, it rounds k up to a multiple of 16 bits (one byte for each half of the block), so for a list of 42 elements it works over 2^16 = 65,536 values instead of 2^6 = 64. This is horribly slow for the task, but it still runs in just a second or so. I also make no promises as to the quality of the results. (Seriously, this was all written from looking at a diagram on Wikipedia. I don't know this stuff.)

"""
Shuffle a list in a way that does not require holding the list in memory.

Function ``shuffle_indices`` generates indices from 0 to size (exclusive)
in a deterministic random order based on the key bytes. These indices can
then be used for lookups into a list of the given length.

This is probably not a cryptographically secure process. Specifically
note that a number of iterations have to be skipped when the list
is not exactly 2^16, 2^32, ... (2^(16x)) elements long, which may reveal
information through timing. But there are probably lots of other problems too!
"""

import hashlib
import math

def xor_bytes(xs, ys):
    return bytes(x ^ y for x, y in zip(xs, ys))


def round_fn(block_size, piece, key):
    # 16: 8 bits per byte, and 2 because half the block size
    return hashlib.shake_128(piece + key).digest(block_size // 16)


def encrypt(plain_int, block_size, round_keys):
    # block_size is in bits; to_bytes() takes a length in bytes
    plain_bytes = plain_int.to_bytes(block_size // 8, byteorder='big')

    def run_round(left_i, right_i, round_key):
        left_next = right_i
        round_results = round_fn(block_size, right_i, round_key)
        right_next = xor_bytes(left_i, round_results)
        return (left_next, right_next)

    left = plain_bytes[:block_size // 16]
    right = plain_bytes[block_size // 16:]
    for round_key in round_keys:
        left, right = run_round(left, right, round_key)

    return int.from_bytes(right + left, byteorder='big')


def decrypt(plain_int, block_size, round_keys):
    """Decryption is just encryption with reversed round keys."""
    return encrypt(plain_int, block_size, list(reversed(round_keys)))


def encrypt_until_within_range(plain_int, block_size, round_keys, list_size):
    """
    Iteratively encrypt until the output is in [0, list_size).
    """
    current = plain_int
    while True:
        current = encrypt(current, block_size, round_keys)
        if current < list_size:
            return current


def decrypt_until_within_range(plain_int, block_size, round_keys, list_size):
    """Decryption is just encryption with reversed round keys."""
    return encrypt_until_within_range(plain_int, block_size, list(reversed(round_keys)), list_size)


def block_size_for(size):
    """
    Compute the needed block size for a list of the given length.
    """
    # Need a number of bits n where 2^n >= size and n divisible by
    # sixteen (two bytes, one for each side of the block.) This
    # results in needing to skip a large number of iterations to get
    # to the next index. This could be reduced heavily with sub-byte
    # bit-twiddling, although n still needs to be divisible by 2
    # because of the two halves of the cipher.
    # max(): log2(1) is 0, but even a one-element list needs a nonzero block.
    block_size = max(1, math.ceil(math.log2(size)))
    block_size += -block_size % 16
    return block_size


def make_round_keys(block_size, key):
    """
    Given a block size (in bits) and a key, return a list of four round keys.
    """
    rounds = 4
    # Make one key piece for each round. shake_128 digest lengths are in
    # bytes, so these are byte counts, not bit counts.
    key_bytes_per_round = block_size // 2
    key_bytes = key_bytes_per_round * rounds
    round_keys_raw = hashlib.shake_128(key).digest(key_bytes)
    return [
        round_keys_raw[start:start + key_bytes_per_round]
        for start in range(0, key_bytes, key_bytes_per_round)
    ]


def demo_reversible(size, start, key):
    block_size = block_size_for(size)
    round_keys = make_round_keys(block_size, key)

    print("Demo of encryption and decryption:")
    enc = encrypt(start, block_size, round_keys)
    dec = decrypt(enc, block_size, round_keys)
    print(f"Encrypted {start} to {enc}, then decrypted back to {dec}")


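def shuffle_indices(size, key):
    """
    Yield the indices 0 to size (exclusive) in a permuted order based on key.
    """
    block_size = block_size_for(size)
    round_keys = make_round_keys(block_size, key)

    for i in range(size):
        yield encrypt_until_within_range(i, block_size, round_keys, size)


def demo_shuffle_indices(size, key):
    """
    Print where each index lands in the permutation based on key.
    """
    for i, enc in enumerate(shuffle_indices(size, key)):
        print(f"{i} -> {enc}")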


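def shuffled_hop(size, start, key):
    """
    Given a starting index in the permuted list, what's the next index?
    """
    block_size = block_size_for(size)
    round_keys = make_round_keys(block_size, key)

    # Decrypt, increment (wrapping around at the end of the list), encrypt.
    # Without the wrap, plain + 1 == size is out of range, and walking from
    # it could return an already-visited index or never terminate.
    plain = decrypt_until_within_range(start, block_size, round_keys, size)
    return encrypt_until_within_range((plain + 1) % size, block_size, round_keys, size)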


def main():
    size = 42  # length of list we'll be traversing
    key = "some random seed".encode()

    demo_reversible(size, 17, key)
    print()

    print(f"To permute a list of length {size} this implementation needs a "
          f"block size of {block_size_for(size)} bits.")
    print()

    print(f"A permutation of a list of {size} elements:")
    demo_shuffle_indices(size, key)
    print()

    start = 10
    print(f"Starting at {start}, what's the next index in this traversal?")
    print(shuffled_hop(size, start, key))


if __name__ == '__main__':
    main()
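To put rough numbers on the performance note: since the encryption is repeated until it lands in range, the average number of encryptions per index is at most the size of the cipher's domain divided by the list length. This sketch compares the 16-bit rounding used in `block_size_for` with a hypothetical bit-level version that only rounds up to an even number of bits (a balanced Feistel network needs two equal halves); the helper names are made up for the comparison.

```python
def bits_needed(size: int) -> int:
    """Smallest k with 2**k >= size (at least 1)."""
    return max(1, (size - 1).bit_length())


def byte_granular_bits(size: int) -> int:
    """Round up to a multiple of 16 bits, as block_size_for() does."""
    k = bits_needed(size)
    return k + (-k % 16)


def bit_granular_bits(size: int) -> int:
    """Hypothetical bit-level version: round up to an even number of bits."""
    k = bits_needed(size)
    return k + (k % 2)


for size in (10, 42, 1000, 70_000):
    for label, bits in (("16-bit", byte_granular_bits(size)),
                        ("even-bit", bit_granular_bits(size))):
        ratio = 2 ** bits / size   # upper bound on mean encryptions per index
        print(f"size={size:>6} {label:>8}: {bits:>2} bits, "
              f"at most {ratio:,.1f} encryptions per index on average")
```

For a 42-element list that's the difference between roughly 1,560 encryptions per index and about 1.5. The even-bit requirement can still cost up to another factor of 2: 70,000 elements need 17 bits but get 18.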

While I do have a possible application for this (hiding the size and number of gaps in a list), there are other ways to solve it, so I don't have much inclination to improve the code beyond this; my appetite for bit-twiddling is only so high. But I'd be super curious to hear from anyone who improves upon it or knows of anything like this being used for real!
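For anyone who does feel like twiddling bits, here's one possible shape of the bit-level version, written as a minimal sketch rather than a vetted design. Assumptions not in the code above: BLAKE2b (from `hashlib`) as the round function, four rounds, per-round keys derived by hashing the key with the round number, and the class name `BitFeistelShuffle`. It uses the smallest even number of bits that covers the list, cycle-walks into range, and wraps around when hopping past the last position.

```python
import hashlib


class BitFeistelShuffle:
    """Keyed permutation of range(size) from a balanced Feistel network on bits."""

    def __init__(self, size: int, key: bytes, rounds: int = 4):
        bits = max(2, (size - 1).bit_length())  # at least 2 bits, 2**bits >= size
        bits += bits % 2                        # balanced halves need an even width
        self.size = size
        self.half_bits = bits // 2
        self.mask = (1 << self.half_bits) - 1
        self.half_bytes = (self.half_bits + 7) // 8
        # Assumed key schedule: hash the key together with the round number.
        self.round_keys = [hashlib.blake2b(key + bytes([r])).digest()
                           for r in range(rounds)]

    def _round_fn(self, half: int, round_key: bytes) -> int:
        data = round_key + half.to_bytes(self.half_bytes, "big")
        return int.from_bytes(hashlib.blake2b(data).digest(), "big") & self.mask

    def _feistel(self, x: int, round_keys: list[bytes]) -> int:
        left, right = x >> self.half_bits, x & self.mask
        for rk in round_keys:
            left, right = right, left ^ self._round_fn(right, rk)
        # Swap halves at the end so decryption is the same network, keys reversed.
        return (right << self.half_bits) | left

    def _walk(self, x: int, round_keys: list[bytes]) -> int:
        """Cycle walking: re-apply the network until the result is a valid index."""
        x = self._feistel(x, round_keys)
        while x >= self.size:
            x = self._feistel(x, round_keys)
        return x

    def encrypt(self, index: int) -> int:
        return self._walk(index, self.round_keys)

    def decrypt(self, index: int) -> int:
        return self._walk(index, self.round_keys[::-1])

    def next_index(self, current: int) -> int:
        """Decrypt, increment (wrapping around), encrypt."""
        return self.encrypt((self.decrypt(current) + 1) % self.size)


shuffle = BitFeistelShuffle(42, b"some random seed")
order = [shuffle.encrypt(i) for i in range(42)]
assert sorted(order) == list(range(42))

# Hopping from the first shuffled index reproduces the same order.
walk = [order[0]]
while len(walk) < 42:
    walk.append(shuffle.next_index(walk[-1]))
assert walk == order
```

For 42 elements this works over 2^6 = 64 values, so most steps need only one or two evaluations of the network.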

Updates

2021-05-06: On lobste.rs, bitshift recommends looking into the swap-or-not shuffle algorithm, which apparently avoids some of the issues of the Feistel-based approach. I haven't read the paper yet, but I've already learned that "small-domain cipher" is a good keyword, and that there's a relationship here to format-preserving encryption.
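For comparison, here is a rough sketch of the swap-or-not idea as it's commonly described. Each round picks a key K in [0, n), pairs every x with K - x (mod n), and flips a keyed coin, evaluated on the larger member of the pair, to decide whether the two trade places. Since both members see the same coin, each round is its own inverse, so the result is a permutation of exactly [0, n) with no rounding up to a power of 2 and no skipping. The `_prf` helper (BLAKE2b from `hashlib`), the `swap_or_not` name, and the round count are assumptions for illustration; the round count in particular is arbitrary, not a parameter taken from the paper.

```python
import hashlib


def _prf(key: bytes, *parts: int) -> int:
    """Stand-in pseudorandom function: hash the key with some integers."""
    data = key + b"".join(p.to_bytes(16, "big") for p in parts)
    return int.from_bytes(hashlib.blake2b(data, digest_size=8).digest(), "big")


def swap_or_not(x: int, size: int, key: bytes, rounds: int,
                inverse: bool = False) -> int:
    """Keyed permutation of range(size); inverse=True undoes it."""
    order = reversed(range(rounds)) if inverse else range(rounds)
    for r in order:
        k = _prf(key, 0, r) % size        # round key K_r in [0, size)
        partner = (k - x) % size          # x is paired with K_r - x (mod size)
        # The coin depends only on the pair (via its larger member), so both
        # members make the same decision and the round is an involution.
        if _prf(key, 1, r, max(x, partner)) & 1:
            x = partner
    return x


size, key = 10, b"some random seed"
rounds = 64  # arbitrary for this sketch, not a vetted security parameter
order = [swap_or_not(i, size, key, rounds) for i in range(size)]
assert sorted(order) == list(range(size))
assert all(swap_or_not(e, size, key, rounds, inverse=True) == i
           for i, e in enumerate(order))
```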

2021-05-06 #2: I found a paper "BISON: Instantiating the Whitened Swap-Or-Not Construction" which seems to be the first concrete instance of swap-or-not, and apparently it is unavoidably slow!

