#include "cwCommon.h"
#include "cwLog.h"
#include "cwCommonImpl.h"
#include "cwMem.h"
#include "cwThread.h"
#include "cwTcpSocket.h"
#include "cwTcpSocketSrv.h"
#include "cwMdns.h"

namespace cw
{
  namespace net
  {
    namespace mdns
    {
      // Message header fields (six 16 bit values).
      typedef struct msg_str
      {
        uint16_t transactionId;
        uint16_t flags;
        uint16_t questionN;
        uint16_t answerN;
        uint16_t nameServerN;
        uint16_t additionalN;
      } msg_t;

      // Question record (linked list node).
      typedef struct question_str
      {
        char*                name;
        uint16_t             type;
        uint16_t             clss;
        struct question_str* link;
      } question_t;

      // Payload of a SRV resource record.
      typedef struct srv_rsrc_str
      {
        uint16_t priority;
        uint16_t weight;
        uint16_t port;
        char*    target;
      } srv_rsrc_t;

      // Resource record (linked list node). The anonymous union holds the
      // type specific payload.
      typedef struct rsrc_str
      {
        char*    name;
        uint16_t type;
        uint16_t clss;
        uint32_t ttl;
        uint16_t dataByteN;
        union
        {
          char*      text;
          srv_rsrc_t srv;
          uint32_t   addr;
        };
        struct rsrc_str* link;
      } rsrc_t;

      typedef struct mdns_str
      {
        rsrc_t* rsrcL;
      } mdns_t;

      // State shared between test() and the socket receive callback.
      typedef struct mdns_app_str
      {
        srv::handle_t mdnsH;
        srv::handle_t tcpH;
        unsigned      cbN;   // count of callbacks received
        mdns_t        mdns;
      } mdns_app_t;

      // Print an error message to stdout. 'p' is currently unused.
      void errorv( mdns_t* p, const char* fmt, va_list vl )
      {
        printf("Error: ");
        vprintf(fmt,vl);
      }

      // Print a log message to stdout and flush. 'p' is currently unused.
      void logv( mdns_t* p, const char* fmt, va_list vl )
      {
        vprintf(fmt,vl);
        fflush(stdout);
      }

      void error( mdns_t* p, const char* fmt, ... )
      {
        va_list vl;
        va_start(vl,fmt);
        errorv(p,fmt,vl);
        va_end(vl);
      }

      void log( mdns_t* p, const char* fmt, ... )
      {
        va_list vl;
        va_start(vl,fmt);
        logv(p,fmt,vl);
        va_end(vl);
      }

      // Record section id's. kInvalidRecdTId terminates a variable argument record list.
      enum
      {
        kInvalidRecdTId,
        kQuestionRecdTId,
        kAnswerRecdTId,
        kNameServerRecdTId,
        kAdditionalRecdTId
      };

      // DNS record type id's.
      enum
      {
        kA_DnsTId    = 1,
        kPTR_DnsTId  = 12,
        kTXT_DnsTId  = 16,
        kAAAA_DnsTId = 28,
        kSRV_DnsTId  = 33,
        kOPT_DnsTId  = 41,
        kANY_DnsTId  = 255
        // REMEMBER: Add new type id's to dnsTypeIdToString()
      };

      // Byte counts of the fixed size parts of a message.
      enum
      {
        kHdrBodyByteN      = 12,
        kQuestionBodyByteN = 4,
        kRsrcBodyByteN     = 10,
        kABodyByteN        = 4,
        kSrvBodyByteN      = 6,
        kOptBodyByteN      = 4,
      };

      enum
      {
        kReplyHdrFl         = 0x8000,
        kAuthoritativeHdrFl = 0x0400,
        kFlushClassFl       = 0x8000,
        kInClassFl          = 0x0001
      };

      // Return the label for a DNS type id or "" if the id is not recognized.
      const char* dnsTypeIdToString( uint16_t id )
      {
        switch( id )
        {
          case kA_DnsTId:    return "A";
          case kPTR_DnsTId:  return "PTR";
          case kTXT_DnsTId:  return "TXT";
          case kAAAA_DnsTId: return "AAAA";
          case kSRV_DnsTId:  return "SRV";
          case kOPT_DnsTId:  return "OPT";
          case kANY_DnsTId:  return "ANY";
        }
        return "";
      }

