package rules

import (
	"fmt"
	"net/netip"
	"strconv"
	"strings"

	"github.com/AdguardTeam/golibs/errors"
	"github.com/AdguardTeam/golibs/netutil"
	"github.com/miekg/dns"
)

type (
	// RCode is a semantic alias for int, used wherever a value is a DNS
	// response code (RCODE).
	RCode = int

	// RRType is a semantic alias for uint16, used wherever a value is a DNS
	// resource record (RR) type.
	RRType = uint16

	// RRValue is the value of a resource record.  Its dynamic type depends on
	// the [RRType] of the record:
	//   - [dns.TypeA] and [dns.TypeAAAA]: [netip.Addr];
	//   - [dns.TypeMX]: a non-nil [*DNSMX];
	//   - [dns.TypePTR]: a string, which is also a valid FQDN;
	//   - [dns.TypeTXT]: a string;
	//   - [dns.TypeHTTPS] and [dns.TypeSVCB]: a non-nil [*DNSSVCB];
	//   - [dns.TypeSRV]: a non-nil [*DNSSRV];
	//   - any other type: nil, although support for new types may be added in
	//     the future.
	RRValue = any
)

// DNSRewrite is a DNS rewrite ($dnsrewrite) rule.
//
// The parsers in this file produce one of three shapes:
//   - a CNAME rewrite, in which only NewCNAME is set;
//   - a response rewrite described by RCode and, when RCode is
//     [dns.RcodeSuccess], by RRType and Value;
//   - a zero DNSRewrite, for an empty modifier value.
type DNSRewrite struct {
	// Value is the value for the record.  See [RRValue] documentation for more
	// details.
	Value RRValue

	// NewCNAME is the new CNAME.  If set, clients must ignore other fields,
	// resolve the CNAME, and set the new records accordingly.
	NewCNAME string

	// RCode is the new DNS RCODE.
	RCode RCode

	// RRType is the new DNS resource record (RR) type.  It is only non-zero
	// if RCode is [dns.RcodeSuccess].
	RRType RRType
}

// parseDNSRewrite parses the value of the $dnsrewrite modifier.
//
// A value without any semicolons is the shorthand form.  A value with at least
// two semicolons is the normal form "rcode;rrtype;value", where everything
// after the second semicolon, including further semicolons, is the value.  A
// value with exactly one semicolon is invalid.
func parseDNSRewrite(s string) (rewrite *DNSRewrite, err error) {
	parts := strings.SplitN(s, ";", 3)
	n := len(parts)

	if n == 1 {
		return loadDNSRewriteShort(s)
	}

	if n == 2 {
		return nil, errors.Error("invalid dnsrewrite: expected zero or two delimiters")
	}

	if n != 3 {
		// SplitN with a limit of 3 never returns more than three parts, nor
		// zero parts for a positive limit, so this is unreachable.
		//
		// TODO(a.garipov): Use panic("unreachable") instead?
		return nil, fmt.Errorf("SplitN returned %d parts", n)
	}

	rcodeStr, rrStr, valStr := parts[0], parts[1], parts[2]

	return loadDNSRewriteNormal(rcodeStr, rrStr, valStr)
}

// allUppercaseASCII reports whether s is non-empty and consists solely of the
// ASCII letters 'A' through 'Z'.  Any other byte, including every byte of a
// multi-byte UTF-8 sequence, makes the result false.
func allUppercaseASCII(s string) (ok bool) {
	for i := 0; i < len(s); i++ {
		if c := s[i]; c < 'A' || c > 'Z' {
			return false
		}
	}

	return len(s) > 0
}

// loadDNSRewriteShort loads the shorthand version of the $dnsrewrite modifier.
// s is one of:
//   - an empty string, which yields an empty *DNSRewrite;
//   - one of the keywords NOERROR, SERVFAIL, NXDOMAIN, or REFUSED, which sets
//     only the RCode; any other all-uppercase ASCII word is an error;
//   - an IP address, which yields an A record for IPv4 addresses and an AAAA
//     record otherwise;
//   - a hostname, which yields a CNAME rewrite.
func loadDNSRewriteShort(s string) (rewrite *DNSRewrite, err error) {
	if s == "" {
		// Return an empty DNSRewrite, because an empty string most probably
		// means that this is a disabling allowlist case.
		return &DNSRewrite{}, nil
	}

	if allUppercaseASCII(s) {
		switch s {
		case "NOERROR", "SERVFAIL", "NXDOMAIN", "REFUSED":
			return &DNSRewrite{
				RCode: dns.StringToRcode[s],
			}, nil
		default:
			return nil, fmt.Errorf("unknown keyword: %q", s)
		}
	}

	if netutil.IsValidIPString(s) {
		// Don't use netip.MustParseAddr here: s comes from user-supplied
		// filtering rules, so any disagreement between the validator above and
		// netip.ParseAddr must result in an error, not a panic.
		var ip netip.Addr
		ip, err = netip.ParseAddr(s)
		if err != nil {
			return nil, fmt.Errorf("invalid shorthand ip %q: %w", s, err)
		}

		var rr RRType = dns.TypeAAAA
		if ip.Is4() {
			rr = dns.TypeA
		}

		return &DNSRewrite{
			RCode:  dns.RcodeSuccess,
			RRType: rr,
			Value:  ip,
		}, nil
	}

	err = netutil.ValidateHostname(s)
	if err != nil {
		return nil, fmt.Errorf("invalid shorthand hostname %q: %w", s, err)
	}

	return &DNSRewrite{
		NewCNAME: s,
	}, nil
}

