Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 230 additions & 7 deletions lib/parser/ast/visitors/BytecodeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ void BytecodeVisitor::Visit(Module& node) {
constructor_params_.clear();
type_aliases_.clear();
pending_init_static_.clear();
pending_init_static_names_.clear();
pending_init_static_types_.clear();

for (auto& decl : node.MutableDecls()) {
if (auto* f = dynamic_cast<FunctionDecl*>(decl.get())) {
Expand All @@ -461,6 +463,7 @@ void BytecodeVisitor::Visit(Module& node) {
if (sd->MutableInit() != nullptr) {
pending_init_static_.push_back(sd->MutableInit());
pending_init_static_names_.push_back(sd->Name());
pending_init_static_types_.push_back(sd->Type());
}
}
if (const auto* md = dynamic_cast<MethodDecl*>(m.get())) {
Expand Down Expand Up @@ -505,6 +508,7 @@ void BytecodeVisitor::Visit(Module& node) {
if (gv->MutableInit() != nullptr) {
pending_init_static_.push_back(gv->MutableInit());
pending_init_static_names_.push_back(gv->Name());
pending_init_static_types_.push_back(gv->Type());
}
}

Expand All @@ -519,6 +523,16 @@ void BytecodeVisitor::Visit(Module& node) {
if (!pending_init_static_.empty()) {
for (size_t i = 0; i < pending_init_static_.size(); ++i) {
pending_init_static_[i]->Accept(*this);

// Add CallConstructor if type is a wrapper type (Float, Int, etc.)
if (i < pending_init_static_types_.size()) {
std::string type_name = TypeToMangledName(pending_init_static_types_[i]);
if (IsPrimitiveWrapper(type_name)) {
std::string primitive_type = GetPrimitiveTypeForWrapper(type_name);
EmitWrapConstructorCall(type_name, primitive_type);
}
}

EmitCommandWithInt("SetStatic", static_cast<int64_t>(GetStaticIndex(pending_init_static_names_[i])));
}
}
Expand Down Expand Up @@ -677,11 +691,40 @@ void BytecodeVisitor::Visit(CallDecl& node) {
if (node.MutableBody() != nullptr) {
node.MutableBody()->Accept(*this);

EmitCommandWithInt("LoadLocal", 0);
EmitCommand("Return");
// Check if body ends with return statement
bool has_return = false;
if (const auto& stmts = node.MutableBody()->GetStatements(); !stmts.empty()) {
if (dynamic_cast<ReturnStmt*>(stmts.back().get()) != nullptr) {
has_return = true;
}
}

// Only add implicit return if body doesn't already have one
if (!has_return) {
if (node.ReturnType() != nullptr) {
std::string return_type_name = TypeToMangledName(*node.ReturnType());
if (return_type_name == "void") {
EmitCommand("Return");
} else {
EmitCommandWithInt("LoadLocal", 0);
EmitCommand("Return");
}
} else {
EmitCommand("Return");
}
}
} else {
EmitCommandWithInt("LoadLocal", 0);
EmitCommand("Return");
if (node.ReturnType() != nullptr) {
std::string return_type_name = TypeToMangledName(*node.ReturnType());
if (return_type_name == "void") {
EmitCommand("Return");
} else {
EmitCommandWithInt("LoadLocal", 0);
EmitCommand("Return");
}
} else {
EmitCommand("Return");
}
}
EmitBlockEnd();
output_ << "\n";
Expand Down Expand Up @@ -712,6 +755,8 @@ void BytecodeVisitor::Visit(TypeAliasDecl& node) {

void BytecodeVisitor::Visit(GlobalVarDecl& node) {
(void) GetStaticIndex(node.Name());
std::string type_name = TypeToMangledName(node.Type());
variable_types_[node.Name()] = type_name;
}

void BytecodeVisitor::Visit(FieldDecl&) {
Expand Down Expand Up @@ -1043,7 +1088,6 @@ void BytecodeVisitor::Visit(ForStmt& node) {
if (const auto local_it = local_variables_.find(ident->Name());
var_it != variable_types_.end() && local_it != local_variables_.end()) {
collection_index = local_it->second;
collection_var_name = ident->Name();
collection_type = var_it->second;
} else {
node.MutableIteratorExpr()->Accept(*this);
Expand Down Expand Up @@ -1332,19 +1376,26 @@ void BytecodeVisitor::Visit(Assign& node) {
}

if (auto* ident = dynamic_cast<IdentRef*>(&node.MutableTarget())) {
node.MutableValue().Accept(*this);
// Check if this is a global variable before generating code
bool is_global = static_variables_.contains(ident->Name());

std::string expected_type_name;
if (auto type_it = variable_types_.find(ident->Name()); type_it != variable_types_.end()) {
expected_type_name = type_it->second;
}

node.MutableValue().Accept(*this);

if (!expected_type_name.empty()) {
std::string value_type_name = GetTypeNameForExpr(&node.MutableValue());
EmitTypeConversionIfNeeded(expected_type_name, value_type_name);
}

EmitCommandWithInt("SetLocal", static_cast<int64_t>(GetLocalIndex(ident->Name())));
if (is_global) {
EmitCommandWithInt("SetStatic", static_cast<int64_t>(GetStaticIndex(ident->Name())));
} else {
EmitCommandWithInt("SetLocal", static_cast<int64_t>(GetLocalIndex(ident->Name())));
}
} else if (auto* index_access = dynamic_cast<IndexAccess*>(&node.MutableTarget())) {
node.MutableValue().Accept(*this);

Expand Down Expand Up @@ -1384,12 +1435,61 @@ void BytecodeVisitor::Visit(Call& node) {
return;
}

// Handle sys::ToString for primitive types
if (ns_name == "ToString" && args.size() == 1) {
args[0]->Accept(*this);
std::string arg_type = GetTypeNameForExpr(args[0].get());

// Handle wrapper types by unwrapping first
if (IsPrimitiveWrapper(arg_type)) {
std::string primitive_type = GetPrimitiveTypeForWrapper(arg_type);
EmitCommand("Unwrap");
arg_type = primitive_type;
}

// For primitive types, use special instructions
if (arg_type == "int") {
EmitCommand("IntToString");
return;
}
if (arg_type == "float") {
EmitCommand("FloatToString");
return;
}
if (arg_type == "byte") {
EmitCommand("ByteToString");
return;
}
if (arg_type == "char") {
EmitCommand("CharToString");
return;
}
if (arg_type == "bool") {
EmitCommand("BoolToString");
return;
}
}

// Handle sys::Sqrt for float type
if (ns_name == "Sqrt" && args.size() == 1) {
args[0]->Accept(*this);
std::string arg_type = GetTypeNameForExpr(args[0].get());

// Sqrt only works with float type
if (arg_type == "float") {
EmitCommand("FloatSqrt");
return;
}
}

std::string full_name = "sys::" + ns_name;
if (auto it = function_name_map_.find(full_name); it != function_name_map_.end()) {
EmitArgumentsInReverse(args);
EmitCommandWithStringWithoutBraces("Call", it->second);
return;
}

EmitArgumentsInReverse(args);
EmitCommandWithStringWithoutBraces("Call", full_name);
return;
}
Expand Down Expand Up @@ -1518,6 +1618,44 @@ void BytecodeVisitor::Visit(Call& node) {
return;
}

// Handle ToString for primitive types
if (name == "ToString" && args.size() == 1) {
args[0]->Accept(*this);
std::string arg_type = GetTypeNameForExpr(args[0].get());

// For wrapper types, use their ToString method from kBuiltinMethods
if (kBuiltinTypeNames.contains(arg_type)) {
if (auto methods_it = kBuiltinMethods.find(arg_type); methods_it != kBuiltinMethods.end()) {
if (auto tostring_it = methods_it->second.find("ToString"); tostring_it != methods_it->second.end()) {
EmitCommandWithStringWithoutBraces("Call", tostring_it->second);
return;
}
}
}

// For primitive types, use special instructions
if (arg_type == "int") {
EmitCommand("IntToString");
return;
}
if (arg_type == "float") {
EmitCommand("FloatToString");
return;
}
if (arg_type == "byte") {
EmitCommand("ByteToString");
return;
}
if (arg_type == "char") {
EmitCommand("CharToString");
return;
}
if (arg_type == "bool") {
EmitCommand("BoolToString");
return;
}
}

if (auto it = function_name_map_.find(name); it != function_name_map_.end()) {
for (auto& arg : std::ranges::reverse_view(args)) {
arg->Accept(*this);
Expand Down Expand Up @@ -2229,9 +2367,22 @@ size_t BytecodeVisitor::GetStaticIndex(const std::string& name) {
}

void BytecodeVisitor::ResetLocalVariables() {
// Save global variable types before clearing
std::unordered_map<std::string, std::string> global_types;
for (const auto& [name, type] : variable_types_) {
if (static_variables_.contains(name)) {
global_types[name] = type;
}
}

local_variables_.clear();
variable_types_.clear();
next_local_index_ = 0;

// Restore global variable types
for (const auto& [name, type] : global_types) {
variable_types_[name] = type;
}
}

BytecodeVisitor::OperandType BytecodeVisitor::DetermineOperandType(Expr* expr) {
Expand Down Expand Up @@ -2375,6 +2526,29 @@ BytecodeVisitor::OperandType BytecodeVisitor::DetermineOperandType(Expr* expr) {
return DetermineOperandType(&unary->MutableOperand());
}

if (auto* call = dynamic_cast<Call*>(expr)) {
// Use GetTypeNameForExpr to determine return type
std::string return_type = GetTypeNameForExpr(call);
if (return_type == "int") {
return OperandType::kInt;
}
if (return_type == "float") {
return OperandType::kFloat;
}
if (return_type == "byte") {
return OperandType::kByte;
}
if (return_type == "bool") {
return OperandType::kBool;
}
if (return_type == "char") {
return OperandType::kChar;
}
if (return_type == "String") {
return OperandType::kString;
}
}

return OperandType::kUnknown;
}

Expand Down Expand Up @@ -2454,6 +2628,55 @@ std::string BytecodeVisitor::GetTypeNameForExpr(Expr* expr) {
return "String";
}

if (auto* call = dynamic_cast<Call*>(expr)) {
if (const auto* ident = dynamic_cast<IdentRef*>(&call->MutableCallee())) {
const std::string func_name = ident->Name();
if (const auto it = function_return_types_.find(func_name); it != function_return_types_.end()) {
return it->second;
}
}

// Handle sys::Sqrt and other sys namespace functions
if (auto* ns_ref = dynamic_cast<NamespaceRef*>(&call->MutableCallee())) {
std::string ns_name = ns_ref->Name();
// Check if namespace is "sys" by examining NamespaceExpr
if (const auto* ns_ident = dynamic_cast<IdentRef*>(&ns_ref->MutableNamespaceExpr())) {
if (ns_ident->Name() == "sys") {
// Check for built-in return types
if (const auto it = kBuiltinReturnPrimitives.find(ns_name); it != kBuiltinReturnPrimitives.end()) {
return it->second;
}
if (ns_name == "Sqrt" && call->Args().size() == 1) {
return "float";
}
if (ns_name == "ToString" && call->Args().size() == 1) {
// ToString returns String, but we need to check the argument type
// to determine which *ToString instruction to use
std::string arg_type = GetTypeNameForExpr(call->Args()[0].get());
if (arg_type == "int" || arg_type == "Int") {
return "String";
}
if (arg_type == "float" || arg_type == "Float") {
return "String";
}
// For other types, ToString still returns String
return "String";
}
}
}
}
}

if (const auto* cast = dynamic_cast<CastAs*>(expr)) {
return TypeToMangledName(cast->Type());
}

// For Call expressions, if we couldn't determine the type, return "unknown"
// to avoid infinite recursion (GetOperandTypeName -> DetermineOperandType -> GetTypeNameForExpr)
if (dynamic_cast<Call*>(expr) != nullptr) {
return "unknown";
}

return GetOperandTypeName(expr);
}

Expand Down
1 change: 1 addition & 0 deletions lib/parser/ast/visitors/BytecodeVisitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class BytecodeVisitor : public AstVisitor {

std::vector<Expr*> pending_init_static_;
std::vector<std::string> pending_init_static_names_;
std::vector<TypeReference> pending_init_static_types_;
std::unordered_map<std::string, std::string> method_name_map_;
std::unordered_map<std::string, std::string> method_vtable_map_;
std::unordered_map<std::string, std::string> method_return_types_;
Expand Down
10 changes: 10 additions & 0 deletions lib/parser/ast/visitors/StructuralValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "lib/parser/ast/nodes/class_members/CallDecl.hpp"
#include "lib/parser/ast/nodes/class_members/DestructorDecl.hpp"
#include "lib/parser/ast/nodes/class_members/MethodDecl.hpp"
#include "lib/parser/ast/nodes/class_members/StaticFieldDecl.hpp"

#include "lib/parser/ast/nodes/decls/ClassDecl.hpp"
#include "lib/parser/ast/nodes/decls/FunctionDecl.hpp"
Expand Down Expand Up @@ -115,4 +116,13 @@ void StructuralValidator::Visit(SafeCall& node) {
WalkVisitor::Visit(node);
}

void StructuralValidator::Visit(StaticFieldDecl& node) {
// Static constants (val) cannot be initialized
if (!node.IsVar() && node.MutableInit() != nullptr) {
sink_.Error("E1401", "static constant cannot be initialized");
}

WalkVisitor::Visit(node);
}

} // namespace ovum::compiler::parser
1 change: 1 addition & 0 deletions lib/parser/ast/visitors/StructuralValidator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class StructuralValidator : public WalkVisitor {
void Visit(CallDecl& node) override;
void Visit(MethodDecl& node) override;
void Visit(DestructorDecl& node) override;
void Visit(StaticFieldDecl& node) override;

void Visit(Call& node) override;
void Visit(Binary& node) override;
Expand Down
3 changes: 1 addition & 2 deletions lib/parser/states/StateBlock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ IState::StepResult StateBlock::TryStep(ContextParser& ctx, ITokenStream& ts) con

// Pop call and add to class
auto call_node = ctx.PopNode();
auto* class_decl = ctx.TopNodeAs<ClassDecl>();
if (class_decl != nullptr) {
if (auto* class_decl = ctx.TopNodeAs<ClassDecl>(); class_decl != nullptr) {
class_decl->AddMember(std::unique_ptr<Decl>(dynamic_cast<Decl*>(call_node.release())));
}
}
Expand Down
Loading