      // Store a 16 bit value at 'b' in network byte order.
      // memcpy() is used because 'b' generally follows a variable length
      // name and therefore cannot be assumed to be 2 byte aligned.
      static void store_net_u16( char* b, unsigned v )
      {
        uint16_t u = htons((uint16_t)v);
        memcpy(b,&u,sizeof(u));
      }

      // Store a 32 bit value at 'b' in network byte order (see store_net_u16()).
      static void store_net_u32( char* b, unsigned v )
      {
        uint32_t u = htonl((uint32_t)v);
        memcpy(b,&u,sizeof(u));
      }

      // Load a network byte order 16 bit value from a possibly unaligned address.
      static uint16_t load_net_u16( const char* b )
      {
        uint16_t u;
        memcpy(&u,b,sizeof(u));
        return ntohs(u);
      }

      // Load a network byte order 32 bit value from a possibly unaligned address.
      static uint32_t load_net_u32( const char* b )
      {
        uint32_t u;
        memcpy(&u,b,sizeof(u));
        return ntohl(u);
      }

      // Records which do not carry text are passed a null 'text' argument.
      // Treat null as the empty string so that the size calculation and the
      // serializers agree and never call strlen() on a null pointer.
      static const char* text_or_empty( const char* text )
      {
        return text == nullptr ? "" : text;
      }

      // Calculate the count of bytes required to serialize the records described
      // by the argument list. The arguments are identical to alloc_msgv().
      // 'clss', 'ttl' and 'numb0' do not affect the size; they are only read
      // to step over them in the variable argument list.
      unsigned calc_msg_buf_byte_count( unsigned recdTId, const char* name, unsigned dnsTId, unsigned clss, unsigned ttl, unsigned numb0, const char* text, unsigned nextRecdTId, va_list vl )
      {
        unsigned msgByteN = kHdrBodyByteN; // msg header bytes
        unsigned recdN    = 0;

        while( true )
        {
          // record name bytes
          msgByteN += strlen(name) + 2; // add 1 for initial segment length and 1 for terminating zero

          if( recdTId == kQuestionRecdTId )
          {
            msgByteN += kQuestionBodyByteN;
          }
          else
          {
            // resource record bytes
            msgByteN += kRsrcBodyByteN;

            switch( dnsTId )
            {
              case kA_DnsTId:
                msgByteN += kABodyByteN;
                break;

              case kPTR_DnsTId:
              case kTXT_DnsTId:
                msgByteN += strlen(text_or_empty(text)) + 1;
                break;

              case kSRV_DnsTId:
                msgByteN += kSrvBodyByteN + strlen(text_or_empty(text)) + 1;
                break;

              default:
                assert(0);
            }
          }

          // The type of the second record is given by 'nextRecdTId', the type
          // of all following records is taken from the variable argument list.
          recdTId = recdN==0 ? nextRecdTId : va_arg(vl,unsigned);

          if( recdTId == kInvalidRecdTId )
            break;

          name   = va_arg(vl,const char*);
          dnsTId = va_arg(vl,unsigned);
          clss   = va_arg(vl,unsigned);
          ttl    = va_arg(vl,unsigned);    // not used
          numb0  = va_arg(vl,unsigned);    // not used
          text   = va_arg(vl,const char*);

          recdN += 1;
        }

        return msgByteN;
      }

      // Write 'name' to b[bN] as a sequence of length prefixed segments.
      // Segments in 'name' are separated by 'sepChar'. Each separator is
      // replaced in the output by the length of the segment that follows it,
      // so the output occupies strlen(name)+1 bytes, plus one more byte if
      // 'zeroTermFl' is set.
      // Returns a pointer just past the last byte written.
      char* format_name( char* b, unsigned bN, const char* name, bool zeroTermFl=true, const char sepChar='.' )
      {
        unsigned n = 0;  // length of the current segment
        unsigned j = 0;  // index of the length cell of the current segment

        // for each input character
        for(unsigned i=0; true; ++i)
        {
          // if this char is a separator or a '\0' then it is the end of a name segment
          if( name[i] == sepChar || name[i]==0 )
          {
            assert( j < bN );
            assert( n <= 255 ); // the length cell is a single byte

            b[j] = n;   // write the length of the previous segment
            j    = i+1; // advance j to the length cell of the next segment
            n    = 0;

            // if this char is a '\0' then we are at the end of the input
            if( name[i] == 0 )
              break;
          }
          else
          {
            n += 1; // advance the segment length counter

            assert( j+n < bN );
            b[j+n] = name[i]; // write the current char to the output
          }
        }

        // terminate the output string
        if( zeroTermFl )
        {
          assert( j < bN );
          b[j] = 0;
          j   += 1;
        }

        return b + j; // return a pointer just past the end of the output string
      }

