Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 93 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,28 @@ const (
RRTypePTR = dns.TypePTR
RRTypeSRV = dns.TypeSRV
RRTypeTXT = dns.TypeTXT

// Message Response Codes, see https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml
RcodeSuccess = dns.RcodeSuccess // No Error
RcodeFormatError = dns.RcodeFormatError // Format Error
RcodeServerFailure = dns.RcodeServerFailure // Server Failure
RcodeNameError = dns.RcodeNameError // Non-Existent Domain
RcodeNotImplemented = dns.RcodeNotImplemented // Not Implemented
RcodeRefused = dns.RcodeRefused // Query Refused
RcodeYXDomain = dns.RcodeYXDomain // Name Exists when it should not
RcodeYXRrset = dns.RcodeYXRrset // RR Set Exists when it should not
RcodeNXRrset = dns.RcodeNXRrset // RR Set that should exist does not
RcodeNotAuth = dns.RcodeNotAuth // Server Not Authoritative for zone
RcodeNotZone = dns.RcodeNotZone // Name not contained in zone
RcodeBadSig = dns.RcodeBadSig // TSIG Signature Failure
RcodeBadVers = dns.RcodeBadVers // Bad OPT Version
RcodeBadKey = dns.RcodeBadKey // Key not recognized
RcodeBadTime = dns.RcodeBadTime // Signature out of time window
RcodeBadMode = dns.RcodeBadMode // Bad TKEY Mode
RcodeBadName = dns.RcodeBadName // Duplicate key name
RcodeBadAlg = dns.RcodeBadAlg // Algorithm not supported
RcodeBadTrunc = dns.RcodeBadTrunc // Bad Truncation
RcodeBadCookie = dns.RcodeBadCookie // Bad/missing Server Cookie
)