// DNSMX is the type of RRValue values returned for MX records in DNS rewrites.
type DNSMX struct {
	// Exchange is the host name of the mail exchange.  The $dnsrewrite MX
	// parser validates it with netutil.ValidateHostname.
	Exchange string

	// Preference is the MX preference value, see RFC 1035.
	Preference uint16
}

// DNSSRV is the type of RRValue values returned for SRV records in DNS rewrites.
//
// See RFC 2782 for the semantics of the fields.
type DNSSRV struct {
	// Target is the host name of the target.  In rewrites parsed from
	// $dnsrewrite rules it is either a valid hostname or ".", which per RFC
	// 2782 means that the service is decidedly not available at this domain.
	Target string

	// Priority is the SRV priority of the target.
	Priority uint16

	// Weight is the SRV weight of the target among targets with the same
	// priority.
	Weight uint16

	// Port is the port of the service on the target host.
	Port uint16
}

// DNSSVCB is the type of RRValue values returned for HTTPS and SVCB records in
// DNS rewrites.
//
// See https://tools.ietf.org/html/draft-ietf-dnsop-svcb-https-02, since
// published as RFC 9460.
type DNSSVCB struct {
	// Params are the service parameters, keyed by parameter name, as given in
	// the "key=value" pairs of the rule.  In rewrites parsed from $dnsrewrite
	// rules it is nil when the rule contains no parameters.
	Params map[string]string

	// Target is the target name.  In rewrites parsed from $dnsrewrite rules it
	// is either a valid hostname or ".", for which special rules apply.
	Target string

	// Priority is the SvcPriority of the record.
	Priority uint16
}

// dnsRewriteRRHandler is a function that parses values for specific resource
// record types.
//
// rcode is the response code of the rule; loadDNSRewriteNormal only calls
// handlers when it is dns.RcodeSuccess.  rr is the record type the handler is
// registered for, and valStr is the unparsed value part of the rule, i.e.
// everything after the second semicolon.  The handlers in this file return
// either a non-nil rewrite and a nil error, or a nil rewrite and a non-nil
// error.
//
// TODO(e.burkov):  Since these functions are used here only as values of
// [dnsRewriteRRHandlers] map, which has a key of [RRType], the rr parameter
// appears useless in most cases except for [svcbDNSRewriteRRHandler].
type dnsRewriteRRHandler func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error)

// cnameDNSRewriteRRHandler is a DNS rewrite handler that parses full-form CNAME
// rewrites.  The response code and record type are ignored: the resulting
// rewrite only has NewCNAME set, to valStr, once valStr is validated as a
// hostname.
func cnameDNSRewriteRRHandler(_ RCode, _ RRType, valStr string) (dnsr *DNSRewrite, err error) {
	if err = netutil.ValidateHostname(valStr); err != nil {
		return nil, fmt.Errorf("invalid cname host: %w", err)
	}

	dnsr = &DNSRewrite{NewCNAME: valStr}

	return dnsr, nil
}

// ptrDNSRewriteRRHandler is a DNS rewrite handler that parses PTR rewrites.
// valStr may be written with or without a trailing dot.  The part before the
// trailing dot is validated as a hostname, and the resulting Value is always
// the FQDN form of valStr.
func ptrDNSRewriteRRHandler(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) {
	host, fqdn := valStr, valStr
	if strings.HasSuffix(valStr, ".") {
		// Already an FQDN; validate it without the final dot.
		host = valStr[:len(valStr)-1]
	} else {
		fqdn = dns.Fqdn(valStr)
	}

	if err = netutil.ValidateHostname(host); err != nil {
		return nil, fmt.Errorf("invalid ptr host: %w", err)
	}

	dnsr = &DNSRewrite{
		RCode:  rcode,
		RRType: rr,
		Value:  fqdn,
	}

	return dnsr, nil
}

