import binascii
import os
import sys
import time
from datetime import datetime

import serial

# Constants
# Addresses and block offsets are kept as hex *strings* because they are
# interpolated verbatim into U-Boot console commands.
LOADADDR     = '0x70000000' # destination of the firmware `unzip`; source address of the firmware `mmc write`
PART_BOOT0   = '0x00082800' # BOOT0 partition 5 (start block passed to `mmc write`)
PART_BOOT1   = '0x00282800' # BOOT1 partition 6 (start block passed to `mmc write`)
PART_CONFIG  = '0x00682800' # config partition 8 (start block passed to `mmc write`)
ERASE_BLOCKS = '0x40000'    # block count for the optional (commented-out) `mmc erase`
LOG_FILE = "flashtool.log"  # every console command/response pair is appended here
PARTITION_MBT_PATH = "gpt-7part.img"  # image written to MMC blocks 0x0-0x21 by write_partitions_corrected()
FIRMWARE_PATH = "smb_netonix.ext4.gz"  # gzip-compressed firmware image, decompressed on the target with `unzip`
# Per-model config files, selected by the WS4-6 / WS4-8 / WS4-14 command-line argument.
WS4_6_CONFG_TXT = "ws4_6_config.txt"
WS4_8_CONFG_TXT = "ws4_8_config.txt"
WS4_14_CONFG_TXT = "ws4_14_config.txt"


# YMODEM Protocol constants (control bytes exchanged with the receiver)
SOH = b'\x01'  # starts a packet with a 128-byte payload
STX = b'\x02'  # starts a packet with a 1024-byte payload
EOT = b'\x04'  # end of transmission
ACK = b'\x06'  # receiver accepted the last packet
NAK = b'\x15'  # receiver rejected the last packet
CRC = b'C'     # receiver is ready and requests CRC-16 mode
CAN = b'\x18'  # receiver cancelled the transfer

