using System;
using System.Collections.Generic;
using System.Net;
using System.Net.NetworkInformation;
using System.Net.Sockets;
using System.Text;
using UnityEngine;

namespace WebViewStream
{
    public class ServiceDiscovery : MonoBehaviour
    {
        private UdpClient udpClient;
        private Action<MdnsService> action;

        private string receivedIp;
        private string receivedPort;
        private string receivedPath;
        private string receivedHost;

        private IPAddress defaultIP;

        private const string multicastAddress = "224.0.0.251";
        private const int multicastPort = 5353;

        private Queue<MdnsService> serviceQueue = new Queue<MdnsService>();

        private IPAddress GetDefaultInterfaceIP()
        {
            foreach (NetworkInterface ni in NetworkInterface.GetAllNetworkInterfaces())
            {
                if (ni.OperationalStatus == OperationalStatus.Up)
                {
                    var ipProps = ni.GetIPProperties();
                    if (ipProps.GatewayAddresses.Count > 0)
                    {
                        foreach (UnicastIPAddressInformation ip in ipProps.UnicastAddresses)
                        {
                            if (ip.Address.AddressFamily == AddressFamily.InterNetwork)
                            {
                                return ip.Address;
                            }
                        }
                    }
                }
            }

            return null;
        }

        /// <summary>
        /// Starts listening for mDNS service announcements.
        /// </summary>
        /// <param name="action"></param>
        public void StartListening(Action<MdnsService> action)
        {
            try
            {
                defaultIP = GetDefaultInterfaceIP();
                if (defaultIP == null)
                {
                    Debug.LogError("No default interface found. Cannot start multicast listener.");
                    return;
                }

                udpClient = new UdpClient();
                udpClient.Client.SetSocketOption(
                    SocketOptionLevel.Socket,
                    SocketOptionName.ReuseAddress,
                    true
                );
                udpClient.Client.Bind(new IPEndPoint(defaultIP, multicastPort));
                udpClient.JoinMulticastGroup(IPAddress.Parse(multicastAddress), defaultIP);

                this.action = action;

                Debug.Log("Listening for service announcements...");

                SendMdnsQuery("_http._tcp.local");

                udpClient.BeginReceive(OnReceive, null);
            }
            catch (Exception ex)
            {
                Debug.LogError($"Error starting UDP listener: {ex.Message}");
            }
        }

        private void SendMdnsQuery(string serviceName)
        {
            byte[] query = CreateMdnsQuery(serviceName);
            Debug.Log($"Sending mDNS query for {serviceName}");

            udpClient.Send(
                query,
                query.Length,
                new IPEndPoint(IPAddress.Parse(multicastAddress), multicastPort)
            );
        }

        private byte[] CreateMdnsQuery(string serviceName)
        {
            ushort transactionId = 0;
            ushort flags = 0x0100;
            ushort questions = 1;
            byte[] header = new byte[12];
            Array.Copy(
                BitConverter.GetBytes((ushort)IPAddress.HostToNetworkOrder((short)transactionId)),
                0,
                header,
                0,
                2
            );
            Array.Copy(
                BitConverter.GetBytes((ushort)IPAddress.HostToNetworkOrder((short)flags)),
                0,
                header,
                2,
                2
            );
            Array.Copy(
                BitConverter.GetBytes((ushort)IPAddress.HostToNetworkOrder((short)questions)),
                0,
                header,
                4,
                2
            );

            byte[] name = EncodeName(serviceName);
            byte[] query = new byte[header.Length + name.Length + 4];
            Array.Copy(header, query, header.Length);
            Array.Copy(name, 0, query, header.Length, name.Length);

            query[query.Length - 4] = 0x00;
            query[query.Length - 3] = 0x0C;
            query[query.Length - 2] = 0x00;
            query[query.Length - 1] = 0x01;

            return query;
        }

        private byte[] EncodeName(string name)
        {
            string[] parts = name.Split('.');
            byte[] result = new byte[name.Length + 2];
            int offset = 0;

            foreach (string part in parts)
            {
                result[offset++] = (byte)part.Length;
                Array.Copy(Encoding.UTF8.GetBytes(part), 0, result, offset, part.Length);
                offset += part.Length;
            }

            result[offset] = 0;
            return result;
        }

        private void OnReceive(IAsyncResult result)
        {
            if (udpClient == null)
            {
                return;
            }

            try
            {
                IPEndPoint remoteEndPoint = new IPEndPoint(IPAddress.Any, multicastPort);
                byte[] receivedBytes = udpClient.EndReceive(result, ref remoteEndPoint);

                ushort flags = BitConverter.ToUInt16(new byte[] { receivedBytes[3], receivedBytes[2] }, 0);
                if (flags == 0x0100)
                {
                    udpClient?.BeginReceive(OnReceive, null);
                    return;
                }

                ParseMdnsResponse(receivedBytes);

                udpClient?.BeginReceive(OnReceive, null);
            }
            catch (Exception ex)
            {
                Debug.LogError($"Error receiving UDP message: {ex.Message}");
            }
        }

