diff --git a/rust/src/import_parsing.rs b/rust/src/import_parsing.rs index 5fbd9768..0dd3e1a1 100644 --- a/rust/src/import_parsing.rs +++ b/rust/src/import_parsing.rs @@ -30,57 +30,21 @@ impl ImportedObject { } } -#[derive(Debug, PartialEq, Eq, Clone)] -struct ImportedObjectWithoutLineContents { - pub name: String, - pub line_number: usize, - pub typechecking_only: bool, -} - -impl ImportedObjectWithoutLineContents { - fn new(name: String, line_number: usize, typechecking_only: bool) -> Self { - Self { - name, - line_number, - typechecking_only, - } - } -} - pub fn parse_imports(path: &Path) -> GrimpResult> { let code = fs::read_to_string(path).expect("failed to read file"); parse_imports_from_code(&code) } pub fn parse_imports_from_code(code: &str) -> GrimpResult> { - let imports_without_line_contents = parse_imports_from_code_without_line_contents(code)?; - - let lines: Vec<&str> = code.lines().collect(); - - Ok(imports_without_line_contents - .into_iter() - .map(|i| { - ImportedObject::new( - i.name, - i.line_number, - lines[i.line_number - 1].trim_start().to_string(), - i.typechecking_only, - ) - }) - .collect()) -} - -fn parse_imports_from_code_without_line_contents( - code: &str, -) -> GrimpResult> { let line_index = LineIndex::from_source_text(code); let source_code = SourceCode::new(code, &line_index); let ast = match parse_module(code) { Ok(ast) => ast, Err(e) => { - let line_number = source_code.line_index(e.location.start()).get(); - let text = source_code.slice(e.location); + let location_index = source_code.line_index(e.location.start()); + let line_number = location_index.get(); + let text = source_code.line_text(location_index).trim(); Err(GrimpError::ParseError { line_number, text: text.to_owned(), @@ -98,7 +62,7 @@ fn parse_imports_from_code_without_line_contents( #[derive(Debug)] struct Visitor<'a> { source_code: SourceCode<'a, 'a>, - pub imported_objects: Vec, + pub imported_objects: Vec, pub typechecking_only: bool, } @@ -118,12 +82,12 @@ impl<'a> StatementVisitor<'a> for Visitor<'a> { Stmt::Import(import_stmt) => { let line_number = self.source_code.line_index(import_stmt.range.start()); for name in import_stmt.names.iter() { - self.imported_objects - .push(ImportedObjectWithoutLineContents::new( - name.name.id.clone(), - line_number.get(), - self.typechecking_only, - )) + self.imported_objects.push(ImportedObject::new( + name.name.id.clone(), + line_number.get(), + self.source_code.line_text(line_number).trim().to_string(), + self.typechecking_only, + )) } walk_stmt(self, stmt); } @@ -147,12 +111,12 @@ impl<'a> StatementVisitor<'a> for Visitor<'a> { ) } }; - self.imported_objects - .push(ImportedObjectWithoutLineContents::new( - imported_object_name, - line_number.get(), - self.typechecking_only, - )) + self.imported_objects.push(ImportedObject::new( + imported_object_name, + line_number.get(), + self.source_code.line_text(line_number).trim().to_string(), + self.typechecking_only, + )) } walk_stmt(self, stmt); } diff --git a/tests/functional/test_error_handling.py b/tests/functional/test_error_handling.py index 8cf1c804..45635986 100644 --- a/tests/functional/test_error_handling.py +++ b/tests/functional/test_error_handling.py @@ -15,7 +15,9 @@ def test_syntax_error_includes_module(): with pytest.raises(exceptions.SourceSyntaxError) as excinfo: build_graph("syntaxerrorpackage", cache_dir=None) - expected_exception = exceptions.SourceSyntaxError(filename=filename, lineno=5, text="import") + expected_exception = exceptions.SourceSyntaxError( + filename=filename, lineno=5, text="fromb . import two" + ) assert expected_exception == excinfo.value