      // Write a question record (name,type,class) to b[bN].
      // The class is always written as kInClassFl.
      // Returns a pointer just past the end of the record.
      char* format_question( char* b, unsigned bN, const char* name, unsigned dnsTypeId )
      {
        char* b1 = format_name(b,bN,name);

        assert( (unsigned)(b1 - b) + kQuestionBodyByteN <= bN );

        store_net_u16( b1,     dnsTypeId );
        store_net_u16( b1 + 2, kInClassFl );

        return b1 + kQuestionBodyByteN;
      }

      // Write the part of a resource record which is common to all record types.
      // Returns a pointer to the first byte of the type specific data.
      char* format_rsrc( char* b, unsigned bN, const char* name, unsigned typeId, unsigned clss, unsigned ttl, unsigned dataByteN )
      {
        // byte offset:  0     2      4    8
        //               type  class  TTL  dlen
        char* b1 = format_name(b,bN,name);

        assert( (unsigned)(b1 - b) + kRsrcBodyByteN <= bN );

        store_net_u16( b1,     typeId );
        store_net_u16( b1 + 2, clss );
        store_net_u32( b1 + 4, ttl );
        store_net_u16( b1 + 8, dataByteN );

        return b1 + kRsrcBodyByteN;
      }

      // Write an 'A' resource record. 'addr' is given in host byte order.
      char* format_A_rsrc( char* b, unsigned bN, const char* name, unsigned clss, unsigned ttl, unsigned addr )
      {
        char* b1 = format_rsrc( b, bN, name, kA_DnsTId, clss, ttl, kABodyByteN );

        assert( (unsigned)(b1 - b) + kABodyByteN <= bN );

        store_net_u32( b1, addr );

        return b1 + kABodyByteN;
      }

      // Write a 'PTR' resource record whose data is 'text' formatted as a name.
      // NOTE(review): the name in the data field is written without a
      // terminating zero length segment - confirm that this is intended.
      char* format_PTR_rsrc( char* b, unsigned bN, const char* name, unsigned clss, unsigned ttl, const char* text )
      {
        text = text_or_empty(text);

        unsigned dataByteN = strlen(text)+1;
        char*    b1        = format_rsrc( b, bN, name, kPTR_DnsTId, clss, ttl, dataByteN );

        return format_name(b1,bN-(b1-b),text,false);
      }

      // Write a 'TXT' resource record.
      // 'text' holds one or more strings separated by '\n'.
      // Each string is written with a one byte length prefix.
      char* format_TXT_rsrc( char* b, unsigned bN, const char* name, unsigned clss, unsigned ttl, const char* text )
      {
        text = text_or_empty(text);

        unsigned dataByteN = strlen(text)+1;
        char*    b1        = format_rsrc( b, bN, name, kTXT_DnsTId, clss, ttl, dataByteN );

        return format_name(b1,bN-(b1-b),text,false,'\n');
      }

      // Write a 'SRV' resource record. 'text' is the target name.
      // NOTE(review): the target name is written without a terminating zero
      // length segment - confirm that this is intended.
      char* format_SRV_rsrc( char* b, unsigned bN, const char* name, unsigned clss, unsigned ttl, const char* text, unsigned port, unsigned priority=0, unsigned weight=0 )
      {
        // data byte offset:  0         2       4     6
        //                    priority  weight  port  target
        text = text_or_empty(text);

        unsigned dataByteN = kSrvBodyByteN + strlen(text)+1;
        char*    b1        = format_rsrc( b, bN, name, kSRV_DnsTId, clss, ttl, dataByteN );

        assert( (unsigned)(b1 - b) + kSrvBodyByteN <= bN );

        store_net_u16( b1,     priority );
        store_net_u16( b1 + 2, weight );
        store_net_u16( b1 + 4, port );

        return format_name(b1 + kSrvBodyByteN,bN-((b1-b)+kSrvBodyByteN),text,false);
      }