class YModem:
    """Minimal YMODEM sender (CRC-16 mode, 1024-byte data blocks).

    I/O goes through two callbacks so the class is transport-agnostic:

    * ``getc(size)`` returns up to *size* received bytes, or a falsy value
      when nothing arrived (the caller's read timeout bounds each call).
    * ``putc(data)`` writes raw bytes to the receiver.

    Set *debug* to print per-packet diagnostics.
    """

    # Attempts per packet (and for EOT) before the transfer is aborted.
    # Re-sending after a NAK or a missing ACK is how XMODEM/YMODEM recover
    # from line errors.
    MAX_RETRIES = 10

    def __init__(self, getc, putc, debug=True):
        self.getc = getc
        self.putc = putc
        self.debug = debug

    def send(self, stream, filename, filesize, callback=None):
        """Send a single file, then the empty header block that ends the batch.

        Args:
            stream: binary file-like object; read in 1024-byte chunks.
            filename: name for the header block (directory part is stripped,
                non-ASCII characters are replaced).
            filesize: size in bytes advertised in the header block.
            callback: optional ``callback(total, sent, errors)`` called after
                each ACKed data block with ``total=filesize``, *sent* = payload
                bytes ACKed so far (padding excluded) and ``errors=0``.

        Raises:
            Exception: the receiver never sent 'C', cancelled while we waited
                for it, or a packet/EOT stayed un-ACKed after ``MAX_RETRIES``
                attempts.
            ValueError: the header block does not fit in 1024 bytes.
        """
        if not self._wait_for_receiver_ready():
            raise Exception("Receiver never sent 'C'")

        name = os.path.basename(filename).encode('ascii', errors='replace')
        size_bytes = str(filesize).encode('ascii')
        header = name + b'\x00' + size_bytes + b'\x00'
        # Block 0 is normally a 128-byte packet; long names need a 1K packet.
        header = header.ljust(128 if len(header) <= 128 else 1024, b'\x00')
        if self.debug:
            print(f"[debug] Header block: {header}")
        self._send_with_retry(0, header, "Header not ACKed")

        block_number = 1
        total_sent = 0
        while True:
            data = stream.read(1024)
            if not data:
                break
            payload_len = len(data)
            if payload_len < 1024:
                # Pad the short final block with 0x1A; the size field in the
                # header lets the receiver discard the padding.
                data += b'\x1A' * (1024 - payload_len)
            if self.debug:
                print(f"[*] Sending block {block_number} ({len(data)} bytes)")
            self._send_with_retry(block_number, data, f"Block {block_number} not ACKed")
            # Count the block only once it is ACKed, so the last callback
            # reports the full file size.
            total_sent += payload_len
            if callback:
                callback(filesize, total_sent, 0)
            block_number = (block_number + 1) % 256

        self._send_eot()

        if self.debug:
            print("[*] Sending final empty packet")
        # An all-zero header block ends the batch; its ACK is best-effort.
        self._send_packet(0, b'\x00' * 128)
        self._wait_ack()

    def _send_with_retry(self, blockno, data, error_message):
        """Send one packet, re-sending it on NAK/timeout.

        Raises:
            Exception: with *error_message* after ``MAX_RETRIES`` failed attempts.
        """
        for attempt in range(1, self.MAX_RETRIES + 1):
            self._send_packet(blockno, data)
            if self._wait_ack():
                return
            if self.debug:
                print(f"[debug] Block {blockno} not ACKed (attempt {attempt}), resending")
        raise Exception(error_message)

    def _send_eot(self):
        """Send EOT until ACKed.

        The YMODEM spec allows the receiver to NAK the first EOT, so a NAK
        here triggers a resend rather than an immediate failure.

        Raises:
            Exception: if no ACK arrives after ``MAX_RETRIES`` attempts.
        """
        for _ in range(self.MAX_RETRIES):
            if self.debug:
                print("[*] Sending EOT")
            self.putc(EOT)
            time.sleep(0.1)
            if self._wait_ack():
                return
        raise Exception("EOT not ACKed")

    def _send_packet(self, blockno, data):
        """Frame and write one packet.

        Frame layout: SOH (128-byte payload) or STX (1024-byte payload), the
        block number, its complement (255 - blockno), the payload, and the
        CRC-16 in big-endian order.

        Raises:
            ValueError: if *data* is not exactly 128 or 1024 bytes.
        """
        if len(data) not in (128, 1024):
            raise ValueError(f"YMODEM payload must be 128 or 1024 bytes, got {len(data)}")
        packet_type = STX if len(data) == 1024 else SOH
        header = packet_type + bytes([blockno, 255 - blockno])
        crc = self._crc16(data)
        self.putc(header + data + crc.to_bytes(2, 'big'))
        time.sleep(0.05)  # pacing between packets
        if self.debug:
            print(f"[debug] Sent block {blockno}, CRC={crc:04X}")

    def _wait_ack(self, timeout=2.0):
        """Return True on ACK, False on NAK or when *timeout* seconds pass."""
        start = time.time()
        while time.time() - start < timeout:
            c = self.getc(1)
            if c == ACK:
                return True
            elif c == NAK:
                return False
        return False

    def _wait_for_receiver_ready(self, timeout=10):
        """Wait up to *timeout* seconds for the receiver's 'C' (CRC mode request).

        Returns:
            True when 'C' arrives, False on timeout.

        Raises:
            Exception: if the receiver sends CAN.
        """
        print("[*] Waiting for receiver to send 'C'...")
        time.sleep(0.5)  # Allow banner to finish
        start = time.time()
        while time.time() - start < timeout:
            c = self.getc(1)
            if c == CRC:
                print("[✔] Received 'C' — receiver is ready.")
                return True
            elif c == CAN:
                raise Exception("Transfer cancelled by receiver.")
            elif c and self.debug:
                print(f"[debug] Unexpected byte from receiver: {c}")
            time.sleep(0.01)
        return False

    def _crc16(self, data):
        """Return the CRC-16/XMODEM of *data* (poly 0x1021, init 0).

        ``binascii.crc_hqx`` computes this CRC in C, instead of looping over
        eight bit-steps per byte of a multi-megabyte image in Python.
        """
        return binascii.crc_hqx(data, 0)



# console configuration
def update_baud(ser, newbaud):
    """Reopen the port behind *ser* at *newbaud* and return the new handle.

    *ser* is closed and must not be used afterwards. The replacement port is
    opened with a 1 s read timeout, both I/O buffers are cleared, and the
    function pauses for 1 s before returning.

    Raises:
        serial.SerialException: re-raised (after printing a message) if the
            port cannot be closed or reopened.
    """
    try:
        device = ser.port
        ser.close()
        print(f"[*] Reopening {device} at {newbaud} baud...")
        reopened = serial.Serial(device, baudrate=newbaud, timeout=1)
        for purge in (reopened.reset_input_buffer, reopened.reset_output_buffer):
            purge()
        time.sleep(1)  # let the port and the device settle before talking again
        print(f"[✔] Baud updated to {newbaud}")
    except serial.SerialException as err:
        print(f"[!] Failed to reopen port at {newbaud} baud: {err}")
        raise
    return reopened


