From b6ce5053fc9a292cd024dd1d53ef50e385873a14 Mon Sep 17 00:00:00 2001 From: Maxim Vladimirskiy Date: Fri, 27 Feb 2026 20:14:14 +0300 Subject: [PATCH] Allow setting response code --- server.go | 102 ++++++++++++++++++++++++++++++++++++++++++++----- server_test.go | 74 +++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 9 deletions(-) diff --git a/server.go b/server.go index 791bb86..86c289a 100644 --- a/server.go +++ b/server.go @@ -26,6 +26,28 @@ const ( RRTypePTR = dns.TypePTR RRTypeSRV = dns.TypeSRV RRTypeTXT = dns.TypeTXT + + // Message Response Codes, see https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml + RcodeSuccess = dns.RcodeSuccess // No Error + RcodeFormatError = dns.RcodeFormatError // Format Error + RcodeServerFailure = dns.RcodeServerFailure // Server Failure + RcodeNameError = dns.RcodeNameError // Non-Existent Domain + RcodeNotImplemented = dns.RcodeNotImplemented // Not Implemented + RcodeRefused = dns.RcodeRefused // Query Refused + RcodeYXDomain = dns.RcodeYXDomain // Name Exists when it should not + RcodeYXRrset = dns.RcodeYXRrset // RR Set Exists when it should not + RcodeNXRrset = dns.RcodeNXRrset // RR Set that should exist does not + RcodeNotAuth = dns.RcodeNotAuth // Server Not Authoritative for zone + RcodeNotZone = dns.RcodeNotZone // Name not contained in zone + RcodeBadSig = dns.RcodeBadSig // TSIG Signature Failure + RcodeBadVers = dns.RcodeBadVers // Bad OPT Version + RcodeBadKey = dns.RcodeBadKey // Key not recognized + RcodeBadTime = dns.RcodeBadTime // Signature out of time window + RcodeBadMode = dns.RcodeBadMode // Bad TKEY Mode + RcodeBadName = dns.RcodeBadName // Duplicate key name + RcodeBadAlg = dns.RcodeBadAlg // Algorithm not supported + RcodeBadTrunc = dns.RcodeBadTrunc // Bad Truncation + RcodeBadCookie = dns.RcodeBadCookie // Bad/missing Server Cookie ) // Server is the wrapper that binds Resolver to the DNS server implementation @@ -37,6 +59,7 @@ type Server struct { stopped bool tcpServ dns.Server udpServ dns.Server + rcodes map[string]int Log Logger Authoritative bool @@ -91,15 +114,17 @@ func NewServerWithLogger(zones map[string]Zone, l Logger, authoritative bool) (* } func (s *Server) writeErr(w dns.ResponseWriter, reply *dns.Msg, err error) { - reply.Rcode = dns.RcodeServerFailure reply.RecursionAvailable = false // A not found response may still include answers (e.g. CNAME chain). //reply.Answer = nil reply.Extra = nil - var dnsErr *net.DNSError + var ( + rcode int + dnsErr *net.DNSError + ) if errors.As(err, &dnsErr) && dnsErr.IsNotFound { - reply.Rcode = dns.RcodeNameError + rcode = dns.RcodeNameError reply.RecursionAvailable = true reply.Ns = []dns.RR{ &dns.SOA{ @@ -119,9 +144,16 @@ func (s *Server) writeErr(w dns.ResponseWriter, reply *dns.Msg, err error) { }, } } else { + rcode = dns.RcodeServerFailure s.Log.Printf("lookup error: %v", err) } + // If the rcode has not been explicitly set to something other than + // Success, then set it to the error code. + if reply.Rcode == dns.RcodeSuccess { + reply.Rcode = rcode + } + w.WriteMsg(reply) } @@ -191,6 +223,16 @@ func (s *Server) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { s.mu.RLock() defer s.mu.RUnlock() + // Check if there is an explicit response code set for this query. If so, + // use it. If a response code is set for the name and type, it takes + // precedence over a response code set for the name only. This allows tests + // to simulate various error conditions. + if rcode := s.rcodes[rcodeKey(qname, q.Qtype)]; rcode != dns.RcodeSuccess { + reply.Rcode = rcode + } else if rcode := s.rcodes[qname]; rcode != dns.RcodeSuccess { + reply.Rcode = rcode + } + switch q.Qtype { case dns.TypeA: ad, cnames, rname, addrs, err := s.r.lookupA(context.Background(), qname) @@ -470,9 +512,7 @@ func (s *Server) AppendRR(name string, rrType uint16, rrData string) error { s.mu.Lock() defer s.mu.Unlock() - if !strings.HasSuffix(name, ".") { - name = name + "." - } + name = dotifyName(name) zone := s.r.Zones[name] switch rrType { case dns.TypeA: @@ -540,15 +580,52 @@ func (s *Server) AppendRR(name string, rrType uint16, rrData string) error { return nil } +// SetResponseCode sets the response code for all queries to the given name. +// It is overridden by a response code set for the name and type. By setting +// Success, any previously set response code for the name is removed. +func (s *Server) SetResponseCode(name string, rcode int) { + s.mu.Lock() + defer s.mu.Unlock() + + key := dotifyName(name) + s.setResponseCodeForKey(key, rcode) +} + +// SetResponseCodeForType sets the response code for queries to the given name +// and type. It takes precedence over a response code set for the name only. By +// setting Success, any previously set response code for the name and type is +// removed. +func (s *Server) SetResponseCodeForType(name string, rrType uint16, rcode int) { + s.mu.Lock() + defer s.mu.Unlock() + + key := rcodeKey(dotifyName(name), rrType) + s.setResponseCodeForKey(key, rcode) +} + +func (s *Server) setResponseCodeForKey(key string, rcode int) { + if s.rcodes == nil { + s.rcodes = make(map[string]int) + } + if rcode == dns.RcodeSuccess { + delete(s.rcodes, key) + return + } + s.rcodes[key] = rcode + return +} + +func rcodeKey(name string, rrType uint16) string { + return name + dns.Type(rrType).String() +} + // RemoveRR removes all records of the given type from the zone for the given // name. If the zone becomes empty, it is removed. The function is thread-safe. func (s *Server) RemoveRR(name string, rrType uint16) { s.mu.Lock() defer s.mu.Unlock() - if !strings.HasSuffix(name, ".") { - name = name + "." - } + name = dotifyName(name) zone := s.r.Zones[name] switch rrType { case dns.TypeA: @@ -583,6 +660,13 @@ func (s *Server) RemoveRR(name string, rrType uint16) { s.r.Zones[name] = zone } +func dotifyName(name string) string { + if !strings.HasSuffix(name, ".") { + return name + "." + } + return name +} + func isZoneEmpty(zone Zone) bool { return len(zone.A) == 0 && len(zone.AAAA) == 0 && diff --git a/server_test.go b/server_test.go index f47d777..af8b5de 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package mockdns import ( "context" + "errors" "net" "reflect" "sort" @@ -419,6 +420,79 @@ func TestServer_MutateRR(t *testing.T) { }) } +func TestServer_ExplicitResponseCodes(t *testing.T) { + srv, err := NewServer(nil, false) + assertNoError(t, err) + defer srv.Close() + + srv.AppendRR("foo.io.", RRTypeTXT, "one") + srv.AppendRR("foo.io.", RRTypeA, "1.2.3.4") + srv.AppendRR("bar.io.", RRTypeTXT, "two") + + r := srv.NewResolver() + // Force server to return on error immediately instead of retrying with + // other names. + r.StrictErrors = true + + // Make sure that all requests succeed. + _, err = r.LookupTXT(context.Background(), "foo.io") + assertNoError(t, err) + + _, err = r.LookupHost(context.Background(), "foo.io") + assertNoError(t, err) + + _, err = r.LookupTXT(context.Background(), "bar.io") + assertNoError(t, err) + + // Set explicit response code for A records of foo.io. Make sure that it + // does not affect TXT records of foo.io. and any record of bar.io. + srv.SetResponseCodeForType("foo.io.", RRTypeA, RcodeServerFailure) + + _, err = r.LookupTXT(context.Background(), "foo.io") + assertNoError(t, err) + + _, err = r.LookupHost(context.Background(), "foo.io") + var dnsErr1 *net.DNSError + if !errors.As(err, &dnsErr1) || dnsErr1.Err != "server misbehaving" { + t.Fatalf("Expected error, got=%s", dnsErr1.Err) + } + + _, err = r.LookupTXT(context.Background(), "bar.io") + assertNoError(t, err) + + // Set explicit response code for A records of foo.io. back to success. + // Make sure that all requests succeed again. + srv.SetResponseCodeForType("foo.io.", RRTypeA, RcodeSuccess) + + _, err = r.LookupTXT(context.Background(), "foo.io") + assertNoError(t, err) + + _, err = r.LookupHost(context.Background(), "foo.io") + assertNoError(t, err) + + _, err = r.LookupTXT(context.Background(), "bar.io") + assertNoError(t, err) + + // Set explicit response code for the whole zone. Make sure that it affects + // all requests for foo.io. but not for bar.io. + srv.SetResponseCode("foo.io.", RcodeServerFailure) + + _, err = r.LookupTXT(context.Background(), "foo.io") + var dnsErr2 *net.DNSError + if !errors.As(err, &dnsErr2) || dnsErr2.Err != "server misbehaving" { + t.Fatalf("Expected error, got=%s", dnsErr2.Err) + } + + _, err = r.LookupHost(context.Background(), "foo.io") + var dnsErr3 *net.DNSError + if !errors.As(err, &dnsErr3) || dnsErr3.Err != "server misbehaving" { + t.Fatalf("Expected error, got=%s", dnsErr3.Err) + } + + _, err = r.LookupTXT(context.Background(), "bar.io") + assertNoError(t, err) +} + func assertNoError(t *testing.T, err error) { if err != nil { t.Fatalf("Unexpected error: %v", err)