      // Allocate and fill a message buffer.
      //
      // The first record is described by (recdTId,name,dnsTId,clss,ttl,numb0,text).
      // 'nextRecdTId' is the type of the second record, or kInvalidRecdTId if
      // there is only one record. Each additional record is read from 'vl0' as:
      //   const char* name, unsigned dnsTId, unsigned clss, unsigned ttl,
      //   unsigned numb0, const char* text, unsigned <type id of the following record>
      // The list is terminated when the record type id is kInvalidRecdTId.
      //
      // 'numb0' is the address of an 'A' record or the port of a 'SRV' record.
      // 'text' is the data of a 'PTR' or 'TXT' record or the target of a 'SRV' record.
      //
      // On success the buffer is returned and *msgByteNRef is set to its size.
      // The caller must release the buffer with free().
      // Returns nullptr, with *msgByteNRef set to 0, if the allocation fails.
      char* alloc_msgv( unsigned* msgByteNRef, uint16_t transactionId, uint16_t flags, unsigned recdTId, const char* name, unsigned dnsTId, unsigned clss, unsigned ttl, unsigned numb0, const char* text, unsigned nextRecdTId, va_list vl0 )
      {
        if( msgByteNRef != nullptr )
          *msgByteNRef = 0;

        // measure the message using a copy of the argument list so that
        // vl0 can be traversed a second time below
        va_list vl1;
        va_copy(vl1,vl0);
        unsigned byteN = calc_msg_buf_byte_count(recdTId,name,dnsTId,clss,ttl,numb0,text,nextRecdTId,vl1);
        va_end(vl1);

        char* buf = (char*)calloc(1,byteN);

        if( buf == nullptr )
          return nullptr;

        uint16_t* u     = (uint16_t*)buf;          // header fields (calloc'd memory is suitably aligned)
        char*     b0    = buf + kHdrBodyByteN;     // current output location
        char*     b1    = nullptr;
        int       bN    = byteN - kHdrBodyByteN;   // count of bytes available following the header
        unsigned  recdN = 0;

        // the parser reads these fields with ntohs() so they must be
        // stored in network byte order
        u[0] = htons(transactionId);
        u[1] = htons(flags);

        // for each specified record
        while( true )
        {
          // track the count of each type of record (in host byte order for now)
          switch( recdTId )
          {
            case kQuestionRecdTId:   u[2] += 1; break;
            case kAnswerRecdTId:     u[3] += 1; break;
            case kNameServerRecdTId: u[4] += 1; break;
            case kAdditionalRecdTId: u[5] += 1; break;
          }

          // if this is a question record
          if( recdTId == kQuestionRecdTId )
          {
            b1 = format_question( b0, bN, name, dnsTId );
          }
          else
          {
            // select the resource record type to generate
            switch( dnsTId )
            {
              case kA_DnsTId:
                b1 = format_A_rsrc( b0, bN, name, clss, ttl, numb0 );
                break;

              case kPTR_DnsTId:
                b1 = format_PTR_rsrc(b0, bN, name, clss, ttl, text );
                break;

              case kTXT_DnsTId:
                b1 = format_TXT_rsrc(b0, bN, name, clss, ttl, text );
                break;

              case kSRV_DnsTId:
                b1 = format_SRV_rsrc(b0, bN, name, clss, ttl, text, numb0 );
                break;

              default:
                assert(0);
            }
          }

          bN -= (b1 - b0);  // track the count of remaining bytes in the buffer
          assert(bN >= 0);  // assert the buffer has not been overrun
          b0 = b1;          // update the current buffer output pointer

          // get the next record type
          recdTId = recdN==0 ? nextRecdTId : va_arg(vl0,unsigned);

          // detect the end of records sentinel
          if( recdTId == kInvalidRecdTId )
            break;

          // get the arguments for the next record
          name   = va_arg(vl0,const char*);
          dnsTId = va_arg(vl0,unsigned);
          clss   = va_arg(vl0,unsigned);
          ttl    = va_arg(vl0,unsigned);
          numb0  = va_arg(vl0,unsigned);
          text   = va_arg(vl0,const char*);

          recdN += 1;
        }

        // Note that the buffer should be exactly full when all data is written.
        // If this is not true then either the buffer size calculation or
        // the buffer serialization code is incorrect.
        assert( bN == 0 );

        if( msgByteNRef != nullptr )
          *msgByteNRef = byteN;

        // convert the record counts to the network endianness
        u[2] = htons(u[2]);
        u[3] = htons(u[3]);
        u[4] = htons(u[4]);
        u[5] = htons(u[5]);

        return buf;
      }