# U-Boot console interactions

def log_response(cmd, response):
    """Append one timestamped SENT/RECEIVED record to LOG_FILE.

    *cmd* is logged with surrounding whitespace stripped. Each record ends
    with a 60-character dashed separator line.
    """
    stamp = datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    with open(LOG_FILE, "a", encoding="utf-8") as log:
        log.write(
            f"{stamp} SENT: {cmd.strip()}\n"
            f"{stamp} RECEIVED:\n{response}\n"
            + "-" * 60 + "\n"
        )

def send_and_read(ser, cmd, delay=0.3, read_timeout=1.5):
    """Send one console command and collect its output until the line goes quiet.

    Args:
        ser: open serial port.
        cmd: command text; a newline is appended before sending.
        delay: seconds to wait after sending, before reading starts.
        read_timeout: idle window in seconds; reading stops once no byte has
            arrived for this long (each received chunk restarts the window).

    Returns:
        The collected output, decoded as UTF-8 with undecodable bytes dropped.
        The exchange is also recorded via log_response().
    """
    ser.write((cmd + '\n').encode('utf-8'))
    ser.flush()
    time.sleep(delay)

    chunks = []
    quiet_until = time.time() + read_timeout
    while time.time() < quiet_until:
        chunk = ser.read(ser.in_waiting or 1)
        if not chunk:
            time.sleep(0.05)
            continue
        chunks.append(chunk)
        quiet_until = time.time() + read_timeout  # new data extends the window

    text = b''.join(chunks).decode(errors="ignore")
    log_response(cmd, text)
    return text

def wait_for_prompt(ser, prompt=b'=>', timeout=10):
    """Poll *ser* until *prompt* appears in the accumulated input.

    Everything read is kept, so a prompt split across two reads is still
    detected. The data read is consumed and not returned.

    Returns:
        True once *prompt* has been seen, False if *timeout* seconds pass first.
    """
    seen = bytearray()
    stop_at = time.time() + timeout
    while time.time() < stop_at:
        incoming = ser.read(ser.in_waiting or 1)
        if incoming:
            seen.extend(incoming)
            if prompt in seen:
                return True
        time.sleep(0.05)
    return False

def send_file_ymodem(ser, filepath):
    """Upload *filepath* over *ser* with YMODEM, printing an in-place progress line.

    The file's basename and on-disk size are sent in the YMODEM header block.
    """
    def read_bytes(size, timeout=1):
        # YModem treats a falsy result as "nothing received yet".
        ser.timeout = timeout
        data = ser.read(size)
        return data if data else None

    def write_bytes(data, timeout=1):
        return ser.write(data)

    def show_progress(total, sent, errors):
        print(f"\r[*] Progress: {sent / total * 100:.2f}%", end='')

    print(f"[*] Sending {filepath} via YMODEM...")
    with open(filepath, 'rb') as source:
        sender = YModem(read_bytes, write_bytes, debug=False)  # Set to True to debug
        sender.send(source, filepath, os.path.getsize(filepath), callback=show_progress)
    print("\n[✔] YMODEM transfer complete.")

def write_partitions_corrected(ser):
    """Upload PARTITION_MBT_PATH via YMODEM and write it to the start of MMC device 0.

    The image is received by ``loady`` at ``$loadaddr``, written to blocks
    0x0..0x21 (0x22 blocks; presumably the primary GPT: protective MBR,
    header and entry array — confirm against the image), and the MMC is
    rescanned.

    The serial port is reopened twice (921600 baud for the transfer, then back
    to 115200), so the passed *ser* is closed on return.

    Returns:
        The reopened serial port at 115200 baud; callers must use it in place
        of *ser*.
    """
    print("[*] Writing partitions...")
    print("[*] Starting loady...")
    send_and_read(ser, 'loady $loadaddr 921600')
    print("[*] Switching baud to 921600...")
    ser = update_baud(ser, 921600)
    time.sleep(5)
    ser.write(b'\r\n')   # send <enter> to trigger transfer start on uboot side
    ser.flush()
    time.sleep(0.5)

    send_file_ymodem(ser, PARTITION_MBT_PATH)

    print("[*] Switching back to 115200...")
    ser = update_baud(ser, 115200)
    ser.write(b'\x1b')  # ESC to exit YMODEM mode
    ser.flush()
    time.sleep(0.5)

    # loady stored the image at $loadaddr, so write from there. The previous
    # hard-coded 0x70000000 is the firmware unzip destination and would write
    # whatever happens to be in that RAM over the partition table.
    send_and_read(ser, 'mmc dev 0')
    send_and_read(ser, 'mmc write $loadaddr 0x0 0x22')
    send_and_read(ser, 'mmc rescan')

    return ser