// strDNSRewriteRRHandler is a simple DNS rewrite handler that never fails: it
// returns a *DNSRewrite with the given rcode and rr and with Value set to
// valStr as is, without any validation.
func strDNSRewriteRRHandler(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) {
	dnsr = &DNSRewrite{
		Value:  valStr,
		RCode:  rcode,
		RRType: rr,
	}

	return dnsr, nil
}

// srvDNSRewriteRRHandler is a DNS rewrite handler that parses SRV rewrites.
// valStr must consist of exactly four fields separated by single spaces:
//
//	priority weight port target
//
// where priority, weight, and port are decimal uint16 values and target is
// either a valid hostname or ".".  Extra trailing fields are rejected instead
// of being silently ignored, consistent with how an MX value with extra fields
// fails hostname validation.
func srvDNSRewriteRRHandler(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) {
	fields := strings.Split(valStr, " ")
	if len(fields) != 4 {
		return nil, fmt.Errorf("invalid srv %q: need four fields", valStr)
	}

	var prio64 uint64
	prio64, err = strconv.ParseUint(fields[0], 10, 16)
	if err != nil {
		return nil, fmt.Errorf("invalid srv priority: %w", err)
	}

	var weight64 uint64
	weight64, err = strconv.ParseUint(fields[1], 10, 16)
	if err != nil {
		return nil, fmt.Errorf("invalid srv weight: %w", err)
	}

	var port64 uint64
	port64, err = strconv.ParseUint(fields[2], 10, 16)
	if err != nil {
		return nil, fmt.Errorf("invalid srv port: %w", err)
	}

	target := fields[3]

	// From RFC 2782:
	//
	//   A Target of "." means that the service is decidedly not available
	//   at this domain.
	//
	if target != "." {
		err = netutil.ValidateHostname(target)
		if err != nil {
			return nil, fmt.Errorf("invalid srv target: %w", err)
		}
	}

	v := &DNSSRV{
		Target:   target,
		Priority: uint16(prio64),
		Weight:   uint16(weight64),
		Port:     uint16(port64),
	}

	dnsr = &DNSRewrite{
		RCode:  rcode,
		RRType: rr,
		Value:  v,
	}

	return dnsr, nil
}

// svcbDNSRewriteRRHandler is a DNS rewrite handler that parses SVCB and HTTPS
// rewrites.  rr must be either [dns.TypeHTTPS] or [dns.TypeSVCB].  valStr must
// have the form:
//
//	priority target [key=value ...]
//
// where the fields are separated by single spaces and target is either a valid
// hostname or ".".  Each param is split on its first "=" only, so values may
// themselves contain "=", as base64-encoded values with padding do.  When there
// are no params, the resulting [DNSSVCB.Params] is nil.
//
// See https://tools.ietf.org/html/draft-ietf-dnsop-svcb-https-02, since
// published as RFC 9460.
//
// TODO(a.garipov): Currently, we only support the contiguous type of
// char-string values from the RFC.
func svcbDNSRewriteRRHandler(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) {
	var name string
	switch rr {
	case dns.TypeHTTPS:
		name = "https"
	case dns.TypeSVCB:
		name = "svcb"
	default:
		return nil, fmt.Errorf("unsupported svcb-like rr type: %d", rr)
	}

	fields := strings.Split(valStr, " ")
	if len(fields) < 2 {
		return nil, fmt.Errorf("invalid %s %q: need at least two fields", name, valStr)
	}

	var prio64 uint64
	prio64, err = strconv.ParseUint(fields[0], 10, 16)
	if err != nil {
		return nil, fmt.Errorf("invalid %s priority: %w", name, err)
	}

	target := fields[1]

	// From the IETF draft:
	//
	//   If TargetName has the value "." (represented in the wire format as
	//   a zero-length label), special rules apply.
	//
	if target != "." {
		err = netutil.ValidateHostname(target)
		if err != nil {
			return nil, fmt.Errorf("invalid %s target: %w", name, err)
		}
	}

	v := &DNSSVCB{
		Priority: uint16(prio64),
		Target:   target,
	}

	// Only allocate the map when there are params, so that Params stays nil
	// otherwise.
	if pairs := fields[2:]; len(pairs) > 0 {
		v.Params = make(map[string]string, len(pairs))
		for i, pair := range pairs {
			// Split on the first "=" only, so that values containing "=", such
			// as padded base64 in the "ech" param, are accepted.
			kv := strings.SplitN(pair, "=", 2)
			if l := len(kv); l != 2 {
				return nil, fmt.Errorf("invalid %s param at index %d: got %d fields", name, i, l)
			}

			// TODO(a.garipov): Validate for uniqueness?  Validate against
			// the currently specified list of params from the RFC?
			v.Params[kv[0]] = kv[1]
		}
	}

	return &DNSRewrite{
		RCode:  rcode,
		RRType: rr,
		Value:  v,
	}, nil
}

