diff --git a/program.go b/program.go index 4b489b8..00af2e9 100644 --- a/program.go +++ b/program.go @@ -1,14 +1,14 @@ package wrapper import ( + "errors" "fmt" "os/exec" ) type Program struct { - executable string - params map[string]interface{} - getCombinedOutputFunc func(cmd *exec.Cmd) ([]byte, error) + executable string + params map[string]interface{} } func NewProgram(executable string) *Program { @@ -16,7 +16,6 @@ func NewProgram(executable string) *Program { executable: executable, params: make(map[string]interface{}), } - program.getCombinedOutputFunc = program.getCombinedOutput return program } @@ -25,7 +24,9 @@ func (p *Program) WithParam(name string, value interface{}) *Program { return p } -func (p *Program) getCombinedOutput(cmd *exec.Cmd) ([]byte, error) { +var getCombinedOutputFunc = getCombinedOutput + +func getCombinedOutput(cmd *exec.Cmd) ([]byte, error) { return cmd.CombinedOutput() } @@ -37,9 +38,22 @@ func (p *Program) Run() (string, error) { } cmd := exec.Command(p.executable, params...) - output, err := p.getCombinedOutputFunc(cmd) + output, err := getCombinedOutputFunc(cmd) if err != nil { return "", err } return string(output), nil } + +func (p *Program) Compile(mainPath string) error { + cmd := exec.Command("go", "build", "-o", p.executable, mainPath) + output, err := getCombinedOutputFunc(cmd) + if err != nil { + return err + } + outputStr := string(output) + if outputStr != "" { + return errors.New("Error compiling: " + outputStr) + } + return nil +} diff --git a/program_test.go b/program_test.go new file mode 100644 index 0000000..fa037f0 --- /dev/null +++ b/program_test.go @@ -0,0 +1,22 @@ +package wrapper + +import ( + "errors" + "os/exec" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_program_cmdError(t *testing.T) { + + cmdErr := errors.New("test error") + + program := NewProgram("") + getCombinedOutputFunc = func(cmd *exec.Cmd) ([]byte, error) { + return nil, cmdErr + } + defer func() { getCombinedOutputFunc = getCombinedOutput }() + _, err := program.Run() + require.Equal(t, cmdErr, err) +} diff --git a/test/program_test.go b/test/program_test.go index 048a379..afdcdc3 100644 --- a/test/program_test.go +++ b/test/program_test.go @@ -29,6 +29,9 @@ func Test_program(t *testing.T) { path := filepath.Join(wd, "../testProgram.exe") program := wrapper.NewProgram(path) + err = program.Compile("../testProgram/main.go") + require.NoError(t, err) + output, err := program.Run() require.NoError(t, err) require.Contains(t, output, "default message")