use labrador_ldpc::{decoder::DecodeFrom, LDPCCode};

use crate::{buffer::Buffer, error::EncodeError, soft_bit::SoftBit};

/// Encodes one `$n`-byte chunk of `$d`, starting at byte offset `$i`, with the
/// labrador-ldpc code `$t`, then advances `$i` by `$n`.
///
/// A `2 * $n`-byte codeword is built with the chunk copied into its first half and
/// handed to `LDPCCode::$t.encode` (which presumably fills the second half with
/// parity — confirm against the labrador-ldpc docs). Because of the `?`, the
/// expansion site must be a function whose error type accepts `EncodeError`.
///
/// NOTE(review): this copy of the macro is incomplete. The statement that owns the
/// `.map_err(..)` below is missing; presumably it appended the parity half
/// `codeword[$n..]` to `$d`. The closing braces are also missing. Restore from
/// version control.
macro_rules! enc_chunk {
    ($d:ident, $i:ident, $t:ident, $n:literal) => {
        // `[<$t>]` is `paste`'s identifier syntax; at the call sites it yields the
        // code name, e.g. `TM8192`.
        ::paste::paste! {
        let mut codeword = [0; $n * 2];
        codeword[0..$n].copy_from_slice(&$d[$i..($i + $n)]);
        labrador_ldpc::LDPCCode::[<$t>].encode(&mut codeword);
            .map_err(|_| EncodeError::CatsOverflow)?;

            $i += $n;

/// Hard-decision (bit-flipping, `decode_bf`) decode of one `$n`-byte chunk.
///
/// - `$d` is the remaining data slice.
/// - `$p` is the remaining parity slice.
/// - `$w` is the scratch buffer; it is sliced to `decode_bf_working_len()` for code `$t`.
///
/// Afterwards `$d` and `$p` are both advanced past the chunk. If fewer than `$n`
/// parity bytes remain, the enclosing function returns `None` via `?`. Callers only
/// expand this when at least `$n` data bytes remain.
///
/// NOTE(review): this copy is missing lines. The visible code never copies
/// `code_data` / `code_parity` into `input`, nor the decoded `out` back into
/// `code_data`. The closing braces are also absent. Restore from version control.
macro_rules! dec_chunk {
    ($d:ident, $p:ident, $w:ident, $t:ident, $n:literal) => {
        ::paste::paste! {
            let code_data = &mut $d[..$n];
            let code_parity = &$p.get_mut(..$n)?;

            // Full codeword: `$n` data bytes followed by `$n` parity bytes.
            let mut input = [0; $n * 2];
            const CODE: LDPCCode = LDPCCode::[<$t>];
            let mut out = [0; CODE.output_len()];
            // 16 is passed as the last argument, presumably the iteration limit.
            CODE.decode_bf(&input, &mut out, &mut $w[..CODE.decode_bf_working_len()], 16);

            $d = &mut $d[$n..];
            $p = &mut $p[$n..];

/// Soft-decision decode of one chunk of `$n` soft bits, using `decode_ms`
/// (labrador-ldpc's min-sum decoder).
///
/// `$n` counts soft values (one per bit), not bytes; `decode_soft` splits its input at
/// `len * 8`. `$w` and `$w_u8` are scratch buffers sliced to the working lengths for
/// code `$t`. `$d` and `$p` are advanced past the chunk. If fewer than `$n` parity
/// values remain, the enclosing function returns `None` via `?`.
///
/// `T::zero()` resolves at the expansion site (macro_rules hygiene does not cover type
/// names), so this only works inside a function generic over the soft-bit type `T`.
///
/// NOTE(review): this copy is missing lines. `$out` is never used in the visible
/// code, nothing fills `input` or consumes `out_tmp`, and the closing braces are
/// absent. Restore from version control.
macro_rules! dec_chunk_soft {
    ($d:ident, $p:ident, $w:ident, $w_u8:ident, $out:ident, $t:ident, $n:literal) => {
        ::paste::paste! {
            let code_data = &mut $d[..$n];
            let code_parity = &$p.get_mut(..$n)?;

            // Full codeword of soft values: `$n` data followed by `$n` parity.
            let mut input = [T::zero(); $n * 2];
            const CODE: LDPCCode = LDPCCode::[<$t>];

			let mut out_tmp = [0; CODE.output_len()];
            CODE.decode_ms(&input, &mut out_tmp, &mut $w[..CODE.decode_ms_working_len()], &mut $w_u8[..CODE.decode_ms_working_u8_len()], 16);

            $d = &mut $d[$n..];
            $p = &mut $p[$n..];

/// Appends LDPC parity for the current contents of `data` to `data`, in place.
///
/// The first `n = data.len()` bytes are consumed from the front in chunks. Each chunk
/// uses the largest code that fits the bytes still to be encoded:
/// - 512 or more remaining: TM8192, 512-byte chunk
/// - 128 or more: TM2048, 128-byte chunk
/// - 32 or more: TC512, 32-byte chunk
/// - 16 or more: TC256, 16-byte chunk
/// - 8 or more: TC128, 8-byte chunk
/// - 1..=7: TC128, with the tail padded to 8 bytes with `0xAA`
///
/// `n` is captured before the loop, so parity appended during the loop is never
/// itself encoded.
///
/// # Errors
/// Returns `EncodeError::CatsOverflow` when appending to `data` fails (presumably
/// because the buffer is full).
///
/// On failure this still modifies the data array! Bytes appended for chunks encoded
/// before the failure remain in `data`.
///
/// NOTE(review): this copy is missing lines — the statements owning the `.map_err`
/// calls, the final `Ok(())`, and the closing braces. Restore from version control.
pub(crate) fn encode<const N: usize>(data: &mut Buffer<N>) -> Result<(), EncodeError> {
    // Number of original data bytes already encoded.
    let mut i = 0;
    let n = data.len();

    loop {
        match n - i {
            512.. => {
                enc_chunk!(data, i, TM8192, 512);

            128.. => {
                enc_chunk!(data, i, TM2048, 128);

            32.. => {
                enc_chunk!(data, i, TC512, 32);

            16.. => {
                enc_chunk!(data, i, TC256, 16);

            8.. => {
                enc_chunk!(data, i, TC128, 8);

            0 => break,

            _ => {
                // 1..=7 bytes left: pad to a full 8-byte TC128 data block with 0xAA.
                let mut codeword = [0xAA; 16];
                codeword[0..(n - i)].copy_from_slice(&data[i..n]);
                LDPCCode::TC128.encode(&mut codeword);
                // NOTE(review): the receiver of this `.map_err` is missing in this copy
                // (presumably appending `codeword[8..]` to `data`).
                    .map_err(|_| EncodeError::CatsOverflow)?;

                i = n;

        // NOTE(review): the receiver of this `.map_err` is missing in this copy.
        // Presumably it appended `n` as the 2-byte little-endian length that `decode`
        // reads back from the end of the buffer.
        .map_err(|_| EncodeError::CatsOverflow)?;


/// Hard-decision decode, in place, of a frame produced by `encode`.
///
/// Frame layout as read here: `len` data bytes, then parity bytes, then a 2-byte
/// little-endian `len`. Data is corrected chunk by chunk with bit-flipping
/// (`decode_bf`). The chunk sizes match those `encode` picks for the same data length.
///
/// Returns `None` in any of these cases:
/// - the buffer is shorter than the 2-byte length field;
/// - `len` is not smaller than the buffer length;
/// - fewer than 2 bytes remain after the data for the length field;
/// - the parity for some chunk is too short.
///
/// NOTE(review): this copy is missing lines — the closing braces, the `decode_bf(`
/// call head and input setup in the tail branch, and the code after "remove the
/// parity data" (presumably truncating `data_av` to `len` and returning `Some(())`).
/// Restore from version control.
pub(crate) fn decode<const N: usize>(data_av: &mut Buffer<N>) -> Option<()> {
    if data_av.len() < 2 {
        return None;

    // The last two bytes hold the data length, little-endian.
    let len = [data_av[data_av.len() - 2], data_av[data_av.len() - 1]];
    let len = u16::from_le_bytes(len).into();

    if len >= data_av.len() {
        return None;

    // data_av = [data: len][parity][len: 2]; drop the 2-byte length field from `parity`.
    let (mut data, parity) = data_av.split_at_mut(len);
    let mut parity = parity.get_mut(..parity.len().checked_sub(2)?)?;

    // Scratch space sized for the largest code; each chunk slices what it needs.
    let mut working = [0; LDPCCode::TM8192.decode_bf_working_len()];

    loop {
        // Largest code that fits the remaining data, mirroring `encode`.
        match data.len() {
            512.. => {
                dec_chunk!(data, parity, working, TM8192, 512);

            128.. => {
                dec_chunk!(data, parity, working, TM2048, 128);

            32.. => {
                dec_chunk!(data, parity, working, TC512, 32);

            16.. => {
                dec_chunk!(data, parity, working, TC256, 16);

            8.. => {
                dec_chunk!(data, parity, working, TC128, 8);

            0 => break,

            _ => {
                // 1..=7 data bytes left: `encode` padded this tail to 8 bytes with 0xAA,
                // so decode against the same padding.
                let mut code_data = [0xAA; 8];
                let code_parity = &parity.get_mut(..8)?;

                let mut input = [0; 16];
                let mut out = [0; LDPCCode::TC128.output_len()];
                // NOTE(review): the lines filling `input` and the
                // `LDPCCode::TC128.decode_bf(` call head are missing in this copy.
                    &mut out,
                    &mut working[..LDPCCode::TC128.decode_bf_working_len()],

                data = &mut data[..0];
                parity = &mut parity[..0];

    // remove the parity data


/// Soft-decision counterpart of `decode`. Decodes an LDPC frame held as soft bits
/// (one `T` per bit) in `data_av` and appends the recovered data bytes to `out`.
///
/// Frame layout as read here: `len * 8` data bits, then parity bits, then 16 bits
/// carrying `len` (the data length in bytes; see `len_from_soft`). Chunks are decoded
/// with `decode_ms`, using the same code for each chunk size that `encode` uses
/// (4096 soft bits = 512 bytes for TM8192, down to 64 soft bits = 8 bytes for TC128).
///
/// Returns `None` in any of these cases:
/// - `data_av.len()` is not a multiple of 8, or is shorter than the 16-bit length field;
/// - the decoded `len` leaves no room for parity and the length field;
/// - the parity for some chunk is too short.
///
/// A corrupted length field is rejected instead of panicking in `split_at_mut`.
///
/// NOTE(review): this copy of the function is missing lines — the closing braces, the
/// `]` that closes the padding array, the `LDPCCode::TC128.decode_ms(` call head and
/// input setup in the tail branch, and whatever consumes the result of
/// `out.try_extend_from_slice(..)`. Restore them from version control.
pub(crate) fn decode_soft<const N: usize, const M: usize, T: DecodeFrom>(
    data_av: &mut Buffer<N, T>,
    out: &mut Buffer<M>,
) -> Option<()> {
    // Only whole bytes' worth of soft bits can be decoded.
    if data_av.len() % 8 != 0 {
        return None;

    // The trailing 16 soft bits hold the length field.
    if data_av.len() < 16 {
        return None;

    // The `unwrap` cannot fail: the slice is exactly 16 elements long.
    let len: usize = len_from_soft(data_av[(data_av.len() - 16)..].try_into().unwrap()).into();

    // `len` counts bytes while `data_av` holds one soft value per bit, so compare in bits.
    // `data_av.len()` is a multiple of 8 (checked above), so this is exactly
    // `len * 8 >= data_av.len()`, with no risk of overflow. A plain
    // `len >= data_av.len()` would let a corrupted length through to `split_at_mut`
    // with `len * 8 > data_av.len()`, which panics.
    if len >= data_av.len() / 8 {
        return None;

    // data_av = [data: len * 8][parity][len: 16]; drop the length field from `parity`.
    let (mut data, parity) = data_av.split_at_mut(len * 8);
    let mut parity = parity.get_mut(..parity.len().checked_sub(16)?)?;

    // Scratch space sized for the largest code; each chunk slices what it needs.
    let mut working = [T::zero(); LDPCCode::TM8192.decode_ms_working_len()];
    let mut working_u8 = [0; LDPCCode::TM8192.decode_ms_working_u8_len()];

    loop {
        // Sizes are in soft bits: 8x the byte chunk sizes used by `encode`.
        match data.len() {
            4096.. => {
                dec_chunk_soft!(data, parity, working, working_u8, out, TM8192, 4096);

            1024.. => {
                dec_chunk_soft!(data, parity, working, working_u8, out, TM2048, 1024);

            256.. => {
                dec_chunk_soft!(data, parity, working, working_u8, out, TC512, 256);

            128.. => {
                dec_chunk_soft!(data, parity, working, working_u8, out, TC256, 128);

            64.. => {
                dec_chunk_soft!(data, parity, working, working_u8, out, TC128, 64);

            0 => break,

            _ => {
                // `encode` padded the last 1..=7 bytes to 8 bytes with 0xAA (bits 10101010).
                // Ideally the decoder is told these padding bits cannot have flipped,
                // i.e. maximum-confidence soft values such as 1 -> -50 and 0 -> +50.
                // For now they are built with `T::from_hard_bit`; whether that conveys
                // enough confidence is the open TODO below.
                let mut code_data = [
                    1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
                    1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
                    1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0,
                // NOTE(review): the `]` closing this array literal is missing in this copy.
                .map(|x| T::from_hard_bit(x > 0)); // TODO !!
                let code_parity = &parity.get_mut(..64)?;

                let mut input = [T::zero(); 128];
                let mut tmp_out = [0; LDPCCode::TC128.output_len()];
                // NOTE(review): the `input` setup and the `LDPCCode::TC128.decode_ms(`
                // call head are missing in this copy.
                    &mut tmp_out,
                    &mut working[..LDPCCode::TC128.decode_ms_working_len()],
                    &mut working_u8[..LDPCCode::TC128.decode_ms_working_u8_len()],
                // Emit only the real (unpadded) data bytes.
                // NOTE(review): the end of this statement is missing in this copy.
                out.try_extend_from_slice(&tmp_out[..(data.len() / 8)])

                data = &mut data[..0];
                parity = &mut parity[..0];


fn len_from_soft<T: DecodeFrom>(bits: &[T; 16]) -> u16 {
    let mut upper = 0;
    for b in &bits[0..8] {
        upper <<= 1;
        upper |= u8::from(b.hard_bit());

    let mut lower = 0;
    for b in &bits[8..] {
        lower <<= 1;
        lower |= u8::from(b.hard_bit());

    u16::from_le_bytes([upper, lower])

mod tests {
    use bitvec::{order::Msb0, view::BitView};

    use super::*;

    /// Checks the encoded frame size against the worked example from the docs.
    ///
    /// The 2349 data bytes are chunked as 4x512 + 2x128 + 32 + 8. That leaves a 5-byte
    /// tail, which is padded to 8 bytes, so 2352 parity bytes are added. The last 2
    /// bytes are presumably the trailing length field that `decode` reads.
    ///
    /// NOTE(review): no `#[test]` attribute (and no closing braces) is visible for this
    /// function in this copy — confirm the harness actually runs it.
    fn len_test() {
        // from the example in the docs
        let mut buf = [0; 8191];
        let mut data = Buffer::new_empty(&mut buf);

        // 41 x 56 bytes + 53 bytes = 2349 bytes.
        for _ in 0..41 {
            data.extend(b"Example packet data  wueirpqwerwrywqoeiruy29346129384761");
        data.extend(b"Example packet data  wueirpqwerwrywqoeiru346129384761");

        assert_eq!(2349, data.len());

        encode(&mut data).unwrap();

        // 2349 data + 2352 parity + 2 length bytes.
        assert_eq!(4703, data.len());

    /// Round-trips a 12-byte message. It is encoded as one full 8-byte TC128 chunk plus
    /// a 4-byte tail that `encode` pads with 0xAA.
    ///
    /// NOTE(review): no `#[test]` attribute (and no closing brace) is visible for this
    /// function in this copy — confirm the harness actually runs it.
    fn basic_encode_decode_short() {
        let mut buf = [0; 32];
        let mut buf2 = [0; 32];
        let mut data = Buffer::new_empty(&mut buf);
        data.try_extend_from_slice(b"Hello world!").unwrap();
        // Presumably snapshots the plaintext into `buf2` for the final comparison.
        let orig = data.clone_backing(&mut buf2);

        encode(&mut data).unwrap();

        decode(&mut data).unwrap();

        assert_eq!(*orig, *data);

    /// Round-trips a multi-kilobyte payload (50 copies of a 58-byte line). The payload
    /// is at least 512 bytes, so the large TM8192 path is exercised.
    ///
    /// NOTE(review): no `#[test]` attribute (and no closing braces) is visible for this
    /// function in this copy — confirm the harness actually runs it.
    fn basic_encode_decode() {
        let mut buf = [0; 8191];
        let mut buf2 = [0; 8191];
        let mut data = Buffer::new_empty(&mut buf);
        for _ in 0..50 {
            data.extend(b"This is a test packet. jsalksjd093809324JASLD:LKD*#$)(*#@)");
        let orig = data.clone_backing(&mut buf2);

        // Encoding appends parity, so the buffer must no longer equal the plaintext.
        encode(&mut data).unwrap();
        assert_ne!(*orig, *data);

        decode(&mut data).unwrap();

        assert_eq!(*orig, *data);

    /// Corrupts an encoded frame and checks that `decode` restores the original data.
    ///
    /// The payload is 50 x 55 = 2750 bytes. The XORs flip 4 + 4 + 3 = 11 bits, in bytes
    /// 0, 234 and 999. These all lie in the data region, within the first two 512-byte
    /// TM8192 chunks.
    ///
    /// NOTE(review): no `#[test]` attribute (and no closing braces) is visible for this
    /// function in this copy — confirm the harness actually runs it.
    fn encode_decode_with_bit_flips() {
        let mut buf = [0; 8191];
        let mut buf2 = [0; 8191];
        let mut data = Buffer::new_empty(&mut buf);

        for _ in 0..50 {
            data.extend(b"jsalksjd093809324JASLD:LKD*#$)(*#@) Another test packet");
        let orig = data.clone_backing(&mut buf2);

        encode(&mut data).unwrap();
        assert_ne!(*orig, *data);

        // Inject bit errors: 0x55 and 0xAA flip 4 bits each, 0x43 flips 3.
        data[234] ^= 0x55;
        data[0] ^= 0xAA;
        data[999] ^= 0x43;

        decode(&mut data).unwrap();

        assert_eq!(*orig, *data);

    fn basic_encode_decode_soft() {
        let mut buf = [0; 8191];
        let mut buf2 = [0; 8191];
        let mut data = Buffer::new_empty(&mut buf);
        for _ in 0..50 {
            data.extend(b"This is a test packet. jsalksjd093809324JASLD:LKD*#$)(*#@)");
        let orig = data.clone_backing(&mut buf2);

        encode(&mut data).unwrap();
        assert_ne!(*orig, *data);

        let mut soft = [0.0; 8191 * 8];
        let mut soft = Buffer::new_empty(&mut soft);
        for b in data.view_bits::<Msb0>() {

        let mut out = [0; 8191];
        let mut out = Buffer::new_empty(&mut out);
        decode_soft(&mut soft, &mut out).unwrap();

        assert_eq!(*orig, *out);