#!/usr/bin/env python3

import sys
import signal
import re
import time

# Print-buffer base offset used by translator(): a write is treated as console
# output when (addr - PRINT_BUF) % 65536 < 1024.  Integer division keeps that
# address arithmetic in ints (``/`` would make this the float 32768.0).
PRINT_BUF = 0x20000 // 4

def _field_elements(line, key):
    """Return the ``', '``-separated items inside ``key...}`` in *line*.

    *key* includes the opening brace (e.g. ``'addr={'``).  Returns an empty
    list when the key or its closing brace is absent, so callers see no lanes
    instead of slicing garbage out of the line.
    """
    start = line.find(key)
    if start < 0:
        return []
    start += len(key)
    end = line.find('}', start)
    if end < 0:
        return []
    return line[start:end].split(', ')


def translator(line, print_buf=None):
    """Decode console text written to the print buffer by a ``core-req-wr`` line.

    The line is expected to carry per-lane ``addr={...}``, ``byteen={...}``
    and ``data={...}`` lists.  Lanes are scanned in order; the first lane that
    (a) has an address with ``addr >> 18 == 0xff0``, (b) lies in the first
    1024 bytes of a 64 KiB window offset by *print_buf*, and (c) has fully
    known hex data is decoded.  Its byte-enabled bytes are taken from the
    least-significant end upward and returned as ASCII, dropping non-ASCII
    bytes.

    Lanes whose address, byte-enable or data contain unknown (``x``/``z``) or
    otherwise non-hex digits are skipped rather than crashing the filter.

    Args:
        line: one line of simulator output.
        print_buf: print-buffer offset; defaults to the module-level
            ``PRINT_BUF``.

    Returns:
        The decoded text (possibly ``""``), or ``None`` if *line* is not a
        ``core-req-wr`` line or no lane qualifies.
    """
    if 'core-req-wr' not in line:
        return None
    if print_buf is None:
        print_buf = PRINT_BUF

    addrs = _field_elements(line, 'addr={')
    byteens = _field_elements(line, 'byteen={')
    datas = _field_elements(line, 'data={')

    for addr_text, data_text, byteen_text in zip(addrs, datas, byteens):
        try:
            addr = int(addr_text, 16)
            byteen = int(byteen_text, 16)
        except ValueError:
            # Unknown ('x'/'z') or malformed address / byte-enable lane.
            continue

        if addr >> 18 != 0xff0:
            continue
        # Only the first 1 KiB of each 64 KiB window counts as print buffer.
        if (addr - print_buf) % 65536 >= 1024:
            continue

        hex_digits = data_text[2:]  # drop the '0x' prefix
        # Pad to at least one 32-bit word and to a whole number of bytes.
        width = max(8, len(hex_digits) + len(hex_digits) % 2)
        try:
            word = bytes.fromhex(hex_digits.rjust(width, '0'))
        except ValueError:
            # Data contains unknown ('x'/'z') or non-hex digits.
            continue

        # byteen bit i enables byte i, counted from the least-significant end.
        enabled = bytes(b for i, b in enumerate(reversed(word))
                        if byteen & (1 << i))
        return enabled.decode('ascii', errors='ignore')

    return None

def timestamp_parser(line):
    """Return the leading ``<digits>:`` timestamp of *line* as a string.

    Whitespace before the digits is allowed; the colon must follow the digits
    immediately.  Returns ``""`` when the line has no such prefix.
    """
    found = re.match(r"^\s*(\d+):", line)
    return found.group(1) if found is not None else ""

# Run-state flags; main() declares them global and updates them while reading
# input, and signal_handler() reads sim_started.
sim_started, sim_ended = False, False

def signal_handler(sig, frame):
    """SIGINT handler: exit with status 0.

    If the simulation had started, first emit the ANSI "cursor down" escape
    (``ESC [ B``) followed by a newline.  *sig* and *frame* are unused.
    """
    was_running = sim_started
    if was_running:
        print("\033[B")
    sys.exit(0)

