package client

import (
	"crypto/rand"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"time"

	"github.com/gorilla/websocket"
	"golang.org/x/crypto/curve25519"
	"golang.zx2c4.com/wireguard/wgctrl"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"tailscale-clone/internal/types"
)

// Client represents a client that connects to the controller.
//
// A Client holds a locally generated, hex-encoded Curve25519 key pair, a
// websocket connection to the controller, and a wgctrl handle used to
// configure the local WireGuard interface.
//
// NOTE(review): conn is replaced by connect (called again from the
// handleMessages goroutine on reconnect) and written by both register and
// sendUpdate (the latter from the updateRoutine goroutine), with no
// synchronization. gorilla/websocket allows at most one concurrent writer —
// consider guarding conn and its writes with a mutex.
type Client struct {
	controllerAddr string          // controller address; dialed as ws://<controllerAddr>/ws
	nodeName       string          // name sent in the register request
	interfaceName  string          // local WireGuard interface to configure
	nodeID         string          // assigned in the register response; empty until then
	privateKey     string          // hex-encoded private key from generateKeyPair
	publicKey      string          // hex-encoded public key, sent in the register request
	ip             net.IP          // IP assigned in the register response
	conn           *websocket.Conn // current controller connection; replaced on reconnect
	wgClient       *wgctrl.Client  // handle used to configure the WireGuard device
	stopChan       chan struct{}   // closed by Stop to end the background goroutines
	// Subnet routing fields
	isSubnetRouter bool                // set by SetSubnetRouter
	subnets        []string            // subnets advertised in the register request
	subnetRoutes   []types.SubnetRoute // routes received from the controller
}

// New creates a Client that will talk to the controller at controllerAddr
// and manage the WireGuard interface named interfaceName.
//
// A fresh key pair is generated on every call and a wgctrl handle is opened;
// call Stop to release that handle. New does not contact the controller —
// Start does that.
//
// Errors from key generation or from opening the wgctrl handle are wrapped
// with %w so callers can inspect them with errors.Is / errors.As.
func New(controllerAddr, nodeName, interfaceName string) (*Client, error) {
	privateKey, publicKey, err := generateKeyPair()
	if err != nil {
		return nil, fmt.Errorf("failed to generate keys: %w", err)
	}

	wgClient, err := wgctrl.New()
	if err != nil {
		return nil, fmt.Errorf("failed to create WireGuard client: %w", err)
	}

	return &Client{
		controllerAddr: controllerAddr,
		nodeName:       nodeName,
		interfaceName:  interfaceName,
		privateKey:     privateKey,
		publicKey:      publicKey,
		wgClient:       wgClient,
		stopChan:       make(chan struct{}),
	}, nil
}

// SetSubnetRouter configures the client as a subnet router advertising the
// given subnets. The values are sent in every register request, including
// the one issued by Start and any issued on reconnect.
//
// The slice is copied so later changes by the caller do not alter what is
// advertised. A nil slice stays nil and an empty slice stays empty, so the
// JSON encoding of the register request is unchanged.
func (c *Client) SetSubnetRouter(subnets []string) {
	c.isSubnetRouter = true
	if subnets != nil {
		subnets = append(make([]string, 0, len(subnets)), subnets...)
	}
	c.subnets = subnets
}

// Start connects to the controller and sends the registration request, then
// launches two background goroutines that run until Stop is called:
// handleMessages (reads and dispatches controller messages, reconnecting on
// read errors) and updateRoutine (periodically calls sendUpdate).
//
// If connecting or registering fails, the error is returned wrapped with %w
// and no goroutines are started.
func (c *Client) Start() error {
	if err := c.connect(); err != nil {
		return fmt.Errorf("failed to connect to controller: %w", err)
	}

	go c.handleMessages()
	go c.updateRoutine()

	return nil
}

// Stop shuts the client down. It closes stopChan so the goroutines started
// by Start exit, then closes the controller connection and the wgctrl handle.
//
// Calling Stop again after it has returned is a no-op. Without this guard, a
// second close(stopChan) would panic. Concurrent calls are not synchronized,
// because the struct has no lock that could make them safe.
func (c *Client) Stop() {
	select {
	case <-c.stopChan:
		return // already stopped
	default:
	}
	close(c.stopChan)

	// Best-effort teardown: close errors are not actionable at shutdown.
	if c.conn != nil {
		c.conn.Close()
	}
	if c.wgClient != nil {
		c.wgClient.Close()
	}
}

