diff --git a/src/AMSlib/include/AMSError.hpp b/src/AMSlib/include/AMSError.hpp index 421d9bd4..9447b6ba 100644 --- a/src/AMSlib/include/AMSError.hpp +++ b/src/AMSlib/include/AMSError.hpp @@ -16,6 +16,7 @@ enum class AMSErrorType { FileDoesNotExist, ///< Path to file or directory does not exist TorchInternal, ///< An internal error that happens to the torch library InvalidModel, ///< A torchscripted model that has not been serialized through AMS + InvalidShapes, ///< Some Data shape is not the proper|expected shape }; /// \brief Strongly-typed error object used across AMS. diff --git a/src/AMSlib/wf/index_map.hpp b/src/AMSlib/wf/index_map.hpp new file mode 100644 index 00000000..2e96321d --- /dev/null +++ b/src/AMSlib/wf/index_map.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include +#include +#include + +namespace ams +{ + +/// Field-to-column mapping for layout transformations. +struct IndexMap { + struct FieldInfo { + std::string Name; + + enum class Kind { Input, InOut, Output }; + Kind EKind; + + int64_t Offset; ///< Starting column in the concatenated tensor + int64_t Cols; ///< Number of columns this field covers + }; + + std::vector Fields; +}; + +} // namespace ams diff --git a/src/AMSlib/wf/layout_transform.hpp b/src/AMSlib/wf/layout_transform.hpp index 16ed851b..b7806fee 100644 --- a/src/AMSlib/wf/layout_transform.hpp +++ b/src/AMSlib/wf/layout_transform.hpp @@ -5,6 +5,8 @@ #include +#include "AMSError.hpp" +#include "wf/index_map.hpp" #include "wf/tensor_bundle.hpp" namespace ams @@ -24,26 +26,15 @@ class LayoutTransform public: virtual ~LayoutTransform() = default; - /// Pack the application-level Inputs and Inouts into a single tensor suitable - /// for feeding into the ML model. - virtual at::Tensor pack(const TensorBundle& Inputs, - const TensorBundle& Inouts) = 0; - - /// Unpack the model's output (an IValue that may be a tensor or a tuple of - /// tensors) into: - /// - Outputs - /// - Inouts - /// - Uncertainties (optional) - /// - /// Concrete layouts determine how the returned IValue maps back to domain - /// tensors. Only LayoutTransform knows the correct indexing and shapes. - virtual void unpack(const torch::jit::IValue& ModelOutput, - TensorBundle& Outputs, - TensorBundle& Inouts, - std::optional& Uncertainties) = 0; - - /// Descriptive name used for debugging, logging, and introspection. - /// Must be implemented by all subclasses. + virtual AMSExpected pack(const TensorBundle& Inputs, + const TensorBundle& InOuts, + at::Tensor& ModelInput) = 0; + + virtual AMSStatus unpack(const torch::jit::IValue& ModelOutput, + TensorBundle& Outs, + TensorBundle& InOuts, + std::optional& Uncertainties) = 0; + virtual const char* name() const noexcept = 0; }; diff --git a/src/AMSlib/wf/pointwise_layout_transform.hpp b/src/AMSlib/wf/pointwise_layout_transform.hpp new file mode 100644 index 00000000..e3b5a2e0 --- /dev/null +++ b/src/AMSlib/wf/pointwise_layout_transform.hpp @@ -0,0 +1,167 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include "wf/index_map.hpp" +#include "wf/layout_transform.hpp" +#include "wf/tensor_bundle.hpp" + +namespace ams +{ + +/// PointwiseConcatTransform: +/// +/// Converts Inputs + InOuts into a single matrix [N, SUM(K_i)] where: +/// - N = batch size (outer dim) +/// - K_i = flattened size of each tensor field except the batch dimension +/// +/// Supports: +/// ✔ Scalar fields (shape [N]) +/// ✔ Multi-channel fields (shape [N, K]) +/// ✔ Arbitrary shapes [N, ...] → flattened to [N, M] +/// ✔ Prediction-only models +/// ✔ Uncertainty-aware models returning (pred, uncertainty) +/// +/// Produces IndexMap for both pack() and unpack(). +class PointwiseConcatTransform : public LayoutTransform +{ +public: + const char* name() const noexcept override + { + return "PointwiseConcatTransform"; + } + + // ------------------------------------------------------------------ + // PACK + // ------------------------------------------------------------------ + AMSExpected pack(const TensorBundle& Inputs, + const TensorBundle& InOuts, + at::Tensor& ModelInput) override + { + IndexMap map; + std::vector cols; + int total_cols{0}; + + if (auto st = process( + Inputs, IndexMap::FieldInfo::Kind::Input, map, cols, total_cols); + !st) + return tl::unexpected(st.error()); + if (auto st = process( + InOuts, IndexMap::FieldInfo::Kind::InOut, map, cols, total_cols); + !st) + return tl::unexpected(st.error()); + + if (total_cols <= 0) { + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, + fmt::format("PointwiseConcatTransform expected at " + "least a single dimension in pack")); + } + // Concatenate horizontally + ModelInput = at::cat(cols, /*dim=*/1); + return map; + } + + // ------------------------------------------------------------------ + // UNPACK + // ------------------------------------------------------------------ + AMSStatus unpack(const torch::jit::IValue& ModelOutput, + TensorBundle& Outs, + TensorBundle& InOuts, + std::optional& Uncertainties) override + { + at::Tensor ModelOut; + at::Tensor Uncertainty; + bool has_uncertainty = false; + + // -------------------------------------------- + // Case 1: Single tensor prediction + // -------------------------------------------- + if (ModelOutput.isTensor()) { + ModelOut = ModelOutput.toTensor(); + } + // -------------------------------------------- + // Case 2: Tuple(pred, uncertainty) + // -------------------------------------------- + else if (ModelOutput.isTuple()) { + auto tup = ModelOutput.toTuple(); + if (tup->elements().size() != 2) + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, + "PointwiseConcatTransform: expected " + "tuple(pred,uncertainty)."); + + ModelOut = tup->elements()[0].toTensor(); + Uncertainty = tup->elements()[1].toTensor(); + has_uncertainty = true; + } else { + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, + "PointwiseConcatTransform: ModelOutput must be " + "tensor or " + "tuple."); + } + + // Uncertainties + if (has_uncertainty) { + Uncertainties = Uncertainty; + } else { + Uncertainties.reset(); + } + + if (ModelOut.size(1) != Outs.size() + InOuts.size()) + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, + "Expected the output size to match the Application " + "output dimensions"); + + int k = 0; + for (; k < Outs.size(); ++k) { + Outs[k].tensor = + ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze(); + } + + for (int i = 0; i < InOuts.size(); ++k, ++i) { + InOuts[i].tensor = + ModelOut.narrow(/*dim=*/1, /*start=*/k, /*length=*/1).squeeze(); + } + + return {}; + } + +private: + AMSStatus process(const TensorBundle& tb, + IndexMap::FieldInfo::Kind kind, + IndexMap& map, + std::vector& cols, + int& total_cols) + { + for (size_t i = 0; i < tb.size(); i++) { + const auto& item = tb.items[i]; + at::Tensor t = item.tensor; + + if (t.dim() < 1) + return AMS_MAKE_ERROR(AMSErrorType::InvalidShapes, + fmt::format("PointwiseConcatTransform for " + "field {} must have at least 1 " + "dimension", + item.name)); + int64_t N = t.size(0); + + // Flatten everything except outer dimension. + at::Tensor flat = t.reshape({N, -1}); + int64_t M = flat.size(1); + + int64_t offset = total_cols; + total_cols += M; + + map.Fields.push_back({item.name, kind, offset, M}); + + cols.push_back(flat); + } + return {}; + } +}; + +} // namespace ams diff --git a/tests/AMSlib/wf/CMakeLists.txt b/tests/AMSlib/wf/CMakeLists.txt index 1a2bf78d..018b04b2 100644 --- a/tests/AMSlib/wf/CMakeLists.txt +++ b/tests/AMSlib/wf/CMakeLists.txt @@ -51,7 +51,8 @@ ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVALUATE_IN_OUTS evaluate_in_and_outs) BUILD_UNIT_TEST(tensor_bundle tensor_bundle.cpp) ADD_WORKFLOW_UNIT_TEST(WORKFLOW::TENSOR_BUNDLE tensor_bundle) -BUILD_UNIT_TEST(layout_transform layout_transform.cpp) -ADD_WORKFLOW_UNIT_TEST(WORKFLOW::LAYOUT_TRANSFORM layout_transform) BUILD_UNIT_TEST(eval_context eval_context.cpp) ADD_WORKFLOW_UNIT_TEST(WORKFLOW::EVAL_CONTEXT eval_context) + +BUILD_UNIT_TEST(pointwise pointwise_layout_transform.cpp) +ADD_WORKFLOW_UNIT_TEST(WORKFLOW::POINTWISE pointwise) diff --git a/tests/AMSlib/wf/layout_transform.cpp b/tests/AMSlib/wf/layout_transform.cpp deleted file mode 100644 index dbbde361..00000000 --- a/tests/AMSlib/wf/layout_transform.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include "wf/layout_transform.hpp" - -#include -#include - -#include -#include - -#include "wf/tensor_bundle.hpp" - -using Catch::Matchers::WithinAbs; - -namespace -{ - -/// Dummy transform that: -/// - pack() returns a constant tensor {9, 9} -/// - unpack() expects an IValue tuple: (prediction, uncertainty) -/// - writes prediction → Outputs -/// - writes uncertainty → Uncertainties -class DummyLayoutTransform : public ams::LayoutTransform -{ -public: - at::Tensor pack(const ams::TensorBundle& Inputs, - const ams::TensorBundle& Inouts) override - { - // A predictable packed tensor for test - return at::full({2}, 9.0f); - } - - void unpack(const torch::jit::IValue& iv, - ams::TensorBundle& Outputs, - ams::TensorBundle& Inouts, - std::optional& Uncertainties) override - { - // Expect a tuple of 2 tensors - auto tup = iv.toTuple(); - auto pred = tup->elements()[0].toTensor(); - auto uncrt = tup->elements()[1].toTensor(); - - // Output bundle receives the prediction - Outputs.add("pred", pred); - - // Uncertainties receives the uncertainty tensor - Uncertainties = uncrt; - } - - const char* name() const noexcept override { return "DummyLayoutTransform"; } -}; - -} // namespace - -// ----------------------------------------------------------------------------- -// TESTS -// ----------------------------------------------------------------------------- - -CATCH_TEST_CASE("LayoutTransform pack() returns model input tensor", "[layout]") -{ - DummyLayoutTransform lt; - - ams::TensorBundle ins; - ams::TensorBundle ios; - - ins.add("a", at::ones({1})); - ios.add("b", at::zeros({1})); - - at::Tensor packed = lt.pack(ins, ios); - - CATCH_REQUIRE(packed.sizes() == at::IntArrayRef({2})); - CATCH_REQUIRE_THAT(packed[0].item(), WithinAbs(9.0f, 1e-6f)); - CATCH_REQUIRE_THAT(packed[1].item(), WithinAbs(9.0f, 1e-6f)); -} - -CATCH_TEST_CASE("LayoutTransform unpack() populates Outputs + Uncertainties", - "[layout]") -{ - DummyLayoutTransform lt; - - // Dummy prediction + uncertainty tensors - at::Tensor pred = at::full({3}, 42.0f); - at::Tensor uncrt = at::full({3}, 0.5f); - - // Construct an IValue tuple: (prediction, uncertainty) - auto tup = c10::ivalue::Tuple::create({pred, uncrt}); - torch::jit::IValue iv(tup); - - ams::TensorBundle outs; - ams::TensorBundle ios; // not modified by dummy - std::optional uncertainties; - - lt.unpack(iv, outs, ios, uncertainties); - - // Output extracted correctly - CATCH_REQUIRE(outs.size() == 1); - CATCH_REQUIRE(outs[0].name == "pred"); - CATCH_REQUIRE(outs[0].tensor.sizes() == at::IntArrayRef({3})); - CATCH_REQUIRE_THAT(outs[0].tensor[0].item(), WithinAbs(42.0f, 1e-6f)); - - // Uncertainty extracted correctly - CATCH_REQUIRE(uncertainties.has_value()); - CATCH_REQUIRE(at::allclose(*uncertainties, at::full({3}, 0.5f))); -} - -CATCH_TEST_CASE("LayoutTransform name() returns identifier", "[layout]") -{ - DummyLayoutTransform lt; - CATCH_REQUIRE(std::string(lt.name()) == "DummyLayoutTransform"); -} diff --git a/tests/AMSlib/wf/pointwise_layout_transform.cpp b/tests/AMSlib/wf/pointwise_layout_transform.cpp new file mode 100644 index 00000000..d64b27d2 --- /dev/null +++ b/tests/AMSlib/wf/pointwise_layout_transform.cpp @@ -0,0 +1,265 @@ +#include "wf/pointwise_layout_transform.hpp" + +#include +#include +#include +using namespace torch::indexing; + +#include +#include +#include + +using Catch::Matchers::ContainsSubstring; +using Catch::Matchers::WithinAbs; + +// ----------------------------------------------------------------------------- +// TEST 1: pack() builds correct mapping and concatenation +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE("PointwiseConcatTransform pack()", "[layout][concat]") +{ + ams::PointwiseConcatTransform lt; + + ams::TensorBundle ins; + ams::TensorBundle ios; + + // Shapes: a -> [4], b -> [4,3], c -> [4] + ins.add("a", at::ones({4})); // 1 column + ins.add("b", at::full({4, 3}, 2.0f)); // 3 columns + ios.add("c", at::full({4}, 3.0f)); // 1 column + + at::Tensor model_input; + auto MapOrErr = lt.pack(ins, ios, model_input); + if (!MapOrErr) std::cout << MapOrErr.error() << "\n"; + CATCH_REQUIRE(MapOrErr); + ams::IndexMap map = std::move(*MapOrErr); + auto A = model_input.index({Slice(), Slice(0, 1)}); + auto B = model_input.index({Slice(), Slice(1, 4)}); + auto C = model_input.index({Slice(), Slice(4, 5)}); + CATCH_REQUIRE(model_input.sizes() == at::IntArrayRef({4, 5})); + + // Expected IndexMap: + // a: offset 0, cols 1 + // b: offset 1, cols 3 + // c: offset 4, cols 1 + CATCH_REQUIRE(map.Fields.size() == 3); + + CATCH_REQUIRE(map.Fields[0].Name == "a"); + CATCH_REQUIRE(map.Fields[0].Offset == 0); + CATCH_REQUIRE(map.Fields[0].Cols == 1); + + CATCH_REQUIRE(map.Fields[1].Name == "b"); + CATCH_REQUIRE(map.Fields[1].Offset == 1); + CATCH_REQUIRE(map.Fields[1].Cols == 3); + + CATCH_REQUIRE(map.Fields[2].Name == "c"); + CATCH_REQUIRE(map.Fields[2].Offset == 4); + CATCH_REQUIRE(map.Fields[2].Cols == 1); + CATCH_REQUIRE(torch::equal(A.squeeze(), ins[0].tensor)); + CATCH_REQUIRE(torch::equal(B, ins[1].tensor)); + CATCH_REQUIRE(torch::equal(C.squeeze(), ios[0].tensor)); +} + +// ----------------------------------------------------------------------------- +// TEST 2: unpack() with only predictions +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE("PointwiseConcatTransform unpack() predictions only", + "[layout][concat]") +{ + + ams::PointwiseConcatTransform lt; + // Prepare pack + at::Tensor pred = at::full({6, 6}, 7.0f); + + ams::TensorBundle outs, inouts; + outs.add("a", torch::Tensor()); + outs.add("b", torch::Tensor()); + outs.add("c", torch::Tensor()); + outs.add("d", torch::Tensor()); + inouts.add("e", at::full({1, 6}, 1.0f)); + inouts.add("f", at::full({1, 6}, 1.0f)); + + std::optional uncrt_out; + + auto res = lt.unpack(pred, outs, inouts, uncrt_out); + CATCH_REQUIRE(res.has_value()); + + at::Tensor corr = at::full({1, 6}, 7.0f); + for (auto& V : outs) { + CATCH_REQUIRE(at::allclose(V.tensor, corr)); + } + + for (auto& V : inouts) { + CATCH_REQUIRE(at::allclose(V.tensor, corr)); + } + + CATCH_REQUIRE_FALSE(uncrt_out.has_value()); +} + +// ----------------------------------------------------------------------------- +// TEST 3: unpack() with uncertainty tuple +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE("PointwiseConcatTransform unpack() with uncertainty", + "[layout][concat][uncertainty]") +{ + + ams::PointwiseConcatTransform lt; + + // Prepare pack + at::Tensor pred = at::full({6, 6}, 7.0f); + at::Tensor uncrt = at::full({6}, 0.25f); + + torch::jit::IValue iv{c10::ivalue::Tuple::create({pred, uncrt})}; + + ams::TensorBundle outs, inouts; + outs.add("a", torch::Tensor()); + outs.add("b", torch::Tensor()); + outs.add("c", torch::Tensor()); + outs.add("d", torch::Tensor()); + inouts.add("e", at::full({1, 6}, 1.0f)); + inouts.add("f", at::full({1, 6}, 1.0f)); + + std::optional uncrt_out; + + auto res = lt.unpack(iv, outs, inouts, uncrt_out); + CATCH_REQUIRE(res.has_value()); + + at::Tensor corr = at::full({1, 6}, 7.0f); + for (auto& V : outs) { + CATCH_REQUIRE(at::allclose(V.tensor, corr)); + } + + for (auto& V : inouts) { + CATCH_REQUIRE(at::allclose(V.tensor, corr)); + } + + CATCH_REQUIRE(uncrt_out.has_value()); + CATCH_REQUIRE(at::allclose(*uncrt_out, uncrt)); +} + +// ----------------------------------------------------------------------------- +// TEST 4: pack() errors on 0-dim tensors (dim < 1) +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE("PointwiseConcatTransform pack() rejects scalar 0-dim tensor", + "[layout][concat][error]") +{ + ams::PointwiseConcatTransform lt; + + ams::TensorBundle ins, ios; + + // 0-dim tensor: dim() == 0 -> should error + ins.add("bad", at::scalar_tensor(1.0f)); + + at::Tensor model_input; + auto res = lt.pack(ins, ios, model_input); + + CATCH_REQUIRE_FALSE(res); + CATCH_REQUIRE(res.error().getType() == ams::AMSErrorType::InvalidShapes); + CATCH_REQUIRE_THAT(res.error().getMessage(), + ContainsSubstring("must have at least 1 dimension")); +} + +// ----------------------------------------------------------------------------- +// TEST 5: pack() errors on empty Inputs + InOuts (no columns) +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE("PointwiseConcatTransform pack() rejects empty bundles", + "[layout][concat][error]") +{ + ams::PointwiseConcatTransform lt; + + ams::TensorBundle ins, ios; + at::Tensor model_input; + + auto res = lt.pack(ins, ios, model_input); + + CATCH_REQUIRE_FALSE(res); + CATCH_REQUIRE(res.error().getType() == ams::AMSErrorType::InvalidShapes); + CATCH_REQUIRE_THAT(res.error().getMessage(), + ContainsSubstring("expected at least a single dimension")); +} + +// ----------------------------------------------------------------------------- +// TEST 6: unpack() errors when ModelOutput is not tensor or tuple +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE( + "PointwiseConcatTransform unpack() rejects non-tensor non-tuple", + "[layout][concat][error]") +{ + ams::PointwiseConcatTransform lt; + + // e.g., int IValue + torch::jit::IValue iv(123); + + ams::TensorBundle outs, inouts; + outs.add("a", torch::Tensor()); + inouts.add("b", torch::Tensor()); + + std::optional uncrt_out; + + auto st = lt.unpack(iv, outs, inouts, uncrt_out); + + CATCH_REQUIRE_FALSE(st); + CATCH_REQUIRE(st.error().getType() == ams::AMSErrorType::InvalidShapes); + CATCH_REQUIRE_THAT(st.error().getMessage(), + ContainsSubstring("ModelOutput must be")); +} + +// ----------------------------------------------------------------------------- +// TEST 7: unpack() errors when tuple size != 2 +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE("PointwiseConcatTransform unpack() rejects tuple size != 2", + "[layout][concat][error]") +{ + ams::PointwiseConcatTransform lt; + + at::Tensor pred = at::full({2, 2}, 7.0f); + at::Tensor uncrt = at::full({2}, 0.25f); + at::Tensor extra = at::zeros({1}); + + torch::jit::IValue iv{ + c10::ivalue::Tuple::create({pred, uncrt, extra}) // size 3 + }; + + ams::TensorBundle outs, inouts; + outs.add("a", torch::Tensor()); + outs.add("b", torch::Tensor()); + inouts.add("c", torch::Tensor()); + inouts.add("d", torch::Tensor()); + + std::optional uncrt_out; + + auto st = lt.unpack(iv, outs, inouts, uncrt_out); + + CATCH_REQUIRE_FALSE(st); + CATCH_REQUIRE(st.error().getType() == ams::AMSErrorType::InvalidShapes); + CATCH_REQUIRE_THAT(st.error().getMessage(), + ContainsSubstring("expected tuple(pred,uncertainty)")); +} + +// ----------------------------------------------------------------------------- +// TEST 8: unpack() errors when output columns don't match Outs+InOuts +// ----------------------------------------------------------------------------- +CATCH_TEST_CASE( + "PointwiseConcatTransform unpack() rejects mismatched output width", + "[layout][concat][error]") +{ + ams::PointwiseConcatTransform lt; + + // ModelOut has 3 columns + at::Tensor pred = at::full({5, 3}, 7.0f); + + // But Outs+InOuts expects 4 fields -> mismatch + ams::TensorBundle outs, inouts; + outs.add("a", torch::Tensor()); + outs.add("b", torch::Tensor()); + inouts.add("c", torch::Tensor()); + inouts.add("d", torch::Tensor()); + + std::optional uncrt_out; + + auto st = lt.unpack(pred, outs, inouts, uncrt_out); + + CATCH_REQUIRE_FALSE(st); + CATCH_REQUIRE(st.error().getType() == ams::AMSErrorType::InvalidShapes); + CATCH_REQUIRE_THAT(st.error().getMessage(), + ContainsSubstring("Expected the output size")); +}