def main():
    """Flash firmware and the model-specific config over the U-Boot console.

    Usage: ``python flashtool.py <port> <WS4-14|WS4-8|WS4-6>``.

    Sequence:
    1. Upload FIRMWARE_PATH with ``loady`` (YMODEM at 921600 baud).
    2. ``unzip`` it to LOADADDR and write it to the BOOT0 and BOOT1 partitions.
    3. Upload the model's config file and write it to PART_CONFIG.
    4. Reset the board.

    Exits with status 1 on bad arguments or missing files.
    """
    if len(sys.argv) != 3:
        print("Usage: python flashtool.py COM4 <WS4-14/WS4-8/WS4-6>")
        sys.exit(1)

    port = sys.argv[1]
    switch_model = sys.argv[2]
    file_path = FIRMWARE_PATH

    model_configs = {
        "WS4-14": WS4_14_CONFG_TXT,
        "WS4-8": WS4_8_CONFG_TXT,
        "WS4-6": WS4_6_CONFG_TXT,
    }
    config_txt = model_configs.get(switch_model)
    if config_txt is None:
        print(f"Error: invalid switch_model {switch_model}")
        print("Usage: python flashtool.py COM4 <WS4-14/WS4-8/WS4-6>")
        sys.exit(1)

    for required in (file_path, config_txt):
        if not os.path.exists(required):
            print("File does not exist:", required)
            sys.exit(1)

    def run_to_prompt(port_handle, cmd, timeout=60):
        # send_and_read() returns after 1.5 s of silence by default, but unzip /
        # mmc write on a large image may stay silent longer. Wait for the
        # console to come back to its prompt before the next command is sent.
        response = send_and_read(port_handle, cmd)
        if not response.rstrip().endswith('=>') and not wait_for_prompt(port_handle, timeout=timeout):
            print(f"[!] No prompt after '{cmd}' within {timeout}s; continuing anyway.")
        return response

    ser = None
    try:
        print(f"[*] Connecting to {port} at 115200...")
        ser = serial.Serial(port, 115200, timeout=1)
        time.sleep(2)
        ser.reset_input_buffer()

        # Wake the console. send_and_read() normally captures the prompt
        # itself, so only poll for it (up to 10 s) when it was not in the reply.
        banner = send_and_read(ser, '')
        if '=>' not in banner and not wait_for_prompt(ser):
            print("[!] U-Boot prompt not detected; continuing anyway.")

        # Optional: rewrite the partition table from PARTITION_MBT_PATH first
        # (image extracted from a simulated raw NAND dump) to fix broken partitions.
        #ser = write_partitions_corrected(ser)

        # --- firmware upload ---
        print("[*] Starting loady...")
        send_and_read(ser, 'loady $loadaddr 921600')

        print("[*] Switching baud to 921600...")
        ser = update_baud(ser, 921600)
        time.sleep(5)
        ser.write(b'\r\n')   # send <enter> to trigger transfer start on uboot side
        ser.flush()
        time.sleep(0.5)

        send_file_ymodem(ser, file_path)

        print("[*] Switching back to 115200...")
        ser = update_baud(ser, 115200)
        ser.write(b'\x1b')  # ESC to exit YMODEM mode
        ser.flush()
        time.sleep(0.5)

        # --- firmware write ---
        print("[*] flashing firmware...")
        run_to_prompt(ser, f'unzip $loadaddr {LOADADDR}')
        # blkcnt = ceil($filesize / 0x200); $filesize is expected to hold the
        # decompressed size at this point.
        send_and_read(ser, 'setexpr blkcnt $filesize + 0x1FF')
        send_and_read(ser, 'setexpr blkcnt $blkcnt / 0x200')
        send_and_read(ser, 'mmc dev 0')
        #send_and_read(ser, f'mmc erase {PART_BOOT0} {ERASE_BLOCKS}') # perhaps too aggressive
        run_to_prompt(ser, f'mmc write {LOADADDR} {PART_BOOT0} $blkcnt')
        run_to_prompt(ser, f'mmc write {LOADADDR} {PART_BOOT1} $blkcnt')

        # --- config upload ---
        print("[*] Starting loady...")
        send_and_read(ser, 'loady $loadaddr 921600')

        print("[*] Switching baud to 921600...")
        ser = update_baud(ser, 921600)
        time.sleep(5)
        ser.write(b'\r\n')   # send <enter> to trigger transfer start on uboot side
        ser.flush()
        time.sleep(0.5)

        send_file_ymodem(ser, config_txt)

        print("[*] Switching back to 115200...")
        ser = update_baud(ser, 115200)
        ser.write(b'\x1b')  # ESC to exit YMODEM mode
        ser.flush()
        time.sleep(0.5)

        # --- config write ---
        print("[*] flashing config...")
        send_and_read(ser, 'mmc dev 0')
        # NOTE(review): exactly one block is written, so a config file larger
        # than one 512-byte block would be truncated. Confirm the config always fits.
        run_to_prompt(ser, f'mmc write $loadaddr {PART_CONFIG} 1')

        print("[✔] Flash complete.")
        time.sleep(1)
        send_and_read(ser, 'reset')
        print("[✔] System reset, please reconnect you terminal program to interact with the device.")

    except KeyboardInterrupt:
        print("\n[!] Interrupted by user.")
    finally:
        if ser is not None and ser.is_open:
            ser.close()

