Skip to content

Commit fe76603

Browse files
committed
Add type cast support to the type checker
1 parent 0799bf7 commit fe76603

File tree

4 files changed

+256
-11
lines changed

4 files changed

+256
-11
lines changed

src/ast/__tests__/expression-extractor.test.ts

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,3 +1035,142 @@ describe("extract ClassInstanceCreationExpression correctly", () => {
10351035
expect(ast).toEqual(expectedAst);
10361036
});
10371037
});
1038+
1039+
describe("extract CastExpression correctly", () => {
1040+
it("extract CastExpression int to char correctly", () => {
1041+
const programStr = `
1042+
class Test {
1043+
void test() {
1044+
char c = (char) 65;
1045+
}
1046+
}
1047+
`;
1048+
1049+
const expectedAst: AST = {
1050+
kind: "CompilationUnit",
1051+
importDeclarations: [],
1052+
topLevelClassOrInterfaceDeclarations: [
1053+
{
1054+
kind: "NormalClassDeclaration",
1055+
classModifier: [],
1056+
typeIdentifier: "Test",
1057+
classBody: [
1058+
{
1059+
kind: "MethodDeclaration",
1060+
methodModifier: [],
1061+
methodHeader: {
1062+
result: "void",
1063+
identifier: "test",
1064+
formalParameterList: [],
1065+
},
1066+
methodBody: {
1067+
kind: "Block",
1068+
blockStatements: [
1069+
{
1070+
kind: "LocalVariableDeclarationStatement",
1071+
localVariableType: "char",
1072+
variableDeclaratorList: [
1073+
{
1074+
kind: "VariableDeclarator",
1075+
variableDeclaratorId: "c",
1076+
variableInitializer: {
1077+
kind: "CastExpression",
1078+
type: "char",
1079+
expression: {
1080+
kind: "Literal",
1081+
literalType: {
1082+
kind: "DecimalIntegerLiteral",
1083+
value: "65",
1084+
},
1085+
location: expect.anything(),
1086+
},
1087+
location: expect.anything(),
1088+
},
1089+
},
1090+
],
1091+
location: expect.anything(),
1092+
},
1093+
],
1094+
location: expect.anything(),
1095+
},
1096+
location: expect.anything(),
1097+
},
1098+
],
1099+
location: expect.anything(),
1100+
},
1101+
],
1102+
location: expect.anything(),
1103+
};
1104+
1105+
const ast = parse(programStr);
1106+
expect(ast).toEqual(expectedAst);
1107+
});
1108+
1109+
it("extract CastExpression double to int correctly", () => {
1110+
const programStr = `
1111+
class Test {
1112+
void test() {
1113+
int x = (int) 3.14;
1114+
}
1115+
}
1116+
`;
1117+
1118+
const expectedAst: AST = {
1119+
kind: "CompilationUnit",
1120+
importDeclarations: [],
1121+
topLevelClassOrInterfaceDeclarations: [
1122+
{
1123+
kind: "NormalClassDeclaration",
1124+
classModifier: [],
1125+
typeIdentifier: "Test",
1126+
classBody: [
1127+
{
1128+
kind: "MethodDeclaration",
1129+
methodModifier: [],
1130+
methodHeader: {
1131+
result: "void",
1132+
identifier: "test",
1133+
formalParameterList: [],
1134+
},
1135+
methodBody: {
1136+
kind: "Block",
1137+
blockStatements: [
1138+
{
1139+
kind: "LocalVariableDeclarationStatement",
1140+
localVariableType: "int",
1141+
variableDeclaratorList: [
1142+
{
1143+
kind: "VariableDeclarator",
1144+
variableDeclaratorId: "x",
1145+
variableInitializer: {
1146+
kind: "CastExpression",
1147+
type: "int",
1148+
expression: {
1149+
kind: "Literal",
1150+
literalType: {
1151+
kind: "DecimalFloatingPointLiteral",
1152+
value: "3.14",
1153+
}
1154+
},
1155+
location: expect.anything(),
1156+
},
1157+
},
1158+
],
1159+
location: expect.anything(),
1160+
},
1161+
],
1162+
location: expect.anything(),
1163+
},
1164+
location: expect.anything(),
1165+
},
1166+
],
1167+
location: expect.anything(),
1168+
},
1169+
],
1170+
location: expect.anything(),
1171+
};
1172+
1173+
const ast = parse(programStr);
1174+
expect(ast).toEqual(expectedAst);
1175+
});
1176+
});

