diff --git a/nerfstudio/data/datamanagers/parallel_datamanager.py b/nerfstudio/data/datamanagers/parallel_datamanager.py index b28e530f91..c24bb56ebd 100644 --- a/nerfstudio/data/datamanagers/parallel_datamanager.py +++ b/nerfstudio/data/datamanagers/parallel_datamanager.py @@ -46,6 +46,8 @@ from nerfstudio.utils.misc import get_orig_class from nerfstudio.utils.rich_utils import CONSOLE +# For orthophotos +import math @dataclass class ParallelDataManagerConfig(VanillaDataManagerConfig): @@ -289,6 +291,24 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]: self.train_count += 1 bundle, batch = self.data_queue.get() ray_bundle = bundle.to(self.device) + + # Save the directions of the rays that are in the center of the image + # The average of these are assumed to be the vertical axis in the scene and is used to project the orthophoto + indices = batch["indices"] + center_index = None + # Set the downscale factor, image width and image height to the correct values + # These are used to find the center of the image + downscale_factor = 2 + width = 5280 + height = 3956 + for i, row in enumerate(indices): + if int(math.floor(height / (2 * downscale_factor))) <= row[1] <= int(math.ceil(height / (2 * downscale_factor))) \ + and int(math.floor(width / (2 * downscale_factor))) <= row[2] <= int(math.ceil(width / (2 * downscale_factor))): + center_index = i + direction = ray_bundle.directions[center_index] + with open("directions.txt", "a") as f: + f.write(f"{direction}\n") + return ray_bundle, batch def next_eval(self, step: int) -> Tuple[RayBundle, Dict]: diff --git a/nerfstudio/pipelines/base_pipeline.py b/nerfstudio/pipelines/base_pipeline.py index 651a2c3008..a49a846387 100644 --- a/nerfstudio/pipelines/base_pipeline.py +++ b/nerfstudio/pipelines/base_pipeline.py @@ -41,6 +41,12 @@ from nerfstudio.models.base_model import Model, ModelConfig from nerfstudio.utils import profiler +# For orthophoto generation +from PIL import Image +import numpy as np +import os 
def load_tensors_from_file(file_path):
    """Parse a text file of printed tensors (one ``tensor([x, y, z])`` per line)
    into a single 2-D ``torch.Tensor`` with one row per input line.

    Robustness fix over the original split-based parser: lines that contain no
    bracketed number list (e.g. blank lines) are skipped instead of raising
    ``IndexError``. A trailing ``, device='cuda:0')`` suffix, which ``str(tensor)``
    may append, is tolerated because only the first ``[...]`` group is read.
    """
    bracket = re.compile(r"\[([^\]]*)\]")
    rows = []
    with open(file_path, "r") as file:
        for line in file:
            match = bracket.search(line)
            if match is None:
                # No numeric payload on this line — skip it.
                continue
            rows.append([float(num) for num in match.group(1).split(",")])
    return torch.Tensor(rows)


def rotation_matrix_from_vectors(vec1, vec2):
    """Return the 3x3 rotation matrix that rotates ``vec1`` onto ``vec2``.

    Uses Rodrigues' rotation formula. Two defects of the original are fixed:

    * Anti-parallel inputs (cross product is zero but the vectors point in
      opposite directions) previously fell into the "identical directions"
      branch and returned the identity, which is wrong; they now return a
      180-degree rotation about an axis perpendicular to ``vec1``.
    * A small numerical tolerance replaces the exact-zero test ``any(v)``,
      so nearly-parallel inputs no longer divide by a vanishing ``s ** 2``.
    """
    a = (vec1 / np.linalg.norm(vec1)).reshape(3)
    b = (vec2 / np.linalg.norm(vec2)).reshape(3)
    v = np.cross(a, b)
    c = float(np.dot(a, b))  # cosine of the angle between a and b
    s = float(np.linalg.norm(v))  # sine of the angle between a and b

    if s > 1e-8:
        # General case: Rodrigues' formula, R = I + K + K^2 * (1 - c) / s^2.
        kmat = np.array(
            [[0.0, -v[2], v[1]], [v[2], 0.0, -v[0]], [-v[1], v[0], 0.0]]
        )
        return np.eye(3) + kmat + kmat.dot(kmat) * ((1.0 - c) / (s ** 2))

    if c > 0.0:
        # Vectors are (numerically) identical in direction: no rotation needed.
        return np.eye(3)

    # Anti-parallel: rotate 180 degrees about any axis perpendicular to ``a``.
    axis = np.cross(a, np.array([1.0, 0.0, 0.0]))
    if np.linalg.norm(axis) < 1e-8:
        # ``a`` is parallel to the x-axis; use the y-axis to build the normal.
        axis = np.cross(a, np.array([0.0, 1.0, 0.0]))
    axis = axis / np.linalg.norm(axis)
    # R = 2*u*u^T - I is the proper rotation by pi about the unit axis u.
    return 2.0 * np.outer(axis, axis) - np.eye(3)
+ directions = load_tensors_from_file("directions.txt") + vertical_axis = torch.mean(directions, axis=0) + + R = rotation_matrix_from_vectors(np.array([0, 0, -1]), np.array(vertical_axis)) + R = torch.Tensor(R) + + # Move the camera along the vertical axis to get orthophotos with different amounts of floaters + distance = 0.05 * i + vertical_axis = torch.Tensor(vertical_axis) + T = torch.zeros(3) + distance * vertical_axis + + # Create a camera-to-world matrix + c2w = torch.cat((R, T.view(3, 1)), dim=1) + + # The below settings produce an orthophoto that covers all of the scene. + # Only change the scale to change the resolution. + scale = 1/3 + fx = torch.Tensor([2000*scale]).view(1, 1) + fy = torch.Tensor([2000*scale]).view(1, 1) + cx = torch.Tensor([3000*scale]).view(1, 1) + cy = torch.Tensor([3000*scale]).view(1, 1) + height = torch.Tensor([6000*scale]).to(int).view(1, 1) + width = torch.Tensor([6000*scale]).to(int).view(1, 1) + distortion_params = torch.zeros(1, 6) + camera_type = torch.Tensor([CameraType.ORTHOPHOTO.value]).to(int).view(1, 1) + ortho_camera = Cameras(c2w, fx, fy, cx, cy, height, width, distortion_params, camera_type) + + # Generate orthophoto + outputs = self.model.get_outputs_for_camera(camera=ortho_camera) + + # Convert RGB valus to jpg image + rgbs = outputs["rgb"] + rgb_array = rgbs.cpu().numpy() + + # Scale the values from [0, 1] to [0, 255] + rgb_array = (rgb_array * 255).astype(np.uint8) + image = Image.fromarray(rgb_array) + + image.save(f"outputs/orthophoto-{i}.jpg") + # time this the following line inner_start = time() outputs = self.model.get_outputs_for_camera(camera=camera)