// connect dials the controller's websocket endpoint (ws://<controllerAddr>/ws),
// installs the new connection as c.conn, and sends the registration request.
//
// If a previous connection exists (handleMessages calls connect again after a
// read error), it is closed once the new dial succeeds, so its socket is not
// leaked. If registration fails, the new connection stays installed and the
// error is returned.
func (c *Client) connect() error {
	url := fmt.Sprintf("ws://%s/ws", c.controllerAddr)
	conn, _, err := websocket.DefaultDialer.Dial(url, nil)
	if err != nil {
		return fmt.Errorf("dialing %s: %w", url, err)
	}

	// The old connection has already failed; release it before replacing it.
	if old := c.conn; old != nil {
		old.Close()
	}
	c.conn = conn

	if err := c.register(); err != nil {
		return fmt.Errorf("registering with controller: %w", err)
	}
	return nil
}

// register sends a MessageTypeRegister message over c.conn. The message
// carries this node's name, public key, and subnet-router settings. The
// controller's reply arrives through the read loop and is dispatched to
// handleRegisterResponse. Marshal and write errors are returned unwrapped.
func (c *Client) register() error {
	encoded, err := json.Marshal(types.Message{
		Type: types.MessageTypeRegister,
		Payload: types.RegisterRequest{
			NodeName:       c.nodeName,
			PublicKey:      c.publicKey,
			IsSubnetRouter: c.isSubnetRouter,
			Subnets:        c.subnets,
		},
	})
	if err != nil {
		return err
	}
	return c.conn.WriteMessage(websocket.TextMessage, encoded)
}

// handleMessages is the read loop started by Start. It reads controller
// messages and passes each one to handleMessage until stopChan is closed.
//
// On a read error it reconnects. It waits 5 seconds, calls connect (which
// also re-registers), and repeats every 5 seconds until connect succeeds or
// Stop is called. It never reads again from the failed connection:
// gorilla/websocket read errors are permanent, and the library panics after
// repeated reads on a failed connection.
//
// A read error seen after Stop has closed stopChan is treated as shutdown.
// Stop closes the connection right after closing stopChan, so this case is
// expected. The loop then exits without logging or reconnecting.
func (c *Client) handleMessages() {
	for {
		select {
		case <-c.stopChan:
			return
		default:
		}

		_, message, err := c.conn.ReadMessage()
		if err != nil {
			// Stop closes stopChan before the connection, so a shutdown-induced
			// error is recognizable here.
			select {
			case <-c.stopChan:
				return
			default:
			}

			log.Printf("WebSocket read error: %v", err)
			for {
				// Back off before each attempt, but wake immediately on Stop.
				select {
				case <-c.stopChan:
					return
				case <-time.After(5 * time.Second):
				}
				connErr := c.connect()
				if connErr == nil {
					break
				}
				log.Printf("Failed to reconnect: %v", connErr)
			}

			// If Stop ran while we were reconnecting, it may have closed the old
			// connection instead of the new one; don't leak the new one.
			select {
			case <-c.stopChan:
				c.conn.Close()
				return
			default:
			}
			continue
		}

		c.handleMessage(message)
	}
}

// handleMessage decodes one raw controller message into a types.Message and
// routes its Payload to the matching handler based on its Type. Messages
// that fail to decode and unknown types are logged and dropped.
// MessageTypePong (a heartbeat reply) is ignored.
func (c *Client) handleMessage(message []byte) {
	var envelope types.Message
	if err := json.Unmarshal(message, &envelope); err != nil {
		log.Printf("Failed to unmarshal message: %v", err)
		return
	}

	switch kind := envelope.Type; kind {
	case types.MessageTypeRegister:
		c.handleRegisterResponse(envelope.Payload)
	case types.MessageTypePeers:
		c.handlePeersUpdate(envelope.Payload)
	case types.MessageTypeSubnetUpdate:
		c.handleSubnetUpdate(envelope.Payload)
	case types.MessageTypePong:
		// Heartbeat reply: nothing to do.
	default:
		log.Printf("Unknown message type: %s", kind)
	}
}

// handleRegisterResponse processes the controller's reply to register. It
// records the assigned node ID, IP, and subnet routes, sets up the WireGuard
// interface with the returned network and peers, and, when this client is a
// subnet router, sets up subnet routing.
//
// payload is the Payload field of the message envelope. It is round-tripped
// through JSON to decode it into types.RegisterResponse. If that fails, the
// reply is logged and ignored and existing state is left untouched. The
// client does not proceed with a zero-value response (empty node ID, nil IP).
func (c *Client) handleRegisterResponse(payload interface{}) {
	data, err := json.Marshal(payload)
	if err != nil {
		log.Printf("Failed to encode register response payload: %v", err)
		return
	}
	var resp types.RegisterResponse
	if err = json.Unmarshal(data, &resp); err != nil {
		log.Printf("Failed to decode register response: %v", err)
		return
	}

	c.nodeID = resp.NodeID
	c.ip = net.ParseIP(resp.IP)
	if c.ip == nil {
		// ParseIP returns nil for anything that is not a bare IP (e.g. a CIDR).
		log.Printf("Controller returned an unparseable IP %q", resp.IP)
	}
	c.subnetRoutes = resp.SubnetRoutes

	log.Printf("Registered with controller: %s at %s", c.nodeID, c.ip)
	if c.isSubnetRouter {
		log.Printf("Subnet router advertising: %v", c.subnets)
	}

	if err := c.setupWireGuard(resp.Network, resp.Peers); err != nil {
		log.Printf("Failed to setup WireGuard: %v", err)
	}

	if c.isSubnetRouter {
		if err := c.setupSubnetRouting(); err != nil {
			log.Printf("Failed to setup subnet routing: %v", err)
		}
	}
}