src/ast/astExtractor/expression-extractor.ts

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,9 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults {
9393
const primitiveCast = ctx.primitiveCastExpression[0];
9494
const type = this.extractType(primitiveCast.children.primitiveType[0]);
9595
const expression = this.visit(primitiveCast.children.unaryExpression[0]);
96-
console.debug({primitiveCast, type, expression});
9796
return {
9897
kind: "CastExpression",
99-
castType: type,
98+
type: type,
10099
expression: expression,
101100
location: this.location,
102101
};
@@ -105,13 +104,41 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults {
105104
throw new Error("Invalid CastExpression format.");
106105
}
107106

108-
private extractType(typeCtx: any) {
109-
if (typeCtx.Identifier) {
110-
return typeCtx.Identifier[0].image;
111-
}
112-
if (typeCtx.unannPrimitiveType) {
113-
return this.visit(typeCtx.unannPrimitiveType);
107+
private extractType(typeCtx: any): string {
108+
// Check for the 'primitiveType' node
109+
if (typeCtx.name === "primitiveType" && typeCtx.children) {
110+
const { children } = typeCtx;
111+
112+
// Handle 'numericType' (e.g., int, char, float, double)
113+
if (children.numericType) {
114+
const numericTypeCtx = children.numericType[0];
115+
116+
if (numericTypeCtx.children.integralType) {
117+
// Handle integral types (e.g., char, int)
118+
const integralTypeCtx = numericTypeCtx.children.integralType[0];
119+
120+
// Extract the specific type (e.g., 'char', 'int')
121+
for (const key in integralTypeCtx.children) {
122+
if (integralTypeCtx.children[key][0].image) {
123+
return integralTypeCtx.children[key][0].image;
124+
}
125+
}
126+
}
127+
128+
if (numericTypeCtx.children.floatingPointType) {
129+
// Handle floating-point types (e.g., float, double)
130+
const floatingPointTypeCtx = numericTypeCtx.children.floatingPointType[0];
131+
132+
// Extract the specific type (e.g., 'float', 'double')
133+
for (const key in floatingPointTypeCtx.children) {
134+
if (floatingPointTypeCtx.children[key][0].image) {
135+
return floatingPointTypeCtx.children[key][0].image;
136+
}
137+
}
138+
}
139+
}
114140
}
141+
115142
throw new Error("Invalid type context in cast expression.");
116143
}
117144

@@ -204,6 +231,10 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults {
204231
}
205232

206233
unaryExpression(ctx: UnaryExpressionCtx) {
234+
if (ctx.primary[0].children.primaryPrefix[0].children.castExpression) {
235+
return this.visit(ctx.primary[0].children.primaryPrefix[0].children.castExpression);
236+
}
237+
207238
const node = this.visit(ctx.primary);
208239
if (ctx.UnaryPrefixOperator) {
209240
return {

src/ast/types/blocks-and-statements.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ export type Expression =
119119
| BinaryExpression
120120
| UnaryExpression
121121
| TernaryExpression
122-
| CastExpression
123122
| Void;
124123

125124
export interface Void extends BaseNode {
@@ -260,7 +259,7 @@ export interface Assignment extends BaseNode {
260259
}
261260

262261
export type LeftHandSide = ExpressionName | ArrayAccess;
263-
export type UnaryExpression = PrefixExpression | PostfixExpression;
262+
export type UnaryExpression = PrefixExpression | PostfixExpression | CastExpression;
264263

265264
export interface PrefixExpression extends BaseNode {
266265
kind: "PrefixExpression";

src/types/checker/index.ts

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { Array as ArrayType } from '../types/arrays'
22
import { Integer, String, Throwable, Void } from '../types/references'
33
import { CaseConstant, Node } from '../ast/specificationTypes'
4-
import { Type } from '../types/type'
4+
import { PrimitiveType, Type } from '../types/type'
55
import {
66
ArrayRequiredError,
77
BadOperandTypesError,
@@ -63,6 +63,33 @@ export const check = (node: Node, frame: Frame = Frame.globalFrame()): Result =>
6363
return typeCheckBody(node, typeCheckingFrame)
6464
}
6565

66+
const isCastCompatible = (fromType: Type, toType: Type): boolean => {
67+
// Handle primitive type compatibility
68+
if (fromType instanceof PrimitiveType && toType instanceof PrimitiveType) {
69+
const fromName = fromType.constructor.name;
70+
const toName = toType.constructor.name;
71+
72+
console.log(fromName, toName);
73+
74+
return !(fromName === 'char' && toName !== 'int');
75+
}
76+
77+
// Handle class type compatibility
78+
if (fromType instanceof ClassType && toType instanceof ClassType) {
79+
// Allow upcasts (base class to derived class) or downcasts (derived class to base class)
80+
return fromType.canBeAssigned(toType) || toType.canBeAssigned(fromType);
81+
}
82+
83+
// Handle array type compatibility
84+
if (fromType instanceof ArrayType && toType instanceof ArrayType) {
85+
// Ensure the content types are compatible
86+
return isCastCompatible(fromType.getContentType(), toType.getContentType());
87+
}
88+
89+
// Disallow other cases by default
90+
return false;
91+
};
92+
6693
export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): Result => {
6794
switch (node.kind) {
6895
case 'ArrayAccess': {
@@ -192,6 +219,55 @@ export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): R
192219
case 'BreakStatement': {
193220
return OK_RESULT
194221
}
222+
223+
case 'CastExpression': {
224+
let castType: Type | TypeCheckerError;
225+
let expressionType: Type | null = null;
226+
let expressionResult: Result;
227+
228+
if ('primitiveType' in node) {
229+
castType = frame.getType(unannTypeToString(node.primitiveType), node.primitiveType.location);
230+
} else {
231+
throw new Error('Invalid CastExpression: Missing type information.');
232+
}
233+
234+
if (castType instanceof TypeCheckerError) {
235+
return newResult(null, [castType]);
236+
}
237+
238+
if ('unaryExpression' in node) {
239+
expressionResult = typeCheckBody(node.unaryExpression, frame);
240+
} else {
241+
throw new Error('Invalid CastExpression: Missing expression.');
242+
}
243+
244+
if (expressionResult.hasErrors) {
245+
return expressionResult;
246+
}
247+
248+
expressionType = expressionResult.currentType;
249+
if (!expressionType) {
250+
throw new Error('Expression in cast should have a type.');
251+
}
252+
253+
if (
254+
(castType instanceof PrimitiveType && expressionType instanceof PrimitiveType)
255+
) {
256+
if (!isCastCompatible(expressionType, castType)) {
257+
return newResult(null, [
258+
new IncompatibleTypesError(node.location),
259+
]);
260+
}
261+
} else {
262+
return newResult(null, [
263+
new IncompatibleTypesError(node.location),
264+
]);
265+
}
266+
267+
// If the cast is valid, return the target type
268+
return newResult(castType);
269+
}
270+
195271
case 'ClassInstanceCreationExpression': {
196272
const classIdentifier =
197273
node.unqualifiedClassInstanceCreationExpression.classOrInterfaceTypeToInstantiate

0 commit comments

Comments
 (0)