Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions packages/common/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,10 @@ export function assertDefined<T>(
throw new Error(msg ?? "Value is undefined");
}
}

export function takeFirst<T>(arr: T[]): T {
if (arr.length === 0) {
throw new Error("takeFirst called with empty array");
}
return arr[0];
}
27 changes: 27 additions & 0 deletions packages/common/src/zkProgrammable/FeatureFlagsExtension.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { FeatureFlags } from "o1js";

function combineFeatureFlag(a: boolean | undefined, b: boolean | undefined) {
if (a === true || b === true) {
return true;
} else if (a === undefined || b === undefined) {
return undefined;
} else {
return false;
}
}

export function combineFeatureFlags(
a: FeatureFlags,
b: FeatureFlags
): FeatureFlags {
return {
xor: combineFeatureFlag(a.xor, b.xor),
rot: combineFeatureFlag(a.rot, b.rot),
lookup: combineFeatureFlag(a.lookup, b.lookup),
foreignFieldAdd: combineFeatureFlag(a.foreignFieldAdd, b.foreignFieldAdd),
foreignFieldMul: combineFeatureFlag(a.foreignFieldMul, b.foreignFieldMul),
rangeCheck0: combineFeatureFlag(a.rangeCheck0, b.rangeCheck0),
rangeCheck1: combineFeatureFlag(a.rangeCheck1, b.rangeCheck1),
runtimeTables: combineFeatureFlag(a.runtimeTables, b.runtimeTables),
};
}
99 changes: 90 additions & 9 deletions packages/common/src/zkProgrammable/ZkProgrammable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,19 @@ import {
Field,
Provable,
Cache as O1Cache,
DynamicProof,
FlexibleProvable,
FeatureFlags,
} from "o1js";
import { Memoize } from "typescript-memoize";

import { log } from "../log";
import { dummyVerificationKey } from "../dummyVerificationKey";
import { reduceSequential } from "../utils";
import { mapSequential, reduceSequential } from "../utils";
import type { CompileRegistry } from "../compiling/CompileRegistry";

import { MOCK_PROOF } from "./provableMethod";
import { combineFeatureFlags } from "./FeatureFlagsExtension";

const errors = {
areProofsEnabledNotSet: (name: string) =>
Expand Down Expand Up @@ -52,6 +56,8 @@ export interface PlainZkProgram<
PublicOutput = undefined,
> {
name: string;
publicInputType: FlexibleProvable<PublicInput>;
publicOutputType: FlexibleProvable<PublicOutput>;
compile: Compile;
verify: Verify<PublicInput, PublicOutput>;
Proof: ReturnType<
Expand All @@ -75,8 +81,15 @@ export interface PlainZkProgram<
}>)
>;
analyzeMethods: () => Promise<
Record<string, Awaited<ReturnType<typeof Provable.constraintSystem>>>
Record<
string,
Awaited<ReturnType<typeof Provable.constraintSystem>> & {
// TODO Properly model ProofClass here
proofs: any[];
}
>
>;
maxProofsVerified: () => Promise<0 | 1 | 2>;
}