def just_send_file():
    """Debug helper: push an arbitrary file with YMODEM at 921600 baud.

    Usage: ``python flashtool.py <port> <file>``. Unlike the other entry
    points, this sends no ``loady`` command itself. Presumably the receiver
    must already be waiting at 921600 baud; start it by hand first.
    After the transfer, the port is switched back to 115200 and ESC is sent.
    """
    if len(sys.argv) != 3:
        print("Usage: python flashtool.py COM4 lan969x.ext4.gz")
        sys.exit(1)

    port, file_path = sys.argv[1], sys.argv[2]

    if not os.path.exists(file_path):
        print("File does not exist:", file_path)
        sys.exit(1)

    ser = None
    try:
        print(f"[*] Connecting to {port} at 921600...")
        ser = serial.Serial(port, 921600, timeout=1)
        time.sleep(2)
        ser.reset_input_buffer()

        # An <enter> prompts the U-Boot side to start the transfer.
        ser.write(b'\r\n')
        ser.flush()
        time.sleep(0.5)

        send_file_ymodem(ser, file_path)
        print("[✔] Transfer complete.")

        print("[*] Switching back to 115200...")
        ser = update_baud(ser, 115200)
        ser.write(b'\x1b')  # ESC leaves YMODEM mode
        ser.flush()
    except KeyboardInterrupt:
        print("\n[!] Interrupted by user.")
    finally:
        if ser is not None and ser.is_open:
            ser.close()