// handlePeersUpdate decodes a []types.Peer from payload (round-tripping it
// through JSON) and applies it to the WireGuard interface via
// updateWireGuardPeers. A payload that cannot be decoded is logged and
// ignored rather than applied as an empty peer list.
func (c *Client) handlePeersUpdate(payload interface{}) {
	data, err := json.Marshal(payload)
	if err != nil {
		log.Printf("Failed to encode peers payload: %v", err)
		return
	}
	var peers []types.Peer
	if err = json.Unmarshal(data, &peers); err != nil {
		log.Printf("Failed to decode peers update: %v", err)
		return
	}

	if err = c.updateWireGuardPeers(peers); err != nil {
		log.Printf("Failed to update WireGuard peers: %v", err)
	}
}

// handleSubnetUpdate decodes a types.SubnetUpdateRequest from payload
// (round-tripping it through JSON), replaces c.subnetRoutes with its routes,
// and re-applies routing via updateSubnetRouting.
//
// A payload that cannot be decoded is logged and ignored. Proceeding with
// the zero value would wipe the current subnet routes.
func (c *Client) handleSubnetUpdate(payload interface{}) {
	data, err := json.Marshal(payload)
	if err != nil {
		log.Printf("Failed to encode subnet update payload: %v", err)
		return
	}
	var update types.SubnetUpdateRequest
	if err = json.Unmarshal(data, &update); err != nil {
		log.Printf("Failed to decode subnet update: %v", err)
		return
	}

	c.subnetRoutes = update.SubnetRoutes
	log.Printf("Received subnet update: %d routes", len(update.SubnetRoutes))

	if err = c.updateSubnetRouting(); err != nil {
		log.Printf("Failed to update subnet routing: %v", err)
	}
}

// setupWireGuard creates the local WireGuard interface (createInterface) and
// then configures it for networkCIDR with the initial peers
// (configureInterface). It stops at the first failing step and returns that
// step's error wrapped with %w.
func (c *Client) setupWireGuard(networkCIDR string, peers []types.Peer) error {
	if err := c.createInterface(); err != nil {
		return fmt.Errorf("failed to create interface: %w", err)
	}

	if err := c.configureInterface(networkCIDR, peers); err != nil {
		return fmt.Errorf("failed to configure interface: %w", err)
	}

	log.Printf("WireGuard interface %s configured", c.interfaceName)
	return nil
}



// configureInterface applies networkCIDR and peers to the WireGuard
// interface by delegating to configureInterfacePlatform. That method is not
// defined in this file; presumably it is a per-OS implementation elsewhere
// in the package.
func (c *Client) configureInterface(networkCIDR string, peers []types.Peer) error {
	if err := c.configureInterfacePlatform(networkCIDR, peers); err != nil {
		return err
	}
	return nil
}