      // Variable argument front end to alloc_msgv().
      // See alloc_msgv() for the layout of the argument list.
      char* alloc_msg( unsigned* msgByteNRef, uint16_t transactionId, uint16_t flags, unsigned recdTId, const char* name, unsigned dnsTId, unsigned clss, unsigned ttl, unsigned numb0, const char* text, unsigned nextRecdTId, ... )
      {
        va_list vl;
        va_start(vl,nextRecdTId);
        char* b = alloc_msgv( msgByteNRef, transactionId, flags, recdTId, name, dnsTId, clss, ttl, numb0, text, nextRecdTId, vl );
        va_end(vl);
        return b;
      }

      // Sum the bytes (length byte plus characters) of the consecutive plain
      // segments beginning at 'b'. The scan stops at a zero length byte or at
      // any byte whose top two bits are not zero (e.g. a compression pointer).
      // TODO: a segment sequence which ends in a pointer is not followed.
      // NOTE(review): the extent of the buffer is not known here so the scan
      // is not bounds checked.
      unsigned calc_ptr_string_byte_count( const char* b )
      {
        unsigned n = 0;

        while( true )
        {
          // read the length as unsigned so that the scan can never step backwards
          unsigned len = (unsigned char)b[n];

          if( len == 0 || (len & 0xc0) != 0 )
            break;

          n += len + 1;
        }

        return n;
      }

      // Measure the name beginning at 'b'.
      //
      // base        Address used to resolve compression pointer offsets.
      // maxSrcByteN Count of bytes available at 'b', or 0 to disable the limit.
      // strLenRef   If not null, receives the count of bytes in the
      //             uncompressed name (segment length bytes included).
      // logFl       Log each segment as it is encountered.
      //
      // Returns the count of bytes the name occupies at 'b', or (unsigned)-1
      // if the name is malformed.
      unsigned calc_name_byte_count( mdns_t* p, const char* base, const char* b, unsigned maxSrcByteN, unsigned* strLenRef=nullptr, bool logFl=true )
      {
        const unsigned kMalformed = (unsigned)-1;

        if( strLenRef != nullptr )
          *strLenRef = 0;

        // Number of bytes required to represent the uncompressed string
        // (including the segment size bytes)
        unsigned strByteN = 0;
        unsigned i        = 0;

        while( maxSrcByteN == 0 || i < maxSrcByteN )
        {
          unsigned len = (unsigned char)b[i];

          // if this is a pointer
          if( (len & 0xc0) == 0xc0 )
          {
            // a pointer is two bytes long - verify that the second byte is available
            if( maxSrcByteN != 0 && i+1 >= maxSrcByteN )
            {
              error(p,"Malformed name.");
              return kMalformed;
            }

            // NOTE(review): 'offset' is not validated against the message length
            unsigned short offset = len & 0x3f;
            offset = (offset<<8) + ((unsigned char)b[i+1]);

            strByteN += calc_ptr_string_byte_count( base + offset) + 1;

            if( logFl )
              log(p,"%.*s | ", base[offset], base + offset + 1 );

            i += 2;
            break; // ptr terminates the name
          }

          if( len == 0 )
          {
            ++i;
            break; // zero terminates the name
          }

          // the remaining bit patterns (0x40,0x80) are not plain segment lengths
          if( (len & 0xc0) != 0 )
          {
            error(p,"Malformed name.");
            return kMalformed;
          }

          if( logFl )
            log(p,"%.*s | ", (int)len, b+i+1 );

          strByteN += len + 1;
          i        += len + 1;
        }

        if( maxSrcByteN != 0 && i > maxSrcByteN )
        {
          // the last segment ran past the end of the source buffer
          error(p,"Malformed name.");
          return kMalformed;
        }

        if( strLenRef != nullptr )
          *strLenRef = strByteN;

        return i; // i is the count of bytes used by the name in the packet buffer
      }

      // Return the total count of bytes (name, fixed body, and data) used by
      // the resource record at b[bN], or (unsigned)-1 if the record is malformed.
      unsigned resource_recd_byte_count( mdns_t* p, const char* base, const char* b, unsigned bN )
      {
        unsigned nameN = calc_name_byte_count( p, base, b, bN, nullptr, false );

        // the fixed body must lie entirely inside the buffer
        if( nameN == (unsigned)-1 || nameN > bN || bN - nameN < kRsrcBodyByteN )
          return (unsigned)-1;

        // the data length is the last 16 bit field of the fixed body
        return nameN + kRsrcBodyByteN + load_net_u16( b + nameN + 8 );
      }