        private void AddMdnsService()
        {
            if (receivedIp != null && receivedPort != null && receivedHost != null && receivedPath != null)
            {
                MdnsService currentService = new MdnsService(
                    receivedIp,
                    int.Parse(receivedPort),
                    receivedPath,
                    receivedHost
                );
                serviceQueue.Enqueue(currentService);
                Debug.Log($"Added service: {currentService}");
                receivedIp = null;
                receivedPort = null;
                receivedPath = null;
                receivedHost = null;
            }
        }

        private void ParseMdnsResponse(byte[] data)
        {
            int offset = 12;
            ushort questions = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, 4));
            ushort answerRRs = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, 6));
            ushort additionalRRs = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, 10));

            for (int i = 0; i < questions; i++)
            {
                offset = SkipName(data, offset);
                offset += 4;
            }

            for (int i = 0; i < answerRRs; i++)
            {
                offset = ParseRecord(data, offset);
            }

            for (int i = 0; i < additionalRRs; i++)
            {
                offset = ParseRecord(data, offset);
                AddMdnsService();
            }
        }

        private int ParseRecord(byte[] data, int offset)
        {
            string name;
            (name, offset) = ReadName(data, offset);

            ushort recordType = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, offset));
            ushort recordClass = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, offset + 2));
            uint ttl = (uint)IPAddress.NetworkToHostOrder(BitConverter.ToInt32(data, offset + 4));
            ushort dataLength = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, offset + 8));
            offset += 10;

            if (ttl == 0)
            {
                Debug.LogWarning($"Zero TTL for {name}");
                return offset + dataLength;
            }

            if (recordType == 1) // A Record
            {
                IPAddress ipAddress = new IPAddress(
                    new ArraySegment<byte>(data, offset, dataLength).ToArray()
                );
                receivedIp = ipAddress.ToString();
                receivedHost = name;
            }
            else if (recordType == 12) // PTR Record
            {
                string target;
                (target, _) = ReadName(data, offset);
            }
            else if (recordType == 33) // SRV Record
            {
                ushort priority = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, offset));
                ushort weight = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, offset + 2));
                ushort port = (ushort)IPAddress.NetworkToHostOrder(BitConverter.ToInt16(data, offset + 4));
                string target;
                (target, _) = ReadName(data, offset + 6);
                receivedPort = port.ToString();
            }
            else if (recordType == 16) // TXT Record
            {
                string txtData = Encoding.UTF8.GetString(data, offset, dataLength);
                if (txtData.Contains("path"))
                {
                    receivedPath = txtData.Split('=')[1];
                }
            }
            else if (recordType == 47) // NSEC Record
            {
                // Debug.Log($"NSEC Record: {name}");
            }
            else
            {
                Debug.Log($"Unknown Record Type {recordType} for {name}");
            }

            return offset + dataLength;
        }

        private (string, int) ReadName(byte[] data, int offset)
        {
            StringBuilder name = new StringBuilder();
            int originalOffset = offset;
            bool jumped = false;

            while (data[offset] != 0)
            {
                if ((data[offset] & 0xC0) == 0xC0)
                {
                    if (!jumped)
                    {
                        originalOffset = offset + 2;
                    }
                    offset = (data[offset] & 0x3F) << 8 | data[offset + 1];
                    jumped = true;
                }
                else
                {
                    int length = data[offset++];
                    name.Append(Encoding.UTF8.GetString(data, offset, length) + ".");
                    offset += length;
                }
            }

            return (name.ToString().TrimEnd('.'), jumped ? originalOffset : offset + 1);
        }

        private int SkipName(byte[] data, int offset)
        {
            while (data[offset] != 0)
            {
                if ((data[offset] & 0xC0) == 0xC0)
                {
                    return offset + 2;
                }
                offset += data[offset] + 1;
            }
            return offset + 1;
        }

        private void Update()
        {
            if (serviceQueue.Count > 0)
            {
                Debug.Log($"Queued services: {serviceQueue.Count}");

                while (serviceQueue.Count > 0)
                {
                    MdnsService service = serviceQueue.Dequeue();
                    if (service == null)
                    {
                        continue;
                    }
                    Debug.Log($"Invoking action with: {service}");
                    action?.Invoke(service);
                }
            }
        }

        private void OnDestroy()
        {
            StopListening();
        }

        private void StopListening()
        {
            udpClient?.DropMulticastGroup(IPAddress.Parse(multicastAddress));
            udpClient?.Close();
            udpClient = null;
        }
    }
}