Skip to content

Commit 6eb1958

Browse files
authored
feat: add support for srvMaxHosts/srvServiceName COMPASS-5324 (#2)
1 parent 45dfeb8 commit 6eb1958

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed

src/index.ts

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ function matchesParentDomain (srvAddress: string, parentDomain: string): boolean
2020
return srv.endsWith(parent);
2121
}
2222

23-
async function resolveDnsSrvRecord (dns: NonNullable<Options['dns']>, lookupAddress: string): Promise<string[]> {
24-
const addresses = await promisify(dns.resolveSrv)(`_mongodb._tcp.${lookupAddress}`);
23+
async function resolveDnsSrvRecord (dns: NonNullable<Options['dns']>, lookupAddress: string, srvServiceName: string): Promise<string[]> {
24+
const addresses = await promisify(dns.resolveSrv)(`_${srvServiceName}._tcp.${lookupAddress}`);
2525
if (!addresses?.length) {
2626
throw new MongoParseError('No addresses found at host');
2727
}
@@ -39,8 +39,8 @@ async function resolveDnsTxtRecord (dns: NonNullable<Options['dns']>, lookupAddr
3939
let records: string[][] | undefined;
4040
try {
4141
records = await promisify(dns.resolveTxt)(lookupAddress);
42-
} catch (err) {
43-
if (err.code && (err.code !== 'ENODATA' && err.code !== 'ENOTFOUND')) {
42+
} catch (err: any) {
43+
if (err?.code && (err.code !== 'ENODATA' && err.code !== 'ENOTFOUND')) {
4444
throw err;
4545
}
4646
}
@@ -89,11 +89,19 @@ async function resolveMongodbSrv (input: string, options?: Options): Promise<str
8989
}
9090

9191
const lookupAddress = url.hostname;
92+
const srvServiceName = url.searchParams.get('srvServiceName') || 'mongodb';
93+
const srvMaxHosts = +(url.searchParams.get('srvMaxHosts') || '0');
94+
9295
const [srvResult, txtResult] = await Promise.all([
93-
resolveDnsSrvRecord(dns, lookupAddress),
96+
resolveDnsSrvRecord(dns, lookupAddress, srvServiceName),
9497
resolveDnsTxtRecord(dns, lookupAddress)
9598
]);
9699

100+
if (srvMaxHosts && srvMaxHosts < srvResult.length) {
101+
// Replace srvResult with shuffled + limited srvResult
102+
srvResult.splice(0, srvResult.length, ...shuffle(srvResult, srvMaxHosts));
103+
}
104+
97105
url.protocol = 'mongodb:';
98106
url.hostname = '__DUMMY_HOSTNAME__';
99107
if (!url.pathname) {
@@ -107,8 +115,39 @@ async function resolveMongodbSrv (input: string, options?: Options): Promise<str
107115
if (!url.searchParams.has('tls') && !url.searchParams.has('ssl')) {
108116
url.searchParams.set('tls', 'true');
109117
}
118+
url.searchParams.delete('srvServiceName');
119+
url.searchParams.delete('srvMaxHosts');
110120

111121
return url.toString().replace('__DUMMY_HOSTNAME__', srvResult.join(','));
112122
}
113123

124+
/**
125+
* Fisher–Yates Shuffle
126+
* (shamelessly copied from https://github.com/mongodb/node-mongodb-native/blob/1f8b539cd3d60dd9f36baa22fd287241b5c65380/src/utils.ts#L1423-L1451)
127+
*
128+
* Reference: https://bost.ocks.org/mike/shuffle/
129+
* @param sequence - items to be shuffled
130+
* @param limit - Defaults to `0`. If nonzero shuffle will slice the randomized array e.g, `.slice(0, limit)` otherwise will return the entire randomized array.
131+
*/
132+
function shuffle<T> (sequence: Iterable<T>, limit = 0): Array<T> {
133+
const items = Array.from(sequence); // shallow copy in order to never shuffle the input
134+
135+
limit = Math.min(limit, items.length);
136+
137+
let remainingItemsToShuffle = items.length;
138+
const lowerBound = limit % items.length === 0 ? 1 : items.length - limit;
139+
while (remainingItemsToShuffle > lowerBound) {
140+
// Pick a remaining element
141+
const randomIndex = Math.floor(Math.random() * remainingItemsToShuffle);
142+
remainingItemsToShuffle -= 1;
143+
144+
// And swap it with the current element
145+
const swapHold = items[remainingItemsToShuffle];
146+
items[remainingItemsToShuffle] = items[randomIndex];
147+
items[randomIndex] = swapHold;
148+
}
149+
150+
return limit % items.length === 0 ? items : items.slice(lowerBound);
151+
}
152+
114153
export = resolveMongodbSrv;

test/index.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,41 @@ describe('resolveMongodbSrv', () => {
190190
await resolveMongodbSrv('mongodb+srv://server.example.com/?loadBalanced=true', { dns }),
191191
'mongodb://asdf.example.com/?loadBalanced=true&tls=true');
192192
});
193+
194+
it('allows specifying a custom SRV service name', async () => {
195+
srvResult = [{ name: 'asdf.example.com', port: 27017 }];
196+
txtResult = [['loadBalanced=false']];
197+
assert.strictEqual(
198+
await resolveMongodbSrv('mongodb+srv://server.example.com/?loadBalanced=true&srvServiceName=custom', { dns }),
199+
'mongodb://asdf.example.com/?loadBalanced=true&tls=true');
200+
assert.deepStrictEqual(srvQueries, ['_custom._tcp.server.example.com']);
201+
});
202+
203+
it('defaults to _mongodb._tcp as a SRV service name', async () => {
204+
srvResult = [{ name: 'asdf.example.com', port: 27017 }];
205+
txtResult = [['loadBalanced=false']];
206+
assert.strictEqual(
207+
await resolveMongodbSrv('mongodb+srv://server.example.com/?loadBalanced=true', { dns }),
208+
'mongodb://asdf.example.com/?loadBalanced=true&tls=true');
209+
assert.deepStrictEqual(srvQueries, ['_mongodb._tcp.server.example.com']);
210+
});
211+
212+
it('allows limiting the SRV result to a specific number of hosts', async () => {
213+
srvResult = ['host1', 'host2', 'host3'].map(name => ({ name: `${name}.example.com`, port: 27017 }));
214+
txtResult = [];
215+
assert.strictEqual(
216+
await resolveMongodbSrv('mongodb+srv://server.example.com/?srvMaxHosts=0', { dns }),
217+
'mongodb://host1.example.com,host2.example.com,host3.example.com/?tls=true');
218+
assert.strictEqual(
219+
await resolveMongodbSrv('mongodb+srv://server.example.com/?srvMaxHosts=3', { dns }),
220+
'mongodb://host1.example.com,host2.example.com,host3.example.com/?tls=true');
221+
assert.strictEqual(
222+
await resolveMongodbSrv('mongodb+srv://server.example.com/?srvMaxHosts=6', { dns }),
223+
'mongodb://host1.example.com,host2.example.com,host3.example.com/?tls=true');
224+
assert.match(
225+
await resolveMongodbSrv('mongodb+srv://server.example.com/?srvMaxHosts=1', { dns }),
226+
/^mongodb:\/\/host[1-3]\.example\.com\/\?tls=true$/);
227+
});
193228
});
194229

195230
for (const [name, dnsProvider] of [

0 commit comments

Comments
 (0)