"""
PicoQuant Unified TTTR (PTU) - File Access Demo in Python

This is demo code. Use at your own risk. No warranties.

Tested with Python     3.6.4
            Numpy      1.14.0
            Matplotlib 2.1.2

Keno Goertz, PicoQuant GmbH, December 2018
"""

from enum import Enum
import struct
import sys
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np

class LinePos(Enum):
    """
    Scanner position relative to a scan line, as tracked by the marker
    handling: either inside a line or anywhere else.
    """
    # Between a line-start marker and the matching line-stop marker
    IN_LINE = 1
    # Any other moment (before the first line, between lines, after a frame)
    IN_BETWEEN = 2

class Tag:
    """
    One PTU header tag, plus the helpers that decode it from the raw file.

    On disk a tag starts with a fixed 48-byte head: a 32-byte null-padded
    identifier, a little-endian int32 index, a little-endian uint32 type code
    and an 8-byte value field. For the variable-length types (float arrays,
    strings and binary blobs) the value field is the payload size in bytes
    and the payload follows directly after the head.

    Attributes:
        ident: tag name, e.g. "ImgHdr_PixX"; None when not read from a file.
        idx: index stored with the tag; None when not read from a file.
        typ: type name from TAG_TYPES, or None for an unknown type code.
        value: decoded value (bool, int, float, str, list of float or bytes,
            depending on typ), or the value passed to the constructor.
    """
    TAG_TYPES = {
        0xffff0008: "empty",
        0x00000008: "bool",
        0x10000008: "int",
        0x11000008: "bitset",
        0x12000008: "color",
        0x20000008: "float",
        0x21000008: "date_time",
        0x2001ffff: "float_array",
        0x4001ffff: "ansi_string",
        0x4002ffff: "wide_string",
        0xffffffff: "binary_blob"
    }

    def __init__(self, file_descriptor=None, value=None):
        """
        Create a tag, reading it from file_descriptor when one is given.

        value is only kept when no file is given. It lets callers build
        synthetic tags such as Tag(value=3).
        """
        self.ident = None
        self.idx = None
        self.typ = None
        self.value = value
        if file_descriptor is not None:
            self.read_tag(file_descriptor)

    def read_tag(self, file_descriptor):
        """
        Read one tag from the current position of file_descriptor (a binary
        file-like object). Afterwards the descriptor is positioned right
        after the tag, including any variable-length payload.
        """
        self.ident = struct.unpack(
            "32s", file_descriptor.read(32))[0].decode("utf-8").strip("\0")
        self.idx = struct.unpack("<i", file_descriptor.read(4))[0]
        tag_typ_value = struct.unpack("<I", file_descriptor.read(4))[0]
        # Unknown type codes give None and are later decoded as a plain int64
        self.typ = self.TAG_TYPES.get(tag_typ_value)
        self.value = file_descriptor.read(8)
        self.handle_tag_types(file_descriptor)

    def handle_tag_types(self, file_descriptor):
        """
        Decode self.value according to self.typ.

        On entry self.value holds the raw 8-byte value field. For
        variable-length types it is the payload size in bytes, and the
        payload is read from file_descriptor.
        """
        if self.typ == "bool":
            self.value = struct.unpack("<q", self.value)[0] != 0
        elif self.typ in ("float", "date_time"):
            # TDateTime is stored as a double (days since the Delphi epoch),
            # not as an integer
            self.value = struct.unpack("<d", self.value)[0]
        elif self.typ == "float_array":
            # The length field is a byte count; each element is a float64.
            # The byte-order character must lead the struct format.
            length = struct.unpack("<q", self.value)[0]
            raw = file_descriptor.read(length)
            count = len(raw) // 8
            self.value = list(struct.unpack_from("<%dd" % count, raw))
        elif self.typ == "ansi_string":
            string_length = struct.unpack("<q", self.value)[0]
            # ANSI text need not be valid UTF-8 (e.g. a Windows-1252 "µ");
            # replace undecodable bytes instead of aborting the header read
            self.value = file_descriptor.read(string_length).decode(
                "utf-8", errors="replace").strip("\0")
        elif self.typ == "wide_string":
            string_length = struct.unpack("<q", self.value)[0]
            # Wide strings are UTF-16 little endian without a BOM; plain
            # "utf-16" would fall back to the host's native byte order
            self.value = file_descriptor.read(string_length).decode(
                "utf-16-le").strip("\0")
        elif self.typ == "binary_blob":
            length = struct.unpack("<q", self.value)[0]
            self.value = file_descriptor.read(length)
        else:
            # empty, int, bitset, color and unknown types: signed int64
            self.value = struct.unpack("<q", self.value)[0]

