tan  0.0.1
type.cpp
1 #include "ast/type.h"
2 #include "ast/decl.h"
3 #include <unordered_set>
4 #include <queue>
5 #include <fmt/format.h>
6 
7 using namespace tanlang;
8 
9 StringType *Type::STRING_TYPE = new StringType();
10 
11 PrimitiveType *PrimitiveType::Create(PrimitiveType::Kind kind) {
12  auto it = CACHE.find(kind);
13  if (it != CACHE.end()) {
14  return it->second;
15  } else {
16  auto *ret = new PrimitiveType();
17  ret->_kind = kind;
18  ret->_type_name = TYPE_NAMES[kind];
19  CACHE[kind] = ret;
20  return ret;
21  }
22 }
23 
24 PrimitiveType *Type::GetVoidType() { return PrimitiveType::Create(PrimitiveType::VOID); }
25 
26 PrimitiveType *Type::GetBoolType() { return PrimitiveType::Create(PrimitiveType::BOOL); }
27 
28 PrimitiveType *Type::GetCharType() { return PrimitiveType::Create(PrimitiveType::CHAR); }
29 
30 PrimitiveType *Type::GetIntegerType(size_t bit_size, bool is_unsigned) {
31  switch (bit_size) {
32  case 8:
33  if (is_unsigned) {
34  return PrimitiveType::Create(PrimitiveType::U8);
35  } else {
36  return PrimitiveType::Create(PrimitiveType::I8);
37  }
38  case 16:
39  if (is_unsigned) {
40  return PrimitiveType::Create(PrimitiveType::U16);
41  } else {
42  return PrimitiveType::Create(PrimitiveType::I16);
43  }
44  case 32:
45  if (is_unsigned) {
46  return PrimitiveType::Create(PrimitiveType::U32);
47  } else {
48  return PrimitiveType::Create(PrimitiveType::I32);
49  }
50  case 64:
51  if (is_unsigned) {
52  return PrimitiveType::Create(PrimitiveType::U64);
53  } else {
54  return PrimitiveType::Create(PrimitiveType::I64);
55  }
56  default:
57  TAN_ASSERT(false);
58  }
59 }
60 
61 PrimitiveType *Type::GetFloatType(size_t bit_size) {
62  switch (bit_size) {
63  case 32:
64  return PrimitiveType::Create(PrimitiveType::F32);
65  case 64:
66  return PrimitiveType::Create(PrimitiveType::F64);
67  default:
68  TAN_ASSERT(false);
69  }
70 }
71 
72 StringType *Type::GetStringType() { return STRING_TYPE; }
73 
74 PointerType *Type::GetPointerType(Type *pointee) {
75  auto it = POINTER_TYPE_CACHE.find(pointee);
76  if (it != POINTER_TYPE_CACHE.end()) {
77  TAN_ASSERT(it->second->is_pointer() && it->second->get_pointee() == pointee);
78  return it->second;
79  } else {
80  auto *ret = new PointerType(pointee);
81  POINTER_TYPE_CACHE[pointee] = ret;
82  return ret;
83  }
84 }
85 
86 ArrayType *Type::GetArrayType(Type *element_type, int size) {
87  auto it = ARRAY_TYPE_CACHE.find({element_type, size});
88  if (it != ARRAY_TYPE_CACHE.end()) {
89  return it->second;
90  } else {
91  auto *ret = new ArrayType(element_type, size);
92  ARRAY_TYPE_CACHE[{element_type, size}] = ret;
93  return ret;
94  }
95 }
96 
97 // TODO IMPORTANT: cache function types
98 FunctionType *Type::GetFunctionType(Type *ret_type, const vector<Type *> &arg_types) {
99  return new FunctionType(ret_type, arg_types);
100 }
101 
102 StructType *Type::GetStructType(StructDecl *decl) {
103  auto it = NAMED_TYPE_CACHE.find(decl->get_name());
104  if (it != NAMED_TYPE_CACHE.end()) {
105  auto *t = pcast<StructType>(it->second);
106  t->_member_types = decl->get_member_types(); // update forward declaration
107  return t;
108  } else {
109  auto *ret = new StructType(decl);
110  NAMED_TYPE_CACHE[decl->get_name()] = ret;
111  return ret;
112  }
113 }
114 
115 TypeRef *Type::GetTypeRef(const str &name) { return new TypeRef(name); }
116 
117 int Type::get_align_bits() {
118  TAN_ASSERT(false);
119  return 0;
120 }
121 
122 int Type::get_size_bits() {
123  TAN_ASSERT(false);
124  return 0;
125 }
126 
127 vector<Type *> Type::children() const {
128  TAN_ASSERT(false);
129  return {};
130 }
131 
132 int PrimitiveType::get_size_bits() { return SIZE_BITS[_kind]; }
133 
134 int PrimitiveType::get_align_bits() {
135  TAN_ASSERT(_kind != VOID);
136  return SIZE_BITS[_kind]; // the same as their sizes
137 }
138 
139 PointerType::PointerType(Type *pointee_type) : _pointee_type(pointee_type) {
140  _type_name = pointee_type->get_typename() + "*";
141 }
142 
143 vector<Type *> PointerType::children() const { return {_pointee_type}; }
144 
145 // TODO: find out the pointer size from llvm::TargetMachine
146 int PointerType::get_align_bits() { return 64; }
147 int PointerType::get_size_bits() { return 64; }
148 int ArrayType::get_align_bits() { return 64; }
149 int ArrayType::get_size_bits() { return 64; }
150 int StringType::get_align_bits() { return 64; }
151 int StringType::get_size_bits() { return 64; }
152 
153 ArrayType::ArrayType(Type *element_type, int size) : _element_type(element_type), _size(size) {
154  _type_name = element_type->get_typename() + "[" + std::to_string(size) + "]";
155 }
156 
157 vector<Type *> ArrayType::children() const { return {_element_type}; }
158 
159 StringType::StringType() { _type_name = "str"; }
160 
161 StructType::StructType(StructDecl *decl) {
162  _decl = decl;
163  _type_name = decl->get_name();
164  _member_types = decl->get_member_types();
165 }
166 
167 int StructType::get_align_bits() {
168  int ret = 0;
169  for (auto *t : _member_types) {
170  ret = std::max(t->get_align_bits(), ret);
171  }
172  TAN_ASSERT(ret);
173  return ret;
174 }
175 
176 int StructType::get_size_bits() {
177  // TODO: calculate struct size in bits
178  return 8;
179 }
180 
181 void StructType::append_member_type(Type *t) { _member_types.push_back(t); }
182 Type *&StructType::operator[](size_t index) { return _member_types[index]; }
183 Type *StructType::operator[](size_t index) const { return _member_types[index]; }
184 vector<Type *> StructType::children() const { return _member_types; }
185 vector<Type *> StructType::get_member_types() const { return _member_types; }
186 StructDecl *StructType::get_decl() const { return _decl; }
187 
188 TypeRef::TypeRef(const str &name) { _type_name = name; }
189 
190 FunctionType::FunctionType(Type *ret_type, const vector<Type *> &arg_types) {
191  _ret_type = ret_type;
192  _arg_types = arg_types;
193 }
194 
195 Type *FunctionType::get_return_type() const { return _ret_type; }
196 
197 vector<Type *> FunctionType::get_arg_types() const { return _arg_types; }
198 
199 void FunctionType::set_arg_types(const vector<Type *> &arg_types) { _arg_types = arg_types; }
200 
201 void FunctionType::set_return_type(Type *t) { _ret_type = t; }
202 
203 vector<Type *> FunctionType::children() const {
204  vector<Type *> ret{_ret_type};
205  ret.insert(ret.begin(), _arg_types.begin(), _arg_types.end());
206  return ret;
207 }
208 
209 bool Type::IsCanonical(const Type &type) {
210  std::queue<Type const *> q{};
211  std::unordered_set<Type const *> s{}; // avoid infinite recursion
212  q.push(&type);
213 
214  umap<Type *, bool> met_pointer{};
215  while (!q.empty()) {
216  auto *t = (Type *)q.front();
217  s.insert(t);
218  q.pop();
219 
220  if (!t || t->is_ref()) {
221  return false;
222  } else if (t->is_array() || t->is_pointer() || t->is_function()) {
223  if (t->is_pointer())
224  met_pointer[t] = true;
225 
226  auto children = t->children();
227  for (auto *c : children) {
228  met_pointer[c] = met_pointer[t];
229 
230  if (!c) {
231  return false;
232  } else if (!s.contains(c))
233  q.push(c);
234  }
235  } else if (t->is_struct()) {
236  auto children = t->children();
237  for (auto *c : children) {
238  if (!c)
239  return false;
240 
241  met_pointer[c] = met_pointer[t];
242 
243  if (!s.contains(c)) {
244  q.push(c);
245  } else if (c->is_struct() && !met_pointer[t]) {
246  Error err(fmt::format("Recursive type reference to {} without using a pointer", c->get_typename()));
247  err.raise();
248  }
249  }
250  }
251  }
252 
253  return true;
254 }
255 
256 bool Type::is_canonical() const { return Type::IsCanonical(*this); }
257 bool Type::is_primitive() const { return false; }
258 bool Type::is_pointer() const { return false; }
259 bool Type::is_array() const { return false; }
260 bool Type::is_string() const { return false; }
261 bool Type::is_struct() const { return false; }
262 bool Type::is_function() const { return false; }
263 bool Type::is_ref() const { return false; }
264 bool Type::is_float() const { return false; }
265 bool Type::is_int() const { return false; }
266 bool Type::is_num() const { return false; }
267 bool Type::is_unsigned() const { return false; }
268 bool Type::is_bool() const { return false; }
269 bool Type::is_void() const { return false; }
270 bool Type::is_char() const { return false; }
Placeholder during parsing.
Definition: type.h:268
Type is immutable once created. The exception is StructType. Its information is updated in multiple s...
Definition: type.h:22
static bool IsCanonical(const Type &type)
A composite type is canonical only if its subtype(s) are also canonical. A non-composite type is cano...
Definition: type.cpp:209