Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/compiler/src/ast/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ export type {
CallExpressionNode,
ConstStatementNode,
DeclarationNode,
DecoratedExpressionNode,
DecoratorDeclarationStatementNode,
DecoratorExpressionNode,
DirectiveExpressionNode,
Expand Down
79 changes: 68 additions & 11 deletions packages/compiler/src/core/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ import {
DecoratorContext,
DecoratorDeclarationStatementNode,
DecoratorExpressionNode,
DecoratedExpressionNode,
DecoratorValidatorCallbacks,
Diagnostic,
DiagnosticTarget,
Expand Down Expand Up @@ -1037,6 +1038,8 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
return checkCallExpression(ctx, node);
case SyntaxKind.TypeOfExpression:
return checkTypeOfExpression(ctx, node);
case SyntaxKind.DecoratedExpression:
return checkDecoratedExpression(ctx, node);
case SyntaxKind.AugmentDecoratorStatement:
return checkAugmentDecorator(ctx, node);
case SyntaxKind.UsingStatement:
Expand Down Expand Up @@ -3927,6 +3930,32 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
return type;
}

function checkDecoratedExpression(
ctx: CheckContext,
node: DecoratedExpressionNode,
): Type | Value | IndeterminateEntity | null {
const targetResult = checkNode(ctx, node.target);
if (targetResult === null) {
return null;
}

// Apply decorators to the resolved type
if (typeof targetResult === "object" && "entityKind" in targetResult) {
if (targetResult.entityKind === "Type" && "decorators" in targetResult) {
const type = targetResult as Type & { decorators: DecoratorApplication[] };
for (const decNode of node.decorators) {
const decorator = checkDecoratorApplication(ctx, type, decNode);
if (decorator) {
type.decorators.unshift(decorator);
applyDecoratorToType(program, decorator, type);
}
}
}
}

return targetResult;
}

/** Find the indexer that applies to this model. Either defined on itself or from a base model */
function findIndexer(model: Model): ModelIndexer | undefined {
let current: Model | undefined = model;
Expand Down Expand Up @@ -4748,7 +4777,10 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
model: ModelStatementNode,
heritageRef: Expression,
): Model | undefined {
if (heritageRef.kind === SyntaxKind.ModelExpression) {
// Unwrap decorated expression to check the target
const innerRef =
heritageRef.kind === SyntaxKind.DecoratedExpression ? heritageRef.target : heritageRef;
if (innerRef.kind === SyntaxKind.ModelExpression) {
reportCheckerDiagnostic(
createDiagnostic({
code: "extend-model",
Expand All @@ -4759,8 +4791,8 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
return undefined;
}
if (
heritageRef.kind !== SyntaxKind.TypeReference &&
heritageRef.kind !== SyntaxKind.ArrayExpression
innerRef.kind !== SyntaxKind.TypeReference &&
innerRef.kind !== SyntaxKind.ArrayExpression
) {
reportCheckerDiagnostic(
createDiagnostic({
Expand All @@ -4773,7 +4805,7 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
const modelSymId = getNodeSym(model);
pendingResolutions.start(modelSymId, ResolutionKind.BaseType);

const target = resolver.getNodeLinks(heritageRef).resolvedSymbol;
const target = resolver.getNodeLinks(innerRef).resolvedSymbol;
if (target && pendingResolutions.has(target, ResolutionKind.BaseType)) {
if (ctx.mapper === undefined) {
reportCheckerDiagnostic(
Expand All @@ -4786,7 +4818,7 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
}
return undefined;
}
const heritageType = getTypeForNode(heritageRef, ctx);
const heritageType = getTypeForNode(innerRef, ctx);
pendingResolutions.finish(modelSymId, ResolutionKind.BaseType);
if (isErrorType(heritageType)) {
compilerAssert(program.hasError(), "Should already have reported an error.", heritageRef);
Expand All @@ -4808,6 +4840,17 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
);
}

// Apply decorators from the decorated expression to the resolved type
if (heritageRef.kind === SyntaxKind.DecoratedExpression) {
for (const decNode of heritageRef.decorators) {
const decorator = checkDecoratorApplication(ctx, heritageType, decNode);
if (decorator) {
heritageType.decorators.unshift(decorator);
applyDecoratorToType(program, decorator, heritageType);
}
}
}

return heritageType;
}

Expand All @@ -4821,7 +4864,10 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
const modelSymId = getNodeSym(model);
pendingResolutions.start(modelSymId, ResolutionKind.BaseType);
let isType;
if (isExpr.kind === SyntaxKind.ModelExpression) {
// Unwrap decorated expression to check the target
const innerExpr =
isExpr.kind === SyntaxKind.DecoratedExpression ? isExpr.target : isExpr;
if (innerExpr.kind === SyntaxKind.ModelExpression) {
reportCheckerDiagnostic(
createDiagnostic({
code: "is-model",
Expand All @@ -4830,10 +4876,10 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
}),
);
return undefined;
} else if (isExpr.kind === SyntaxKind.ArrayExpression) {
isType = checkArrayExpression(ctx, isExpr);
} else if (isExpr.kind === SyntaxKind.TypeReference) {
const target = resolver.getNodeLinks(isExpr).resolvedSymbol;
} else if (innerExpr.kind === SyntaxKind.ArrayExpression) {
isType = checkArrayExpression(ctx, innerExpr);
} else if (innerExpr.kind === SyntaxKind.TypeReference) {
const target = resolver.getNodeLinks(innerExpr).resolvedSymbol;
if (target && pendingResolutions.has(target, ResolutionKind.BaseType)) {
if (ctx.mapper === undefined) {
reportCheckerDiagnostic(
Expand All @@ -4846,7 +4892,7 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
}
return undefined;
}
isType = getTypeForNode(isExpr, ctx);
isType = getTypeForNode(innerExpr, ctx);
} else {
reportCheckerDiagnostic(createDiagnostic({ code: "is-model", target: isExpr }));
return undefined;
Expand All @@ -4866,6 +4912,17 @@ export function createChecker(program: Program, resolver: NameResolver): Checker
return undefined;
}

// Apply decorators from the decorated expression to the resolved type
if (isExpr.kind === SyntaxKind.DecoratedExpression) {
for (const decNode of isExpr.decorators) {
const decorator = checkDecoratorApplication(ctx, isType, decNode);
if (decorator) {
isType.decorators.unshift(decorator);
applyDecoratorToType(program, decorator, isType);
}
}
}

return isType;
}

Expand Down
19 changes: 16 additions & 3 deletions packages/compiler/src/core/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
Comment,
ConstStatementNode,
DeclarationNode,
DecoratedExpressionNode,
DecoratorDeclarationStatementNode,
DecoratorExpressionNode,
Diagnostic,
Expand Down Expand Up @@ -1684,9 +1685,7 @@ function createParser(code: string | SourceFile, options: ParseOptions = {}): Pa
case Token.OpenParen:
return parseParenthesizedExpression();
case Token.At:
const decorators = parseDecoratorList();
reportInvalidDecorators(decorators, "expression");
continue;
return parseDecoratedExpression();
case Token.Hash:
const directives = parseDirectiveList();
reportInvalidDirective(directives, "expression");
Expand All @@ -1707,6 +1706,18 @@ function createParser(code: string | SourceFile, options: ParseOptions = {}): Pa
}
}

function parseDecoratedExpression(): DecoratedExpressionNode {
const pos = tokenPos();
const decorators = parseDecoratorList();
const target = parseExpression();
return {
kind: SyntaxKind.DecoratedExpression,
decorators,
target,
...finishNode(pos),
};
}

function parseExternKeyword(): ExternKeywordNode {
const pos = tokenPos();
parseExpected(Token.ExternKeyword);
Expand Down Expand Up @@ -2951,6 +2962,8 @@ export function visitChildren<T>(node: Node, cb: NodeCallback<T>): T | undefined
return visitNode(cb, node.base) || visitNode(cb, node.id);
case SyntaxKind.ModelExpression:
return visitEach(cb, node.properties);
case SyntaxKind.DecoratedExpression:
return visitEach(cb, node.decorators) || visitNode(cb, node.target);
case SyntaxKind.ModelProperty:
return (
visitEach(cb, node.decorators) ||
Expand Down
10 changes: 9 additions & 1 deletion packages/compiler/src/core/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ export enum SyntaxKind {
ConstStatement,
CallExpression,
ScalarConstructor,
DecoratedExpression,
}

export const enum NodeFlags {
Expand Down Expand Up @@ -1357,7 +1358,8 @@ export type Expression =
| StringTemplateExpressionNode
| VoidKeywordNode
| NeverKeywordNode
| AnyKeywordNode;
| AnyKeywordNode
| DecoratedExpressionNode;

export type ReferenceExpression =
| TypeReferenceNode
Expand Down Expand Up @@ -1512,6 +1514,12 @@ export interface ModelExpressionNode extends BaseNode {
readonly bodyRange: TextRange;
}

export interface DecoratedExpressionNode extends BaseNode {
readonly kind: SyntaxKind.DecoratedExpression;
readonly decorators: readonly DecoratorExpressionNode[];
readonly target: Expression;
}

export interface ArrayExpressionNode extends BaseNode {
readonly kind: SyntaxKind.ArrayExpression;
readonly elementType: Expression;
Expand Down
12 changes: 12 additions & 0 deletions packages/compiler/src/formatter/print/printer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
CallExpressionNode,
Comment,
ConstStatementNode,
DecoratedExpressionNode,
DecoratorDeclarationStatementNode,
DecoratorExpressionNode,
DirectiveExpressionNode,
Expand Down Expand Up @@ -189,6 +190,8 @@ export function printNode(
return printBooleanLiteral(path as AstPath<BooleanLiteralNode>, options);
case SyntaxKind.ModelExpression:
return printModelExpression(path as AstPath<ModelExpressionNode>, options, print);
case SyntaxKind.DecoratedExpression:
return printDecoratedExpression(path as AstPath<DecoratedExpressionNode>, options, print);
case SyntaxKind.ModelProperty:
return printModelProperty(path as AstPath<ModelPropertyNode>, options, print);
case SyntaxKind.DecoratorExpression:
Expand Down Expand Up @@ -937,6 +940,15 @@ export function printModelExpression(
}
}

export function printDecoratedExpression(
path: AstPath<DecoratedExpressionNode>,
options: TypeSpecPrettierOptions,
print: PrettierChildPrint,
) {
const decorators = path.map((x) => [print(x as any), " "], "decorators");
return group([...decorators, path.call(print, "target")]);
}

export function printObjectLiteral(
path: AstPath<ObjectLiteralNode>,
options: TypeSpecPrettierOptions,
Expand Down
Loading
Loading