class PTUFile:
    """
    PTU file reading class.

    Opens the file, validates its magic number, parses the header tags into
    tag_dict and allocates an empty intensity image (rows = scan lines,
    columns = pixels per line). After construction input_file is positioned
    directly after the Header_End tag, i.e. at the first TTTR record.
    """
    MAGIC = b"PQTTTR\0\0"

    def __init__(self, file_name: str):
        self.input_file = open(file_name, "rb")
        try:
            if not self.check_magic():
                print("[ERROR]: File is not a PTU file")
                raise ValueError("%s is not a PTU file" % file_name)
            self.image = None
            self.tag_dict = self.read_tags()
        except BaseException:
            # Don't leak the handle when the file is rejected or the
            # header turns out to be malformed
            self.input_file.close()
            raise

    def check_magic(self) -> bool:
        """
        Check whether the file starts with the PTU magic number.

        The raw bytes are compared directly, so arbitrary binary data or a
        file shorter than 8 bytes yields False instead of raising a decode
        or struct error.
        """
        self.input_file.seek(0)
        return self.input_file.read(8) == self.MAGIC

    def read_tags(self) -> "Dict[str, Tag]":
        """
        Read all header tags and return them in a dict keyed by tag ident.

        Reading starts at byte 16, skipping the 8-byte magic and the 8 bytes
        after it (presumably the format version string). It stops after the
        Header_End tag. Tags that share an ident (differing only in idx)
        overwrite each other, so only the last one is kept.

        Also allocates self.image with shape (ImgHdr_PixY, ImgHdr_PixX).
        """
        self.input_file.seek(16)
        tag_dict = {}
        while True:
            tag = Tag(self.input_file)
            if tag.ident == "Header_End":
                break
            tag_dict[tag.ident] = tag
        # Special case for converted old SPT32 files, did not have this info
        if tag_dict["ImgHdr_Ident"].value == 1 and \
                ("ImgHdr_LineStart" not in tag_dict or
                 "ImgHdr_LineStop" not in tag_dict):
            tag_dict["ImgHdr_LineStart"] = Tag(value=3)
            tag_dict["ImgHdr_LineStop"] = Tag(value=2)
        # The image is filled as image[line, pixel] with line < PixY and
        # pixel < PixX, so rows must be Y and columns X
        self.image = np.zeros((tag_dict["ImgHdr_PixY"].value,
                               tag_dict["ImgHdr_PixX"].value))
        return tag_dict