export function verifyToMockable<PublicInput, PublicOutput>(
Expand Down Expand Up @@ -125,17 +138,18 @@ export abstract class ZkProgrammable<
> {
public abstract get areProofsEnabled(): AreProofsEnabled | undefined;

public abstract zkProgramFactory(): PlainZkProgram<
PublicInput,
PublicOutput
>[];
public abstract zkProgramFactory(): Promise<
PlainZkProgram<PublicInput, PublicOutput>[]
>;

private zkProgramSingleton?: PlainZkProgram<PublicInput, PublicOutput>[];

@Memoize()
public get zkProgram(): PlainZkProgram<PublicInput, PublicOutput>[] {
public async zkProgram(): Promise<
PlainZkProgram<PublicInput, PublicOutput>[]
> {
if (this.zkProgramSingleton === undefined) {
this.zkProgramSingleton = this.zkProgramFactory();
this.zkProgramSingleton = await this.zkProgramFactory();
}

return this.zkProgramSingleton.map((bucket) => {
Expand All @@ -150,9 +164,76 @@ export abstract class ZkProgrammable<
});
}

@Memoize()
public async proofType(): Promise<typeof Proof<PublicInput, PublicOutput>> {
const programs = await this.zkProgram();

const Template = programs[0].Proof;
const maxProofsVerifeds = await mapSequential(programs, (p) =>
p.maxProofsVerified()
);
// eslint-disable-next-line @typescript-eslint/consistent-type-assertions
const maxProofsVerified = Math.max(...maxProofsVerifeds) as 0 | 1 | 2;

return class ZkProgrammableProofType extends Proof<
PublicInput,
PublicOutput
> {
static publicInputType = Template.publicInputType;

static publicOutputType = Template.publicOutputType;

static maxProofsVerified = maxProofsVerified;
};
}

@Memoize()
public async dynamicProofType(): Promise<
typeof DynamicProof<PublicInput, PublicOutput>
> {
const programs = await this.zkProgram();

let maxProofsVerified: 0 | 1 | 2;
let featureFlags: FeatureFlags;

// We actually only need to compute maxProofsVerified and featuresflags if proofs
// are enabled, otherwise o1js will ignore it anyways. This way startup is a bit
// faster for non-proof environments
if (this.areProofsEnabled?.areProofsEnabled === true) {
const maxProofsVerifieds = await mapSequential(
programs,
async (zkProgram) => await zkProgram.maxProofsVerified()
);
// eslint-disable-next-line @typescript-eslint/consistent-type-assertions
maxProofsVerified = Math.max(...maxProofsVerifieds) as 0 | 1 | 2;
const featureFlagsSet = await mapSequential(
programs,
async (zkProgram) => await FeatureFlags.fromZkProgram(zkProgram)
);
featureFlags = featureFlagsSet.reduce(combineFeatureFlags);
} else {
featureFlags = FeatureFlags.allNone;
maxProofsVerified = 0;
}

return class DynamicProofType extends DynamicProof<
PublicInput,
PublicOutput
> {
static publicInputType = programs[0].publicInputType;

static publicOutputType = programs[0].publicOutputType;

static maxProofsVerified = maxProofsVerified;

static featureFlags = featureFlags;
};
}

public async compile(registry: CompileRegistry) {
const programs = await this.zkProgram();
return await reduceSequential(
this.zkProgram,
programs,
async (acc, program) => {
const result = await registry.compile(program);
return {
Expand Down
2 changes: 1 addition & 1 deletion packages/common/src/zkProgrammable/provableMethod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export function toProver(
return async function prover(this: ZkProgrammable<any, any>) {
const { areProofsEnabled } = this.areProofsEnabled!;

const zkProgram = this.zkProgram.find((prog) =>
const zkProgram = (await this.zkProgram()).find((prog) =>
Object.keys(prog.methods).includes(methodName)
);

Expand Down
48 changes: 35 additions & 13 deletions packages/common/test/zkProgrammable/ZkProgrammable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
MOCK_VERIFICATION_KEY,
ZkProgrammable,
ProvableMethodExecutionContext,
takeFirst,
} from "../../src";

const appChainMock: AreProofsEnabled = {
Expand Down Expand Up @@ -60,7 +61,7 @@ class TestProgrammable extends ZkProgrammable<
};
}

public zkProgramFactory() {
public async zkProgramFactory() {
const program = ZkProgram({
name: "testprogram",
publicInput: TestPublicInput,
Expand Down Expand Up @@ -89,9 +90,12 @@ class TestProgrammable extends ZkProgrammable<
return [
{
name: program.name,
publicInputType: program.publicInputType,
publicOutputType: program.publicOutputType,
compile: program.compile.bind(program),
verify: program.verify.bind(program),
analyzeMethods: program.analyzeMethods.bind(program),
maxProofsVerified: program.maxProofsVerified.bind(program),
Proof: SelfProof,
methods,
},
Expand All @@ -106,19 +110,21 @@ class OtherTestProgrammable extends ZkProgrammable<undefined, void> {
super();
}

proofType = this.testProgrammable.zkProgram[0].Proof;

@provableMethod()
public async bar(testProgrammableProof: InstanceType<typeof this.proofType>) {
public async bar(
testProgrammableProof: InstanceType<
Awaited<ReturnType<typeof this.testProgrammable.proofType>>
>
) {
testProgrammableProof.verify();
}

public zkProgramFactory() {
public async zkProgramFactory() {
const program = ZkProgram({
name: "testprogram2",
methods: {
bar: {
privateInputs: [this.testProgrammable.zkProgram[0].Proof],
privateInputs: [await this.testProgrammable.proofType()],
method: this.bar.bind(this),
},
},
Expand All @@ -133,9 +139,12 @@ class OtherTestProgrammable extends ZkProgrammable<undefined, void> {
return [
{
name: program.name,
publicInputType: program.publicInputType,
publicOutputType: program.publicOutputType,
compile: program.compile.bind(program),
verify: program.verify.bind(program),
analyzeMethods: program.analyzeMethods.bind(program),
maxProofsVerified: program.maxProofsVerified.bind(program),
Proof: SelfProof,
methods,
},
Expand Down Expand Up @@ -189,7 +198,11 @@ describe("zkProgrammable", () => {
testProgrammable = new TestProgrammable();
testProgrammable.areProofsEnabled.setProofsEnabled(areProofsEnabled);
zkProgramFactorySpy = jest.spyOn(testProgrammable, "zkProgramFactory");
artifact = await testProgrammable.zkProgram[0].compile();

artifact = await testProgrammable
.zkProgram()
.then((p) => takeFirst(p))
.then((p) => p.compile());
}, 500_000);

describe("zkProgramFactory", () => {
Expand All @@ -216,7 +229,8 @@ describe("zkProgrammable", () => {
it("if proofs are disabled, it should successfully verify mock proofs", async () => {
expect.assertions(1);

const proof = new testProgrammable.zkProgram[0].Proof({
const program = await testProgrammable.zkProgram().then(takeFirst);
const proof = new program.Proof({
proof: MOCK_PROOF,

publicInput: new TestPublicInput({
Expand All @@ -230,7 +244,7 @@ describe("zkProgrammable", () => {
maxProofsVerified: 0,
});

const verified = await testProgrammable.zkProgram[0].verify(proof);
const verified = await program.verify(proof);

expect(verified).toBe(shouldVerifyMockProofs);

Expand All @@ -254,7 +268,10 @@ describe("zkProgrammable", () => {
describe("zkProgram interoperability", () => {
beforeAll(async () => {
otherTestProgrammable = new OtherTestProgrammable(testProgrammable);
await otherTestProgrammable.zkProgram[0].compile();
await otherTestProgrammable
.zkProgram()
.then(takeFirst)
.then((p) => p.compile());
}, 500_000);

it("should successfully pass proof of one zkProgram as input to another zkProgram", async () => {
Expand All @@ -267,8 +284,10 @@ describe("zkProgrammable", () => {
const testProof = await executionContext
.current()
.result.prove<Proof<TestPublicInput, TestPublicOutput>>();
const testProofVerified =
await testProgrammable.zkProgram[0].verify(testProof);
const zkProgram = await testProgrammable
.zkProgram()
.then(takeFirst);
const testProofVerified = await zkProgram.verify(testProof);

// execute bar
await otherTestProgrammable.bar(testProof);
Expand All @@ -277,8 +296,11 @@ describe("zkProgrammable", () => {
const otherTestProof = await executionContext
.current()
.result.prove<Proof<undefined, undefined>>();
const otherZkProgram = await otherTestProgrammable
.zkProgram()
.then(takeFirst);
const otherTestProofVerified =
await otherTestProgrammable.zkProgram[0].verify(otherTestProof);
await otherZkProgram.verify(otherTestProof);

expect(testProof.publicOutput.bar.toString()).toBe(
testPublicInput.foo.toString()
Expand Down
Loading
Loading