// Server is the wrapper that binds Resolver to the DNS server implementation
Expand All @@ -37,6 +59,7 @@ type Server struct {
stopped bool
tcpServ dns.Server
udpServ dns.Server
rcodes map[string]int

Log Logger
Authoritative bool
Expand Down Expand Up @@ -91,15 +114,17 @@ func NewServerWithLogger(zones map[string]Zone, l Logger, authoritative bool) (*
}

func (s *Server) writeErr(w dns.ResponseWriter, reply *dns.Msg, err error) {
reply.Rcode = dns.RcodeServerFailure
reply.RecursionAvailable = false
// A not found response may still include answers (e.g. CNAME chain).
//reply.Answer = nil
reply.Extra = nil

var dnsErr *net.DNSError
var (
rcode int
dnsErr *net.DNSError
)
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
reply.Rcode = dns.RcodeNameError
rcode = dns.RcodeNameError
reply.RecursionAvailable = true
reply.Ns = []dns.RR{
&dns.SOA{
Expand All @@ -119,9 +144,16 @@ func (s *Server) writeErr(w dns.ResponseWriter, reply *dns.Msg, err error) {
},
}
} else {
rcode = dns.RcodeServerFailure
s.Log.Printf("lookup error: %v", err)
}

// If the rcode has not been explicitly set to something other than
// Success, then set it to the error code.
if reply.Rcode == dns.RcodeSuccess {
reply.Rcode = rcode
}

w.WriteMsg(reply)
}

Expand Down Expand Up @@ -191,6 +223,16 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, m *dns.Msg) {
s.mu.RLock()
defer s.mu.RUnlock()

// Check if there is an explicit response code set for this query. If so,
// use it. If a response code is set for the name and type, it takes
// precedence over a response code set for the name only. This allows tests
// to simulate various error conditions.
if rcode := s.rcodes[rcodeKey(qname, q.Qtype)]; rcode != dns.RcodeSuccess {
reply.Rcode = rcode
} else if rcode := s.rcodes[qname]; rcode != dns.RcodeSuccess {
reply.Rcode = rcode
}

switch q.Qtype {
case dns.TypeA:
ad, cnames, rname, addrs, err := s.r.lookupA(context.Background(), qname)
Expand Down Expand Up @@ -470,9 +512,7 @@ func (s *Server) AppendRR(name string, rrType uint16, rrData string) error {
s.mu.Lock()
defer s.mu.Unlock()

if !strings.HasSuffix(name, ".") {
name = name + "."
}
name = dotifyName(name)
zone := s.r.Zones[name]
switch rrType {
case dns.TypeA:
Expand Down Expand Up @@ -540,15 +580,52 @@ func (s *Server) AppendRR(name string, rrType uint16, rrData string) error {
return nil
}

// SetResponseCode sets the response code for all queries to the given name.
// It is overridden by a response code set for the name and type. By setting
// Success, any previously set response code for the name is removed.
func (s *Server) SetResponseCode(name string, rcode int) {
s.mu.Lock()
defer s.mu.Unlock()

key := dotifyName(name)
s.setResponseCodeForKey(key, rcode)
}

// SetResponseCodeForType sets the response code for queries to the given name
// and type. It takes precedence over a response code set for the name only. By
// setting Success, any previously set response code for the name and type is
// removed.
func (s *Server) SetResponseCodeForType(name string, rrType uint16, rcode int) {
s.mu.Lock()
defer s.mu.Unlock()

key := rcodeKey(dotifyName(name), rrType)
s.setResponseCodeForKey(key, rcode)
}

func (s *Server) setResponseCodeForKey(key string, rcode int) {
if s.rcodes == nil {
s.rcodes = make(map[string]int)
}
if rcode == dns.RcodeSuccess {
delete(s.rcodes, key)
return
}
s.rcodes[key] = rcode
return
}

func rcodeKey(name string, rrType uint16) string {
return name + dns.Type(rrType).String()
}

// RemoveRR removes all records of the given type from the zone for the given
// name. If the zone becomes empty, it is removed. The function is thread-safe.
func (s *Server) RemoveRR(name string, rrType uint16) {
s.mu.Lock()
defer s.mu.Unlock()

if !strings.HasSuffix(name, ".") {
name = name + "."
}
name = dotifyName(name)
zone := s.r.Zones[name]
switch rrType {
case dns.TypeA:
Expand Down Expand Up @@ -583,6 +660,13 @@ func (s *Server) RemoveRR(name string, rrType uint16) {
s.r.Zones[name] = zone
}

func dotifyName(name string) string {
if !strings.HasSuffix(name, ".") {
return name + "."
}
return name
}

func isZoneEmpty(zone Zone) bool {
return len(zone.A) == 0 &&
len(zone.AAAA) == 0 &&
Expand Down
74 changes: 74 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package mockdns

import (
"context"
"errors"
"net"
"reflect"
"sort"
Expand Down Expand Up @@ -419,6 +420,79 @@ func TestServer_MutateRR(t *testing.T) {
})
}

func TestServer_ExplicitResponseCodes(t *testing.T) {
srv, err := NewServer(nil, false)
assertNoError(t, err)
defer srv.Close()

srv.AppendRR("foo.io.", RRTypeTXT, "one")
srv.AppendRR("foo.io.", RRTypeA, "1.2.3.4")
srv.AppendRR("bar.io.", RRTypeTXT, "two")

r := srv.NewResolver()
// Force server to return on error immediately instead of retrying with
// other names.
r.StrictErrors = true

// Make sure that all requests succeed.
_, err = r.LookupTXT(context.Background(), "foo.io")
assertNoError(t, err)

_, err = r.LookupHost(context.Background(), "foo.io")
assertNoError(t, err)

_, err = r.LookupTXT(context.Background(), "bar.io")
assertNoError(t, err)

// Set explicit response code for A records of foo.io. Make sure that it
// does not affect TXT records of foo.io. and any record of bar.io.
srv.SetResponseCodeForType("foo.io.", RRTypeA, RcodeServerFailure)

_, err = r.LookupTXT(context.Background(), "foo.io")
assertNoError(t, err)

_, err = r.LookupHost(context.Background(), "foo.io")
var dnsErr1 *net.DNSError
if !errors.As(err, &dnsErr1) || dnsErr1.Err != "server misbehaving" {
t.Fatalf("Expected error, got=%s", dnsErr1.Err)
}

_, err = r.LookupTXT(context.Background(), "bar.io")
assertNoError(t, err)

// Set explicit response code for A records of foo.io. back to success.
// Make sure that all requests succeed again.
srv.SetResponseCodeForType("foo.io.", RRTypeA, RcodeSuccess)

_, err = r.LookupTXT(context.Background(), "foo.io")
assertNoError(t, err)

_, err = r.LookupHost(context.Background(), "foo.io")
assertNoError(t, err)

_, err = r.LookupTXT(context.Background(), "bar.io")
assertNoError(t, err)

// Set explicit response code for the whole zone. Make sure that it affects
// all requests for foo.io. but not for bar.io.
srv.SetResponseCode("foo.io.", RcodeServerFailure)

_, err = r.LookupTXT(context.Background(), "foo.io")
var dnsErr2 *net.DNSError
if !errors.As(err, &dnsErr2) || dnsErr2.Err != "server misbehaving" {
t.Fatalf("Expected error, got=%s", dnsErr2.Err)
}

_, err = r.LookupHost(context.Background(), "foo.io")
var dnsErr3 *net.DNSError
if !errors.As(err, &dnsErr3) || dnsErr3.Err != "server misbehaving" {
t.Fatalf("Expected error, got=%s", dnsErr3.Err)
}

_, err = r.LookupTXT(context.Background(), "bar.io")
assertNoError(t, err)
}

func assertNoError(t *testing.T, err error) {
if err != nil {
t.Fatalf("Unexpected error: %v", err)
Expand Down