def only_flash_config():
    """Upload the model's config file and write all of it to PART_CONFIG.

    Usage: ``python flashtool.py <port> <WS4-14|WS4-8|WS4-6>``. The block count
    is computed on the target as ceil($filesize / 0x200). Exits with status 1
    on bad arguments or a missing config file.
    """
    if len(sys.argv) != 3:
        print("Usage: python flashtool.py COM4 <WS4-14/WS4-8/WS4-6>")
        sys.exit(1)

    port = sys.argv[1]
    switch_model = sys.argv[2]

    model_configs = {
        "WS4-14": WS4_14_CONFG_TXT,
        "WS4-8": WS4_8_CONFG_TXT,
        "WS4-6": WS4_6_CONFG_TXT,
    }
    config_txt = model_configs.get(switch_model)
    if config_txt is None:
        print(f"Error: invalid switch_model {switch_model}")
        print("Usage: python flashtool.py COM4 <WS4-14/WS4-8/WS4-6>")
        sys.exit(1)

    if not os.path.exists(config_txt):
        print("File does not exist:", config_txt)
        sys.exit(1)

    print(f"[*] Sending flash config for {switch_model}")

    ser = None
    try:
        print(f"[*] Connecting to {port} at 115200...")
        ser = serial.Serial(port, 115200, timeout=1)
        time.sleep(2)
        ser.reset_input_buffer()

        # Wake the console. send_and_read() normally captures the prompt
        # itself, so only poll for it (up to 10 s) when it was not in the reply.
        banner = send_and_read(ser, '')
        if '=>' not in banner and not wait_for_prompt(ser):
            print("[!] U-Boot prompt not detected; continuing anyway.")

        # Optional: rewrite the partition table from PARTITION_MBT_PATH first
        # (image extracted from a simulated raw NAND dump) to fix broken partitions.
        #ser = write_partitions_corrected(ser)

        print("[*] Starting loady...")
        send_and_read(ser, 'loady $loadaddr 921600')

        print("[*] Switching baud to 921600...")
        ser = update_baud(ser, 921600)
        time.sleep(5)
        ser.write(b'\r\n')   # send <enter> to trigger transfer start on uboot side
        ser.flush()
        time.sleep(0.5)

        send_file_ymodem(ser, config_txt)

        print("[*] Switching back to 115200...")
        ser = update_baud(ser, 115200)
        ser.write(b'\x1b')  # ESC to exit YMODEM mode
        ser.flush()
        time.sleep(0.5)

        print("[*] Finishing pushing config...")
        send_and_read(ser, 'setexpr blkcnt $filesize + 0x1FF')
        send_and_read(ser, 'setexpr blkcnt $blkcnt / 0x200')
        send_and_read(ser, 'mmc dev 0')
        # loady stored the file at $loadaddr, so write from there. LOADADDR is
        # the firmware unzip destination and does not hold the config.
        send_and_read(ser, f'mmc write $loadaddr {PART_CONFIG} $blkcnt')

        print("[✔] Flash complete.")
        time.sleep(1)
        send_and_read(ser, 'reset')
        print("[✔] System reset, please reconnect you terminal program to interact with the device.")

    except KeyboardInterrupt:
        print("\n[!] Interrupted by user.")
    finally:
        if ser is not None and ser.is_open:
            ser.close()



def flash_firmware():
    """Upload FIRMWARE_PATH and write it to both the BOOT0 and BOOT1 partitions.

    Usage: ``python flashtool.py <port> <model>``. The model argument is
    required so the command line matches the other entry points, but it is
    not used here.

    The image is received by ``loady`` at ``$loadaddr`` (YMODEM at 921600
    baud), decompressed with ``unzip`` to LOADADDR, then written to
    PART_BOOT0 and PART_BOOT1, and the board is reset. Exits with status 1 on
    bad arguments or a missing firmware file.
    """
    if len(sys.argv) != 3:
        print("Usage: python flashtool.py COM4 <WS4-14/WS4-8/WS4-6>")
        sys.exit(1)

    port = sys.argv[1]
    # sys.argv[2] (the switch model) is accepted for symmetry but unused.

    file_path = FIRMWARE_PATH
    if not os.path.exists(file_path):
        print("Firmware file does not exist:", file_path)
        sys.exit(1)

    def run_to_prompt(port_handle, cmd, timeout=60):
        # send_and_read() returns after 1.5 s of silence by default, but unzip /
        # mmc write on a large image may stay silent longer. Wait for the
        # console to come back to its prompt before the next command is sent.
        response = send_and_read(port_handle, cmd)
        if not response.rstrip().endswith('=>') and not wait_for_prompt(port_handle, timeout=timeout):
            print(f"[!] No prompt after '{cmd}' within {timeout}s; continuing anyway.")
        return response

    ser = None
    try:
        print(f"[*] Connecting to {port} at 115200...")
        ser = serial.Serial(port, 115200, timeout=1)
        time.sleep(2)
        ser.reset_input_buffer()

        # Wake the console. send_and_read() normally captures the prompt
        # itself, so only poll for it (up to 10 s) when it was not in the reply.
        banner = send_and_read(ser, '')
        if '=>' not in banner and not wait_for_prompt(ser):
            print("[!] U-Boot prompt not detected; continuing anyway.")

        print("[*] Starting loady for firmware...")
        send_and_read(ser, 'loady $loadaddr 921600')

        print("[*] Switching baud to 921600...")
        ser = update_baud(ser, 921600)
        time.sleep(5)
        ser.write(b'\r\n')  # <enter> triggers the transfer start on the U-Boot side
        ser.flush()
        time.sleep(0.5)

        send_file_ymodem(ser, file_path)

        print("[*] Switching back to 115200...")
        ser = update_baud(ser, 115200)
        ser.write(b'\x1b')  # ESC to exit YMODEM mode
        ser.flush()
        time.sleep(0.5)

        print("[*] Flashing firmware...")
        run_to_prompt(ser, f'unzip $loadaddr {LOADADDR}')
        # blkcnt = ceil($filesize / 0x200); $filesize is expected to hold the
        # decompressed size at this point.
        send_and_read(ser, 'setexpr blkcnt $filesize + 0x1FF')
        send_and_read(ser, 'setexpr blkcnt $blkcnt / 0x200')
        send_and_read(ser, 'mmc dev 0')
        run_to_prompt(ser, f'mmc write {LOADADDR} {PART_BOOT0} $blkcnt')
        run_to_prompt(ser, f'mmc write {LOADADDR} {PART_BOOT1} $blkcnt')

        print("[✔] Firmware flash complete.")
        time.sleep(1)
        send_and_read(ser, 'reset')

    except KeyboardInterrupt:
        print("\n[!] Interrupted by user.")
    finally:
        if ser is not None and ser.is_open:
            ser.close()