      // Log the address held in the data field of an 'A' record.
      // Returns a pointer just past the address.
      const char* parse_A_recd( mdns_t* p, const char* base, const char* b, unsigned byteN )
      {
        assert( byteN >= kABodyByteN );

        unsigned addr = load_net_u32(b);

        log(p,"0x%04x inet addr", addr );

        return b + kABodyByteN;
      }

      // Log the name held in the data field of a 'PTR' record.
      // Returns a pointer just past the name.
      const char* parse_PTR_recd( mdns_t* p, const char* base, const char* b, unsigned byteN )
      {
        unsigned nameN = calc_name_byte_count( p, base, b, byteN );
        return b + nameN;
      }

      // NOTE(review): The text of this file is damaged between the 'while(' below
      // and the 'static_cast' expressions which follow it. The remainder of
      // parse_TXT_recd(), every definition which followed it (parse_question,
      // parse_answer, parse_name_server, parse_additional and parse_msg_segment
      // are referenced below but are not defined anywhere in the surviving text),
      // and the opening of the function called elsewhere in this file as
      // parse_msg() are missing. The template arguments of both static_cast
      // expressions are also missing.
      // The surviving tokens are reproduced verbatim rather than reconstructed by
      // guesswork. Restore this region from version control before building.
      const char* parse_TXT_recd( mdns_t* p, const char* base, const char* b, unsigned byteN )
      {
        unsigned i =0;
        while( i(buf);

        const uint16_t* hdr = static_cast(buf);

        uint16_t transId   = ntohs(hdr[0]);
        uint16_t flags     = ntohs(hdr[1]);
        uint16_t questionN = ntohs(hdr[2]);
        uint16_t answerN   = ntohs(hdr[3]);
        uint16_t nameSrvN  = ntohs(hdr[4]);
        uint16_t addN      = ntohs(hdr[5]);

        log(p,"*** Msg: id:0x%04x flags:0x%04x qN:%i aN:%i nsN:%i addN:%i\n", transId, flags, questionN, answerN, nameSrvN,addN);

        const char* b0 = (const char*)(hdr + 6);
        const char* b1 = nullptr;
        int         bN = byteN - (b0 - base);

        if((b1 = parse_msg_segment( p, parse_question, questionN, base, b0, bN )) == nullptr )
          goto errLabel;

        bN -= b1 - b0;
        b0  = b1;

        if((b1 = parse_msg_segment( p, parse_answer, answerN, base, b0, bN )) == nullptr )
          goto errLabel;

        bN -= b1 - b0;
        b0  = b1;

        if((b1 = parse_msg_segment( p, parse_name_server, nameSrvN, base, b0, bN )) == nullptr )
          goto errLabel;

        bN -= b1 - b0;
        b0  = b1;

        if((b1 = parse_msg_segment( p, parse_additional, addN, base, b0, bN )) == nullptr )
          goto errLabel;

      errLabel:
        return 0;
      }

      // NOTE(review): The text of this file is also damaged between the 'for('
      // below and '(arg);'. The body of the print_hex() loop and the opening of
      // the socket receive callback (passed to srv::create() in test() as
      // mdnsRecieveCallback) are missing. The surviving tokens are reproduced
      // verbatim. Restore this region from version control before building.
      void print_hex( const char* buf, unsigned dataByteCnt )
      {
        unsigned char* data = (unsigned char*)buf;
        const unsigned colN = 8;
        unsigned       ci   = 0;

        for(unsigned i=0; i(arg);

        char addrBuf[ INET_ADDRSTRLEN ];
        socket::addrToString( fromAddr, addrBuf, INET_ADDRSTRLEN );
        p->cbN += 1;
        printf("%i bytes:%i %s\n", p->cbN, dataByteCnt, addrBuf );
        print_hex( (const char*)data, dataByteCnt );
        parse_msg(&p->mdns,data,dataByteCnt);
      }

      // Build, print and parse three example messages. 'tag' is not used.
      void testAllocMsg( const char* tag )
      {
        unsigned bufByteN = 0;
        unsigned transId  = 0;
        unsigned flags    = 0;
        unsigned ttl      = 120;

        char* buf = alloc_msg( &bufByteN, transId, flags,
                               kQuestionRecdTId,   "80.0.168.192.in-addr.arpa",      kANY_DnsTId, kInClassFl, 0,   0,     nullptr,
                               kQuestionRecdTId,   "Euphonix-MC-0090D580F4DE.local", kANY_DnsTId, kInClassFl, 0,   0,     nullptr,
                               kNameServerRecdTId, "Euphonix-MC-0090D580F4DE.local", kA_DnsTId,   kInClassFl, ttl, 49168, nullptr,
                               kNameServerRecdTId, "80.0.168.192.in-addr.arpa",      kPTR_DnsTId, kInClassFl, ttl, 0,     "in-addr.arpa",
                               kInvalidRecdTId );

        // alloc_msg() returns nullptr if the buffer could not be allocated
        if( buf != nullptr )
        {
          print_hex( buf, bufByteN );
          parse_msg( nullptr, buf, bufByteN );
          free(buf);
        }

        buf = alloc_msg( &bufByteN, transId, flags,
                         kQuestionRecdTId,   "MC Mix._EuConProxy._tcp.local",  kANY_DnsTId, kInClassFl, 0,   0,     nullptr,
                         kNameServerRecdTId, "Euphonix-MC-0090D580F4DE.local", kSRV_DnsTId, kInClassFl, 120, 49168, "local",
                         kInvalidRecdTId );

        if( buf != nullptr )
        {
          print_hex( buf, bufByteN );
          parse_msg( nullptr, buf, bufByteN );
          free(buf);
        }

        buf = alloc_msg( &bufByteN, transId, kReplyHdrFl | kAuthoritativeHdrFl,
                         kAnswerRecdTId, "MC Mix - 1._EuConProxy._tcp.local", kTXT_DnsTId, kFlushClassFl | kInClassFl, 0, 0, "lmac=00-90-D5-80-F4-DE\ndummy=0",
                         kInvalidRecdTId );

        if( buf != nullptr )
        {
          print_hex( buf, bufByteN );
          parse_msg( nullptr, buf, bufByteN );
          free(buf);
        }
      }
    }
  }
}

// Interactive test: create a socket server on port 5353, join the multicast
// group 224.0.0.251, and then read commands from stdin.
//   "msg0"  run testAllocMsg() and exit
//   "quit"  exit
// The loop also exits when stdin reaches end-of-file or a read error occurs.
cw::rc_t cw::net::mdns::test()
{
  rc_t                 rc;
  socket::portNumber_t mdnsPort       = 5353;
  unsigned             recvBufByteCnt = 4096;
  unsigned             timeOutMs      = 0;  // if timeOutMs==0 server uses recv_from()
  const unsigned       sbufN          = 31;
  char                 sbuf[ sbufN+1 ];
  mdns_app_t           app;

  app.cbN        = 0;
  app.mdns.rsrcL = nullptr; // the receive callback passes &app.mdns to the parser

  // create the mDNS UDP socket server
  if((rc = srv::create(
        app.mdnsH,
        mdnsPort,
        socket::kNonBlockingFl | socket::kReuseAddrFl | socket::kReusePortFl | socket::kMultiCastTtlFl | socket::kMultiCastLoopFl,
        mdnsRecieveCallback,
        &app,
        recvBufByteCnt,
        timeOutMs,
        NULL,
        socket::kInvalidPortNumber )) != kOkRC )
  {
    return rc;
  }

  // add the mDNS socket to the multicast group
  if((rc = join_multicast_group( socketHandle(app.mdnsH), "224.0.0.251" )) != kOkRC )
    goto errLabel;

  // start the mDNS socket server
  if((rc = srv::start( app.mdnsH )) != kOkRC )
    goto errLabel;

  while( true )
  {
    printf("? ");

    // fgets() returns null on end-of-file or error. Without this test the
    // loop would spin forever once stdin is closed.
    if( std::fgets(sbuf,sbufN,stdin) != sbuf )
      break;

    if( strcmp(sbuf,"msg0\n") == 0 )
    {
      testAllocMsg(sbuf);
      break;
    }

    if( strcmp(sbuf,"quit\n") == 0)
      break;
  }

 errLabel:
  // close the mDNS server
  rc_t rc0 = destroy(app.mdnsH);

  return rcSelect(rc,rc0);
}