def _draw_status(lineno, run_label, timestamp_text, speed_in_hz):
    """Overwrite terminal row *lineno* with a bold status line.

    Uses ANSI escapes to save the cursor, jump to column 1 of row *lineno*,
    clear that row, print ``<run_label> [TIME] <timestamp> [SPEED] <speed>``
    in bold, and restore the cursor so normal output resumes in place.
    """
    # Save cursor position
    print("\033[s", end='')
    # Move to row `lineno`, clear the line, bold on, status text, attributes off
    print("\033[" + str(lineno) + "H\033[2K\033[1m" + run_label,
          "[TIME]", timestamp_text, "[SPEED]", int(speed_in_hz), "\033[0m", end='')
    # Restore cursor position; flush so the update shows immediately
    print("\033[u", end='', flush=True)


def main():
    """Filter simulator output on stdin into a readable console view.

    Usage: ``<script> RUN_LABEL STATUS_ROW`` with simulator output on stdin.

    * Lines before the "Chronologic VCS simulator" banner are echoed as-is.
    * After it, lines containing "has no more active warps" are dropped.
    * Non-trace lines (not starting with ``<digits>:``) are echoed, except
      inside a perf-counter dump, which runs from a "====================CORE"
      line through the next "dcache stores:" line and is suppressed.
    * Trace lines are not echoed; each is passed to ``translator`` and any
      text it decodes is printed without a trailing newline.
    * Every 200 trace lines (the first time after 100), terminal row
      STATUS_ROW is overwritten with the current trace timestamp and how fast
      timestamps advance per wall-clock second (0 on the first update, which
      has no earlier sample to compare against).

    Exits with a message if the arguments are missing or STATUS_ROW is not an
    integer.
    """
    global sim_started, sim_ended

    signal.signal(signal.SIGINT, signal_handler)

    # Validate arguments up front rather than dying with a traceback.
    if len(sys.argv) < 3:
        sys.exit("usage: %s RUN_LABEL STATUS_ROW" % sys.argv[0])
    run_label = sys.argv[1]
    try:
        lineno = int(sys.argv[2])
    except ValueError:
        sys.exit("STATUS_ROW must be an integer, got %r" % sys.argv[2])

    # One pattern both classifies trace lines and captures their timestamp.
    trace_re = re.compile(r"^\s*([0-9]+):")

    prev_timestamp = None     # timestamp at the previous status update, if any
    prev_clock = time.time()  # wall-clock time of the previous status update
    ts_countdown = 100        # trace lines left before the next status update
    perf_counters = False     # True while inside a perf-counter dump
    hang_detector = 0         # perf dumps seen since the last trace line

    # Clear the screen and home the cursor.
    print("\033[2J\033[H")

    for line in sys.stdin:
        line = line.rstrip('\n')

        if "Chronologic VCS simulator" in line:
            sim_started = True

        trace_match = trace_re.match(line)

        if "====================CORE" in line:
            perf_counters = True
            # NOTE: a "possible hang" warning for hang_detector >= 8 is
            # disabled; the counter is kept so it can be re-enabled.

        if "has no more active warps" in line:
            sim_ended = True

        if not sim_started:
            print(line)
            continue

        # Drop "no more active warps" lines; the first later line without
        # that message clears sim_ended again.
        if sim_ended:
            if "has no more active warps" in line:
                continue
            sim_ended = False

        if trace_match is None:
            # Non-trace output: echo it unless inside a perf-counter dump.
            if not perf_counters:
                print(line)
            elif line.startswith("dcache stores:"):
                perf_counters = False
                hang_detector += 1
            continue
        hang_detector = 0

        if ts_countdown == 0:
            now = time.time()
            timestamp_text = trace_match.group(1)
            timestamp = int(timestamp_text)
            elapsed = now - prev_clock
            if prev_timestamp is not None and elapsed > 0:
                speed_in_hz = (timestamp - prev_timestamp) / elapsed
            else:
                # No earlier sample to diff against (or no time has passed).
                speed_in_hz = 0
            prev_timestamp, prev_clock = timestamp, now
            _draw_status(lineno, run_label, timestamp_text, speed_in_hz)
            ts_countdown = 200
        ts_countdown -= 1

        translated_line = translator(line)
        if translated_line:
            print(translated_line, end='', flush=True)

    print("")

if __name__ == "__main__":
    # Run the filter only when executed as a script, not when imported.
    main()