def flash_config():
    """Upload the model's config file and write one block of it to PART_CONFIG.

    Usage: ``python flashtool.py <port> <WS4-14|WS4-8|WS4-6>``. The file is
    received by ``loady`` at ``$loadaddr`` (YMODEM at 921600 baud), written
    from there, and the board is reset. Exits with status 1 on an invalid
    model or a missing config file.
    """
    if len(sys.argv) != 3:
        print("Usage: python flashtool.py COM4 <WS4-14/WS4-8/WS4-6>")
        sys.exit(1)

    port = sys.argv[1]
    switch_model = sys.argv[2]

    model_configs = {
        "WS4-14": WS4_14_CONFG_TXT,
        "WS4-8":  WS4_8_CONFG_TXT,
        "WS4-6":  WS4_6_CONFG_TXT,
    }

    if switch_model not in model_configs:
        print(f"[!] Invalid switch model '{switch_model}'")
        sys.exit(1)

    config_txt = model_configs[switch_model]
    if not os.path.exists(config_txt):
        print("Config file does not exist:", config_txt)
        sys.exit(1)

    ser = None
    try:
        print(f"[*] Connecting to {port} at 115200...")
        ser = serial.Serial(port, 115200, timeout=1)
        time.sleep(2)
        ser.reset_input_buffer()

        # Wake the console. send_and_read() normally captures the prompt
        # itself, so only poll for it (up to 10 s) when it was not in the reply.
        banner = send_and_read(ser, '')
        if '=>' not in banner and not wait_for_prompt(ser):
            print("[!] U-Boot prompt not detected; continuing anyway.")

        print("[*] Starting loady for config...")
        send_and_read(ser, 'loady $loadaddr 921600')

        print("[*] Switching baud to 921600...")
        ser = update_baud(ser, 921600)
        time.sleep(5)
        ser.write(b'\r\n')  # <enter> triggers the transfer start on the U-Boot side
        ser.flush()
        time.sleep(0.5)

        send_file_ymodem(ser, config_txt)

        print("[*] Switching back to 115200...")
        ser = update_baud(ser, 115200)
        ser.write(b'\x1b')  # ESC to exit YMODEM mode
        ser.flush()
        time.sleep(0.5)

        print("[*] Flashing config...")
        # NOTE(review): exactly one 512-byte block is written, so a larger
        # config file would be truncated. Re-enabling the two setexpr lines and
        # writing $blkcnt would write the whole file.
        #send_and_read(ser, 'setexpr blkcnt $filesize + 0x1FF')
        #send_and_read(ser, 'setexpr blkcnt $blkcnt / 0x200')
        send_and_read(ser, 'mmc dev 0')
        send_and_read(ser, f'mmc write $loadaddr {PART_CONFIG} 1')

        print("[✔] Config flash complete.")
        time.sleep(1)
        send_and_read(ser, 'reset')

    except KeyboardInterrupt:
        print("\n[!] Interrupted by user.")
    finally:
        if ser is not None and ser.is_open:
            ser.close()


if __name__ == '__main__':
    # Exactly one entry point is meant to be active; uncomment the one you need.
    #main()            # firmware + model config in one session
    #just_send_file()  # debugging only: raw YMODEM push that emulates lrzsz `sz`/`sb` where those tools are unavailable (e.g. on Windows); not for normal use
    #flash_config()    # model config only
    flash_firmware()   # firmware only (currently active)