class Record:
    """
    Interpret the recorded binary data.

    Walks the TTTR records of a PTUFile and hands each one to the decoder
    matching the file's TTResultFormat_TTTRRecType.

    Attributes:
        typ: record type name from RECORD_TYPES.
        overflow_correction: accumulated time offset from overflow records.
        position: scan state with keys "line_pos" (LinePos), "start_pos"
            (time tag of the current line start) and "cur_line".
        temp_cur / temp_photons_time_tags: number of buffered photon time
            tags of the current line, and the buffer holding them.
        record: current record as a 32-character binary string, MSB first.
    """
    RECORD_TYPES = {
        0x00010303: "picoharp_t3",
        0x00010203: "picoharp_t2",
        0x00010304: "hydraharp_t3",
        0x00010204: "hydraharp_t2",
        0x01010304: "hydraharp2_t3",
        0x01010204: "hydraharp2_t2",
        0x00010305: "timeharp260n_t3",
        0x00010205: "timeharp260n_t2",
        0x00010306: "timeharp260p_t3",
        0x00010206: "timeharp260p_t2",
        0x00010307: "multiharp_t3",
        0x00010207: "multiharp_t2"
    }

    def __init__(self, ptu_file: "PTUFile"):
        self.ptu_file = ptu_file
        self.typ = self.RECORD_TYPES[
            self.ptu_file.tag_dict["TTResultFormat_TTTRRecType"].value
        ]
        self.overflow_correction = 0
        self.position = {
            "line_pos": LinePos.IN_BETWEEN,
            "start_pos": 0,
            "cur_line": 0
        }
        self.temp_cur = 0
        self.temp_photons_time_tags = np.zeros(10000)
        self.record = None

    def process_records(self):
        """
        Process all records in the parent PTU file.

        Reads TTResult_NumberOfRecords little-endian uint32 records from the
        file's current position. Each one is stored in self.record as a
        32-character binary string and passed to process_record(). Progress
        is reported on stdout.
        """
        num_records = self.ptu_file.tag_dict["TTResult_NumberOfRecords"].value
        # Report roughly every 0.1 %; never use a step of 0, which raised
        # ZeroDivisionError for files with fewer than 1000 records
        progress_step = max(1, num_records // 1000)
        for num in range(num_records):
            if num % progress_step == 0:
                sys.stdout.write("\rProgress: %3d%%"
                                 % (int(round(num / num_records * 100))))
                sys.stdout.flush()
            raw_value = struct.unpack(
                "<I", self.ptu_file.input_file.read(4))[0]
            self.record = format(raw_value, "032b")
            self.process_record()
        if num_records > 0:
            # Finish the progress line so later output starts on a new line
            sys.stdout.write("\rProgress: 100%\n")
            sys.stdout.flush()

    def process_record(self):
        """
        Process the current record with the decoder for self.typ.

        Version 2 covers the HydraHarp 2, TimeHarp 260 and MultiHarp formats.
        """
        if self.typ == "picoharp_t2":
            process_pht2(self)
        elif self.typ == "picoharp_t3":
            process_pht3(self)
        elif self.typ == "hydraharp_t2":
            process_hht2(self, 1)
        elif self.typ == "hydraharp_t3":
            process_hht3(self, 1)
        elif self.typ in ["multiharp_t2", "hydraharp2_t2", "timeharp260n_t2",
                          "timeharp260p_t2"]:
            process_hht2(self, 2)
        else:
            process_hht3(self, 2)

def got_marker(record, time_tag, markers):
    """
    Handle a marker event.

    markers is a bit mask in which bit n - 1 stands for marker number n.
    The marker numbers used for line start, line stop and (optionally)
    frame are read from the ImgHdr_LineStart, ImgHdr_LineStop and
    ImgHdr_Frame tags.

    On a line-stop marker, the photon time tags buffered since the line
    start are binned into row cur_line of record.ptu_file.image. The
    start-to-stop interval is split into equally wide pixels, and each
    photon adds one count to the pixel it falls into.

    ATTENTION: Several markers can arrive in the same record, especially the
    stop line marker and frame marker often come together. The ordering of
    handling the markers (stop, then frame, then start) is important.
    """
    tag_dict = record.ptu_file.tag_dict
    position = record.position

    # Check if the line stop arrived, only makes sense inside a line
    stop_marker = tag_dict["ImgHdr_LineStop"].value - 1
    if (((1 << stop_marker) & markers) != 0) and \
            (position["line_pos"] == LinePos.IN_LINE):
        # distance between start and end marker:
        dist = time_tag - position["start_pos"]
        pix_x = tag_dict["ImgHdr_PixX"].value
        pix_y = tag_dict["ImgHdr_PixY"].value
        scanner_ident = tag_dict["ImgHdr_Ident"].value
        # Marker is one pixel off for E710:
        size_pix = (dist / (pix_x + 1)) if (scanner_ident == 1) \
                   else (dist / pix_x)
        is_bidir = tag_dict["ImgHdr_BiDirect"].value
        cur_line = position["cur_line"]

        # A non-positive pixel width would never advance next_pix in the
        # while loop below (endless loop). Lines beyond the image have no
        # row to go to. Skip binning in both cases.
        if size_pix > 0 and cur_line < pix_y:
            image = record.ptu_file.image
            # Time tag at which the next pixel starts:
            next_pix = position["start_pos"] + size_pix
            # Current pixel to fill (x-axis):
            pix_num = 0
            for photon_time in \
                    record.temp_photons_time_tags[:record.temp_cur]:
                # Advance to the pixel containing this photon
                while photon_time >= next_pix:
                    pix_num += 1
                    next_pix += size_pix

                # Photons are buffered in the order their records were read.
                # Assuming non-decreasing time tags, every remaining photon
                # also lies past the last pixel.
                if pix_num >= pix_x:
                    break

                # Revert every 2nd line if bidirectional scan
                if is_bidir and cur_line % 2 == 1:
                    cur_pix = pix_x - pix_num - 1
                else:
                    cur_pix = pix_num

                # We only count the photons in the pixel to get an intensity
                # image; see the got_photon docstring
                image[cur_line, cur_pix] += 1

        # Line is finished -> Next line
        position["cur_line"] += 1
        # Line is ended -> We are outside the line
        position["line_pos"] = LinePos.IN_BETWEEN

    # Check if frame marker arrived
    frame_tag = tag_dict.get("ImgHdr_Frame")
    frame_marker = frame_tag.value - 1 if frame_tag is not None else -1
    if (frame_marker >= 0) and (((1 << frame_marker) & markers) != 0):
        # A frame marker just restarts from the beginning
        position["cur_line"] = 0
        position["line_pos"] = LinePos.IN_BETWEEN

    # Check if start marker arrived, only makes sense outside a line
    start_marker = tag_dict["ImgHdr_LineStart"].value - 1
    if (((1 << start_marker) & markers) != 0) and \
            (position["line_pos"] == LinePos.IN_BETWEEN):
        position["line_pos"] = LinePos.IN_LINE
        record.temp_cur = 0
        position["start_pos"] = time_tag

def got_photon(record, time_tag):
    """
    Handle a photon event. For this example, we'll only be doing an intensity
    plot, so we only need to pass on the time tag of the photon. If you want
    to interpret other data of the record, you'd need to pass it on to this
    function and then do your calculations.

    Photons are only buffered while inside a line (between start and stop
    markers). They are appended to record.temp_photons_time_tags at index
    record.temp_cur, and got_marker bins them when the line stop arrives.
    """
    # Only care about photons after start and before the stop marker
    if record.position["line_pos"] != LinePos.IN_LINE:
        return

    buffer = record.temp_photons_time_tags
    if record.temp_cur >= len(buffer):
        # Grow geometrically (at least 10000 slots). Appending a fixed
        # 10000 slots on every overflow makes long lines cost O(n^2) copying.
        record.temp_photons_time_tags = np.concatenate(
            (buffer, np.zeros(max(len(buffer), 10000)))
        )

    record.temp_photons_time_tags[record.temp_cur] = time_tag
    record.temp_cur += 1

def process_pht2(record: Record):
    """
    Decode one PicoHarp T2 record and dispatch it.

    Record layout (MSB first): 4 bits channel, 28 bits time tag. Channel 0xf
    marks a special record whose low four time bits are marker bits. All-zero
    marker bits mean an overflow, which advances the overflow correction.
    """
    wraparound = 210698240
    word = int(record.record, base=2)
    channel = word >> 28
    time = word & 0x0FFFFFFF

    if channel != 0xF:
        got_photon(record, record.overflow_correction + time)
        return

    # In the special case, the lower four bits are marker bits
    markers = time & 0xF
    if markers:
        got_marker(record, record.overflow_correction + time, markers)
    else:
        # Overflow record
        record.overflow_correction += wraparound

def process_pht3(record: Record):
    """
    Decode one PicoHarp T3 record and dispatch it.

    Record layout (MSB first): 4 bits channel, 12 bits dtime, 16 bits sync
    count. On special records (channel 0xf) the dtime field carries the
    marker bits. A zero value there signals an overflow.
    """
    wraparound = 65536
    word = int(record.record, base=2)
    channel = word >> 28
    dtime = (word >> 16) & 0xFFF
    num_sync = word & 0xFFFF

    if channel != 0xF:
        got_photon(record, record.overflow_correction + num_sync)
    elif dtime == 0:
        # Special record without marker bits -> overflow
        record.overflow_correction += wraparound
    else:
        got_marker(record, record.overflow_correction + num_sync, dtime)

def process_hht2(record: Record, version: int):
    """
    Decode one HydraHarp-style T2 record and dispatch it.

    Record layout (MSB first): 1 special bit, 6 bits channel, 25 bits time
    tag. For special records, channel 0x3f is an overflow, channels 1..15 are
    markers and channel 0 is a sync event (handled like a photon here).

    version selects the overflow handling. Version 1 adds a fixed
    wraparound per overflow record. Version 2 multiplies the wraparound by
    the count in the time-tag field, with 0 treated as a single overflow.
    """
    wrap_v1 = 33552000
    wrap_v2 = 33554432
    word = int(record.record, base=2)
    special = word >> 31
    channel = (word >> 25) & 0x3F
    time_tag = word & 0x1FFFFFF

    if not special:
        got_photon(record, record.overflow_correction + time_tag)
    elif channel == 0x3F:
        if version == 1:
            record.overflow_correction += wrap_v1
        else:
            # A count of 0 is the old-style single overflow; it shouldn't
            # happen with new firmware and is kept for backwards-compatibility
            record.overflow_correction += wrap_v2 * (time_tag or 1)
    elif 1 <= channel <= 15:
        got_marker(record, record.overflow_correction + time_tag, channel)
    elif channel == 0:
        got_photon(record, record.overflow_correction + time_tag)

def process_hht3(record: Record, version: int):
    """
    Decode one HydraHarp-style T3 record and dispatch it.

    Record layout (MSB first): 1 special bit, 6 bits channel, 15 bits dtime,
    10 bits sync count. For special records, channel 0x3f is an overflow
    and channels 1..15 are markers. Version 1 (and a zero count) means a
    single overflow. Otherwise the sync field holds the number of overflows.
    """
    wraparound = 1024
    word = int(record.record, base=2)
    special = word >> 31
    channel = (word >> 25) & 0x3F
    # The micro time is not needed for this demonstration. If you want it,
    # it is (word >> 10) & 0x7FFF; pass it on to your custom got_photon()
    # function (see got_photon docstring).
    num_sync = word & 0x3FF

    if not special:
        got_photon(record, record.overflow_correction + num_sync)
    elif channel == 0x3F:
        overflows = 1 if version == 1 else (num_sync or 1)
        record.overflow_correction += wraparound * overflows
    elif 1 <= channel <= 15:
        got_marker(record, record.overflow_correction + num_sync, channel)

def main():
    """
    PicoQuant Unified TTTR (PTU) Image File Demo.

    Reads the PTU file named by the single command-line argument, bins its
    photons into an intensity image and shows it with matplotlib. On a wrong
    argument count it prints a usage message and exits with status 1.
    """
    if len(sys.argv) != 2:
        print("Usage: ./ptuimagedemo.py file.ptu")
        # sys.exit, not the site-provided exit() meant for the interactive
        # shell; a non-zero status reports the usage error to the caller
        sys.exit(1)
    print("PicoQuant Unified TTTR (PTU) Image File Demo\n"
          "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
    ptu_file = PTUFile(sys.argv[1])
    try:
        record = Record(ptu_file)
        record.process_records()
    finally:
        # All records have been read; release the file before plotting
        ptu_file.input_file.close()
    plt.imshow(ptu_file.image)
    plt.show()


if __name__ == "__main__":
    main()