// dnsRewriteRRHandlers maps each resource record type that has a dedicated
// value syntax to the handler that parses that value.  loadDNSRewriteNormal
// accepts record types missing from this map without parsing any value.
var dnsRewriteRRHandlers = map[RRType]dnsRewriteRRHandler{
	// IPv4 addresses only.
	dns.TypeA: func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) {
		if !netutil.IsValidIPString(valStr) {
			return nil, fmt.Errorf("%q is not a valid ipv4", valStr)
		}

		var ip netip.Addr
		ip, err = netip.ParseAddr(valStr)
		if err != nil {
			// Don't wrap the error since it's informative enough as is.
			return nil, err
		}

		if !ip.Is4() {
			return nil, fmt.Errorf("%q is not a valid ipv4", valStr)
		}

		dnsr = &DNSRewrite{
			RCode:  rcode,
			RRType: rr,
			Value:  ip,
		}

		return dnsr, nil
	},

	// IPv6 addresses only; dotted IPv4 addresses are rejected.
	dns.TypeAAAA: func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) {
		if !netutil.IsValidIPString(valStr) {
			return nil, fmt.Errorf("%q is not a valid ipv6", valStr)
		}

		var ip netip.Addr
		ip, err = netip.ParseAddr(valStr)
		if err != nil {
			// Don't wrap the error since it's informative enough as is.
			return nil, err
		}

		if !ip.Is6() {
			return nil, fmt.Errorf("%q is an ipv4, not an ipv6", valStr)
		}

		dnsr = &DNSRewrite{
			RCode:  rcode,
			RRType: rr,
			Value:  ip,
		}

		return dnsr, nil
	},

	dns.TypeCNAME: cnameDNSRewriteRRHandler,

	// "<preference> <exchange>", for example "10 mail.example.com".
	dns.TypeMX: func(rcode RCode, rr RRType, valStr string) (dnsr *DNSRewrite, err error) {
		prefStr, exch, found := strings.Cut(valStr, " ")
		if !found {
			return nil, fmt.Errorf("invalid mx: %q", valStr)
		}

		var pref64 uint64
		if pref64, err = strconv.ParseUint(prefStr, 10, 16); err != nil {
			return nil, fmt.Errorf("invalid mx preference: %w", err)
		}

		if err = netutil.ValidateHostname(exch); err != nil {
			return nil, fmt.Errorf("invalid mx exchange: %w", err)
		}

		mx := &DNSMX{
			Exchange:   exch,
			Preference: uint16(pref64),
		}

		dnsr = &DNSRewrite{
			RCode:  rcode,
			RRType: rr,
			Value:  mx,
		}

		return dnsr, nil
	},

	dns.TypePTR: ptrDNSRewriteRRHandler,
	dns.TypeTXT: strDNSRewriteRRHandler,
	dns.TypeSRV: srvDNSRewriteRRHandler,

	// HTTPS and SVCB records share the same value syntax.
	dns.TypeHTTPS: svcbDNSRewriteRRHandler,
	dns.TypeSVCB:  svcbDNSRewriteRRHandler,
}

// loadDNSRewriteNormal loads the normal "rcode;rrtype;value" form of the
// $dnsrewrite modifier.  rcodeStr is matched case-insensitively.  For any
// response code other than NOERROR, or when both rrStr and valStr are empty,
// only the response code is kept.  Record types without a dedicated handler
// produce a rewrite without a value.
func loadDNSRewriteNormal(rcodeStr, rrStr, valStr string) (rewrite *DNSRewrite, err error) {
	rcode, known := dns.StringToRcode[strings.ToUpper(rcodeStr)]
	if !known {
		return nil, fmt.Errorf("unknown rcode: %q", rcodeStr)
	}

	noRecord := rrStr == "" && valStr == ""
	if noRecord || rcode != dns.RcodeSuccess {
		return &DNSRewrite{RCode: rcode}, nil
	}

	var rr RRType
	if rr, err = strToRRType(rrStr); err != nil {
		return nil, err
	}

	if handler, found := dnsRewriteRRHandlers[rr]; found {
		return handler(rcode, rr, valStr)
	}

	return &DNSRewrite{RCode: rcode, RRType: rr}, nil
}
