diff --git a/README.md b/README.md index 121f7f9..9f851d9 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,10 @@ Given the following configuration: { "host": "127.0.0.1", "port": 53, - "default_upstream": "1.1.1.1", + "default_upstream": [ + "1.1.1.1", + "1.0.0.1" + ], "internal": [ { "regex": "mail.example.com", diff --git a/config.json.example b/config.json.example index 8c74396..91bca6c 100644 --- a/config.json.example +++ b/config.json.example @@ -7,7 +7,10 @@ { "host": "127.0.0.1", "port": 53, - "default_upstream": "1.1.1.1", + "default_upstream": [ + "1.1.1.1", + "1.0.0.1" + ], "internal": [ { "regex": "mail.example.com", diff --git a/internal/config/router_config.go b/internal/config/router_config.go index 5f9e6c8..664415d 100644 --- a/internal/config/router_config.go +++ b/internal/config/router_config.go @@ -15,7 +15,7 @@ import ( ) // DefaultUpstream is a system wide default -const DefaultUpstream = "1.1.1.1" +var DefaultUpstream = []string{"1.1.1.1"} // ServerConfig ... type ServerConfig struct { @@ -37,7 +37,7 @@ type RouterConfig struct { Port int `json:"port"` Upstreams []UpstreamConfig `json:"upstreams"` InternalRecords []InternalRecordConfig `json:"internal"` - DefaultUpstream string `json:"default_upstream"` + DefaultUpstream []string `json:"default_upstream"` } // LogConfig is self explanatatory diff --git a/internal/server/handler.go b/internal/server/handler.go index 09ce824..a7db552 100644 --- a/internal/server/handler.go +++ b/internal/server/handler.go @@ -78,31 +78,38 @@ func (h *DNSHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { msg.Answer = internalAnswer } else { // use upstream next - upstreamHost := getDNSServerFromLookup(h.RouterConf, domain) - logger.Debug("[%d] DNSLookup %s %s -> %s", h.ServerIndex, domain, getRecordTypeString(msg.Question[0].Qtype), upstreamHost) - - if upstreamHost == "nxdomain" { - // Return nxdomain asap - msg.SetRcode(r, dns.RcodeNameError) - } else { - // Forward to the determined upstream dns server - m := new(dns.Msg) - m.SetQuestion(dns.Fqdn(domain), msg.Question[0].Qtype) - m.RecursionDesired = true - - upstreamResponse, _, err := c.Exchange(m, net.JoinHostPort(upstreamHost, "53")) - if upstreamResponse == nil { - logger.Error("UpstreamError", err) - return - } + upstreamHosts := getDNSServerFromLookup(h.RouterConf, domain) + + for upstreamHostIndex, upstreamHost := range upstreamHosts { + logger.Debug("[%d] DNSLookup %s %s -> %s", h.ServerIndex, domain, getRecordTypeString(msg.Question[0].Qtype), upstreamHost) - if upstreamResponse.Rcode != dns.RcodeSuccess { - msg.SetRcode(r, upstreamResponse.Rcode) + if upstreamHost == "nxdomain" { + // Return nxdomain asap + msg.SetRcode(r, dns.RcodeNameError) } else { - msg.Answer = upstreamResponse.Answer - // Cache it - if memCache != nil { - memCache.Set(cacheKey, upstreamResponse.Answer, cache.DefaultExpiration) + // Forward to the determined upstream dns server + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), msg.Question[0].Qtype) + m.RecursionDesired = true + + upstreamResponse, _, err := c.Exchange(m, net.JoinHostPort(upstreamHost, "53")) + if upstreamResponse == nil { + logger.Error("UpstreamError", err) + if (len(upstreamHosts) - 1) > upstreamHostIndex { + logger.Debug("[%d] DNSLookupRetry %s -> %s", upstreamHost, upstreamHosts[upstreamHostIndex+1]) + continue + } + return + } + + if upstreamResponse.Rcode != dns.RcodeSuccess { + msg.SetRcode(r, upstreamResponse.Rcode) + } else { + msg.Answer = upstreamResponse.Answer + // Cache it + if memCache != nil { + memCache.Set(cacheKey, upstreamResponse.Answer, cache.DefaultExpiration) + } } } } @@ -114,20 +121,22 @@ func (h *DNSHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } } -func getDNSServerFromLookup(conf config.RouterConfig, domain string) string { - dnsServer := conf.DefaultUpstream +func getDNSServerFromLookup(conf config.RouterConfig, domain string) []string { + var dnsServer []string if len(conf.Upstreams) > 0 { for _, upstream := range conf.Upstreams { if found := upstream.CompiledRegex.MatchString(domain); found { if upstream.NXDomain { - dnsServer = "nxdomain" + dnsServer = []string{"nxdomain"} } else { - dnsServer = upstream.DNSServer + dnsServer = []string{upstream.DNSServer} } break } } + } else { + dnsServer = conf.DefaultUpstream } return dnsServer diff --git a/scripts/lint.sh b/scripts/lint.sh index 0ddcf3b..27d599e 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -7,7 +7,7 @@ cd "$PROJECT_DIR" if ! command -v golangci-lint &>/dev/null; then echo -e "${YELLOW}Installing golangci-lint ...${RESET}" - go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest fi if ! command -v govulncheck &>/dev/null; then