use labrador_ldpc::decoder::DecodeFrom;

/// Initial state (seed) of the whitening LFSR.
///
/// Both `whiten` and `whiten_soft` restart the sequence from this value on every call, so every
/// buffer is whitened with the same sequence. NOTE(review): presumably the seed mandated by the
/// specification this whitener follows — confirm.
const START_STATE: u16 = 0xE9CF;

/// XORs `data` in place with the byte-wise LFSR whitening sequence.
///
/// The sequence restarts from `START_STATE` on every call. Applying `whiten` twice to the same
/// buffer therefore restores the original contents.
pub(crate) fn whiten(data: &mut [u8]) {
    let mut lfsr_state = START_STATE;

    for byte in data {
        let (mask, next_state) = lfsr_byte(lfsr_state);
        lfsr_state = next_state;
        *byte ^= mask;
    }
}

/// Soft-decision counterpart of `whiten`, where `data` holds one soft bit per element.
///
/// Element `i` is negated whenever output bit `i` of the whitening sequence (starting from
/// `START_STATE`) is 1. In LLR terms, negation flips the hard decision. The result therefore
/// matches `whiten` applied to the same bits packed MSB-first into bytes. Because negation is
/// its own inverse, applying this twice restores the input.
///
/// NOTE(review): if `T` can be a signed integer type, negating `T::MIN` overflows (panics in
/// debug builds). Confirm callers never pass that value.
pub(crate) fn whiten_soft<T: DecodeFrom>(data: &mut [T]) {
    let mut lfsr_state = START_STATE;

    for soft_bit in data.iter_mut() {
        if lfsr_state & 1 != 0 {
            // Whitening bit is 1: flip the soft bit by negating its LLR.
            *soft_bit = -*soft_bit;
        }
        lfsr_state = lfsr(lfsr_state);
    }
}

/// Clocks the LFSR eight times and packs the output bits into one byte.
///
/// Each step outputs the state's LSB *before* advancing. The first bit produced ends up in the
/// byte's MSB. Returns `(byte, state_after_eight_steps)`.
fn lfsr_byte(state: u16) -> (u8, u16) {
    (0..8).fold((0u8, state), |(byte, current), _| {
        // Shift earlier bits toward the MSB, then append this step's output bit.
        ((byte << 1) | u8::from(current & 1 != 0), lfsr(current))
    })
}

/// Advances the 16-bit Galois LFSR by one step.
///
/// The state is shifted right by one. If the bit shifted out was 1, the toggle mask `0xB400`
/// is XORed into the result.
///
/// See <https://en.wikipedia.org/wiki/Linear-feedback_shift_register#Galois_LFSRs>. That page
/// uses the same mask, corresponding to taps 16, 14, 13, 11.
fn lfsr(state: u16) -> u16 {
    const TOGGLE_MASK: u16 = 0xB400;

    let shifted = state >> 1;
    if state & 1 == 1 {
        shifted ^ TOGGLE_MASK
    } else {
        shifted
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::soft_bit::SoftBit;

    /// Sample payload (57 bytes) shared by the round-trip tests.
    const MESSAGE: &[u8] = b"Hello world! The quick brown fox jumped over the lazy dog";

    /// First 16 bytes of the whitening sequence starting from state 0xE9CF.
    const DOC_EXAMPLE: [u8; 16] = [
        0xF3, 0x8D, 0xD0, 0x6E, 0x1F, 0x65, 0x75, 0x75, 0xA5, 0xBA, 0xA9, 0xD0, 0x7A, 0x1D, 0x1,
        0x21,
    ];

    /// 64-byte buffer holding `MESSAGE` followed by zero padding.
    fn padded_message() -> [u8; 64] {
        let mut data = [0; 64];
        // Slice by the message length so the range cannot drift from the literal.
        data[..MESSAGE.len()].copy_from_slice(MESSAGE);
        data
    }

    #[test]
    fn basic() {
        let mut data = padded_message();
        let orig = data;

        whiten(&mut data);
        assert_ne!(orig, data);

        whiten(&mut data);
        assert_eq!(orig, data);
    }

    #[test]
    fn basic_soft() {
        let mut data = [-2.3, 4.7, 0.0, 2.7, 7.8, 45.2, -0.1, -0.82];
        let orig = data;

        whiten_soft(&mut data);
        assert_eq!([2.3, -4.7, -0.0, -2.7, 7.8, 45.2, 0.1, 0.82], data);

        whiten_soft(&mut data);
        assert_eq!(orig, data);
    }

    #[test]
    fn compare_soft_and_hard() {
        let mut data = padded_message();
        let orig = data;

        whiten(&mut data);

        // Unpack the whitened bytes MSB-first into one soft bit per element.
        let mut soft = [0.0; 64 * 8];
        for (i, b) in data.iter().enumerate() {
            for j in 0..8 {
                soft[8 * i + j] = f32::from_hard_bit(b & (1 << (7 - j)) > 0);
            }
        }

        whiten_soft(&mut soft);

        assert_eq!(orig.len() * 8, soft.len());
        for (i, b) in orig.iter().enumerate() {
            for j in 0..8 {
                assert_eq!(soft[8 * i + j].hard_bit(), *b & (1 << (7 - j)) > 0);
            }
        }
    }

    #[test]
    fn test_lfsr() {
        let start = 0xACE1;
        let end_expected = 0xE270;

        let state = lfsr(start);

        assert_eq!(end_expected, state);
    }

    #[test]
    fn test_lfsr_byte() {
        let start = 0xE9CF;
        let (out, state) = lfsr_byte(start);
        assert_eq!(0xF3, out);
        assert_eq!(0xE3B1, state);
    }

    #[test]
    fn test_doc_example() {
        let start = 0xE9CF;

        let mut actual_out = [0; 16];
        let mut state = start;
        for a in &mut actual_out {
            let (out, ns) = lfsr_byte(state);
            state = ns;
            *a = out;
        }

        assert_eq!(DOC_EXAMPLE, actual_out);
    }

    #[test]
    fn test_doc_example_through_whitener() {
        // Whitening an all-zero buffer yields the raw whitening sequence.
        let mut actual_out = [0; 16];
        whiten(&mut actual_out);

        assert_eq!(DOC_EXAMPLE, actual_out);
    }

    #[test]
    fn whiten_empty_is_noop() {
        let mut empty: [u8; 0] = [];
        whiten(&mut empty);
        assert_eq!(empty, []);
    }
}