diff --git a/node.lua b/node.lua index b9cf87b..526195f 100644 --- a/node.lua +++ b/node.lua @@ -137,7 +137,8 @@ function nnNode:label() elseif istable(data) then local tstr = {} for i,v in ipairs(data) do - table.insert(tstr, getstr(v)) + local gsv = getstr(v) -- avoids luajit error for type(v)='string' + table.insert(tstr, gsv) end return '{' .. table.concat(tstr,',') .. '}' else diff --git a/test/test_nngraph.lua b/test/test_nngraph.lua index 0b06f9c..534d131 100644 --- a/test/test_nngraph.lua +++ b/test/test_nngraph.lua @@ -371,6 +371,24 @@ function test.test_gradInputType() checkDotFile(bg_tmpfile) end + function test.test_table_string_label() + local inp = nn.Identity()() + local in1 = nn.SelectTable(1)(inp) + local in2 = nn.SelectTable(2)(inp) + local out = nn.Linear(10,10)(in1) + -- in2 propagates 'as is': it could be, say, a string debug tag + local mod = nn.gModule({inp}, {out,in2}) + + local mod_out = mod:forward{ + torch.Tensor(10), + "nnNode:label() should handle a string without bad argument #2 to 'insert' (number expected, got string) error" + } + --print('mod_out[1] type is '..torch.type(mod_out[1])) + --print('mod_out[2] type is '..torch.type(mod_out[2])) -- string + local dot0 = mod.fg:todot() + --print(dot0) + end + function test.test_splitMore() local nSplits = 2 local in1 = nn.Identity()()