// updateWireGuardPeers pushes peers to the WireGuard device c.interfaceName
// through wgctrl.
//
// For each peer:
//   - PublicKey is decoded from hex into a wgtypes.Key. A peer with an
//     invalid key is logged and skipped.
//   - Each AllowedIPs entry is parsed as a CIDR. Invalid entries are logged
//     and dropped.
//   - A non-empty Endpoint is resolved as a UDP address. If resolution fails,
//     the peer is configured without an endpoint.
//   - KeepAlive is interpreted as seconds and always set as the persistent
//     keepalive interval. A zero value disables keepalive.
//
// The config sets neither ReplacePeers nor ReplaceAllowedIPs. Device peers
// absent from peers are kept, and allowed IPs are appended to each peer's
// existing list. NOTE(review): confirm this is intended if peers is a full
// snapshot from the controller.
func (c *Client) updateWireGuardPeers(peers []types.Peer) error {
	wgPeers := make([]wgtypes.PeerConfig, 0, len(peers))

	for _, peer := range peers {
		publicKeyBytes, err := hexToBytes(peer.PublicKey)
		if err != nil {
			log.Printf("Invalid peer public key: %v", err)
			continue
		}

		key, err := wgtypes.NewKey(publicKeyBytes)
		if err != nil {
			log.Printf("Failed to create peer key: %v", err)
			continue
		}

		allowedIPs := make([]net.IPNet, 0, len(peer.AllowedIPs))
		for _, allowedIP := range peer.AllowedIPs {
			_, ipNet, cidrErr := net.ParseCIDR(allowedIP)
			if cidrErr != nil {
				log.Printf("Invalid allowed IP: %v", cidrErr)
				continue
			}
			allowedIPs = append(allowedIPs, *ipNet)
		}

		var endpoint *net.UDPAddr
		if peer.Endpoint != "" {
			addr, resolveErr := net.ResolveUDPAddr("udp", peer.Endpoint)
			if resolveErr != nil {
				log.Printf("Invalid endpoint: %v", resolveErr)
			} else {
				endpoint = addr
			}
		}

		// Declared per iteration so each PeerConfig points at its own value.
		keepAlive := time.Duration(peer.KeepAlive) * time.Second

		wgPeers = append(wgPeers, wgtypes.PeerConfig{
			PublicKey:                   key,
			AllowedIPs:                  allowedIPs,
			Endpoint:                    endpoint,
			PersistentKeepaliveInterval: &keepAlive,
		})
	}

	cfg := wgtypes.Config{Peers: wgPeers}
	if err := c.wgClient.ConfigureDevice(c.interfaceName, cfg); err != nil {
		return fmt.Errorf("failed to update peers: %w", err)
	}

	log.Printf("Updated %d peers", len(wgPeers))
	return nil
}

// updateRoutine calls sendUpdate every 30 seconds until stopChan is closed.
// The ticker is stopped on exit.
func (c *Client) updateRoutine() {
	const interval = 30 * time.Second
	tick := time.NewTicker(interval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			c.sendUpdate()
		case <-c.stopChan:
			return
		}
	}
}



// sendUpdate sends a MessageTypeUpdate message carrying this node's ID to
// the controller. It does nothing until a node ID has been assigned. Marshal
// and write failures are logged, not returned.
func (c *Client) sendUpdate() {
	// Nothing to report until the controller has assigned us an ID.
	if len(c.nodeID) == 0 {
		return
	}

	encoded, err := json.Marshal(types.Message{
		Type:    types.MessageTypeUpdate,
		Payload: types.UpdateRequest{NodeID: c.nodeID},
	})
	if err != nil {
		log.Printf("Failed to marshal update message: %v", err)
		return
	}

	if err = c.conn.WriteMessage(websocket.TextMessage, encoded); err != nil {
		log.Printf("Failed to send update: %v", err)
	}
}

// generateKeyPair creates a Curve25519 key pair and returns it as
// (privateKeyHex, publicKeyHex, err). Each key is 32 bytes encoded as
// lowercase hex.
//
// The private key is crypto/rand output used as-is, with no clamping here.
// curve25519.X25519 (RFC 7748) clamps the scalar internally when deriving
// the public key.
func generateKeyPair() (string, string, error) {
	secret := make([]byte, curve25519.ScalarSize)
	if _, err := rand.Read(secret); err != nil {
		return "", "", err
	}

	pub, err := curve25519.X25519(secret, curve25519.Basepoint)
	if err != nil {
		return "", "", err
	}

	privHex := fmt.Sprintf("%x", secret)
	pubHex := fmt.Sprintf("%x", pub)
	return privHex, pubHex, nil
}

// hexToBytes decodes a string of hexadecimal digit pairs (either case) into
// bytes. An empty string yields an empty, non-nil slice. It returns an error
// for odd-length input, or for the first non-hex character it encounters.
func hexToBytes(hexStr string) ([]byte, error) {
	if len(hexStr)%2 == 1 {
		return nil, fmt.Errorf("hex string must have even length")
	}

	out := make([]byte, 0, len(hexStr)/2)
	for rest := hexStr; rest != ""; rest = rest[2:] {
		hi, err := hexToByte(rest[0])
		if err != nil {
			return nil, err
		}
		lo, err := hexToByte(rest[1])
		if err != nil {
			return nil, err
		}
		out = append(out, hi<<4|lo)
	}
	return out, nil
}

// hexToByte returns the 4-bit value of one hexadecimal digit ('0'-'9',
// 'a'-'f' or 'A'-'F'), or an error for any other byte.
func hexToByte(c byte) (byte, error) {
	if c >= '0' && c <= '9' {
		return c - '0', nil
	}
	// Setting bit 0x20 folds 'A'-'F' onto 'a'-'f'; no other byte lands there.
	if lower := c | 0x20; lower >= 'a' && lower <= 'f' {
		return lower - 'a' + 10, nil
	}
	return 0, fmt.Errorf("invalid hex character: %c", c)
}