tan  0.0.1
type_check.cpp
1 #include "analysis/type_check.h"
2 #include "ast/ast_base.h"
3 #include "ast/ast_node_type.h"
4 #include "common/ast_visitor.h"
5 #include "ast/type.h"
6 #include "ast/expr.h"
7 #include "ast/stmt.h"
8 #include "ast/decl.h"
9 #include "ast/intrinsic.h"
10 #include "ast/context.h"
11 #include "fmt/core.h"
12 #include "source_file/token.h"
13 #include <iostream>
14 #include <set>
15 
16 namespace tanlang {
17 
18 void TypeCheck::run_impl(Package *p) {
19  push_scope(p);
20 
21  // check unresolved symbols
22  auto [unresolved_symbols, error_node] = p->top_level_symbol_dependency.topological_sort();
23  if (error_node.has_value()) {
24  error(ErrorType::COMPILE_ERROR, error_node.value(), "Cyclic dependency detected");
25  }
26 
27  for (auto *c : unresolved_symbols.value()) {
28  visit(c);
29  }
30 
31  // check remaining ones
32  std::set<ASTBase *> s(unresolved_symbols.value().begin(), unresolved_symbols.value().end());
33  for (auto *c : p->get_children()) {
34  if (!s.contains(c))
35  visit(c);
36  }
37 
38  pop_scope();
39 }
40 
41 FunctionDecl *TypeCheck::search_function_callee(FunctionCall *p) {
42  const str &name = p->get_name();
43  const vector<Expr *> &args = p->_args;
44 
45  FunctionDecl *candidate = top_ctx()->get_func_decl(name);
46  if (!candidate) {
47  error(ErrorType::TYPE_ERROR, p, fmt::format("Unknown function call: {}", name));
48  }
49 
50  size_t n = candidate->get_n_args();
51  if (n != args.size()) {
52  error(ErrorType::SEMANTIC_ERROR,
53  p,
54  fmt::format("Incorrect number of arguments: expect {} but found {}", candidate->get_n_args(), n));
55  }
56 
57  auto *func_type = pcast<FunctionType>(candidate->get_type());
58 
59  // Check if argument types match (return type not checked)
60  // Allow implicit cast from actual arguments to expected arguments
61  for (size_t i = 0; i < n; ++i) {
62  auto *actual_type = args[i]->get_type();
63  auto *expected_type = func_type->get_arg_types()[i];
64 
65  if (actual_type != expected_type) {
66  if (!CanImplicitlyConvert(actual_type, expected_type)) {
67  error(ErrorType::TYPE_ERROR,
68  p,
69  fmt::format("Cannot implicitly convert the type of argument {}: expect {} but found {}",
70  i + 1,
71  actual_type->get_typename(),
72  expected_type->get_typename()));
73  }
74  }
75  }
76 
77  return candidate;
78 }
79 
80 Type *TypeCheck::resolve_type_ref(Type *p, ASTBase *node) {
81  TAN_ASSERT(p->is_ref());
82  Type *ret = p;
83 
84  const str &referred_name = p->get_typename();
85  auto *decl = search_decl_in_scopes(referred_name);
86  if (decl && decl->is_type_decl()) {
87  ret = decl->get_type();
88  } else {
89  error(ErrorType::TYPE_ERROR, node, fmt::format("Unknown type {}", referred_name));
90  }
91 
92  TAN_ASSERT(ret);
93  return ret;
94 }
95 
96 Type *TypeCheck::resolve_type(Type *p, ASTBase *node) {
97  TAN_ASSERT(p);
98 
99  Type *ret = p;
100  if (p->is_ref()) {
101  ret = resolve_type_ref(p, node);
102  } else if (p->is_pointer()) {
103  auto *pointee = pcast<PointerType>(p)->get_pointee();
104 
105  TAN_ASSERT(pointee);
106  if (pointee->is_ref()) {
107  pointee = resolve_type_ref(pointee, node);
108  ret = Type::GetPointerType(pointee);
109  }
110  }
111 
112  TAN_ASSERT(ret);
113  return ret;
114 }
115 
116 void TypeCheck::analyze_func_decl_prototype(ASTBase *_p) {
117  auto *p = pcast<FunctionDecl>(_p);
118 
119  push_scope(p);
120 
121  /// update return type
122  auto *func_type = pcast<FunctionType>(p->get_type());
123  auto *ret_type = resolve_type(func_type->get_return_type(), p);
124  func_type->set_return_type(ret_type);
125 
126  /// type_check_ast args
127  size_t n = p->get_n_args();
128  const auto &arg_decls = p->get_arg_decls();
129  vector<Type *> arg_types(n, nullptr);
130  for (size_t i = 0; i < n; ++i) {
131  visit(arg_decls[i]); /// args will be added to the scope here
132  arg_types[i] = arg_decls[i]->get_type();
133  TAN_ASSERT(arg_types[i]->is_canonical());
134  }
135  func_type->set_arg_types(arg_types); /// update arg types
136 
137  pop_scope();
138 }
139 
140 void TypeCheck::analyze_func_body(ASTBase *_p) {
141  auto *p = pcast<FunctionDecl>(_p);
142 
143  push_scope(p);
144 
145  if (!p->is_external()) {
146  visit(p->get_body());
147  }
148 
149  pop_scope();
150 }
151 
152 void TypeCheck::analyze_function_call(FunctionCall *p, bool include_intrinsics) {
153  for (const auto &a : p->_args) {
154  visit(a);
155  }
156 
157  FunctionDecl *callee = search_function_callee(p);
158  if (include_intrinsics || !callee->is_intrinsic()) {
159  p->_callee = callee;
160  } else {
161  error(ErrorType::UNKNOWN_SYMBOL,
162  p,
163  fmt::format("Unknown function call. Maybe use @{} if you want to call this intrinsic?", p->get_name()));
164  }
165 
166  auto *func_type = pcast<FunctionType>(callee->get_type());
167  p->set_type(func_type->get_return_type());
168 }
169 
170 void TypeCheck::analyze_intrinsic_func_call(Intrinsic *p, FunctionCall *func_call) {
171  auto *void_type = Type::GetVoidType();
172  switch (p->get_intrinsic_type()) {
173  case IntrinsicType::STACK_TRACE:
174  func_call->set_name(Intrinsic::STACK_TRACE_FUNCTION_REAL_NAME);
175  [[fallthrough]];
176  case IntrinsicType::ABORT:
177  analyze_function_call(func_call, true);
178  p->set_type(void_type);
179  break;
180  case IntrinsicType::GET_DECL: {
181  if (func_call->get_n_args() != 1) {
182  error(ErrorType::SEMANTIC_ERROR, func_call, "Expect the number of args to be 1");
183  }
184 
185  auto *target = func_call->get_arg(0);
186  if (target->get_node_type() != ASTNodeType::ID) {
187  error(ErrorType::TYPE_ERROR,
188  target,
189  fmt::format("Expect an identifier as the operand, but got {}",
190  ASTBase::ASTTypeNames[target->get_node_type()]));
191  }
192 
193  visit(target);
194 
195  auto *id = pcast<Identifier>(target);
196  if (id->get_id_type() != IdentifierType::ID_VAR_REF) {
197  error(ErrorType::TYPE_ERROR, id, fmt::format("Expect a value but got type"));
198  }
199 
200  auto *decl = id->get_var_ref()->get_referred();
201  auto *source_str = Literal::CreateStringLiteral(p->src(), decl->src()->get_source_code(decl->start(), decl->end()));
202 
203  // FEATURE: Return AST?
204  p->set_sub(source_str);
205  p->set_type(source_str->get_type());
206  break;
207  }
208  case IntrinsicType::COMP_PRINT: {
209  p->set_type(void_type);
210 
211  // FEATURE: print with var args
212  auto args = func_call->_args;
213  if (args.size() != 1 || args[0]->get_node_type() != ASTNodeType::STRING_LITERAL) {
214  error(ErrorType::TYPE_ERROR, p, "Invalid call to compprint, one argument with type 'str' required");
215  }
216 
217  str msg = pcast<StringLiteral>(args[0])->get_value();
218  std::cout << fmt::format("Message ({}): {}\n", p->src()->get_src_location_str(p->start()), msg);
219  break;
220  }
221  default:
222  TAN_ASSERT(false);
223  }
224 }
225 
226 // ASSUMES lhs has been already analyzed, while rhs has not
227 void TypeCheck::analyze_member_func_call(MemberAccess *p, Expr *lhs, FunctionCall *rhs) {
228  if (!lhs->is_lvalue() && !lhs->get_type()->is_pointer()) {
229  error(ErrorType::TYPE_ERROR, p, "Invalid member function call");
230  }
231 
232  // insert the address of the struct instance as the first parameter
233  if (lhs->is_lvalue() && !lhs->get_type()->is_pointer()) {
234  Expr *tmp = UnaryOperator::Create(UnaryOpKind::ADDRESS_OF, lhs->src(), lhs);
235  visit(tmp);
236  rhs->_args.insert(rhs->_args.begin(), tmp);
237  } else {
238  rhs->_args.insert(rhs->_args.begin(), lhs);
239  }
240 
241  visit(rhs);
242  p->set_type(rhs->get_type());
243 }
244 
245 // ASSUMES lhs has been already analyzed, while rhs has not
246 void TypeCheck::analyze_bracket_access(MemberAccess *p, Expr *lhs, Expr *rhs) {
247  visit(rhs);
248 
249  if (!lhs->is_lvalue()) {
250  error(ErrorType::TYPE_ERROR, p, "Expect lhs to be an lvalue");
251  }
252 
253  auto *lhs_type = lhs->get_type();
254  if (!(lhs_type->is_pointer() || lhs_type->is_array() || lhs_type->is_string())) {
255  error(ErrorType::TYPE_ERROR, p, "Expect a type that supports bracket access");
256  }
257  if (!rhs->get_type()->is_int()) {
258  error(ErrorType::TYPE_ERROR, rhs, "Expect an integer");
259  }
260 
261  Type *sub_type = nullptr;
262  if (lhs_type->is_pointer()) {
263  sub_type = pcast<PointerType>(lhs_type)->get_pointee();
264  } else if (lhs_type->is_array()) {
265  auto *array_type = pcast<ArrayType>(lhs_type);
266  sub_type = array_type->get_element_type();
267  /// check if array index is out-of-bound
268  if (rhs->get_node_type() == ASTNodeType::INTEGER_LITERAL) {
269  uint64_t size = pcast<IntegerLiteral>(rhs)->get_value();
270  if (lhs->get_type()->is_array() && (int)size >= array_type->array_size()) {
271  error(ErrorType::TYPE_ERROR,
272  p,
273  fmt::format("Index {} out of bound, the array size is {}",
274  std::to_string(size),
275  std::to_string(array_type->array_size())));
276  }
277  }
278  } else if (lhs_type->is_string()) {
279  sub_type = Type::GetCharType();
280  }
281 
282  p->set_type(sub_type);
283 }
284 
285 // ASSUMES lhs has been already analyzed, while rhs has not
286 void TypeCheck::analyze_member_access_member_variable(MemberAccess *p, Expr *lhs, Expr *rhs) {
287  str m_name = pcast<Identifier>(rhs)->get_name();
288  Type *t = nullptr;
289 
290  // auto dereference pointers
291  if (lhs->get_type()->is_pointer()) {
292  t = pcast<PointerType>(lhs->get_type())->get_pointee();
293  } else {
294  t = lhs->get_type();
295  }
296 
297  // resole struct type
298  t = resolve_type(t, lhs);
299  if (!t->is_struct()) {
300  error(ErrorType::TYPE_ERROR, lhs, "Expect a struct type");
301  }
302 
303  // resolve the member being accessed
304  auto *struct_type = pcast<StructType>(t);
305  auto *struct_decl = struct_type->get_decl();
306  TAN_ASSERT(struct_decl);
307 
308  p->_access_idx = struct_decl->get_struct_member_index(m_name);
309  if (p->_access_idx == -1) {
310  error(ErrorType::UNKNOWN_SYMBOL,
311  p,
312  fmt::format("Cannot find member variable '{}' of struct '{}'", m_name, struct_decl->get_name()));
313  }
314  auto *mem_type = struct_decl->get_struct_member_ty(p->_access_idx);
315  p->set_type(resolve_type(mem_type, p));
316 }
317 
318 DEFINE_AST_VISITOR_IMPL(TypeCheck, Identifier) {
319  auto *referred = search_decl_in_scopes(p->get_name());
320  if (referred) {
321  if (referred->is_type_decl()) { /// refers to a type
322  auto *ty = resolve_type_ref(referred->get_type(), p);
323  p->set_type_ref(ty);
324  } else { /// refers to a variable
325  p->set_var_ref(VarRef::Create(p->src(), p->get_name(), referred));
326  p->set_type(resolve_type(referred->get_type(), p));
327  }
328  } else {
329  error(ErrorType::UNKNOWN_SYMBOL, p, "Unknown identifier");
330  }
331 }
332 
333 DEFINE_AST_VISITOR_IMPL(TypeCheck, Parenthesis) {
334  visit(p->get_sub());
335  p->set_type(p->get_sub()->get_type());
336 }
337 
338 DEFINE_AST_VISITOR_IMPL(TypeCheck, If) {
339  size_t n = p->get_num_branches();
340  for (size_t i = 0; i < n; ++i) {
341  auto *cond = p->get_predicate(i);
342  if (cond) { /// can be nullptr, meaning an "else" branch
343  visit(cond);
344  p->set_predicate(i, create_implicit_conversion(cond, Type::GetBoolType()));
345  }
346 
347  visit(p->get_branch(i));
348  }
349 }
350 
351 DEFINE_AST_VISITOR_IMPL(TypeCheck, VarDecl) {
352  Type *ty = p->get_type();
353 
354  // assume the type is always non-null
355  // type_check_assignment is responsible for setting the deduced type if necessary
356  if (!ty) {
357  error(ErrorType::TYPE_ERROR, p, "Cannot deduce the type of variable declaration");
358  }
359  p->set_type(resolve_type(ty, p));
360 }
361 
362 DEFINE_AST_VISITOR_IMPL(TypeCheck, ArgDecl) { p->set_type(resolve_type(p->get_type(), p)); }
363 
364 DEFINE_AST_VISITOR_IMPL(TypeCheck, Return) {
365  FunctionDecl *func = search_node_in_parent_scopes<FunctionDecl, ASTNodeType::FUNC_DECL>();
366  if (!func) {
367  error(ErrorType::TYPE_ERROR, p, "Return statement must be inside a function definition");
368  }
369 
370  auto *rhs = p->get_rhs();
371  Type *ret_type = Type::GetVoidType();
372  if (rhs) {
373  visit(rhs);
374  ret_type = rhs->get_type();
375  }
376  // check if return type is the same as the function return type
377  if (!CanImplicitlyConvert(ret_type, pcast<FunctionType>(func->get_type())->get_return_type())) {
378  error(ErrorType::TYPE_ERROR, p, "Returned type cannot be coerced to function return type");
379  }
380 }
381 
382 DEFINE_AST_VISITOR_IMPL(TypeCheck, CompoundStmt) {
383  push_scope(p);
384 
385  for (const auto &c : p->get_children()) {
386  visit(c);
387  }
388 
389  pop_scope();
390 }
391 
392 DEFINE_AST_VISITOR_IMPL(TypeCheck, BinaryOrUnary) { visit(p->get_expr_ptr()); }
393 
394 DEFINE_AST_VISITOR_IMPL(TypeCheck, BinaryOperator) {
395  Expr *lhs = p->get_lhs();
396  Expr *rhs = p->get_rhs();
397 
398  if (p->get_op() == BinaryOpKind::MEMBER_ACCESS) {
399  CALL_AST_VISITOR(MemberAccess, p);
400  return;
401  }
402 
403  visit(lhs);
404  visit(rhs);
405 
406  switch (p->get_op()) {
407  case BinaryOpKind::SUM:
408  case BinaryOpKind::SUBTRACT:
409  case BinaryOpKind::MULTIPLY:
410  case BinaryOpKind::DIVIDE:
411  case BinaryOpKind::BAND:
412  case BinaryOpKind::BOR:
413  case BinaryOpKind::MOD: {
414  p->set_type(auto_promote_bop_operand_types(p));
415  break;
416  }
417  case BinaryOpKind::LAND:
418  case BinaryOpKind::LOR:
419  case BinaryOpKind::XOR: {
420  // check if both operators are bool
421  auto *bool_type = PrimitiveType::GetBoolType();
422  p->set_lhs(create_implicit_conversion(lhs, bool_type));
423  p->set_rhs(create_implicit_conversion(rhs, bool_type));
424  p->set_type(bool_type);
425  break;
426  }
427  case BinaryOpKind::GT:
428  case BinaryOpKind::GE:
429  case BinaryOpKind::LT:
430  case BinaryOpKind::LE:
431  case BinaryOpKind::EQ:
432  case BinaryOpKind::NE:
433  auto_promote_bop_operand_types(p);
434  p->set_type(PrimitiveType::GetBoolType());
435  break;
436  default:
437  TAN_ASSERT(false);
438  }
439 }
440 
441 DEFINE_AST_VISITOR_IMPL(TypeCheck, UnaryOperator) {
442  auto *rhs = p->get_rhs();
443  visit(rhs);
444 
445  auto *rhs_type = rhs->get_type();
446  switch (p->get_op()) {
447  case UnaryOpKind::LNOT:
448  rhs = create_implicit_conversion(rhs, Type::GetBoolType());
449  p->set_rhs(rhs);
450  p->set_type(PrimitiveType::GetBoolType());
451  break;
452  case UnaryOpKind::BNOT:
453  if (!rhs_type->is_int()) {
454  error(ErrorType::TYPE_ERROR, rhs, "Expect an integer type");
455  }
456  p->set_type(rhs_type);
457  break;
458  case UnaryOpKind::ADDRESS_OF:
459  p->set_type(Type::GetPointerType(rhs_type));
460  break;
461  case UnaryOpKind::PTR_DEREF:
462  if (!rhs_type->is_pointer()) {
463  error(ErrorType::TYPE_ERROR, rhs, "Expect a pointer type");
464  }
465  TAN_ASSERT(rhs->is_lvalue());
466  p->set_lvalue(true);
467  p->set_type(pcast<PointerType>(rhs_type)->get_pointee());
468  break;
469  case UnaryOpKind::PLUS:
470  case UnaryOpKind::MINUS: /// unary plus/minus
471  if (!rhs_type->is_num()) {
472  error(ErrorType::TYPE_ERROR, rhs, "Expect a numerical type");
473  }
474  p->set_type(rhs_type);
475  break;
476  default:
477  TAN_ASSERT(false);
478  }
479 }
480 
481 DEFINE_AST_VISITOR_IMPL(TypeCheck, Cast) {
482  Expr *lhs = p->get_lhs();
483  visit(lhs);
484  p->set_type(resolve_type(p->get_type(), p));
485 }
486 
487 DEFINE_AST_VISITOR_IMPL(TypeCheck, Assignment) {
488  Expr *rhs = p->get_rhs();
489  visit(rhs);
490 
491  auto *lhs = p->get_lhs();
492  Type *lhs_type = nullptr;
493  if (lhs->get_node_type() == ASTNodeType::VAR_DECL) {
494  auto *var_decl = pcast<VarDecl>(lhs);
495 
496  // deduce type of variable declaration
497  if (!var_decl->get_type()) {
498  var_decl->set_type(rhs->get_type());
499  }
500 
501  visit(lhs);
502  lhs_type = var_decl->get_type();
503  } else {
504  visit(lhs);
505 
506  switch (lhs->get_node_type()) {
507  case ASTNodeType::ID: {
508  auto *id = pcast<Identifier>(lhs);
509  if (id->get_id_type() != IdentifierType::ID_VAR_REF) {
510  error(ErrorType::TYPE_ERROR, lhs, "Can only assign value to a variable");
511  }
512  lhs_type = id->get_type();
513  break;
514  }
515  case ASTNodeType::ARG_DECL:
516  case ASTNodeType::BOP_OR_UOP:
517  case ASTNodeType::UOP:
518  case ASTNodeType::BOP:
519  lhs_type = pcast<Expr>(lhs)->get_type();
520  break;
521  default:
522  error(ErrorType::TYPE_ERROR, lhs, "Invalid left-hand operand");
523  }
524  }
525  p->set_type(lhs_type);
526 
527  rhs = create_implicit_conversion(rhs, lhs_type);
528  p->set_rhs(rhs);
529  p->set_lvalue(true);
530 }
531 
532 DEFINE_AST_VISITOR_IMPL(TypeCheck, FunctionCall) {
533  analyze_function_call(p, false); // intrinsic function call is handled elsewhere
534 }
535 
536 DEFINE_AST_VISITOR_IMPL(TypeCheck, FunctionDecl) {
537  analyze_func_decl_prototype(p);
538  analyze_func_body(p);
539 }
540 
541 DEFINE_AST_VISITOR_IMPL(TypeCheck, Import) {
542  for (TypeDecl *t : p->_imported_types) {
543  visit(t);
544  }
545 }
546 
547 DEFINE_AST_VISITOR_IMPL(TypeCheck, Intrinsic) {
548  switch (p->get_intrinsic_type()) {
549  case IntrinsicType::LINENO: {
550  auto sub = IntegerLiteral::Create(p->src(), p->src()->get_line(p->start()), true);
551  auto type = PrimitiveType::GetIntegerType(32, true);
552  sub->set_type(type);
553  p->set_type(type);
554  p->set_sub(sub);
555  break;
556  }
557  case IntrinsicType::FILENAME: {
558  auto sub = StringLiteral::Create(p->src(), p->src()->get_filename());
559  auto type = Type::GetStringType();
560  sub->set_type(type);
561  p->set_type(type);
562  p->set_sub(sub);
563  break;
564  }
565  case IntrinsicType::TEST_COMP_ERROR: {
566  auto *tce = pcast<TestCompError>(p->get_sub());
567  if (tce->_caught)
568  return;
569 
570  push_scope(p);
571 
572  try {
573  visit(tce);
574  } catch (const CompileException &e) {
575  std::cerr << fmt::format("Caught expected compile error: {}\nContinue compilation...\n", e.what());
576  tce->_caught = true;
577  }
578 
579  if (!tce->_caught)
580  error(ErrorType::TYPE_ERROR, p, "Expect a compile error");
581 
582  pop_scope();
583  break;
584  }
585  case IntrinsicType::NOOP:
586  break;
587  case IntrinsicType::INVALID:
588  TAN_ASSERT(false);
589  break;
590  default: {
591  auto *c = p->get_sub();
592  if (c->get_node_type() == ASTNodeType::FUNC_CALL) {
593  analyze_intrinsic_func_call(p, pcast<FunctionCall>(c));
594  return;
595  }
596 
597  break;
598  }
599  }
600 }
601 
602 DEFINE_AST_VISITOR_IMPL(TypeCheck, StringLiteral) {
603  TAN_ASSERT(!p->get_value().empty());
604  p->set_type(Type::GetStringType());
605 }
606 
607 DEFINE_AST_VISITOR_IMPL(TypeCheck, CharLiteral) { p->set_type(Type::GetCharType()); }
608 
609 DEFINE_AST_VISITOR_IMPL(TypeCheck, IntegerLiteral) {
610  Type *ty;
611  if (p->is_unsigned()) {
612  ty = Type::GetIntegerType(32, true);
613  } else {
614  ty = Type::GetIntegerType(32, false);
615  }
616  p->set_type(ty);
617 }
618 
619 DEFINE_AST_VISITOR_IMPL(TypeCheck, BoolLiteral) { p->set_type(Type::GetBoolType()); }
620 
621 DEFINE_AST_VISITOR_IMPL(TypeCheck, FloatLiteral) { p->set_type(Type::GetFloatType(32)); }
622 
623 DEFINE_AST_VISITOR_IMPL(TypeCheck, ArrayLiteral) {
624  // TODO IMPORTANT: find the type that all elements can implicitly convert to
625  // for example: [1, 2.2, 3u] has element type float
626  auto elements = p->get_elements();
627  Type *element_type = nullptr;
628  for (auto *e : elements) {
629  visit(e);
630  if (!element_type) {
631  element_type = e->get_type();
632  }
633  create_implicit_conversion(e, element_type);
634  }
635 
636  TAN_ASSERT(element_type);
637  p->set_type(Type::GetArrayType(element_type, (int)elements.size()));
638 }
639 
640 DEFINE_AST_VISITOR_IMPL(TypeCheck, MemberAccess) {
641  Expr *lhs = p->get_lhs();
642  visit(lhs);
643 
644  Expr *rhs = p->get_rhs();
645 
646  if (rhs->get_node_type() == ASTNodeType::FUNC_CALL) { /// method call
647  p->_access_type = MemberAccess::MemberAccessMemberFunction;
648  auto func_call = pcast<FunctionCall>(rhs);
649  analyze_member_func_call(p, lhs, func_call);
650  } else if (p->_access_type == MemberAccess::MemberAccessBracket) {
651  analyze_bracket_access(p, lhs, rhs);
652  } else if (rhs->get_node_type() == ASTNodeType::ID) { /// member variable
653  p->_access_type = MemberAccess::MemberAccessMemberVariable;
654  analyze_member_access_member_variable(p, lhs, rhs);
655  } else {
656  error(ErrorType::UNKNOWN_SYMBOL, p, "Invalid right-hand operand");
657  }
658 }
659 
660 DEFINE_AST_VISITOR_IMPL(TypeCheck, StructDecl) {
661  str struct_name = p->get_name();
662  auto *ty = pcast<StructType>(p->get_type());
663  TAN_ASSERT(ty && ty->is_struct());
664  auto members = p->get_member_decls();
665 
666  push_scope(p);
667 
668  size_t n = members.size();
669  for (size_t i = 0; i < n; ++i) {
670  Expr *m = members[i];
671 
672  if (m->get_node_type() == ASTNodeType::VAR_DECL) { // member variable without initial value
673  (*ty)[i] = resolve_type((*ty)[i], m);
674 
675  } else if (m->get_node_type() == ASTNodeType::ASSIGN) { // member variable with an initial value
676  auto init_val = pcast<Assignment>(m)->get_rhs();
677  visit(init_val);
678 
679  (*ty)[i] = resolve_type((*ty)[i], m);
680 
681  if (!init_val->is_comptime_known()) {
682  error(ErrorType::TYPE_ERROR, p, "Initial value of a member variable must be compile-time known");
683  }
684 
685  } else if (m->get_node_type() == ASTNodeType::FUNC_DECL) { // TODO: member functions
686  auto f = pcast<FunctionDecl>(m);
687  (*ty)[i] = f->get_type();
688 
689  } else {
690  error(ErrorType::TYPE_ERROR, p, "Invalid struct member");
691  }
692  }
693 
694  pop_scope();
695 }
696 
697 DEFINE_AST_VISITOR_IMPL(TypeCheck, Loop) {
698  push_scope(p);
699 
700  if (p->_loop_type == ASTLoopType::FOR) {
701  visit(p->_initialization);
702  }
703 
704  visit(p->_predicate);
705 
706  if (p->_loop_type == ASTLoopType::FOR) {
707  visit(p->_iteration);
708  }
709 
710  visit(p->_body);
711 
712  pop_scope();
713 }
714 
715 DEFINE_AST_VISITOR_IMPL(TypeCheck, BreakContinue) {
716  Loop *loop = search_node_in_parent_scopes<Loop, ASTNodeType::LOOP>();
717  if (!loop) {
718  error(ErrorType::SEMANTIC_ERROR, p, "Break or continue must be inside a loop");
719  }
720  p->set_parent_loop(pcast<Loop>(loop));
721 }
722 
723 } // namespace tanlang
static umap< ASTNodeType, str > ASTTypeNames
string representation of ASTNodeType
Definition: ast_base.h:16
FunctionDecl * get_func_decl(const str &name) const
Search for a function declaration by name.
Definition: context.cpp:32
static VarRef * Create(TokenizedSourceFile *src, const str &name, Decl *referred)
Definition: expr.cpp:90