Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions library/agent/hooks/wrapExport.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,30 @@ t.test("With callback", async (t) => {
{ name: "test", type: "external" },
{
kind: "outgoing_http_op",
bindContext: false,
inspectArgs: (args) => {
t.same(args, ["input", () => {}]);
},
}
);

toWrap.test("input", () => {});
});

t.test("With callback with bindContext", async (t) => {
const toWrap = {
test(input: string, callback: (input: string) => void) {
callback(input);
},
};

wrapExport(
toWrap,
"test",
{ name: "test", type: "external" },
{
kind: "outgoing_http_op",
bindContext: true,
inspectArgs: (args) => {
t.same(args, ["input", bindContext(() => {})]);
},
Expand Down
16 changes: 12 additions & 4 deletions library/agent/hooks/wrapExport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ export type InterceptorObject = {
// This will be used to collect stats
// For sources, this will often be undefined
kind: OperationKind | undefined;
// Whether to bind the async resource execution context to callback functions passed as arguments
// Only applies to inspectArgs right now
// If the called function uses code where the context would be lost, like calling the callback in a setTimeout
// or an event listener, this should be true
// In other cases this can be false to avoid unnecessary overhead
bindContext?: boolean;
};

/**
Expand Down Expand Up @@ -60,10 +66,12 @@ export function wrapExport(

// Run inspectArgs interceptor if provided
if (typeof interceptors.inspectArgs === "function") {
// Bind context to functions in arguments
for (let i = 0; i < args.length; i++) {
if (typeof args[i] === "function") {
args[i] = bindContext(args[i]);
if (interceptors.bindContext) {
// Bind context to functions in arguments
for (let i = 0; i < args.length; i++) {
if (typeof args[i] === "function") {
args[i] = bindContext(args[i]);
}
}
}

Expand Down
11 changes: 10 additions & 1 deletion library/sinks/ChildProcess.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as t from "tap";
import { Context, runWithContext } from "../agent/Context";
import { Context, getContext, runWithContext } from "../agent/Context";
import { ChildProcess } from "./ChildProcess";
import { execFile, execFileSync } from "child_process";
import { createTestAgent } from "../helpers/createTestAgent";
Expand Down Expand Up @@ -219,4 +219,13 @@ t.test("it works", async (t) => {
);
}
);

await new Promise<void>((resolve) => {
runWithContext(unsafeContext, () => {
exec("ls", () => {
t.same(getContext(), unsafeContext);
resolve();
}).unref();
});
});
});
16 changes: 14 additions & 2 deletions library/sinks/FileSystem.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as t from "tap";
import { Context, runWithContext } from "../agent/Context";
import { Context, getContext, runWithContext } from "../agent/Context";
import { FileSystem } from "./FileSystem";
import { createTestAgent } from "../helpers/createTestAgent";

Expand Down Expand Up @@ -52,19 +52,22 @@ t.test("it works", async (t) => {
const {
writeFile,
writeFileSync,
readFile,
rename,
realpath,
promises: fsDotPromise,
realpathSync,
} = require("fs");
} = require("fs") as typeof import("fs");
const { writeFile: writeFilePromise } =
require("fs/promises") as typeof import("fs/promises");

t.ok(typeof realpath.native === "function");
t.ok(typeof realpathSync.native === "function");

const runCommandsWithInvalidArgs = () => {
// @ts-expect-error Invalid args test
throws(() => writeFile(), /Received undefined/);
// @ts-expect-error Invalid args test
throws(() => writeFileSync(), /Received undefined/);
};

Expand Down Expand Up @@ -308,4 +311,13 @@ t.test("it works", async (t) => {
rename(new URL("file:///../../test.txt"), "../test2.txt", () => {});
}
);

await new Promise<void>((resolve) => {
runWithContext(unsafeContext, () => {
readFile("./test.txt", "utf-8", (err, data) => {
t.same(getContext(), unsafeContext);
resolve();
});
});
});
});
11 changes: 10 additions & 1 deletion library/sinks/HTTPRequest.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import * as dns from "dns";
import * as t from "tap";
import { Token } from "../agent/api/Token";
import { Context, runWithContext } from "../agent/Context";
import { Context, getContext, runWithContext } from "../agent/Context";
import { wrap } from "../helpers/wrap";
import { HTTPRequest } from "./HTTPRequest";
import { createTestAgent } from "../helpers/createTestAgent";
Expand Down Expand Up @@ -358,6 +358,15 @@ t.test("it works", (t) => {
}
);

runWithContext(createContext(), () => {
const req = https.get("https://app.aikido.dev", (res) => {
t.same(getContext(), createContext());
res.on("data", () => {});
res.on("end", () => {});
});
req.end();
});

setTimeout(() => {
t.end();
}, 3000);
Expand Down
14 changes: 10 additions & 4 deletions library/sinks/MariaDB.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as t from "tap";
import { runWithContext, type Context } from "../agent/Context";
import { getContext, runWithContext, type Context } from "../agent/Context";
import { createTestAgent } from "../helpers/createTestAgent";
import { MariaDB } from "./MariaDB";

Expand Down Expand Up @@ -267,9 +267,15 @@ t.test("it detects SQL injections using callbacks", (t) => {
}
}

connection.end();
pool.end();
t.end();
runWithContext(dangerousContext, () => {
connection.query("SELECT 1;", () => {
t.same(getContext(), dangerousContext);

connection.end();
pool.end();
t.end();
});
});
}
);
} catch (error: any) {
Expand Down
2 changes: 2 additions & 0 deletions library/sinks/MariaDB.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ export class MariaDB implements Wrapper {
for (const fn of functions) {
wrapExport(exports.prototype, fn, pkgInfo, {
kind: "sql_op",
bindContext: true,
inspectArgs: (args) => this.inspectQuery(args, fn),
});
}
Expand All @@ -66,6 +67,7 @@ export class MariaDB implements Wrapper {
for (const fn of functions) {
wrapExport(exports.prototype, fn, pkgInfo, {
kind: "sql_op",
bindContext: true,
inspectArgs: (args) => this.inspectQuery(args, fn),
});
}
Expand Down
1 change: 1 addition & 0 deletions library/sinks/MySQL.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ export class MySQL implements Wrapper {
.onFileRequire("lib/Connection.js", (exports, pkgInfo) => {
wrapExport(exports.prototype, "query", pkgInfo, {
kind: "sql_op",
bindContext: true,
inspectArgs: (args) => this.inspectQuery(args),
});
});
Expand Down
11 changes: 10 additions & 1 deletion library/sinks/MySQL2.tests.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as t from "tap";
import { runWithContext, type Context } from "../agent/Context";
import { getContext, runWithContext, type Context } from "../agent/Context";
import { MySQL2 } from "./MySQL2";
import { startTestAgent } from "../helpers/startTestAgent";

Expand Down Expand Up @@ -159,6 +159,15 @@ export function createMySQL2Tests(versionPkgName: string) {
runWithContext(safeContext, () => {
connection2!.query("-- This is a comment");
});

await runWithContext(dangerousContext, () => {
return new Promise<void>((resolve) => {
connection2!.query("SELECT petname FROM cats;", () => {
t.same(getContext(), dangerousContext);
resolve();
});
});
});
} catch (error: any) {
t.fail(error);
} finally {
Expand Down
2 changes: 2 additions & 0 deletions library/sinks/MySQL2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ export class MySQL2 implements Wrapper {
// Wrap connection.query
wrapExport(connectionPrototype, "query", pkgInfo, {
kind: "sql_op",
bindContext: true,
inspectArgs: (args) => this.inspectQuery("mysql2.query", args),
});
}
Expand All @@ -94,6 +95,7 @@ export class MySQL2 implements Wrapper {
// Wrap connection.execute
wrapExport(connectionPrototype, "execute", pkgInfo, {
kind: "sql_op",
bindContext: true,
inspectArgs: (args) => this.inspectQuery("mysql2.execute", args),
});
}
Expand Down
1 change: 1 addition & 0 deletions library/sinks/Postgres.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ export class Postgres implements Wrapper {
.onRequire((exports, pkgInfo) => {
wrapExport(exports.Client.prototype, "query", pkgInfo, {
kind: "sql_op",
bindContext: true,
inspectArgs: (args) => this.inspectQuery(args),
});
});
Expand Down
21 changes: 20 additions & 1 deletion library/sinks/SQLite3.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as t from "tap";
import { runWithContext, type Context } from "../agent/Context";
import { getContext, runWithContext, type Context } from "../agent/Context";
import { SQLite3 } from "./SQLite3";
import { promisify } from "util";
import { createTestAgent } from "../helpers/createTestAgent";
Expand Down Expand Up @@ -124,6 +124,25 @@ t.test("it detects SQL injections", async () => {
'SQLITE_ERROR: unrecognized token: "\' SELECT * FROM test"'
);
}

await new Promise<void>((resolve) => {
runWithContext(dangerousContext, () => {
db.get("SELECT petname FROM cats;", () => {
t.match(getContext(), dangerousContext);

try {
db.get("-- should be blocked", () => {});
} catch (error: any) {
t.match(
error.message,
/Zen has blocked an SQL injection: sqlite3\.get\(\.\.\.\) originating from body\.myTitle/
);
}

resolve();
});
});
});
} catch (error: any) {
t.fail(error);
} finally {
Expand Down
1 change: 1 addition & 0 deletions library/sinks/SQLite3.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ export class SQLite3 implements Wrapper {
for (const func of sqlFunctions) {
wrapExport(db, func, pkgInfo, {
kind: "sql_op",
bindContext: true,
inspectArgs: (args) => {
return this.inspectQuery(`sqlite3.${func}`, args);
},
Expand Down
15 changes: 14 additions & 1 deletion library/sinks/Shelljs.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as t from "tap";
import { runWithContext, type Context } from "../agent/Context";
import { getContext, runWithContext, type Context } from "../agent/Context";
import { Shelljs } from "./Shelljs";
import { ChildProcess } from "./ChildProcess";
import { FileSystem } from "./FileSystem";
Expand Down Expand Up @@ -201,3 +201,16 @@ t.test("invalid arguments are passed to shelljs", async () => {
t.same(result.code, 1);
});
});

t.test("context is available in callbacks", async (t) => {
const shell = require("shelljs");

await new Promise<void>((resolve) => {
runWithContext(safeContext, () => {
shell.exec("ls", { silent: true }, () => {
t.match(getContext(), safeContext);
resolve();
});
});
});
});
Loading