1 #include "analysis/type_check.h"
2 #include "ast/ast_base.h"
3 #include "ast/ast_node_type.h"
4 #include "common/ast_visitor.h"
9 #include "ast/intrinsic.h"
10 #include "ast/context.h"
12 #include "source_file/token.h"
/**
 * Entry point of the type-check pass for package `p`.
 *
 * Topologically sorts the package's top-level symbol dependencies.
 * When the sort returns an error node, a COMPILE_ERROR ("Cyclic dependency
 * detected") is reported at that node. It then loops over the sorted symbols,
 * builds a set `s` of them, and loops over all of the package's children.
 * The loop bodies are not part of this listing.
 * NOTE(review): `s` presumably lets the second loop skip children that were
 * already handled in sorted order. Confirm against the full source.
 */
18 void TypeCheck::run_impl(Package *p) {
22 auto [unresolved_symbols, error_node] = p->top_level_symbol_dependency.topological_sort();
// error_node is populated when the sort fails; that failure is reported as a cycle.
23 if (error_node.has_value()) {
24 error(ErrorType::COMPILE_ERROR, error_node.value(),
"Cyclic dependency detected");
// First pass: the topologically sorted top-level symbols.
27 for (
auto *c : unresolved_symbols.value()) {
32 std::set<ASTBase *> s(unresolved_symbols.value().begin(), unresolved_symbols.value().end());
// Second pass: every child of the package.
33 for (
auto *c : p->get_children()) {
/**
 * Find the FunctionDecl called by `p` and check the call's arguments against it.
 *
 * Visible checks:
 * - an unknown function name is reported as TYPE_ERROR;
 * - an argument count that differs from the candidate's is reported as
 *   SEMANTIC_ERROR;
 * - an argument whose type differs from the parameter type and cannot be
 *   implicitly converted is reported as TYPE_ERROR.
 *
 * How `candidate` is looked up and what is returned are on lines not shown
 * in this listing.
 */
41 FunctionDecl *TypeCheck::search_function_callee(FunctionCall *p) {
42 const str &name = p->get_name();
43 const vector<Expr *> &args = p->_args;
47 error(ErrorType::TYPE_ERROR, p, fmt::format(
"Unknown function call: {}", name));
// NOTE(review): `n` equals candidate->get_n_args(), so the "found" slot in the
// message below prints the expected count again. It should most likely be
// args.size().
50 size_t n = candidate->get_n_args();
51 if (n != args.size()) {
52 error(ErrorType::SEMANTIC_ERROR,
54 fmt::format(
"Incorrect number of arguments: expect {} but found {}", candidate->get_n_args(), n));
57 auto *func_type = pcast<FunctionType>(candidate->get_type());
// Compare each argument type with the declared parameter type. Identical types
// are accepted as-is; otherwise an implicit conversion must exist.
// NOTE(review): in the error below, actual_type is passed before expected_type.
// Assuming the hidden line supplies the argument index first, "expect" shows
// the actual type and "found" shows the expected one, which looks swapped.
61 for (
size_t i = 0; i < n; ++i) {
62 auto *actual_type = args[i]->get_type();
63 auto *expected_type = func_type->get_arg_types()[i];
65 if (actual_type != expected_type) {
66 if (!CanImplicitlyConvert(actual_type, expected_type)) {
67 error(ErrorType::TYPE_ERROR,
69 fmt::format(
"Cannot implicitly convert the type of argument {}: expect {} but found {}",
71 actual_type->get_typename(),
72 expected_type->get_typename()));
/**
 * Resolve the type reference `p` to the type it names.
 *
 * `p` must be a ref type (asserted). Its type name is looked up in the
 * enclosing scopes:
 * - if the result is a type declaration, that declaration's type becomes the
 *   result;
 * - otherwise a TYPE_ERROR "Unknown type <name>" is reported at `node`.
 *
 * The declaration and return of `ret` are not part of this listing.
 */
80 Type *TypeCheck::resolve_type_ref(Type *p, ASTBase *node) {
81 TAN_ASSERT(p->is_ref());
84 const str &referred_name = p->get_typename();
85 auto *decl = search_decl_in_scopes(referred_name);
86 if (decl && decl->is_type_decl()) {
87 ret = decl->get_type();
89 error(ErrorType::TYPE_ERROR, node, fmt::format(
"Unknown type {}", referred_name));
/**
 * Resolve the type references contained in `p`, reporting errors at `node`.
 *
 * Visible cases:
 * - one branch forwards `p` to resolve_type_ref(), which requires a ref type.
 *   The branch's condition is not in this listing.
 * - for a pointer type whose pointee is a ref, the pointee is resolved and a
 *   new pointer type is built from it.
 *
 * Remaining cases and the return are not in this listing.
 * NOTE(review): only a pointee that is itself a ref is resolved here. A
 * pointer-to-pointer-to-ref would need a recursive resolve_type() on the
 * pointee. Confirm whether the hidden lines handle that case.
 */
96 Type *TypeCheck::resolve_type(Type *p, ASTBase *node) {
101 ret = resolve_type_ref(p, node);
102 }
else if (p->is_pointer()) {
103 auto *pointee = pcast<PointerType>(p)->get_pointee();
106 if (pointee->is_ref()) {
107 pointee = resolve_type_ref(pointee, node);
108 ret = Type::GetPointerType(pointee);
/**
 * Type-check the prototype (signature) of function declaration `_p`.
 *
 * Resolves the return type of the declaration's FunctionType and stores it
 * back. It then gathers the type of every argument declaration (each asserted
 * to be canonical) and stores them as the FunctionType's argument types.
 */
116 void TypeCheck::analyze_func_decl_prototype(ASTBase *_p) {
117 auto *p = pcast<FunctionDecl>(_p);
122 auto *func_type = pcast<FunctionType>(p->get_type());
123 auto *ret_type = resolve_type(func_type->get_return_type(), p);
124 func_type->set_return_type(ret_type);
127 size_t n = p->get_n_args();
128 const auto &arg_decls = p->get_arg_decls();
129 vector<Type *> arg_types(n,
nullptr);
// Argument types must already be canonical at this point. They are presumably
// resolved by per-argument processing that is hidden from this listing.
130 for (
size_t i = 0; i < n; ++i) {
132 arg_types[i] = arg_decls[i]->get_type();
133 TAN_ASSERT(arg_types[i]->is_canonical());
135 func_type->set_arg_types(arg_types);
/**
 * Type-check the body of function declaration `_p`.
 *
 * The body is visited only when the function is not external.
 */
140 void TypeCheck::analyze_func_body(ASTBase *_p) {
141 auto *p = pcast<FunctionDecl>(_p);
145 if (!p->is_external()) {
146 visit(p->get_body());
/**
 * Type-check function call `p` and set its type to the callee's return type.
 *
 * @param p                  the call expression. Its arguments are iterated
 *                           first; the loop body is not in this listing.
 * @param include_intrinsics whether the call may resolve to an intrinsic
 *                           function. When false and the callee is intrinsic,
 *                           the UNKNOWN_SYMBOL error below (presumably the else
 *                           branch) suggests using the `@name` form instead.
 * NOTE(review): `callee` is dereferenced without a null check. This assumes
 * search_function_callee() never returns null (e.g. error() aborts). Confirm.
 */
152 void TypeCheck::analyze_function_call(FunctionCall *p,
bool include_intrinsics) {
153 for (
const auto &a : p->_args) {
157 FunctionDecl *callee = search_function_callee(p);
158 if (include_intrinsics || !callee->is_intrinsic()) {
161 error(ErrorType::UNKNOWN_SYMBOL,
163 fmt::format(
"Unknown function call. Maybe use @{} if you want to call this intrinsic?", p->get_name()));
// The call expression takes the declared return type of the callee.
166 auto *func_type = pcast<FunctionType>(callee->get_type());
167 p->set_type(func_type->get_return_type());
/**
 * Type-check intrinsic `p`, whose operand is the function call `func_call`.
 *
 * - STACK_TRACE: renames the call to Intrinsic::STACK_TRACE_FUNCTION_REAL_NAME.
 *   NOTE(review): the line between this case and ABORT is not in this listing.
 *   If it is not a `break`, STACK_TRACE falls through into the ABORT handling.
 * - ABORT: analyzes the call with intrinsics allowed; the intrinsic has type void.
 * - GET_DECL: requires exactly one argument, which must be an identifier that
 *   refers to a value (ID_VAR_REF). The intrinsic's sub-expression becomes a
 *   string literal holding the referred declaration's source text, and the
 *   intrinsic takes that literal's type.
 * - COMP_PRINT: the intrinsic has type void. It requires exactly one
 *   string-literal argument, which is printed to stdout together with the
 *   intrinsic's source location during type checking.
 */
170 void TypeCheck::analyze_intrinsic_func_call(Intrinsic *p, FunctionCall *func_call) {
171 auto *void_type = Type::GetVoidType();
172 switch (p->get_intrinsic_type()) {
173 case IntrinsicType::STACK_TRACE:
174 func_call->set_name(Intrinsic::STACK_TRACE_FUNCTION_REAL_NAME);
176 case IntrinsicType::ABORT:
177 analyze_function_call(func_call,
true);
178 p->set_type(void_type);
180 case IntrinsicType::GET_DECL: {
181 if (func_call->get_n_args() != 1) {
182 error(ErrorType::SEMANTIC_ERROR, func_call,
"Expect the number of args to be 1");
185 auto *target = func_call->get_arg(0);
186 if (target->get_node_type() != ASTNodeType::ID) {
187 error(ErrorType::TYPE_ERROR,
189 fmt::format(
"Expect an identifier as the operand, but got {}",
// Only identifiers that refer to a value (not a type) are accepted.
195 auto *
id = pcast<Identifier>(target);
196 if (id->get_id_type() != IdentifierType::ID_VAR_REF) {
197 error(ErrorType::TYPE_ERROR,
id, fmt::format(
"Expect a value but got type"));
// Replace the intrinsic's sub-expression with the declaration's source text.
200 auto *decl =
id->get_var_ref()->get_referred();
201 auto *source_str = Literal::CreateStringLiteral(p->src(), decl->src()->get_source_code(decl->start(), decl->end()));
204 p->set_sub(source_str);
205 p->set_type(source_str->get_type());
208 case IntrinsicType::COMP_PRINT: {
209 p->set_type(void_type);
// NOTE(review): this copies the argument vector; `const auto &` would avoid the copy.
212 auto args = func_call->_args;
213 if (args.size() != 1 || args[0]->get_node_type() != ASTNodeType::STRING_LITERAL) {
214 error(ErrorType::TYPE_ERROR, p,
"Invalid call to compprint, one argument with type 'str' required");
217 str msg = pcast<StringLiteral>(args[0])->get_value();
218 std::cout << fmt::format(
"Message ({}): {}\n", p->src()->get_src_location_str(p->start()), msg);
/**
 * Type-check member function call `lhs.rhs(...)`, represented by `p`.
 *
 * `lhs` must be an lvalue or have pointer type; otherwise a TYPE_ERROR is
 * reported. The object is then inserted as the first argument of `rhs`:
 * - a non-pointer lvalue is passed by address through a newly created
 *   ADDRESS_OF unary operator;
 * - otherwise (presumably the hidden else branch) `lhs` itself is inserted.
 *
 * The access takes the type of `rhs`. `rhs` is presumably analyzed on lines
 * not in this listing.
 */
227 void TypeCheck::analyze_member_func_call(MemberAccess *p, Expr *lhs, FunctionCall *rhs) {
228 if (!lhs->is_lvalue() && !lhs->get_type()->is_pointer()) {
229 error(ErrorType::TYPE_ERROR, p,
"Invalid member function call");
// Pass the object as the implicit first argument. A non-pointer lvalue has its
// address taken, so the callee receives a pointer.
233 if (lhs->is_lvalue() && !lhs->get_type()->is_pointer()) {
234 Expr *tmp = UnaryOperator::Create(UnaryOpKind::ADDRESS_OF, lhs->src(), lhs);
236 rhs->_args.insert(rhs->_args.begin(), tmp);
238 rhs->_args.insert(rhs->_args.begin(), lhs);
242 p->set_type(rhs->get_type());
/**
 * Type-check bracket access `lhs[rhs]`, represented by `p`.
 *
 * Requirements (a TYPE_ERROR is reported otherwise):
 * - `lhs` is an lvalue of pointer, array, or string type;
 * - `rhs` has an integer type.
 *
 * The result type is the pointee type, the array element type, or char,
 * respectively. For an array indexed by an integer literal, an index at or
 * past the array size is reported at compile time.
 * NOTE(review): the bound check casts the uint64_t literal value to int. A
 * literal above INT_MAX may wrap to a negative number and escape the check;
 * comparing as unsigned would be safer. The `lhs->get_type()->is_array()` test
 * is redundant inside the array branch.
 */
246 void TypeCheck::analyze_bracket_access(MemberAccess *p, Expr *lhs, Expr *rhs) {
249 if (!lhs->is_lvalue()) {
250 error(ErrorType::TYPE_ERROR, p,
"Expect lhs to be an lvalue");
253 auto *lhs_type = lhs->get_type();
254 if (!(lhs_type->is_pointer() || lhs_type->is_array() || lhs_type->is_string())) {
255 error(ErrorType::TYPE_ERROR, p,
"Expect a type that supports bracket access");
257 if (!rhs->get_type()->is_int()) {
258 error(ErrorType::TYPE_ERROR, rhs,
"Expect an integer");
// Determine the element type produced by indexing.
261 Type *sub_type =
nullptr;
262 if (lhs_type->is_pointer()) {
263 sub_type = pcast<PointerType>(lhs_type)->get_pointee();
264 }
else if (lhs_type->is_array()) {
265 auto *array_type = pcast<ArrayType>(lhs_type);
266 sub_type = array_type->get_element_type();
// Compile-time bound check, only possible when the index is a literal.
268 if (rhs->get_node_type() == ASTNodeType::INTEGER_LITERAL) {
269 uint64_t size = pcast<IntegerLiteral>(rhs)->get_value();
270 if (lhs->get_type()->is_array() && (
int)size >= array_type->array_size()) {
271 error(ErrorType::TYPE_ERROR,
273 fmt::format(
"Index {} out of bound, the array size is {}",
274 std::to_string(size),
275 std::to_string(array_type->array_size())));
278 }
else if (lhs_type->is_string()) {
279 sub_type = Type::GetCharType();
282 p->set_type(sub_type);
/**
 * Type-check member variable access `lhs.rhs`, represented by `p`.
 *
 * Steps:
 * 1. `rhs` is an identifier naming the member.
 * 2. When `lhs` has pointer type, its pointee type is used. The non-pointer
 *    case is on lines not in this listing.
 * 3. The type is resolved and must be a struct; otherwise a TYPE_ERROR is
 *    reported.
 * 4. The member's index in the struct declaration is stored in
 *    `p->_access_idx`. An index of -1 means the member was not found, which
 *    is reported as UNKNOWN_SYMBOL.
 * 5. The access takes the resolved type of that member.
 */
286 void TypeCheck::analyze_member_access_member_variable(MemberAccess *p, Expr *lhs, Expr *rhs) {
287 str m_name = pcast<Identifier>(rhs)->get_name();
291 if (lhs->get_type()->is_pointer()) {
292 t = pcast<PointerType>(lhs->get_type())->get_pointee();
298 t = resolve_type(t, lhs);
299 if (!t->is_struct()) {
300 error(ErrorType::TYPE_ERROR, lhs,
"Expect a struct type");
304 auto *struct_type = pcast<StructType>(t);
305 auto *struct_decl = struct_type->get_decl();
306 TAN_ASSERT(struct_decl);
// -1 signals that no member with this name exists.
308 p->_access_idx = struct_decl->get_struct_member_index(m_name);
309 if (p->_access_idx == -1) {
310 error(ErrorType::UNKNOWN_SYMBOL,
312 fmt::format(
"Cannot find member variable '{}' of struct '{}'", m_name, struct_decl->get_name()));
314 auto *mem_type = struct_decl->get_struct_member_ty(p->_access_idx);
315 p->set_type(resolve_type(mem_type, p));
/**
 * Type-check identifier `p` by looking its name up in the enclosing scopes.
 *
 * - If the name refers to a type declaration, that type reference is
 *   resolved. What is done with the result is not in this listing.
 * - In the other visible branch, the identifier becomes a variable reference
 *   to the declaration and takes the declaration's resolved type.
 * - The UNKNOWN_SYMBOL "Unknown identifier" error is presumably reported when
 *   no declaration is found. Its guarding condition is not in this listing.
 * NOTE(review): `referred` is dereferenced at the type-decl check. Confirm a
 * null check precedes it on the hidden line.
 */
318 DEFINE_AST_VISITOR_IMPL(TypeCheck, Identifier) {
319 auto *referred = search_decl_in_scopes(p->get_name());
321 if (referred->is_type_decl()) {
322 auto *ty = resolve_type_ref(referred->get_type(), p);
325 p->set_var_ref(
VarRef::Create(p->src(), p->get_name(), referred));
326 p->set_type(resolve_type(referred->get_type(), p));
329 error(ErrorType::UNKNOWN_SYMBOL, p,
"Unknown identifier");
/**
 * Type-check parenthesized expression `p`. It takes the type of its
 * sub-expression, which is presumably visited on a line not in this listing.
 */
333 DEFINE_AST_VISITOR_IMPL(TypeCheck, Parenthesis) {
335 p->set_type(p->get_sub()->get_type());
/**
 * Type-check if-statement `p` branch by branch.
 *
 * Each branch predicate is replaced by its implicit conversion to bool, and
 * each branch body is visited. Lines around the conversion are not in this
 * listing; presumably they visit the predicate and skip a predicate-less else
 * branch. Confirm.
 */
338 DEFINE_AST_VISITOR_IMPL(TypeCheck, If) {
339 size_t n = p->get_num_branches();
340 for (
size_t i = 0; i < n; ++i) {
341 auto *cond = p->get_predicate(i);
344 p->set_predicate(i, create_implicit_conversion(cond, Type::GetBoolType()));
347 visit(p->get_branch(i));
/**
 * Type-check variable declaration `p`. Its declared type is resolved and
 * stored back.
 *
 * A TYPE_ERROR "Cannot deduce the type of variable declaration" is reported
 * under a condition not shown in this listing (presumably when no type is
 * available).
 */
351 DEFINE_AST_VISITOR_IMPL(TypeCheck, VarDecl) {
352 Type *ty = p->get_type();
357 error(ErrorType::TYPE_ERROR, p,
"Cannot deduce the type of variable declaration");
359 p->set_type(resolve_type(ty, p));
/// Type-check argument declaration `p`: resolve its declared type and store it back.
362 DEFINE_AST_VISITOR_IMPL(TypeCheck, ArgDecl) { p->set_type(resolve_type(p->get_type(), p)); }
/**
 * Type-check return statement `p`.
 *
 * - Looks up the enclosing function declaration. A TYPE_ERROR is reported when
 *   the return is not inside one; the guarding condition is not in this
 *   listing.
 * - The returned type is void, unless there is a return value (condition
 *   hidden), in which case it is the value's type.
 * - The returned type must be implicitly convertible to the function's
 *   declared return type.
 * NOTE(review): only convertibility is checked. No implicit conversion node is
 * inserted for the returned value in the visible lines.
 */
364 DEFINE_AST_VISITOR_IMPL(TypeCheck, Return) {
365 FunctionDecl *func = search_node_in_parent_scopes<FunctionDecl, ASTNodeType::FUNC_DECL>();
367 error(ErrorType::TYPE_ERROR, p,
"Return statement must be inside a function definition");
370 auto *rhs = p->get_rhs();
// A bare `return` has type void.
371 Type *ret_type = Type::GetVoidType();
374 ret_type = rhs->get_type();
377 if (!CanImplicitlyConvert(ret_type, pcast<FunctionType>(func->get_type())->get_return_type())) {
378 error(ErrorType::TYPE_ERROR, p,
"Returned type cannot be coerced to function return type");
/**
 * Type-check compound statement `p` by iterating over its children. The loop
 * body and any scope handling are not in this listing.
 */
382 DEFINE_AST_VISITOR_IMPL(TypeCheck, CompoundStmt) {
385 for (
const auto &c : p->get_children()) {
/// Type-check a BinaryOrUnary node by visiting the expression returned by get_expr_ptr().
392 DEFINE_AST_VISITOR_IMPL(TypeCheck, BinaryOrUnary) { visit(p->get_expr_ptr()); }
/**
 * Type-check binary operator `p`.
 *
 * MEMBER_ACCESS is delegated to the MemberAccess visitor. Otherwise:
 * - SUM, SUBTRACT, MULTIPLY, DIVIDE, BAND, BOR, MOD: operand types are
 *   auto-promoted, and the promoted type is the result type.
 * - LAND, LOR, XOR: both operands are implicitly converted to bool, and the
 *   result is bool.
 * - GT, GE, LT, LE, EQ, NE: operand types are auto-promoted, and the result
 *   is bool.
 * NOTE(review): XOR is grouped with the logical operators, while BAND/BOR are
 * handled as arithmetic. If XOR denotes bitwise `^`, integer operands are
 * collapsed to bool here. Confirm the intended semantics.
 */
394 DEFINE_AST_VISITOR_IMPL(TypeCheck, BinaryOperator) {
395 Expr *lhs = p->get_lhs();
396 Expr *rhs = p->get_rhs();
398 if (p->get_op() == BinaryOpKind::MEMBER_ACCESS) {
399 CALL_AST_VISITOR(MemberAccess, p);
406 switch (p->get_op()) {
407 case BinaryOpKind::SUM:
408 case BinaryOpKind::SUBTRACT:
409 case BinaryOpKind::MULTIPLY:
410 case BinaryOpKind::DIVIDE:
411 case BinaryOpKind::BAND:
412 case BinaryOpKind::BOR:
413 case BinaryOpKind::MOD: {
414 p->set_type(auto_promote_bop_operand_types(p));
417 case BinaryOpKind::LAND:
418 case BinaryOpKind::LOR:
419 case BinaryOpKind::XOR: {
421 auto *bool_type = PrimitiveType::GetBoolType();
422 p->set_lhs(create_implicit_conversion(lhs, bool_type));
423 p->set_rhs(create_implicit_conversion(rhs, bool_type));
424 p->set_type(bool_type);
427 case BinaryOpKind::GT:
428 case BinaryOpKind::GE:
429 case BinaryOpKind::LT:
430 case BinaryOpKind::LE:
431 case BinaryOpKind::EQ:
432 case BinaryOpKind::NE:
// Operands are promoted to a common type; the comparison itself yields bool.
433 auto_promote_bop_operand_types(p);
434 p->set_type(PrimitiveType::GetBoolType());
/**
 * Type-check unary operator `p` based on its operand `rhs`.
 *
 * - LNOT: the operand is converted to bool; the result is bool.
 * - BNOT: the operand must be an integer; the result has the operand's type.
 * - ADDRESS_OF: the result is a pointer to the operand's type.
 * - PTR_DEREF: the operand must be a pointer (TYPE_ERROR otherwise) and is
 *   asserted to be an lvalue; the result is the pointee type.
 * - PLUS, MINUS: the operand must be numeric; the result has the operand's type.
 * NOTE(review): for LNOT, the converted node is assigned only to the local
 * `rhs`. Unless the hidden line after it stores the node back (e.g. via
 * set_rhs), the conversion is lost. PTR_DEREF also uses an assertion rather
 * than a user-facing error for non-lvalue operands.
 */
441 DEFINE_AST_VISITOR_IMPL(TypeCheck, UnaryOperator) {
442 auto *rhs = p->get_rhs();
445 auto *rhs_type = rhs->get_type();
446 switch (p->get_op()) {
447 case UnaryOpKind::LNOT:
448 rhs = create_implicit_conversion(rhs, Type::GetBoolType());
450 p->set_type(PrimitiveType::GetBoolType());
452 case UnaryOpKind::BNOT:
453 if (!rhs_type->is_int()) {
454 error(ErrorType::TYPE_ERROR, rhs,
"Expect an integer type");
456 p->set_type(rhs_type);
458 case UnaryOpKind::ADDRESS_OF:
459 p->set_type(Type::GetPointerType(rhs_type));
461 case UnaryOpKind::PTR_DEREF:
462 if (!rhs_type->is_pointer()) {
463 error(ErrorType::TYPE_ERROR, rhs,
"Expect a pointer type");
465 TAN_ASSERT(rhs->is_lvalue());
467 p->set_type(pcast<PointerType>(rhs_type)->get_pointee());
469 case UnaryOpKind::PLUS:
470 case UnaryOpKind::MINUS:
471 if (!rhs_type->is_num()) {
472 error(ErrorType::TYPE_ERROR, rhs,
"Expect a numerical type");
474 p->set_type(rhs_type);
/**
 * Type-check cast `p`. Its target type is resolved and stored back. `lhs`
 * (presumably the expression being cast) is fetched, but its processing is not
 * in this listing.
 */
481 DEFINE_AST_VISITOR_IMPL(TypeCheck, Cast) {
482 Expr *lhs = p->get_lhs();
484 p->set_type(resolve_type(p->get_type(), p));
/**
 * Type-check assignment `p`.
 *
 * - If the left-hand side is a variable declaration without a type, the
 *   declaration takes the right-hand side's type. The declaration's type then
 *   becomes the target type.
 * - Otherwise (presumably the hidden else branch), the left-hand side must be
 *   one of: an identifier referring to a variable, an argument declaration,
 *   or a unary/binary operator expression. Anything else is a TYPE_ERROR.
 *
 * The assignment takes the target type, and the right-hand side is implicitly
 * converted to it.
 * NOTE(review): the conversion result is assigned to the local `rhs`. Confirm
 * the hidden following line stores it back into `p`.
 */
487 DEFINE_AST_VISITOR_IMPL(TypeCheck, Assignment) {
488 Expr *rhs = p->get_rhs();
491 auto *lhs = p->get_lhs();
492 Type *lhs_type =
nullptr;
493 if (lhs->get_node_type() == ASTNodeType::VAR_DECL) {
494 auto *var_decl = pcast<VarDecl>(lhs);
// Infer the declared type from the initializer when none was written.
497 if (!var_decl->get_type()) {
498 var_decl->set_type(rhs->get_type());
502 lhs_type = var_decl->get_type();
506 switch (lhs->get_node_type()) {
507 case ASTNodeType::ID: {
508 auto *
id = pcast<Identifier>(lhs);
509 if (id->get_id_type() != IdentifierType::ID_VAR_REF) {
510 error(ErrorType::TYPE_ERROR, lhs,
"Can only assign value to a variable");
512 lhs_type =
id->get_type();
515 case ASTNodeType::ARG_DECL:
516 case ASTNodeType::BOP_OR_UOP:
517 case ASTNodeType::UOP:
518 case ASTNodeType::BOP:
519 lhs_type = pcast<Expr>(lhs)->get_type();
522 error(ErrorType::TYPE_ERROR, lhs,
"Invalid left-hand operand");
525 p->set_type(lhs_type);
527 rhs = create_implicit_conversion(rhs, lhs_type);
/// Type-check a plain function call. Resolving to an intrinsic function is reported as an error (include_intrinsics = false).
532 DEFINE_AST_VISITOR_IMPL(TypeCheck, FunctionCall) {
533 analyze_function_call(p,
false);
/// Type-check a function declaration: first its prototype, then its body.
536 DEFINE_AST_VISITOR_IMPL(TypeCheck, FunctionDecl) {
537 analyze_func_decl_prototype(p);
538 analyze_func_body(p);
/**
 * Type-check import `p` by iterating over the types it imports. The loop body
 * is not in this listing.
 */
541 DEFINE_AST_VISITOR_IMPL(TypeCheck, Import) {
542 for (TypeDecl *t : p->_imported_types) {
/**
 * Type-check intrinsic `p`.
 *
 * - LINENO: builds an integer literal holding the intrinsic's source line,
 *   typed as a 32-bit unsigned integer. The `true` flag passed to Create
 *   presumably marks the literal as unsigned.
 * - FILENAME: builds a string literal holding the source file name, typed as
 *   the string type.
 * - TEST_COMP_ERROR: type-checks the wrapped TestCompError and expects a
 *   CompileException. A caught exception is logged to stderr and compilation
 *   continues. Otherwise (presumably when nothing was thrown), a TYPE_ERROR
 *   "Expect a compile error" is reported.
 * - NOOP and INVALID: handled on lines not in this listing.
 * - Other intrinsics (presumably the default case) whose sub-expression is a
 *   function call are handled by analyze_intrinsic_func_call().
 *
 * Storing `sub` and `type` into `p` is not in this listing.
 */
547 DEFINE_AST_VISITOR_IMPL(TypeCheck, Intrinsic) {
548 switch (p->get_intrinsic_type()) {
549 case IntrinsicType::LINENO: {
550 auto sub = IntegerLiteral::Create(p->src(), p->src()->get_line(p->start()),
true);
551 auto type = PrimitiveType::GetIntegerType(32,
true);
557 case IntrinsicType::FILENAME: {
558 auto sub = StringLiteral::Create(p->src(), p->src()->get_filename());
559 auto type = Type::GetStringType();
565 case IntrinsicType::TEST_COMP_ERROR: {
566 auto *tce = pcast<TestCompError>(p->get_sub());
574 }
catch (
const CompileException &e) {
575 std::cerr << fmt::format(
"Caught expected compile error: {}\nContinue compilation...\n", e.what());
580 error(ErrorType::TYPE_ERROR, p,
"Expect a compile error");
585 case IntrinsicType::NOOP:
587 case IntrinsicType::INVALID:
591 auto *c = p->get_sub();
592 if (c->get_node_type() == ASTNodeType::FUNC_CALL) {
593 analyze_intrinsic_func_call(p, pcast<FunctionCall>(c));
/**
 * Type-check string literal `p`. Asserts its value is non-empty and gives it
 * the string type.
 * NOTE(review): an empty literal would trip this assertion. Confirm the parser
 * never produces one.
 */
602 DEFINE_AST_VISITOR_IMPL(TypeCheck, StringLiteral) {
603 TAN_ASSERT(!p->get_value().empty());
604 p->set_type(Type::GetStringType());
/// Type-check character literal `p`: it has the char type.
607 DEFINE_AST_VISITOR_IMPL(TypeCheck, CharLiteral) { p->set_type(Type::GetCharType()); }
/**
 * Type-check integer literal `p`.
 *
 * The literal is given a 32-bit unsigned type if it is marked unsigned, and a
 * 32-bit signed type otherwise. Assigning `ty` to `p` is not in this listing.
 * NOTE(review): the visible lines never consult the literal's value, so values
 * that do not fit in 32 bits still get a 32-bit type. Confirm this is intended.
 */
609 DEFINE_AST_VISITOR_IMPL(TypeCheck, IntegerLiteral) {
611 if (p->is_unsigned()) {
612 ty = Type::GetIntegerType(32,
true);
614 ty = Type::GetIntegerType(32,
false);
/// Type-check boolean literal `p`: it has the bool type.
619 DEFINE_AST_VISITOR_IMPL(TypeCheck, BoolLiteral) { p->set_type(Type::GetBoolType()); }
/// Type-check floating-point literal `p`: it has the 32-bit float type.
621 DEFINE_AST_VISITOR_IMPL(TypeCheck, FloatLiteral) { p->set_type(Type::GetFloatType(32)); }
/**
 * Type-check array literal `p`.
 *
 * The element type is taken from one element (presumably the first; the
 * selecting lines are not in this listing), and the other elements are
 * implicitly converted to it. The literal's type is an array of that element
 * type, sized by the number of elements. `element_type` is asserted non-null,
 * so an empty array literal trips the assertion.
 * NOTE(review): the result of create_implicit_conversion() is discarded below.
 * Other call sites (If, BinaryOperator) store the returned node, so if it
 * returns a new node, the converted elements are never written back into `p`.
 */
623 DEFINE_AST_VISITOR_IMPL(TypeCheck, ArrayLiteral) {
626 auto elements = p->get_elements();
627 Type *element_type =
nullptr;
628 for (
auto *e : elements) {
631 element_type = e->get_type();
633 create_implicit_conversion(e, element_type);
636 TAN_ASSERT(element_type);
637 p->set_type(Type::GetArrayType(element_type, (
int)elements.size()));
/**
 * Type-check member access `p` by dispatching on its form:
 * - `rhs` is a function call: member function call;
 * - the access type is already marked as bracket access: bracket access;
 * - `rhs` is an identifier: member variable access;
 * - anything else (presumably the hidden else branch): UNKNOWN_SYMBOL
 *   "Invalid right-hand operand".
 *
 * For the function-call and variable forms, the access type is recorded in
 * `p->_access_type`. Processing of lhs/rhs between the two fetches is not in
 * this listing.
 */
640 DEFINE_AST_VISITOR_IMPL(TypeCheck, MemberAccess) {
641 Expr *lhs = p->get_lhs();
644 Expr *rhs = p->get_rhs();
646 if (rhs->get_node_type() == ASTNodeType::FUNC_CALL) {
647 p->_access_type = MemberAccess::MemberAccessMemberFunction;
648 auto func_call = pcast<FunctionCall>(rhs);
649 analyze_member_func_call(p, lhs, func_call);
650 }
else if (p->_access_type == MemberAccess::MemberAccessBracket) {
651 analyze_bracket_access(p, lhs, rhs);
652 }
else if (rhs->get_node_type() == ASTNodeType::ID) {
653 p->_access_type = MemberAccess::MemberAccessMemberVariable;
654 analyze_member_access_member_variable(p, lhs, rhs);
656 error(ErrorType::UNKNOWN_SYMBOL, p,
"Invalid right-hand operand");
/**
 * Type-check struct declaration `p` by resolving each member's entry in its
 * StructType.
 *
 * - VAR_DECL member: its type entry is resolved.
 * - ASSIGN member (member with an initial value): its type entry is resolved,
 *   and the initial value must be compile-time known (TYPE_ERROR otherwise).
 * - FUNC_DECL member: its type entry becomes the function's type.
 * - Anything else (presumably the hidden else branch): TYPE_ERROR
 *   "Invalid struct member".
 * NOTE(review): member errors are reported at the struct `p` rather than at
 * member `m`, which makes the diagnostics less precise.
 */
660 DEFINE_AST_VISITOR_IMPL(TypeCheck, StructDecl) {
661 str struct_name = p->get_name();
662 auto *ty = pcast<StructType>(p->get_type());
663 TAN_ASSERT(ty && ty->is_struct());
664 auto members = p->get_member_decls();
// Member i of the declaration corresponds to type entry (*ty)[i].
668 size_t n = members.size();
669 for (
size_t i = 0; i < n; ++i) {
670 Expr *m = members[i];
672 if (m->get_node_type() == ASTNodeType::VAR_DECL) {
673 (*ty)[i] = resolve_type((*ty)[i], m);
675 }
else if (m->get_node_type() == ASTNodeType::ASSIGN) {
676 auto init_val = pcast<Assignment>(m)->get_rhs();
679 (*ty)[i] = resolve_type((*ty)[i], m);
681 if (!init_val->is_comptime_known()) {
682 error(ErrorType::TYPE_ERROR, p,
"Initial value of a member variable must be compile-time known");
685 }
else if (m->get_node_type() == ASTNodeType::FUNC_DECL) {
686 auto f = pcast<FunctionDecl>(m);
687 (*ty)[i] = f->get_type();
690 error(ErrorType::TYPE_ERROR, p,
"Invalid struct member");
/**
 * Type-check loop `p`.
 *
 * For FOR loops the initialization is visited first. The predicate is always
 * visited. For FOR loops the iteration statement is visited afterwards. Body
 * handling and any scope management are not in this listing.
 */
697 DEFINE_AST_VISITOR_IMPL(TypeCheck, Loop) {
700 if (p->_loop_type == ASTLoopType::FOR) {
701 visit(p->_initialization);
704 visit(p->_predicate);
706 if (p->_loop_type == ASTLoopType::FOR) {
707 visit(p->_iteration);
/**
 * Type-check break/continue statement `p`.
 *
 * Searches the parent scopes for a Loop and records it as the statement's
 * parent loop. A SEMANTIC_ERROR is reported when the statement is not inside
 * a loop; the guarding condition is not in this listing.
 * NOTE(review): `loop` already has type Loop*, so the pcast<Loop> is redundant.
 */
715 DEFINE_AST_VISITOR_IMPL(TypeCheck, BreakContinue) {
716 Loop *loop = search_node_in_parent_scopes<Loop, ASTNodeType::LOOP>();
718 error(ErrorType::SEMANTIC_ERROR, p,
"Break or continue must be inside a loop");
720 p->set_parent_loop(pcast<Loop>(loop));
static umap< ASTNodeType, str > ASTTypeNames
string representation of ASTNodeType
FunctionDecl * get_func_decl(const str &name) const
Search for a function declaration by name.
static VarRef * Create(TokenizedSourceFile *src, const str &name, Decl *referred)