Skip to content

Commit 76e13ac

Browse files
committed
fix: add validation for outputSchema in StructuredGenerationTask, changes to tool use schema
Enhances the StructuredGenerationTask by adding a check to ensure that outputSchema is a valid JSON Schema. Updates test cases to reflect changes in input handling for better type safety.
1 parent 0c529fe commit 76e13ac

3 files changed

Lines changed: 47 additions & 19 deletions

File tree

packages/ai/src/task/ChatMessage.ts

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,22 @@ export type ContentBlockToolUse = {
2828
readonly input: Record<string, unknown>;
2929
};
3030

31+
/**
32+
* Blocks that may appear in a `tool_result`'s `content` array. Provider payloads
33+
* typically use text, image, and tool_use; nested `tool_result` is not modeled
34+
* here so the JSON schema can be embedded in parent task schemas without a
35+
* recursive `$ref` (which fails to resolve when `ContentBlockSchema` is nested
36+
* under a larger document such as `ToolCallingInputSchema`).
37+
*/
38+
export type ContentBlockInToolResultBody =
39+
| ContentBlockText
40+
| ContentBlockImage
41+
| ContentBlockToolUse;
42+
3143
export type ContentBlockToolResult = {
3244
readonly type: "tool_result";
3345
readonly tool_use_id: string;
34-
readonly content: ReadonlyArray<ContentBlock>;
46+
readonly content: ReadonlyArray<ContentBlockInToolResultBody>;
3547
readonly is_error: boolean | undefined;
3648
};
3749

@@ -89,16 +101,19 @@ const ContentBlockToolUseSchema = {
89101
additionalProperties: false,
90102
} as const;
91103

92-
// tool_result is recursive — its `content` is an array of ContentBlock.
93-
// The $ref resolves against ContentBlockSchema.definitions.ContentBlock below.
104+
/** `tool_result.content` — text, image, and tool_use only (no nested `tool_result`). */
105+
const ContentBlockInToolResultBodySchema = {
106+
oneOf: [ContentBlockTextSchema, ContentBlockImageSchema, ContentBlockToolUseSchema],
107+
} as const;
108+
94109
const ContentBlockToolResultSchema = {
95110
type: "object",
96111
properties: {
97112
type: { type: "string", enum: ["tool_result"] },
98113
tool_use_id: { type: "string" },
99114
content: {
100115
type: "array",
101-
items: { $ref: "#/definitions/ContentBlock" },
116+
items: ContentBlockInToolResultBodySchema,
102117
},
103118
is_error: { type: "boolean" },
104119
},
@@ -115,16 +130,6 @@ export const ContentBlockSchema = {
115130
ContentBlockToolUseSchema,
116131
ContentBlockToolResultSchema,
117132
],
118-
definitions: {
119-
ContentBlock: {
120-
oneOf: [
121-
ContentBlockTextSchema,
122-
ContentBlockImageSchema,
123-
ContentBlockToolUseSchema,
124-
ContentBlockToolResultSchema,
125-
],
126-
},
127-
},
128133
title: "ContentBlock",
129134
description: "A single content block within a chat message",
130135
} as const;
@@ -148,6 +153,28 @@ export const ChatMessageSchema = {
148153
// Runtime type guards
149154
// ========================================================================
150155

156+
export function isContentBlockInToolResultBody(
157+
value: unknown
158+
): value is ContentBlockInToolResultBody {
159+
if (!value || typeof value !== "object") return false;
160+
const v = value as Record<string, unknown>;
161+
switch (v.type) {
162+
case "text":
163+
return typeof v.text === "string";
164+
case "image":
165+
return typeof v.mimeType === "string" && typeof v.data === "string";
166+
case "tool_use":
167+
return (
168+
typeof v.id === "string" &&
169+
typeof v.name === "string" &&
170+
v.input !== null &&
171+
typeof v.input === "object"
172+
);
173+
default:
174+
return false;
175+
}
176+
}
177+
151178
export function isContentBlock(value: unknown): value is ContentBlock {
152179
if (!value || typeof value !== "object") return false;
153180
const v = value as Record<string, unknown>;
@@ -167,7 +194,7 @@ export function isContentBlock(value: unknown): value is ContentBlock {
167194
return (
168195
typeof v.tool_use_id === "string" &&
169196
Array.isArray(v.content) &&
170-
v.content.every(isContentBlock)
197+
v.content.every(isContentBlockInToolResultBody)
171198
);
172199
default:
173200
return false;

packages/ai/src/task/StructuredGenerationTask.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
* SPDX-License-Identifier: Apache-2.0
55
*/
66

7-
import { CreateWorkflow, TaskConfigurationError, TaskError, Workflow } from "@workglow/task-graph";
87
import type { IExecuteContext, StreamEvent, TaskConfig } from "@workglow/task-graph";
9-
import { compileSchema } from "@workglow/util/schema";
8+
import { CreateWorkflow, TaskConfigurationError, TaskError, Workflow } from "@workglow/task-graph";
109
import type { DataPortSchema, FromSchema, SchemaNode } from "@workglow/util/schema";
10+
import { compileSchema } from "@workglow/util/schema";
1111
import { TypeModel } from "./base/AiTaskSchemas";
1212
import { StreamingAiTask } from "./base/StreamingAiTask";
1313

@@ -163,6 +163,7 @@ export class StructuredGenerationTask extends StreamingAiTask<
163163
let validator: SchemaNode;
164164
try {
165165
validator = compileSchema(input.outputSchema);
166+
if (!input.outputSchema) throw new Error("outputSchema is not a valid JSON Schema");
166167
} catch (err) {
167168
const msg = err instanceof Error ? err.message : String(err);
168169
const configErr = new TaskConfigurationError(

packages/test/src/test/task/StructuredGenerationTask.test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,10 @@ describe("StructuredGenerationTask — schema compile errors", () => {
323323
outputSchema: null as unknown as Record<string, unknown>,
324324
maxRetries: 0,
325325
};
326-
const task = new StructuredGenerationTask({ defaults: input } as any);
326+
const task = new StructuredGenerationTask({ defaults: input });
327327
let caught: unknown;
328328
try {
329-
await drain(task.executeStream(input as any, mkContext()));
329+
await drain(task.executeStream(input, mkContext()));
330330
} catch (e) {
331331
caught = e;
332332
}

0 commit comments

Comments
 (0)