Skip to content
2 changes: 2 additions & 0 deletions end2end/server/src/zen/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type AppConfig = {
domains: any[];
failureRate?: number;
timeout?: number;
promptProtectionMode: string;
};

const configs: AppConfig[] = [];
Expand All @@ -26,6 +27,7 @@ export function generateConfig(app: App): AppConfig {
blockedUserIds: [],
allowedIPAddresses: [],
blockNewOutgoingRequests: false,
promptProtectionMode: "disabled",
domains: [],
};
}
Expand Down
34 changes: 34 additions & 0 deletions library/agent/Agent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ t.test(
allowedIPAddresses: [],
block: true,
blockNewOutgoingRequests: false,
promptProtectionMode: "disabled",
});
const agent = createTestAgent({
api,
Expand Down Expand Up @@ -1083,6 +1084,7 @@ t.test("it fetches blocked lists", async () => {

await setTimeout(0);

t.same(agent.getConfig().getPromptProtectionMode(), "disabled");
t.same(agent.getConfig().isIPAddressBlocked("1.3.2.4"), {
blocked: true,
reason: "Description",
Expand Down Expand Up @@ -1354,3 +1356,35 @@ t.test(
clock.uninstall();
}
);

// The agent must report "disabled" right after start() and only switch to the
// mode configured on the reporting API ("monitor") once flushStats() resolved.
t.test("it fetches prompt protection status", async () => {
  // Fake timers are installed before the agent is created and removed at the end.
  const fakeClock = FakeTimers.install();

  const noopLogger = new LoggerNoop();
  const reportingAPI = new ReportingAPIForTesting({
    success: true,
    endpoints: [],
    configUpdatedAt: 0,
    heartbeatIntervalInMS: 10 * 60 * 1000,
    blockedUserIds: [],
    allowedIPAddresses: [],
    block: true,
    blockNewOutgoingRequests: false,
    promptProtectionMode: "monitor",
  });
  const agent = createTestAgent({
    api: reportingAPI,
    logger: noopLogger,
    token: new Token("123"),
    suppressConsoleLog: false,
  });

  // Reads the mode through the public config accessor each time it is called.
  const currentMode = () => agent.getConfig().getPromptProtectionMode();

  agent.start([]);
  t.same(currentMode(), "disabled");

  await agent.flushStats(1000);
  t.same(currentMode(), "monitor");

  fakeClock.uninstall();
});
19 changes: 18 additions & 1 deletion library/agent/Agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ import { isNewInstrumentationUnitTest } from "../helpers/isNewInstrumentationUni
import { AttackWaveDetector } from "../vulnerabilities/attack-wave-detection/AttackWaveDetector";
import type { FetchListsAPI } from "./api/FetchListsAPI";
import { PendingEvents } from "./PendingEvents";
import type { PromptProtectionApi } from "./api/PromptProtectionAPI";
import { PromptProtectionAPINodeHTTP } from "./api/PromptProtectionAPINodeHTTP";
import type { AiMessage } from "../vulnerabilities/prompt-injection/messages";
import type { IdorProtectionConfig } from "./IdorProtectionConfig";
import { warnIfTsxIsUsed } from "../helpers/warnIfTsxIsUsed";

Expand Down Expand Up @@ -75,7 +78,8 @@ export class Agent {
private readonly token: Token | undefined,
private readonly serverless: string | undefined,
private readonly newInstrumentation: boolean = false,
private readonly fetchListsAPI: FetchListsAPI
private readonly fetchListsAPI: FetchListsAPI,
private readonly promptProtectionAPI: PromptProtectionApi = new PromptProtectionAPINodeHTTP()
) {
if (typeof this.serverless === "string" && this.serverless.length === 0) {
throw new Error("Serverless cannot be an empty string");
Expand Down Expand Up @@ -338,6 +342,12 @@ export class Agent {
);
this.serviceConfig.updateDomains(response.domains);
}

if (typeof response.promptProtectionMode === "string") {
this.serviceConfig.setPromptProtectionMode(
response.promptProtectionMode
);
}
}
}

Expand Down Expand Up @@ -712,4 +722,11 @@ export class Agent {
this.pendingEvents.onAPICall(promise);
}
}

/**
 * Asks the configured prompt protection API whether the given messages
 * contain a prompt injection.
 *
 * @param input Messages forwarded unchanged to the prompt protection API.
 * @returns The promise returned by the API, or — when the agent has no
 * token — an already resolved `{ success: false, block: false }`.
 */
checkForPromptInjection(input: AiMessage[]) {
  const token = this.token;

  if (token) {
    return this.promptProtectionAPI.checkForInjection(token, input);
  }

  // Without a token no request is made.
  return Promise.resolve({ success: false, block: false });
}
}
5 changes: 4 additions & 1 deletion library/agent/Attack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ export type Kind =
| "path_traversal"
| "ssrf"
| "stored_ssrf"
| "code_injection";
| "code_injection"
| "prompt_injection";

export function attackKindHumanName(kind: Kind) {
switch (kind) {
Expand All @@ -23,5 +24,7 @@ export function attackKindHumanName(kind: Kind) {
return "a stored server-side request forgery";
case "code_injection":
return "a JavaScript injection";
case "prompt_injection":
return "a prompt injection";
}
}
3 changes: 3 additions & 0 deletions library/agent/Config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ export type Endpoint = Omit<EndpointConfig, "allowedIPAddresses"> & {

export type Domain = { hostname: string; mode: "allow" | "block" };

/**
 * Modes the prompt protection feature can be configured with.
 * NOTE(review): presumably "monitor" only reports and "block" rejects the
 * input — the enforcement code is not part of this file, confirm there.
 */
export type PromptProtectionMode = "disabled" | "monitor" | "block";

export type Config = {
endpoints: EndpointConfig[];
heartbeatIntervalInMS: number;
Expand All @@ -31,4 +33,5 @@ export type Config = {
block?: boolean;
blockNewOutgoingRequests?: boolean;
domains?: Domain[];
promptProtectionMode?: PromptProtectionMode;
};
15 changes: 15 additions & 0 deletions library/agent/ServiceConfig.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,18 @@ t.test("outbound request blocking", async (t) => {
t.same(config.shouldBlockOutgoingRequest("aikido.dev"), false);
t.same(config.shouldBlockOutgoingRequest("unknown.com"), false);
});

// A fresh ServiceConfig reports "disabled"; after that every mode handed to
// the setter must be returned unchanged by the getter.
t.test("prompt protection", async (t) => {
  const config = new ServiceConfig([], 0, [], [], [], []);

  t.same(config.getPromptProtectionMode(), "disabled");

  // Same order as before: block -> monitor -> disabled.
  for (const mode of ["block", "monitor", "disabled"] as const) {
    config.setPromptProtectionMode(mode);
    t.same(config.getPromptProtectionMode(), mode);
  }
});
17 changes: 16 additions & 1 deletion library/agent/ServiceConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ import { addIPv4MappedAddresses } from "../helpers/addIPv4MappedAddresses";
import { IPMatcher } from "../helpers/ip-matcher/IPMatcher";
import { LimitedContext, matchEndpoints } from "../helpers/matchEndpoints";
import { isPrivateIP } from "../vulnerabilities/ssrf/isPrivateIP";
import type { Endpoint, EndpointConfig, Domain } from "./Config";
import type {
Endpoint,
EndpointConfig,
Domain,
PromptProtectionMode,
} from "./Config";
import type { IPList, UserAgentDetails } from "./api/FetchListsAPI";
import { safeCreateRegExp } from "./safeCreateRegExp";

Expand Down Expand Up @@ -31,6 +36,8 @@ export class ServiceConfig {
private blockNewOutgoingRequests = false;
private domains = new Map<string, Domain["mode"]>();

private promptProtectionMode: PromptProtectionMode = "disabled";

constructor(
endpoints: EndpointConfig[],
private lastUpdatedAt: number,
Expand Down Expand Up @@ -305,4 +312,12 @@ export class ServiceConfig {
// Only block outgoing requests if the mode is "block"
return mode === "block";
}

/**
 * Replaces the stored prompt protection mode.
 *
 * @param mode The new mode; stored as-is, no runtime validation happens here.
 */
setPromptProtectionMode(mode: PromptProtectionMode) {
this.promptProtectionMode = mode;
}

/**
 * @returns The prompt protection mode currently stored on this config
 * (the value last passed to setPromptProtectionMode).
 */
getPromptProtectionMode(): PromptProtectionMode {
return this.promptProtectionMode;
}
}
14 changes: 14 additions & 0 deletions library/agent/api/PromptProtectionAPI.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import type { AiMessage } from "../../vulnerabilities/prompt-injection/messages";
import type { Token } from "./Token";

/**
 * Result of a prompt injection check as returned by a PromptProtectionApi.
 */
export type PromptProtectionApiResponse = {
// NOTE(review): presumably true when the check itself completed — confirm against the service contract.
success: boolean;
// NOTE(review): presumably true when the input should be blocked — confirm against the caller.
block: boolean;
};

/**
 * Contract for anything that can check a list of AI messages for prompt
 * injection.
 */
export interface PromptProtectionApi {
/**
 * @param token Token identifying the caller.
 * @param messages Messages to check.
 * @returns A promise for the check result.
 */
checkForInjection(
token: Token,
messages: AiMessage[]
): Promise<PromptProtectionApiResponse>;
}
34 changes: 34 additions & 0 deletions library/agent/api/PromptProtectionAPIForTesting.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import type { AiMessage } from "../../vulnerabilities/prompt-injection/messages";
import type {
PromptProtectionApi,
PromptProtectionApiResponse,
} from "./PromptProtectionAPI";
import type { Token } from "./Token";

/**
 * In-memory PromptProtectionApi for tests: no request is made. Messages that
 * contain the marker `!prompt-injection-block-me!` produce a blocking result,
 * everything else yields the response passed to the constructor.
 */
export class PromptProtectionAPIForTesting implements PromptProtectionApi {
  /**
   * @param response Result returned when no message contains the marker;
   * defaults to a successful, non-blocking response.
   */
  constructor(
    private response: PromptProtectionApiResponse = {
      success: true,
      block: false,
    }
  ) {}

  // oxlint-disable-next-line require-await
  async checkForInjection(
    _token: Token,
    _messages: AiMessage[]
  ): Promise<PromptProtectionApiResponse> {
    const marker = "!prompt-injection-block-me!";
    const containsMarker = _messages.some(({ content }) =>
      content.includes(marker)
    );

    return containsMarker ? { success: true, block: true } : this.response;
  }
}
48 changes: 48 additions & 0 deletions library/agent/api/PromptProtectionAPINodeHTTP.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import { fetch } from "../../helpers/fetch";
import { getPromptInjectionServiceURL } from "../../helpers/getPromptInjectionServiceURL";
import type { AiMessage } from "../../vulnerabilities/prompt-injection/messages";
import type {
PromptProtectionApi,
PromptProtectionApiResponse,
} from "./PromptProtectionAPI";
import type { Token } from "./Token";

/**
 * PromptProtectionApi implementation that sends the messages to the prompt
 * injection service over HTTP and maps the JSON reply to a
 * PromptProtectionApiResponse.
 */
export class PromptProtectionAPINodeHTTP implements PromptProtectionApi {
  /**
   * @param baseUrl Base URL of the service; defaults to the value returned by
   * getPromptInjectionServiceURL().
   */
  constructor(private baseUrl = getPromptInjectionServiceURL()) {}

  /**
   * POSTs the messages to `/api/v1/analyze` and returns the parsed result.
   *
   * @param token Sent in the Authorization header via `token.asString()`.
   * @param messages Sent as the `input` field of the JSON request body.
   * @returns The `success` and `block` flags read from the response body.
   * @throws Error when the service answers with a status code other than 200
   * (401 gets a dedicated message pointing at the token). A body that is not
   * valid JSON makes `JSON.parse` throw a SyntaxError.
   */
  async checkForInjection(
    token: Token,
    messages: AiMessage[]
  ): Promise<PromptProtectionApiResponse> {
    const { body, statusCode } = await fetch({
      url: new URL("/api/v1/analyze", this.baseUrl.toString()),
      method: "POST",
      headers: {
        Accept: "application/json",
        // The request body is JSON, so it has to be declared as such;
        // otherwise the server has to guess the media type.
        "Content-Type": "application/json",
        Authorization: token.asString(),
      },
      body: JSON.stringify({ input: messages }),
      timeoutInMS: 15 * 1000,
    });

    if (statusCode !== 200) {
      if (statusCode === 401) {
        throw new Error(
          `Unable to access the Prompt Protection service, please check your token.`
        );
      }
      throw new Error(`Failed to fetch prompt analysis: ${statusCode}`);
    }

    return this.toAPIResponse(body);
  }

  /**
   * Converts the raw response body into a PromptProtectionApiResponse.
   * A flag is only true when the JSON contains the literal boolean `true`
   * for it; anything else (missing, wrong type) becomes false.
   */
  private toAPIResponse(data: string): PromptProtectionApiResponse {
    const result: unknown = JSON.parse(data);

    // `null` is valid JSON but has no properties; treat it like any other
    // reply without the expected fields instead of failing with a TypeError.
    if (typeof result !== "object" || result === null) {
      return { success: false, block: false };
    }

    const { success, block } = result as {
      success?: unknown;
      block?: unknown;
    };

    return {
      success: success === true,
      block: block === true,
    };
  }
}
6 changes: 5 additions & 1 deletion library/helpers/createTestAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import { Agent } from "../agent/Agent";
import { setInstance } from "../agent/AgentSingleton";
import type { FetchListsAPI } from "../agent/api/FetchListsAPI";
import { FetchListsAPIForTesting } from "../agent/api/FetchListsAPIForTesting";
import type { PromptProtectionApi } from "../agent/api/PromptProtectionAPI";
import { PromptProtectionAPIForTesting } from "../agent/api/PromptProtectionAPIForTesting";
import type { ReportingAPI } from "../agent/api/ReportingAPI";
import { ReportingAPIForTesting } from "../agent/api/ReportingAPIForTesting";
import type { Token } from "../agent/api/Token";
Expand All @@ -20,6 +22,7 @@ export function createTestAgent(opts?: {
serverless?: string;
suppressConsoleLog?: boolean;
fetchListsAPI?: FetchListsAPI;
promptProtectionAPI?: PromptProtectionApi;
}) {
if (opts?.suppressConsoleLog ?? true) {
wrap(console, "log", function log() {
Expand All @@ -34,7 +37,8 @@ export function createTestAgent(opts?: {
opts?.token, // Defaults to undefined
opts?.serverless, // Defaults to undefined
false, // During tests this is controlled by the AIKIDO_TEST_NEW_INSTRUMENTATION env var
opts?.fetchListsAPI ?? new FetchListsAPIForTesting()
opts?.fetchListsAPI ?? new FetchListsAPIForTesting(),
opts?.promptProtectionAPI ?? new PromptProtectionAPIForTesting()
);

setInstance(agent);
Expand Down
8 changes: 8 additions & 0 deletions library/helpers/getPromptInjectionServiceURL.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
 * Resolves the base URL of the prompt injection service.
 *
 * @returns The URL from the `PROMPT_INJECTION_SERVICE_URL` environment
 * variable when it is set to a non-empty value, otherwise
 * `http://localhost:8123`.
 * @throws TypeError when the environment variable is not a valid URL.
 */
export function getPromptInjectionServiceURL(): URL {
  const override = process.env.PROMPT_INJECTION_SERVICE_URL;

  // Todo add default URL when deployed
  return new URL(override ? override : "http://localhost:8123");
}
2 changes: 2 additions & 0 deletions library/helpers/startTestAgent.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { PromptProtectionApi } from "../agent/api/PromptProtectionAPI";
import type { ReportingAPI } from "../agent/api/ReportingAPI";
import type { Token } from "../agent/api/Token";
import { __internalRewritePackageNamesForTesting } from "../agent/hooks/instrumentation/instructions";
Expand All @@ -20,6 +21,7 @@ export function startTestAgent(opts: {
serverless?: string;
wrappers: Wrapper[];
rewrite: Record<PackageName, AliasToRequire>;
promptProtectionAPI?: PromptProtectionApi;
}) {
const agent = createTestAgent(opts);

Expand Down
Loading
Loading