diff --git a/libraries/targets/wgsl.mtlx b/libraries/targets/wgsl.mtlx new file mode 100644 index 0000000000..58c1405dc3 --- /dev/null +++ b/libraries/targets/wgsl.mtlx @@ -0,0 +1,14 @@ + + + + + + + + + + + diff --git a/source/MaterialXGenGlsl/WgslResourceBindingContext.cpp b/source/MaterialXGenGlsl/WgslResourceBindingContext.cpp index b9ceb270a4..98b4d31a04 100644 --- a/source/MaterialXGenGlsl/WgslResourceBindingContext.cpp +++ b/source/MaterialXGenGlsl/WgslResourceBindingContext.cpp @@ -46,25 +46,26 @@ void WgslResourceBindingContext::emitResourceBindings(GenContext& context, const { if (uniform->getType() != Type::FILENAME) { - if ( uniform->getType() == Type::BOOLEAN ) + if (uniform->getType() == Type::BOOLEAN) { - // Cannot have boolean uniforms in WGSL - std::cerr << "Warning: WGSL does not allow boolean types to be stored in uniform or storage address spaces." << std::endl; + // WGSL does not support boolean uniforms; emit as integer + // with the resolved variable name so that replaceTokens + // (which wraps bool-uniform tokens in bool()) won't corrupt + // the declaration. Cast to bool at use sites is handled by + // WgslShaderGenerator::emitInput and replaceTokens. + uniform->setType(Type::INTEGER); + + string tokenName = uniform->getVariable(); + const auto& subs = generator.getTokenSubstitutions(); + auto it = subs.find(tokenName); + if (it != subs.end()) + uniform->setVariable(it->second); - // Set uniform type to integer - uniform->setType( Type::INTEGER ); - - // Write declaration as normal generator.emitLineBegin(stage); generator.emitVariableDeclaration(uniform, EMPTY_STRING, context, stage, false); generator.emitString(Syntax::SEMICOLON, stage); generator.emitLineEnd(stage, false); - - // Add macro to treat any follow usages of this variable as a boolean - // eg. u_myUniformBool -> bool(u_myUniformBool) - generator.emitString("#define " + uniform->getVariable() + " bool(" + uniform->getVariable() + ")", stage); - generator.emitLineBreak(stage); - } + } else { generator.emitLineBegin(stage); diff --git a/source/MaterialXGenGlsl/WgslShaderGenerator.cpp b/source/MaterialXGenGlsl/WgslShaderGenerator.cpp index d39a2498f4..85553d04ab 100644 --- a/source/MaterialXGenGlsl/WgslShaderGenerator.cpp +++ b/source/MaterialXGenGlsl/WgslShaderGenerator.cpp @@ -7,9 +7,12 @@ #include #include +#include +#include MATERIALX_NAMESPACE_BEGIN +const string WgslShaderGenerator::TARGET = "wgsl"; const string WgslShaderGenerator::LIGHTDATA_TYPEVAR_STRING = "light_type"; WgslShaderGenerator::WgslShaderGenerator(TypeSystemPtr typeSystem) : @@ -51,17 +54,75 @@ void WgslShaderGenerator::emitFunctionDefinitionParameter(const ShaderPort* shad } } -// Called by SourceCodeNode::emitFunctionCall() +// Called by SourceCodeNode::emitFunctionCall() and CompoundNode::emitFunctionCall() void WgslShaderGenerator::emitInput(const ShaderInput* input, GenContext& context, ShaderStage& stage) const { if (input->getType() == Type::FILENAME) { emitString(getUpstreamResult(input, context)+"_texture, "+getUpstreamResult(input, context)+"_sampler", stage); } + else if (input->getType() == Type::BOOLEAN) + { + const string result = getUpstreamResult(input, context); + emitString("bool(" + result + ")", stage); + } else { VkShaderGenerator::emitInput(input, context, stage); } } +void WgslShaderGenerator::replaceTokens(const StringMap& substitutions, ShaderStage& stage) const +{ + // Bool-as-int uniform tokens. Add new entries when introducing more bool uniforms. + // Local static avoids static initialization order issues with extern HW:: constants. + static const vector> boolUniformTokens = { + { HW::T_REFRACTION_TWO_SIDED, HW::REFRACTION_TWO_SIDED }, + }; + + // Source code: bool-as-int uniform tokens get wrapped in bool() so that + // uses like "if ($refractionTwoSided)" become "if (bool(u_refractionTwoSided))". + const StringMap codeSubstitutions = [&]() { + StringMap subs = substitutions; + for (const auto& entry : boolUniformTokens) + subs[entry.first] = "bool(" + entry.second + ")"; + return subs; + }(); + + string code = stage.getSourceCode(); + tokenSubstitution(codeSubstitutions, code); + stage.setSourceCode(code); + + // Interface ports: bool-as-int uniform tokens stay as plain names so that + // uniform declarations and application-side binding remain correct. + const StringMap portSubstitutions = [&]() { + StringMap subs = substitutions; + for (const auto& entry : boolUniformTokens) + subs[entry.first] = entry.second; + return subs; + }(); + + auto replacePorts = [&portSubstitutions](VariableBlock& block) + { + for (size_t i = 0; i < block.size(); ++i) + { + ShaderPort* port = block[i]; + string name = port->getName(); + tokenSubstitution(portSubstitutions, name); + port->setName(name); + string variable = port->getVariable(); + tokenSubstitution(portSubstitutions, variable); + port->setVariable(variable); + } + }; + + replacePorts(stage.getConstantBlock()); + for (const auto& it : stage.getUniformBlocks()) + replacePorts(*it.second); + for (const auto& it : stage.getInputBlocks()) + replacePorts(*it.second); + for (const auto& it : stage.getOutputBlocks()) + replacePorts(*it.second); +} + MATERIALX_NAMESPACE_END diff --git a/source/MaterialXGenGlsl/WgslShaderGenerator.h b/source/MaterialXGenGlsl/WgslShaderGenerator.h index af9141f440..d72f4e3805 100644 --- a/source/MaterialXGenGlsl/WgslShaderGenerator.h +++ b/source/MaterialXGenGlsl/WgslShaderGenerator.h @@ -35,6 +35,11 @@ class MX_GENGLSL_API WgslShaderGenerator : public VkShaderGenerator return std::make_shared(typeSystem ? typeSystem : TypeSystem::create()); } + const string& getTarget() const override { return TARGET; } + + /// Unique identifier for this generator target + static const string TARGET; + void emitDirectives(GenContext& context, ShaderStage& stage) const override; const string& getLightDataTypevarString() const override { return LIGHTDATA_TYPEVAR_STRING; } @@ -43,6 +48,8 @@ class MX_GENGLSL_API WgslShaderGenerator : public VkShaderGenerator void emitInput(const ShaderInput* input, GenContext& context, ShaderStage& stage) const override; + void replaceTokens(const StringMap& substitutions, ShaderStage& stage) const override; + protected: static const string LIGHTDATA_TYPEVAR_STRING; }; diff --git a/source/MaterialXGenShader/ShaderGenerator.h b/source/MaterialXGenShader/ShaderGenerator.h index ce6c3cf996..512498c162 100644 --- a/source/MaterialXGenShader/ShaderGenerator.h +++ b/source/MaterialXGenShader/ShaderGenerator.h @@ -245,7 +245,7 @@ class MX_GENSHADER_API ShaderGenerator } /// Replace tokens with identifiers according to the given substitutions map. - void replaceTokens(const StringMap& substitutions, ShaderStage& stage) const; + virtual void replaceTokens(const StringMap& substitutions, ShaderStage& stage) const; /// Create shader variables (e.g. uniforms, inputs and outputs) for /// nodes that require input data from the application. diff --git a/source/MaterialXGenShader/ShaderGraph.cpp b/source/MaterialXGenShader/ShaderGraph.cpp index bfb15e9281..522e13f4dd 100644 --- a/source/MaterialXGenShader/ShaderGraph.cpp +++ b/source/MaterialXGenShader/ShaderGraph.cpp @@ -10,7 +10,7 @@ #include #include -#include +#include MATERIALX_NAMESPACE_BEGIN @@ -916,6 +916,39 @@ void ShaderGraph::finalize(GenContext& context) if (context.getOptions().shaderInterfaceType == SHADER_INTERFACE_COMPLETE) { + // Track shared sockets so that different nodes with the same input name + // get separate sockets deterministically. + std::unordered_map sharedSockets; + + // Helper lambda function to resolve the name of the input socket + auto resolveInputSocketName = [this, &sharedSockets, &context]( + const ShaderNode* node, ShaderInput* input, bool useGenericName) + -> std::pair + { + string name = useGenericName ? input->getName() : input->getFullName(); + ShaderGraphInputSocket* socket = getInputSocket(name); + if (socket && socket->getType() != input->getType()) + { + name = input->getFullName(); + socket = getInputSocket(name); + } + if (socket) + { + auto it = sharedSockets.find(name); + if (it != sharedSockets.end() && it->second != node) + { + string sanitized = node->getUniqueId(); + context.getShaderGenerator().getSyntax().makeValidName(sanitized); + if (!sanitized.empty() && sanitized[0] == '_') + sanitized.erase(0, 1); + string baseName = useGenericName ? input->getName() : input->getFullName(); + name = baseName + "_" + sanitized; + socket = getInputSocket(name); + } + } + return { name, socket }; + }; + // Publish all node inputs that has not been connected already. for (const ShaderNode* node : getNodes()) { @@ -927,12 +960,14 @@ void ShaderGraph::finalize(GenContext& context) // publish the input as an editable uniform. if (!input->getType().isClosure() && node->isEditable(*input)) { - // Use a consistent naming convention: _ - // so application side can figure out what uniforms to set - // when node inputs change on application side. - const string interfaceName = node->getName() + "_" + input->getName(); + // Create simpler names for generic nodes if possible + // so application side can employ techniques to easily switch between materials + // that are similar and only differ in unifrom values. + const bool useGenericName = (node->getClassification() & (ShaderNode::Classification::SHADER | + ShaderNode::Classification::CLOSURE | + ShaderNode::Classification::MATERIAL)) != 0; + auto [interfaceName, inputSocket] = resolveInputSocketName(node, input, useGenericName); - ShaderGraphInputSocket* inputSocket = getInputSocket(interfaceName); if (!inputSocket) { inputSocket = addInputSocket(interfaceName, input->getType()); @@ -944,6 +979,7 @@ void ShaderGraph::finalize(GenContext& context) { inputSocket->setUniform(); } + sharedSockets[interfaceName] = node; } inputSocket->makeConnection(input); inputSocket->setMetadata(input->getMetadata()); @@ -1121,10 +1157,18 @@ void ShaderGraph::topologicalSort() // Calculate a topological order of the children, using Kahn's algorithm // to avoid recursion. // - // Running time: O(numNodes + numEdges). - - // Calculate in-degrees for all nodes, and enqueue those with degree 0. + // Running time: O((numNodes + numEdges) + numNodes * log(numNodes)). + // + // The BFS traversal runs in O(numNodes + numEdges). A final stable sort + // over the result, keyed by topological depth then by name, ensures + // deterministic ordering of nodes at the same depth. This guarantees that + // materials with the same set of functions always emit them in the same + // order, regardless of which inputs happen to be connected. + + // Calculate in-degrees and topological depth for all nodes, + // and enqueue those with degree 0. std::unordered_map inDegree(_nodeMap.size()); + std::unordered_map depth(_nodeMap.size()); std::deque nodeQueue; for (ShaderNode* node : _nodeOrder) { @@ -1138,6 +1182,7 @@ void ShaderGraph::topologicalSort() } inDegree[node] = connectionCount; + depth[node] = 0; if (connectionCount == 0) { @@ -1156,7 +1201,7 @@ void ShaderGraph::topologicalSort() _nodeOrder[count++] = node; // Find connected nodes and decrease their in-degree, - // adding node to the queue if in-degrees becomes 0. + // adding node to the queue if in-degree becomes 0. for (const ShaderOutput* output : node->getOutputs()) { for (const ShaderInput* input : output->getConnections()) @@ -1164,6 +1209,7 @@ void ShaderGraph::topologicalSort() ShaderNode* downstreamNode = const_cast(input->getNode()); if (downstreamNode != this) { + depth[downstreamNode] = std::max(depth[downstreamNode], depth[node] + 1); if (--inDegree[downstreamNode] <= 0) { nodeQueue.push_back(downstreamNode); @@ -1172,6 +1218,17 @@ void ShaderGraph::topologicalSort() } } } + + // Stable sort by (depth, name, uniqueId) for deterministic output + // while preserving topological correctness. + std::stable_sort(_nodeOrder.begin(), _nodeOrder.begin() + count, + [&depth](ShaderNode* a, ShaderNode* b) { + if (depth[a] != depth[b]) + return depth[a] < depth[b]; + if (a->getName() != b->getName()) + return a->getName() < b->getName(); + return a->getUniqueId() < b->getUniqueId(); + }); } void ShaderGraph::setVariableNames(GenContext& context) @@ -1181,6 +1238,14 @@ void ShaderGraph::setVariableNames(GenContext& context) const Syntax& syntax = context.getShaderGenerator().getSyntax(); + // Use generic base names for material and surfaceshader so multiple outputs get + // consistent names (surfaceshader_out, surfaceshader_out1, ...) via getVariableName. + auto variableBaseName = [](const TypeDesc& type, const string& defaultName) -> string { + if (type == Type::MATERIAL) return "material_out"; + if (type == Type::SURFACESHADER) return "surfaceshader_out"; + return defaultName; + }; + for (ShaderGraphInputSocket* inputSocket : getInputSockets()) { const string variable = syntax.getVariableName(inputSocket->getName(), inputSocket->getType(), _identifiers); @@ -1188,7 +1253,8 @@ void ShaderGraph::setVariableNames(GenContext& context) } for (ShaderGraphOutputSocket* outputSocket : getOutputSockets()) { - const string variable = syntax.getVariableName(outputSocket->getName(), outputSocket->getType(), _identifiers); + const string baseName = variableBaseName(outputSocket->getType(), outputSocket->getName()); + const string variable = syntax.getVariableName(baseName, outputSocket->getType(), _identifiers); outputSocket->setVariable(variable); } for (ShaderNode* node : getNodes()) @@ -1201,8 +1267,8 @@ void ShaderGraph::setVariableNames(GenContext& context) } for (ShaderOutput* output : node->getOutputs()) { - string variable = output->getFullName(); - variable = syntax.getVariableName(variable, output->getType(), _identifiers); + const string baseName = variableBaseName(output->getType(), output->getFullName()); + const string variable = syntax.getVariableName(baseName, output->getType(), _identifiers); output->setVariable(variable); } } diff --git a/source/MaterialXView/Viewer.cpp b/source/MaterialXView/Viewer.cpp index 9fc69cfa2f..3ea6bc3512 100644 --- a/source/MaterialXView/Viewer.cpp +++ b/source/MaterialXView/Viewer.cpp @@ -36,6 +36,7 @@ #include #endif #include +#include #include #include @@ -202,6 +203,7 @@ Viewer::Viewer(const std::string& materialFilename, #ifndef MATERIALXVIEW_METAL_BACKEND _genContext(mx::GlslShaderGenerator::create(_typeSystem)), _genContextEssl(mx::EsslShaderGenerator::create(_typeSystem)), + _genContextWgsl(mx::WgslShaderGenerator::create(_typeSystem)), #else _genContext(mx::MslShaderGenerator::create(_typeSystem)), #endif @@ -267,6 +269,10 @@ Viewer::Viewer(const std::string& materialFilename, _genContextEssl.getOptions().targetColorSpaceOverride = "lin_rec709"; _genContextEssl.getOptions().fileTextureVerticalFlip = false; _genContextEssl.getOptions().hwMaxActiveLightSources = 1; + + // Set Wgsl generator options + _genContextWgsl.getOptions().targetColorSpaceOverride = "lin_rec709"; + _genContextWgsl.getOptions().fileTextureVerticalFlip = false; #endif #if MATERIALX_BUILD_GEN_OSL // Set OSL generator options. @@ -504,6 +510,7 @@ void Viewer::applyDirectLights(mx::DocumentPtr doc) _lightHandler->registerLights(doc, lights, _genContext); #ifndef MATERIALXVIEW_METAL_BACKEND _lightHandler->registerLights(doc, lights, _genContextEssl); + //_lightHandler->registerLights(doc, lights, _genContextWgsl); #endif _lightHandler->setLightSources(lights); } @@ -815,6 +822,7 @@ void Viewer::createAdvancedSettings(ng::ref parent) _genContext.getOptions().hwSpecularEnvironmentMethod = enable ? mx::SPECULAR_ENVIRONMENT_FIS : mx::SPECULAR_ENVIRONMENT_PREFILTER; #ifndef MATERIALXVIEW_METAL_BACKEND _genContextEssl.getOptions().hwSpecularEnvironmentMethod = _genContext.getOptions().hwSpecularEnvironmentMethod; + _genContextWgsl.getOptions().hwSpecularEnvironmentMethod = _genContext.getOptions().hwSpecularEnvironmentMethod; #endif _lightHandler->setUsePrefilteredMap(!enable); reloadShaders(); @@ -827,6 +835,7 @@ void Viewer::createAdvancedSettings(ng::ref parent) _genContext.getOptions().hwTransmissionRenderMethod = enable ? mx::TRANSMISSION_REFRACTION : mx::TRANSMISSION_OPACITY; #ifndef MATERIALXVIEW_METAL_BACKEND _genContextEssl.getOptions().hwTransmissionRenderMethod = _genContext.getOptions().hwTransmissionRenderMethod; + _genContextWgsl.getOptions().hwTransmissionRenderMethod = _genContext.getOptions().hwTransmissionRenderMethod; #endif reloadShaders(); }); @@ -911,6 +920,7 @@ void Viewer::createAdvancedSettings(ng::ref parent) _genContext.getOptions().hwAiryFresnelIterations = MIN_AIRY_FRESNEL_ITERATIONS * (int)std::pow(2, index); #ifndef MATERIALXVIEW_METAL_BACKEND _genContextEssl.getOptions().hwAiryFresnelIterations = _genContext.getOptions().hwAiryFresnelIterations; + _genContextWgsl.getOptions().hwAiryFresnelIterations = _genContext.getOptions().hwAiryFresnelIterations; #endif reloadShaders(); }); @@ -995,6 +1005,7 @@ void Viewer::createAdvancedSettings(ng::ref parent) _genContext.getOptions().targetDistanceUnit = _distanceUnitOptions[index]; #ifndef MATERIALXVIEW_METAL_BACKEND _genContextEssl.getOptions().targetDistanceUnit = _distanceUnitOptions[index]; + _genContextWgsl.getOptions().targetDistanceUnit = _distanceUnitOptions[index]; #endif #if MATERIALX_BUILD_GEN_OSL _genContextOsl.getOptions().targetDistanceUnit = _distanceUnitOptions[index]; @@ -1313,6 +1324,7 @@ void Viewer::loadDocument(const mx::FilePath& filename, mx::DocumentPtr librarie _genContext.clearUserData(); #ifndef MATERIALXVIEW_METAL_BACKEND _genContextEssl.clearUserData(); + _genContextWgsl.clearUserData(); #endif // Clear materials if merging is not requested. @@ -1443,6 +1455,7 @@ void Viewer::loadDocument(const mx::FilePath& filename, mx::DocumentPtr librarie _genContext.clearNodeImplementations(); #ifndef MATERIALXVIEW_METAL_BACKEND _genContextEssl.clearNodeImplementations(); + _genContextWgsl.clearNodeImplementations(); #endif // Add new materials to the global vector. @@ -1644,6 +1657,16 @@ void Viewer::saveShaderSource(mx::GenContext& context) new ng::MessageDialog(this, ng::MessageDialog::Type::Information, "Saved ESSL source: ", sourceFilename.asString() + "_essl_*.glsl"); } + else if (context.getShaderGenerator().getTarget() == mx::WgslShaderGenerator::TARGET) + { + mx::ShaderPtr shader = createShader(elem->getNamePath(), context, elem); + const std::string& pixelShader = shader->getSourceCode(mx::Stage::PIXEL); + const std::string& vertexShader = shader->getSourceCode(mx::Stage::VERTEX); + writeTextFile(vertexShader, sourceFilename.asString() + "_wgsl_vs.wgsl"); + writeTextFile(pixelShader, sourceFilename.asString() + "_wgsl_ps.wgsl"); + new ng::MessageDialog(this, ng::MessageDialog::Type::Information, "Saved WGSL source: ", + sourceFilename.asString() + "_wgsl_*.wgsl"); + } #else if (context.getShaderGenerator().getTarget() == mx::MslShaderGenerator::TARGET) { @@ -1859,6 +1882,7 @@ void Viewer::loadStandardLibraries() initContext(_genContext); #ifndef MATERIALXVIEW_METAL_BACKEND initContext(_genContextEssl); + initContext(_genContextWgsl); #endif #if MATERIALX_BUILD_GEN_OSL initContext(_genContextOsl); @@ -1937,6 +1961,15 @@ bool Viewer::keyboard_event(int key, int scancode, int action, int modifiers) return true; } + // Save WGSL shader source to file. + if (key == GLFW_KEY_W && action == GLFW_PRESS) + { +#ifndef MATERIALXVIEW_METAL_BACKEND + saveShaderSource(_genContextWgsl); +#endif + return true; + } + // Load GLSL shader source from file. Editing the source files before // loading provides a way to debug and experiment with shader source code. if (key == GLFW_KEY_L && action == GLFW_PRESS) @@ -2604,6 +2637,7 @@ void Viewer::setShaderInterfaceType(mx::ShaderInterfaceType interfaceType) _genContext.getOptions().shaderInterfaceType = interfaceType; #ifndef MATERIALXVIEW_METAL_BACKEND _genContextEssl.getOptions().shaderInterfaceType = interfaceType; + _genContextWgsl.getOptions().shaderInterfaceType = interfaceType; #endif #if MATERIALX_BUILD_GEN_OSL _genContextOsl.getOptions().shaderInterfaceType = interfaceType; diff --git a/source/MaterialXView/Viewer.h b/source/MaterialXView/Viewer.h index 515b8a75d8..04e5fd9518 100644 --- a/source/MaterialXView/Viewer.h +++ b/source/MaterialXView/Viewer.h @@ -423,6 +423,7 @@ class Viewer : public ng::Screen mx::GenContext _genContext; #ifndef MATERIALXVIEW_METAL_BACKEND mx::GenContext _genContextEssl; + mx::GenContext _genContextWgsl; #endif #if MATERIALX_BUILD_GEN_OSL mx::GenContext _genContextOsl;