executor: Enable command substitution and environment variables for ssh and subworkflow executor config (#748)
Some checks are pending
CI / Check for spelling errors (push) Waiting to run
CI / Go Linter (push) Waiting to run
CI / Test (push) Waiting to run

This commit is contained in:
Yota Hamada 2024-12-26 22:43:23 +09:00 committed by GitHub
parent 6ca7f2adc1
commit bb2c8ac777
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
39 changed files with 1831 additions and 1848 deletions

View File

@ -48,7 +48,7 @@ func (td *testDAG) AssertCurrentStatus(t *testing.T, expected scheduler.Status)
dag, err := digraph.Load(td.Context, td.Config.Paths.BaseConfig, td.Path, "")
require.NoError(t, err)
cli := td.Client()
cli := td.Client
require.Eventually(t, func() bool {
status, err := cli.GetCurrentStatus(td.Context, dag)
require.NoError(t, err)
@ -59,7 +59,7 @@ func (td *testDAG) AssertCurrentStatus(t *testing.T, expected scheduler.Status)
func (th *testDAG) AssertLastStatus(t *testing.T, expected scheduler.Status) {
t.Helper()
hs := th.DataStore().HistoryStore()
hs := th.DataStores.HistoryStore()
require.Eventually(t, func() bool {
status := hs.ReadStatusRecent(th.Context, th.Path, 1)
if len(status) < 1 {

View File

@ -23,7 +23,7 @@ func TestRetryCommand(t *testing.T) {
th.RunCommand(t, startCmd(), cmdTest{args: args})
// Find the request ID.
cli := th.Client()
cli := th.Client
ctx := context.Background()
status, err := cli.GetStatus(ctx, dagFile.Path)
require.NoError(t, err)

View File

@ -8,7 +8,10 @@ import (
)
func TestStartCommand(t *testing.T) {
t.Parallel()
th := testSetup(t)
tests := []cmdTest{
{
name: "StartDAG",

View File

@ -24,7 +24,7 @@ func TestStatusCommand(t *testing.T) {
close(done)
}()
hs := th.DataStore().HistoryStore()
hs := th.DataStores.HistoryStore()
require.Eventually(t, func() bool {
status := hs.ReadStatusRecent(th.Context, dagFile.Path, 1)
if len(status) < 1 {

View File

@ -116,6 +116,7 @@ func (a *Agent) Run(ctx context.Context) error {
// Check if the DAG is already running.
if err := a.checkIsAlreadyRunning(ctx); err != nil {
a.scheduler.Cancel(ctx, a.graph)
return err
}
@ -203,7 +204,7 @@ func (a *Agent) Run(ctx context.Context) error {
// Send the execution report if necessary.
a.lastErr = lastErr
if err := a.reporter.send(a.dag, finishedStatus, lastErr); err != nil {
if err := a.reporter.send(ctx, a.dag, finishedStatus, lastErr); err != nil {
logger.Error(ctx, "Mail notification failed", "err", err)
}
@ -314,14 +315,13 @@ func (a *Agent) setup(ctx context.Context) error {
defer a.lock.Unlock()
a.scheduler = a.newScheduler()
a.reporter = newReporter(
mailer.New(mailer.Config{
Host: a.dag.SMTP.Host,
Port: a.dag.SMTP.Port,
Username: a.dag.SMTP.Username,
Password: a.dag.SMTP.Password,
}),
)
mailer := mailer.New(mailer.Config{
Host: a.dag.SMTP.Host,
Port: a.dag.SMTP.Port,
Username: a.dag.SMTP.Username,
Password: a.dag.SMTP.Password,
})
a.reporter = newReporter(mailer)
return a.setupGraph(ctx)
}
@ -482,7 +482,7 @@ func (a *Agent) checkPreconditions(ctx context.Context) error {
// If one of the conditions does not met, cancel the execution.
if err := digraph.EvalConditions(a.dag.Preconditions); err != nil {
logger.Error(ctx, "Preconditions are not met", "err", err)
a.scheduler.Cancel(a.graph)
a.scheduler.Cancel(ctx, a.graph)
return err
}
return nil

View File

@ -4,18 +4,12 @@
package agent_test
import (
"context"
"net/http"
"net/url"
"path/filepath"
"syscall"
"testing"
"time"
"github.com/dagu-org/dagu/internal/agent"
"github.com/dagu-org/dagu/internal/fileutil"
"github.com/dagu-org/dagu/internal/test"
"github.com/google/uuid"
"github.com/dagu-org/dagu/internal/digraph"
"github.com/dagu-org/dagu/internal/digraph/scheduler"
@ -24,236 +18,172 @@ import (
)
func TestAgent_Run(t *testing.T) {
t.Parallel()
t.Run("RunDAG", func(t *testing.T) {
th := test.Setup(t)
dag := th.LoadDAGFile(t, "run.yaml")
dagAgent := dag.Agent()
dag := testLoadDAG(t, "run.yaml")
cli := th.Client()
ctx := th.Context
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
latestStatus, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
require.Equal(t, scheduler.StatusNone, latestStatus.Status)
dag.AssertLatestStatus(t, scheduler.StatusNone)
go func() {
err := agt.Run(ctx)
require.NoError(t, err)
dagAgent.RunSuccess(t)
}()
time.Sleep(100 * time.Millisecond)
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusSuccess
}, time.Second*2, time.Millisecond*100)
dag.AssertLatestStatus(t, scheduler.StatusSuccess)
})
t.Run("DeleteOldHistory", func(t *testing.T) {
th := test.Setup(t)
dag := th.LoadDAGFile(t, "delete_old_history.yaml")
dagAgent := dag.Agent()
// Create a history file by running a DAG
dag := testLoadDAG(t, "simple.yaml")
cli := th.Client()
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
ctx := th.Context
dagAgent.RunSuccess(t)
dag.AssertHistoryCount(t, 1)
err := agt.Run(ctx)
require.NoError(t, err)
history := cli.GetRecentHistory(ctx, dag, 2)
require.Equal(t, 1, len(history))
// Set the retention days to 0 and run the DAG again
// Set the retention days to 0 (delete all history files except the latest one)
dag.HistRetentionDays = 0
agt = newAgent(th, genRequestID(), dag, &agent.Options{})
err = agt.Run(ctx)
require.NoError(t, err)
// Run the DAG again
dagAgent = dag.Agent()
dagAgent.RunSuccess(t)
// Check if only the latest history file exists
history = cli.GetRecentHistory(ctx, dag, 2)
require.Equal(t, 1, len(history))
dag.AssertHistoryCount(t, 1)
})
t.Run("AlreadyRunning", func(t *testing.T) {
th := test.Setup(t)
dag := testLoadDAG(t, "is_running.yaml")
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
ctx := th.Context
dag := th.LoadDAGFile(t, "is_running.yaml")
dagAgent := dag.Agent()
go func() {
_ = agt.Run(ctx)
// Run the DAG in the background so that it is running
dagAgent.RunSuccess(t)
}()
time.Sleep(time.Millisecond * 30)
dag.AssertCurrentStatus(t, scheduler.StatusRunning)
curStatus := agt.Status()
require.NotNil(t, curStatus)
require.Equal(t, curStatus.Status, scheduler.StatusRunning)
agt = newAgent(th, genRequestID(), dag, &agent.Options{})
err := agt.Run(ctx)
require.Error(t, err)
require.Contains(t, err.Error(), "is already running")
// Try to run the DAG again while it is running
dagAgent = dag.Agent()
dagAgent.RunCheckErr(t, "is already running")
})
t.Run("PreConditionNotMet", func(t *testing.T) {
th := test.Setup(t)
dag := th.LoadDAGFile(t, "multiple_steps.yaml")
dag := testLoadDAG(t, "multiple_steps.yaml")
// Set a precondition that always fails
dag.Preconditions = []digraph.Condition{
{Condition: "`echo 1`", Expected: "0"},
}
// Precondition is not met
dag.Preconditions = []digraph.Condition{{Condition: "`echo 1`", Expected: "0"}}
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
ctx := th.Context
err := agt.Run(ctx)
require.Error(t, err)
dagAgent := dag.Agent()
dagAgent.RunCheckErr(t, "condition was not met")
// Check if all nodes are not executed
status := agt.Status()
require.Equal(t, scheduler.StatusCancel, status.Status)
require.Equal(t, scheduler.NodeStatusNone, status.Nodes[0].Status)
require.Equal(t, scheduler.NodeStatusNone, status.Nodes[1].Status)
status := dagAgent.Status()
require.Equal(t, scheduler.StatusCancel.String(), status.Status.String())
require.Equal(t, scheduler.NodeStatusNone.String(), status.Nodes[0].Status.String())
require.Equal(t, scheduler.NodeStatusNone.String(), status.Nodes[1].Status.String())
})
t.Run("FinishWithError", func(t *testing.T) {
th := test.Setup(t)
// Run a DAG that fails
errDAG := testLoadDAG(t, "error.yaml")
agt := newAgent(th, genRequestID(), errDAG, &agent.Options{})
ctx := th.Context
err := agt.Run(ctx)
require.Error(t, err)
errDAG := th.LoadDAGFile(t, "error.yaml")
dagAgent := errDAG.Agent()
dagAgent.RunError(t)
// Check if the status is saved correctly
require.Equal(t, scheduler.StatusError, agt.Status().Status)
require.Equal(t, scheduler.StatusError, dagAgent.Status().Status)
})
t.Run("FinishWithTimeout", func(t *testing.T) {
th := test.Setup(t)
// Run a DAG that timeout
timeoutDAG := testLoadDAG(t, "timeout.yaml")
agt := newAgent(th, genRequestID(), timeoutDAG, &agent.Options{})
ctx := th.Context
err := agt.Run(ctx)
require.Error(t, err)
timeoutDAG := th.LoadDAGFile(t, "timeout.yaml")
dagAgent := timeoutDAG.Agent()
dagAgent.RunError(t)
// Check if the status is saved correctly
require.Equal(t, scheduler.StatusError, agt.Status().Status)
require.Equal(t, scheduler.StatusError, dagAgent.Status().Status)
})
t.Run("ReceiveSignal", func(t *testing.T) {
th := test.Setup(t)
ctx := th.Context
abortFunc := func(a *agent.Agent) { a.Signal(ctx, syscall.SIGTERM) }
dag := testLoadDAG(t, "sleep.yaml")
cli := th.Client()
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
dag := th.LoadDAGFile(t, "sleep.yaml")
dagAgent := dag.Agent()
go func() {
_ = agt.Run(ctx)
dagAgent.RunCancel(t)
}()
// wait for the DAG to start
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusRunning
}, time.Second*1, time.Millisecond*100)
dag.AssertLatestStatus(t, scheduler.StatusRunning)
// send a signal to cancel the DAG
abortFunc(agt)
dagAgent.Abort()
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusCancel
}, time.Second*1, time.Millisecond*100)
// wait for the DAG to be canceled
dag.AssertLatestStatus(t, scheduler.StatusCancel)
})
t.Run("ExitHandler", func(t *testing.T) {
th := test.Setup(t)
dag := testLoadDAG(t, "on_exit.yaml")
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
ctx := th.Context
err := agt.Run(ctx)
require.NoError(t, err)
dag := th.LoadDAGFile(t, "on_exit.yaml")
dagAgent := dag.Agent()
dagAgent.RunSuccess(t)
// Check if the DAG is executed successfully
status := agt.Status()
require.Equal(t, scheduler.StatusSuccess, status.Status)
status := dagAgent.Status()
require.Equal(t, scheduler.StatusSuccess.String(), status.Status.String())
for _, s := range status.Nodes {
require.Equal(t, scheduler.NodeStatusSuccess, s.Status)
require.Equal(t, scheduler.NodeStatusSuccess.String(), s.Status.String())
}
// Check if the exit handler is executed
require.Equal(t, scheduler.NodeStatusSuccess, status.OnExit.Status)
require.Equal(t, scheduler.NodeStatusSuccess.String(), status.OnExit.Status.String())
})
}
func TestAgent_DryRun(t *testing.T) {
t.Parallel()
t.Run("DryRun", func(t *testing.T) {
th := test.Setup(t)
dag := testLoadDAG(t, "dry.yaml")
ctx := th.Context
agt := newAgent(th, genRequestID(), dag, &agent.Options{
Dry: true,
})
dag := th.LoadDAGFile(t, "dry.yaml")
dagAgent := dag.Agent(test.WithAgentOptions(&agent.Options{Dry: true}))
err := agt.Run(ctx)
require.NoError(t, err)
dagAgent.RunSuccess(t)
curStatus := agt.Status()
require.NoError(t, err)
curStatus := dagAgent.Status()
require.Equal(t, scheduler.StatusSuccess, curStatus.Status)
// Check if the status is not saved
cli := th.Client()
history := cli.GetRecentHistory(ctx, dag, 1)
require.Equal(t, 0, len(history))
dag.AssertHistoryCount(t, 0)
})
}
func TestAgent_Retry(t *testing.T) {
t.Parallel()
t.Run("RetryDAG", func(t *testing.T) {
th := test.Setup(t)
// retry.yaml has a DAG that fails
dag := testLoadDAG(t, "retry.yaml")
dag := th.LoadDAGFile(t, "retry.yaml")
dagAgent := dag.Agent()
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
ctx := th.Context
err := agt.Run(ctx)
require.Error(t, err)
// Check if the DAG failed
status := agt.Status()
require.Equal(t, scheduler.StatusError, status.Status)
dagAgent.RunError(t)
// Modify the DAG to make it successful
for _, node := range status.Nodes {
node.Step.CmdWithArgs = "true"
status := dagAgent.Status()
for i := range status.Nodes {
status.Nodes[i].Step.CmdWithArgs = "true"
}
// Retry the DAG and check if it is successful
agt = newAgent(th, genRequestID(), dag, &agent.Options{
dagAgent = dag.Agent(test.WithAgentOptions(&agent.Options{
RetryTarget: status,
})
err = agt.Run(ctx)
require.NoError(t, err)
}))
dagAgent.RunSuccess(t)
status = agt.Status()
require.Equal(t, scheduler.StatusSuccess, status.Status)
for _, node := range status.Nodes {
for _, node := range dagAgent.Status().Nodes {
if node.Status != scheduler.NodeStatusSuccess &&
node.Status != scheduler.NodeStatusSkipped {
t.Errorf("invalid status: %s", node.Status.String())
t.Errorf("node %q is not successful: %s", node.Step.Name, node.Status)
}
}
})
@ -265,25 +195,19 @@ func TestAgent_HandleHTTP(t *testing.T) {
th := test.Setup(t)
// Start a long-running DAG
dag := testLoadDAG(t, "handle_http.yaml")
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
dag := th.LoadDAGFile(t, "handle_http_valid.yaml")
dagAgent := dag.Agent()
ctx := th.Context
go func() {
err := agt.Run(ctx)
require.NoError(t, err)
dagAgent.RunCancel(t)
}()
// Wait for the DAG to start
cli := th.Client()
require.Eventually(t, func() bool {
status, _ := cli.GetLatestStatus(ctx, dag)
// require.NoError(t, err)
return status.Status == scheduler.StatusRunning
}, time.Second*2, time.Millisecond*100)
dag.AssertLatestStatus(t, scheduler.StatusRunning)
// Get the status of the DAG
var mockResponseWriter = mockResponseWriter{}
agt.HandleHTTP(ctx)(&mockResponseWriter, &http.Request{
dagAgent.HandleHTTP(ctx)(&mockResponseWriter, &http.Request{
Method: "GET", URL: &url.URL{Path: "/status"},
})
require.Equal(t, http.StatusOK, mockResponseWriter.status)
@ -294,76 +218,55 @@ func TestAgent_HandleHTTP(t *testing.T) {
require.Equal(t, scheduler.StatusRunning, status.Status)
// Stop the DAG
agt.Signal(ctx, syscall.SIGTERM)
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusCancel
}, time.Second*2, time.Millisecond*100)
dagAgent.Abort()
dag.AssertLatestStatus(t, scheduler.StatusCancel)
})
t.Run("HTTP_InvalidRequest", func(t *testing.T) {
th := test.Setup(t)
// Start a long-running DAG
dag := testLoadDAG(t, "handle_http2.yaml")
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
ctx := th.Context
dag := th.LoadDAGFile(t, "handle_http_invalid.yaml")
dagAgent := dag.Agent()
go func() {
err := agt.Run(ctx)
require.NoError(t, err)
dagAgent.RunCancel(t)
}()
// Wait for the DAG to start
cli := th.Client()
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusRunning
}, time.Second*2, time.Millisecond*100)
dag.AssertLatestStatus(t, scheduler.StatusRunning)
var mockResponseWriter = mockResponseWriter{}
// Request with an invalid path
agt.HandleHTTP(ctx)(&mockResponseWriter, &http.Request{
dagAgent.HandleHTTP(th.Context)(&mockResponseWriter, &http.Request{
Method: "GET",
URL: &url.URL{Path: "/invalid-path"},
})
require.Equal(t, http.StatusNotFound, mockResponseWriter.status)
// Stop the DAG
agt.Signal(ctx, syscall.SIGTERM)
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusCancel
}, time.Second*2, time.Millisecond*100)
dagAgent.Abort()
dag.AssertLatestStatus(t, scheduler.StatusCancel)
})
t.Run("HTTP_HandleCancel", func(t *testing.T) {
th := test.Setup(t)
// Start a long-running DAG
dag := testLoadDAG(t, "handle_http3.yaml")
agt := newAgent(th, genRequestID(), dag, &agent.Options{})
ctx := th.Context
dag := th.LoadDAGFile(t, "handle_http_cancel.yaml")
dagAgent := dag.Agent()
done := make(chan struct{})
go func() {
err := agt.Run(ctx)
require.NoError(t, err)
dagAgent.RunCancel(t)
close(done)
}()
// Wait for the DAG to start
cli := th.Client()
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusRunning
}, time.Second*2, time.Millisecond*100)
dag.AssertLatestStatus(t, scheduler.StatusRunning)
// Cancel the DAG
var mockResponseWriter = mockResponseWriter{}
agt.HandleHTTP(ctx)(&mockResponseWriter, &http.Request{
dagAgent.HandleHTTP(th.Context)(&mockResponseWriter, &http.Request{
Method: "POST",
URL: &url.URL{Path: "/stop"},
})
@ -371,11 +274,8 @@ func TestAgent_HandleHTTP(t *testing.T) {
require.Equal(t, "OK", mockResponseWriter.body)
// Wait for the DAG to stop
require.Eventually(t, func() bool {
status, err := cli.GetLatestStatus(ctx, dag)
require.NoError(t, err)
return status.Status == scheduler.StatusCancel
}, time.Second*10, time.Millisecond*100)
<-done
dag.AssertLatestStatus(t, scheduler.StatusCancel)
})
}
@ -403,38 +303,3 @@ func (h *mockResponseWriter) Write(body []byte) (int, error) {
func (h *mockResponseWriter) WriteHeader(statusCode int) {
h.status = statusCode
}
// testLoadDAG load the specified DAG file for testing
// without base config or parameters.
func testLoadDAG(t *testing.T, name string) *digraph.DAG {
filePath := filepath.Join(fileutil.MustGetwd(), "testdata", name)
dag, err := digraph.Load(context.Background(), "", filePath, "")
require.NoError(t, err)
return dag
}
func genRequestID() string {
id, err := uuid.NewRandom()
if err != nil {
panic(err)
}
return id.String()
}
func newAgent(
th test.Helper,
requestID string,
dag *digraph.DAG,
opts *agent.Options,
) *agent.Agent {
logDir, logFile := th.Config.Paths.LogDir, ""
return agent.New(
requestID,
dag,
logDir,
logFile,
th.Client(),
th.DataStore(),
opts,
)
}

View File

@ -18,14 +18,12 @@ import (
// Sender is a mailer interface.
type Sender interface {
Send(from string, to []string, subject, body string, attachments []string) error
Send(ctx context.Context, from string, to []string, subject, body string, attachments []string) error
}
// reporter is responsible for reporting the status of the scheduler
// to the user.
type reporter struct {
sender Sender
}
type reporter struct{ sender Sender }
func newReporter(sender Sender) *reporter {
return &reporter{sender: sender}
@ -40,13 +38,12 @@ func (r *reporter) reportStep(
logger.Info(ctx, "Step execution finished", "step", node.Data().Step.Name, "status", nodeStatus)
}
if nodeStatus == scheduler.NodeStatusError && node.Data().Step.MailOnError {
return r.sender.Send(
dag.ErrorMail.From,
[]string{dag.ErrorMail.To},
fmt.Sprintf("%s %s (%s)", dag.ErrorMail.Prefix, dag.Name, status.Status),
renderHTML(status.Nodes),
addAttachmentList(dag.ErrorMail.AttachLogs, status.Nodes),
)
fromAddress := dag.ErrorMail.From
toAddresses := []string{dag.ErrorMail.To}
subject := fmt.Sprintf("%s %s (%s)", dag.ErrorMail.Prefix, dag.Name, status.Status)
html := renderHTML(status.Nodes)
attachments := addAttachments(dag.ErrorMail.AttachLogs, status.Nodes)
return r.sender.Send(ctx, fromAddress, toAddresses, subject, html, attachments)
}
return nil
}
@ -56,105 +53,101 @@ func (r *reporter) getSummary(_ context.Context, status *model.Status, err error
var buf bytes.Buffer
_, _ = buf.Write([]byte("\n"))
_, _ = buf.Write([]byte("Summary ->\n"))
_, _ = buf.Write([]byte(renderSummary(status, err)))
_, _ = buf.Write([]byte(renderDAGSummary(status, err)))
_, _ = buf.Write([]byte("\n"))
_, _ = buf.Write([]byte("Details ->\n"))
_, _ = buf.Write([]byte(renderTable(status.Nodes)))
_, _ = buf.Write([]byte(renderStepSummary(status.Nodes)))
return buf.String()
}
// send is a function that sends a report mail.
func (r *reporter) send(
dag *digraph.DAG, status *model.Status, err error,
) error {
func (r *reporter) send(ctx context.Context, dag *digraph.DAG, status *model.Status, err error) error {
if err != nil || status.Status == scheduler.StatusError {
if dag.MailOn != nil && dag.MailOn.Failure {
return r.sender.Send(
dag.ErrorMail.From,
[]string{dag.ErrorMail.To},
fmt.Sprintf(
"%s %s (%s)", dag.ErrorMail.Prefix, dag.Name, status.Status,
),
renderHTML(status.Nodes),
addAttachmentList(dag.ErrorMail.AttachLogs, status.Nodes),
)
fromAddress := dag.ErrorMail.From
toAddresses := []string{dag.ErrorMail.To}
subject := fmt.Sprintf("%s %s (%s)", dag.ErrorMail.Prefix, dag.Name, status.Status)
html := renderHTML(status.Nodes)
attachments := addAttachments(dag.ErrorMail.AttachLogs, status.Nodes)
return r.sender.Send(ctx, fromAddress, toAddresses, subject, html, attachments)
}
} else if status.Status == scheduler.StatusSuccess {
if dag.MailOn != nil && dag.MailOn.Success {
_ = r.sender.Send(
dag.InfoMail.From,
[]string{dag.InfoMail.To},
fmt.Sprintf(
"%s %s (%s)", dag.InfoMail.Prefix, dag.Name, status.Status,
),
renderHTML(status.Nodes),
addAttachmentList(dag.InfoMail.AttachLogs, status.Nodes),
)
fromAddress := dag.InfoMail.From
toAddresses := []string{dag.InfoMail.To}
subject := fmt.Sprintf("%s %s (%s)", dag.InfoMail.Prefix, dag.Name, status.Status)
html := renderHTML(status.Nodes)
attachments := addAttachments(dag.InfoMail.AttachLogs, status.Nodes)
_ = r.sender.Send(ctx, fromAddress, toAddresses, subject, html, attachments)
}
}
return nil
}
func renderSummary(status *model.Status, err error) string {
t := table.NewWriter()
var errText string
if err != nil {
errText = err.Error()
}
t.AppendHeader(
table.Row{
"RequestID",
"Name",
"Started At",
"Finished At",
"Status",
"Params",
"Error",
},
)
t.AppendRow(table.Row{
var dagHeader = table.Row{
"RequestID",
"Name",
"Started At",
"Finished At",
"Status",
"Params",
"Error",
}
func renderDAGSummary(status *model.Status, err error) string {
dataRow := table.Row{
status.RequestID,
status.Name,
status.StartedAt,
status.FinishedAt,
status.Status,
status.Params,
errText,
})
return t.Render()
}
if err != nil {
dataRow = append(dataRow, err.Error())
} else {
dataRow = append(dataRow, "")
}
reportTable := table.NewWriter()
reportTable.AppendHeader(dagHeader)
reportTable.AppendRow(dataRow)
return reportTable.Render()
}
func renderTable(nodes []*model.Node) string {
t := table.NewWriter()
t.AppendHeader(
table.Row{
"#",
"Step",
"Started At",
"Finished At",
"Status",
"Command",
"Error",
},
)
var stepHeader = table.Row{
"#",
"Step",
"Started At",
"Finished At",
"Status",
"Command",
"Error",
}
func renderStepSummary(nodes []*model.Node) string {
stepTable := table.NewWriter()
stepTable.AppendHeader(stepHeader)
for i, n := range nodes {
var command = n.Step.Command
if n.Step.Args != nil {
command = strings.Join(
[]string{n.Step.Command, strings.Join(n.Step.Args, " ")}, " ",
)
}
t.AppendRow(table.Row{
fmt.Sprintf("%d", i+1),
number := fmt.Sprintf("%d", i+1)
dataRow := table.Row{
number,
n.Step.Name,
n.StartedAt,
n.FinishedAt,
n.StatusText,
command,
n.Error,
})
}
if n.Step.Args != nil {
dataRow = append(dataRow, strings.Join(n.Step.Args, " "))
} else {
dataRow = append(dataRow, "")
}
dataRow = append(dataRow, n.Error)
stepTable.AppendRow(dataRow)
}
return t.Render()
return stepTable.Render()
}
func renderHTML(nodes []*model.Node) string {
@ -204,7 +197,7 @@ func renderHTML(nodes []*model.Node) string {
return buffer.String()
}
func addAttachmentList(
func addAttachments(
trigger bool, nodes []*model.Node,
) (attachments []string) {
if trigger {

View File

@ -4,6 +4,7 @@
package agent
import (
"context"
"errors"
"fmt"
"testing"
@ -75,7 +76,7 @@ func testErrorMail(t *testing.T, rp *reporter, dag *digraph.DAG, nodes []*model.
dag.MailOn.Failure = true
dag.MailOn.Success = false
_ = rp.send(dag, &model.Status{
_ = rp.send(context.Background(), dag, &model.Status{
Status: scheduler.StatusError,
Nodes: nodes,
}, fmt.Errorf("Error"))
@ -91,7 +92,7 @@ func testNoErrorMail(t *testing.T, rp *reporter, dag *digraph.DAG, nodes []*mode
dag.MailOn.Failure = false
dag.MailOn.Success = true
err := rp.send(dag, &model.Status{
err := rp.send(context.Background(), dag, &model.Status{
Status: scheduler.StatusError,
Nodes: nodes,
}, nil)
@ -106,7 +107,7 @@ func testSuccessMail(t *testing.T, rp *reporter, dag *digraph.DAG, nodes []*mode
dag.MailOn.Failure = true
dag.MailOn.Success = true
err := rp.send(dag, &model.Status{
err := rp.send(context.Background(), dag, &model.Status{
Status: scheduler.StatusSuccess,
Nodes: nodes,
}, nil)
@ -125,13 +126,13 @@ func testRenderSummary(t *testing.T, _ *reporter, dag *digraph.DAG, nodes []*mod
Status: scheduler.StatusError,
Nodes: nodes,
}
summary := renderSummary(status, errors.New("test error"))
summary := renderDAGSummary(status, errors.New("test error"))
require.Contains(t, summary, "test error")
require.Contains(t, summary, dag.Name)
}
func testRenderTable(t *testing.T, _ *reporter, _ *digraph.DAG, nodes []*model.Node) {
summary := renderTable(nodes)
summary := renderStepSummary(nodes)
require.Contains(t, summary, nodes[0].Step.Name)
require.Contains(t, summary, nodes[0].Step.Args[0])
}
@ -144,7 +145,7 @@ type mockSender struct {
count int
}
func (m *mockSender) Send(from string, to []string, subject, body string, _ []string) error {
func (m *mockSender) Send(_ context.Context, from string, to []string, subject, body string, _ []string) error {
m.count += 1
m.from = from
m.to = to

View File

@ -1,3 +1,3 @@
steps:
- name: "1"
command: "sleep 1"
command: "sleep 10"

View File

@ -4,7 +4,6 @@
package client_test
import (
"context"
"fmt"
"net/http"
"path/filepath"
@ -16,29 +15,24 @@ import (
"github.com/dagu-org/dagu/internal/client"
"github.com/dagu-org/dagu/internal/digraph"
"github.com/dagu-org/dagu/internal/digraph/scheduler"
"github.com/dagu-org/dagu/internal/fileutil"
"github.com/dagu-org/dagu/internal/persistence/model"
"github.com/dagu-org/dagu/internal/sock"
"github.com/dagu-org/dagu/internal/test"
)
var testdataDir = filepath.Join(fileutil.MustGetwd(), "./testdata")
func TestClient_GetStatus(t *testing.T) {
t.Parallel()
th := test.Setup(t)
t.Run("Valid", func(t *testing.T) {
th := test.Setup(t)
filePath := testDAG("sleep1.yaml")
cli := th.Client()
ctx := context.Background()
dagStatus, err := cli.GetStatus(ctx, filePath)
require.NoError(t, err)
dag := th.LoadDAGFile(t, "valid.yaml")
ctx := th.Context
socketServer, _ := sock.NewServer(
dagStatus.DAG.SockAddr(),
dag.SockAddr(),
func(w http.ResponseWriter, _ *http.Request) {
status := model.NewStatus(dagStatus.DAG, nil,
status := model.NewStatus(dag.DAG, nil,
scheduler.StatusRunning, 0, nil, nil)
w.WriteHeader(http.StatusOK)
b, _ := status.ToJSON()
@ -51,24 +45,17 @@ func TestClient_GetStatus(t *testing.T) {
_ = socketServer.Shutdown(ctx)
}()
time.Sleep(time.Millisecond * 100)
curStatus, err := cli.GetCurrentStatus(ctx, dagStatus.DAG)
require.NoError(t, err)
require.Equal(t, scheduler.StatusRunning, curStatus.Status)
dag.AssertCurrentStatus(t, scheduler.StatusRunning)
_ = socketServer.Shutdown(ctx)
curStatus, err = cli.GetCurrentStatus(ctx, dagStatus.DAG)
require.NoError(t, err)
require.Equal(t, scheduler.StatusNone, curStatus.Status)
dag.AssertCurrentStatus(t, scheduler.StatusNone)
})
t.Run("InvalidDAGName", func(t *testing.T) {
th := test.Setup(t)
ctx := th.Context
cli := th.Client
cli := th.Client()
ctx := context.Background()
dagStatus, err := cli.GetStatus(ctx, testDAG("invalid_dag"))
dagStatus, err := cli.GetStatus(ctx, "invalid-dag-name")
require.Error(t, err)
require.NotNil(t, dagStatus)
@ -76,171 +63,137 @@ func TestClient_GetStatus(t *testing.T) {
require.Error(t, dagStatus.Error)
})
t.Run("UpdateStatus", func(t *testing.T) {
th := test.Setup(t)
dag := th.LoadDAGFile(t, "update_status.yaml")
var (
file = testDAG("success.yaml")
requestID = "test-update-status"
now = time.Now()
cli = th.Client()
)
ctx := context.Background()
dagStatus, err := cli.GetStatus(ctx, file)
requestID := "test-update-status"
now := time.Now()
ctx := th.Context
cli := th.Client
// Open the history store and write a status before updating it.
historyStore := th.DataStores.HistoryStore()
err := historyStore.Open(ctx, dag.Location, now, requestID)
require.NoError(t, err)
historyStore := th.DataStore().HistoryStore()
err = historyStore.Open(ctx, dagStatus.DAG.Location, now, requestID)
require.NoError(t, err)
status := testNewStatus(dagStatus.DAG, requestID,
scheduler.StatusSuccess, scheduler.NodeStatusSuccess)
status := testNewStatus(dag.DAG, requestID, scheduler.StatusSuccess, scheduler.NodeStatusSuccess)
err = historyStore.Write(ctx, status)
require.NoError(t, err)
_ = historyStore.Close(ctx)
time.Sleep(time.Millisecond * 100)
status, err = cli.GetStatusByRequestID(ctx, dagStatus.DAG, requestID)
// Get the status and check if it is the same as the one we wrote.
status, err = cli.GetStatusByRequestID(ctx, dag.DAG, requestID)
require.NoError(t, err)
require.Equal(t, scheduler.NodeStatusSuccess, status.Nodes[0].Status)
// Update the status.
newStatus := scheduler.NodeStatusError
status.Nodes[0].Status = newStatus
err = cli.UpdateStatus(ctx, dagStatus.DAG, status)
err = cli.UpdateStatus(ctx, dag.DAG, status)
require.NoError(t, err)
statusByRequestID, err := cli.GetStatusByRequestID(ctx, dagStatus.DAG, requestID)
statusByRequestID, err := cli.GetStatusByRequestID(ctx, dag.DAG, requestID)
require.NoError(t, err)
require.Equal(t, 1, len(status.Nodes))
require.Equal(t, newStatus, statusByRequestID.Nodes[0].Status)
})
t.Run("InvalidUpdateStatusWithInvalidReqID", func(t *testing.T) {
th := test.Setup(t)
var (
cli = th.Client()
file = testDAG("sleep1.yaml")
wrongReqID = "invalid-request-id"
)
ctx := context.Background()
dagStatus, err := cli.GetStatus(ctx, file)
require.NoError(t, err)
wrongReqID := "invalid-request-id"
dag := th.LoadDAGFile(t, "invalid_reqid.yaml")
ctx := th.Context
cli := th.Client
// update with invalid request id
status := testNewStatus(dagStatus.DAG, wrongReqID, scheduler.StatusError,
status := testNewStatus(dag.DAG, wrongReqID, scheduler.StatusError,
scheduler.NodeStatusError)
// Check if the update fails.
err = cli.UpdateStatus(ctx, dagStatus.DAG, status)
err := cli.UpdateStatus(ctx, dag.DAG, status)
require.Error(t, err)
})
}
func TestClient_RunDAG(t *testing.T) {
th := test.Setup(t)
t.Run("RunDAG", func(t *testing.T) {
th := test.Setup(t)
cli := th.Client()
filePath := testDAG("success.yaml")
ctx := context.Background()
dagStatus, err := cli.GetStatus(ctx, filePath)
dag := th.LoadDAGFile(t, "run_dag.yaml")
dagStatus, err := th.Client.GetStatus(th.Context, dag.Location)
require.NoError(t, err)
err = cli.Start(ctx, dagStatus.DAG, client.StartOptions{})
err = th.Client.Start(th.Context, dagStatus.DAG, client.StartOptions{})
require.NoError(t, err)
status, err := cli.GetLatestStatus(ctx, dagStatus.DAG)
status, err := th.Client.GetLatestStatus(th.Context, dagStatus.DAG)
require.NoError(t, err)
require.Equal(t, scheduler.StatusSuccess.String(), status.Status.String())
})
t.Run("Stop", func(t *testing.T) {
th := test.Setup(t)
dag := th.LoadDAGFile(t, "stop.yaml")
ctx := th.Context
cli := th.Client()
filePath := testDAG("sleep10.yaml")
ctx := context.Background()
dagStatus, err := cli.GetStatus(ctx, filePath)
th.Client.StartAsync(ctx, dag.DAG, client.StartOptions{})
dag.AssertLatestStatus(t, scheduler.StatusRunning)
err := th.Client.Stop(ctx, dag.DAG)
require.NoError(t, err)
cli.StartAsync(ctx, dagStatus.DAG, client.StartOptions{})
require.Eventually(t, func() bool {
curStatus, _ := cli.GetCurrentStatus(ctx, dagStatus.DAG)
return curStatus.Status == scheduler.StatusRunning
}, time.Millisecond*1500, time.Millisecond*100)
_ = cli.Stop(ctx, dagStatus.DAG)
require.Eventually(t, func() bool {
latestStatus, _ := cli.GetLatestStatus(ctx, dagStatus.DAG)
return latestStatus.Status == scheduler.StatusCancel
}, time.Millisecond*1500, time.Millisecond*100)
dag.AssertLatestStatus(t, scheduler.StatusCancel)
})
t.Run("Restart", func(t *testing.T) {
th := test.Setup(t)
dag := th.LoadDAGFile(t, "restart.yaml")
ctx := th.Context
cli := th.Client()
filePath := testDAG("success.yaml")
ctx := context.Background()
dagStatus, err := cli.GetStatus(ctx, filePath)
err := th.Client.Restart(ctx, dag.DAG, client.RestartOptions{})
require.NoError(t, err)
err = cli.Restart(ctx, dagStatus.DAG, client.RestartOptions{})
require.NoError(t, err)
status, err := cli.GetLatestStatus(ctx, dagStatus.DAG)
require.NoError(t, err)
require.Equal(t, scheduler.StatusSuccess, status.Status)
dag.AssertLatestStatus(t, scheduler.StatusSuccess)
})
t.Run("Retry", func(t *testing.T) {
th := test.Setup(t)
dag := th.LoadDAGFile(t, "retry.yaml")
ctx := th.Context
cli := th.Client
ctx := context.Background()
cli := th.Client()
filePath := testDAG("retry.yaml")
dagStatus, err := cli.GetStatus(ctx, filePath)
err := cli.Start(ctx, dag.DAG, client.StartOptions{Params: "x y z"})
require.NoError(t, err)
err = cli.Start(ctx, dagStatus.DAG, client.StartOptions{Params: "x y z"})
// Wait for the DAG to finish
dag.AssertLatestStatus(t, scheduler.StatusSuccess)
// Retry the DAG with the same params.
status, err := cli.GetLatestStatus(ctx, dag.DAG)
require.NoError(t, err)
status, err := cli.GetLatestStatus(ctx, dagStatus.DAG)
require.NoError(t, err)
require.Equal(t, scheduler.StatusSuccess, status.Status)
previousRequestID := status.RequestID
previousParams := status.Params
requestID := status.RequestID
params := status.Params
err = cli.Retry(ctx, dagStatus.DAG, requestID)
require.NoError(t, err)
status, err = cli.GetLatestStatus(ctx, dagStatus.DAG)
err = cli.Retry(ctx, dag.DAG, previousRequestID)
require.NoError(t, err)
require.Equal(t, scheduler.StatusSuccess, status.Status)
require.Equal(t, params, status.Params)
// Wait for the DAG to finish
dag.AssertLatestStatus(t, scheduler.StatusSuccess)
statusByRequestID, err := cli.GetStatusByRequestID(ctx, dagStatus.DAG, status.RequestID)
status, err = cli.GetLatestStatus(ctx, dag.DAG)
require.NoError(t, err)
require.Equal(t, status, statusByRequestID)
recentStatuses := cli.GetRecentHistory(ctx, dagStatus.DAG, 1)
require.Equal(t, status, recentStatuses[0].Status)
// Check if the params are the same as the previous run.
require.NotEqual(t, previousRequestID, status.RequestID)
require.Equal(t, previousParams, status.Params)
})
}
func TestClient_UpdateDAG(t *testing.T) {
t.Parallel()
t.Run("Update", func(t *testing.T) {
th := test.Setup(t)
cli := th.Client()
ctx := context.Background()
th := test.Setup(t)
t.Run("Update", func(t *testing.T) {
ctx := th.Context
cli := th.Client
// valid DAG
validDAG := `name: test DAG
@ -266,10 +219,8 @@ steps:
require.Equal(t, validDAG, spec)
})
t.Run("Remove", func(t *testing.T) {
th := test.Setup(t)
cli := th.Client()
ctx := context.Background()
ctx := th.Context
cli := th.Client
spec := `name: test DAG
steps:
@ -293,10 +244,8 @@ steps:
require.NoError(t, err)
})
t.Run("Create", func(t *testing.T) {
th := test.Setup(t)
cli := th.Client()
ctx := context.Background()
ctx := th.Context
cli := th.Client
id, err := cli.CreateDAG(ctx, "test-dag")
require.NoError(t, err)
@ -307,10 +256,8 @@ steps:
require.Equal(t, "test-dag", dag.Name)
})
t.Run("Rename", func(t *testing.T) {
th := test.Setup(t)
cli := th.Client()
ctx := context.Background()
ctx := th.Context
cli := th.Client
// Create a DAG to rename.
id, err := cli.CreateDAG(ctx, "old_name")
@ -328,21 +275,21 @@ steps:
}
func TestClient_ReadHistory(t *testing.T) {
t.Parallel()
th := test.Setup(t)
t.Run("TestClient_Empty", func(t *testing.T) {
th := test.Setup(t)
ctx := th.Context
cli := th.Client
dag := th.LoadDAGFile(t, "empty_status.yaml")
cli := th.Client()
filePath := testDAG("success.yaml")
ctx := context.Background()
_, err := cli.GetStatus(ctx, filePath)
_, err := cli.GetStatus(ctx, dag.Location)
require.NoError(t, err)
})
t.Run("TestClient_All", func(t *testing.T) {
th := test.Setup(t)
cli := th.Client()
ctx := context.Background()
ctx := th.Context
cli := th.Client
// Create a DAG
_, err := cli.CreateDAG(ctx, "test-dag1")
@ -358,33 +305,21 @@ func TestClient_ReadHistory(t *testing.T) {
})
}
func testDAG(name string) string {
return filepath.Join(testdataDir, name)
}
func testNewStatus(dag *digraph.DAG, requestID string, status scheduler.Status,
nodeStatus scheduler.NodeStatus) *model.Status {
ret := model.NewStatus(
dag,
[]scheduler.NodeData{
{
State: scheduler.NodeState{Status: nodeStatus},
},
},
status,
0,
model.Time(time.Now()),
nil,
)
ret.RequestID = requestID
return ret
func testNewStatus(dag *digraph.DAG, requestID string, status scheduler.Status, nodeStatus scheduler.NodeStatus) *model.Status {
nodeData := scheduler.NodeData{
State: scheduler.NodeState{Status: nodeStatus},
}
startedAt := model.Time(time.Now())
statusModel := model.NewStatus(dag, []scheduler.NodeData{nodeData}, status, 0, startedAt, nil)
statusModel.RequestID = requestID
return statusModel
}
func TestClient_GetTagList(t *testing.T) {
th := test.Setup(t)
cli := th.Client()
ctx := context.Background()
ctx := th.Context
cli := th.Client
// Create DAG List
for i := 0; i < 40; i++ {

3
internal/client/testdata/restart.yaml vendored Normal file
View File

@ -0,0 +1,3 @@
steps:
- name: "1"
command: "true"

3
internal/client/testdata/run_dag.yaml vendored Normal file
View File

@ -0,0 +1,3 @@
steps:
- name: "1"
command: "true"

View File

@ -0,0 +1,3 @@
steps:
- name: "1"
command: "true"

3
internal/client/testdata/valid.yaml vendored Normal file
View File

@ -0,0 +1,3 @@
steps:
- name: "1"
command: "sleep 1"

View File

@ -252,6 +252,16 @@ func TestSubstituteStringFields(t *testing.T) {
}
}
func TestSubstituteStringFields_AnonymousStruct(t *testing.T) {
obj, err := SubstituteStringFields(struct {
Field string
}{
Field: "`echo hello`",
})
require.NoError(t, err)
require.Equal(t, "hello", obj.Field)
}
func TestSubstituteStringFields_NonStruct(t *testing.T) {
_, err := SubstituteStringFields("not a struct")
if err == nil {

View File

@ -30,22 +30,38 @@ type Config struct {
// Legacy fields for backward compatibility - Start
// Note: These fields are used for backward compatibility and should not be used in new code
DAGs string `mapstructure:"dags"`
Executable string `mapstructure:"executable"`
LogDir string `mapstructure:"logDir"`
DataDir string `mapstructure:"dataDir"`
SuspendFlagsDir string `mapstructure:"suspendFlagsDir"`
AdminLogsDir string `mapstructure:"adminLogsDir"`
BaseConfig string `mapstructure:"baseConfig"`
IsBasicAuth bool `mapstructure:"isBasicAuth"`
BasicAuthUsername string `mapstructure:"basicAuthUsername"`
BasicAuthPassword string `mapstructure:"basicAuthPassword"`
IsAuthToken bool `mapstructure:"isAuthToken"`
AuthToken string `mapstructure:"authToken"`
LogEncodingCharset string `mapstructure:"logEncodingCharset"`
NavbarColor string `mapstructure:"navbarColor"`
NavbarTitle string `mapstructure:"navbarTitle"`
MaxDashboardPageLimit int `mapstructure:"maxDashboardPageLimit"`
// Deprecated: Use Auth.Basic.Enabled instead
DAGs string `mapstructure:"dags"`
// Deprecated: Use Paths.Executable instead
Executable string `mapstructure:"executable"`
// Deprecated: Use Paths.LogDir instead
LogDir string `mapstructure:"logDir"`
// Deprecated: Use Paths.DataDir instead
DataDir string `mapstructure:"dataDir"`
// Deprecated: Use Paths.SuspendFlagsDir instead
SuspendFlagsDir string `mapstructure:"suspendFlagsDir"`
// Deprecated: Use Paths.AdminLogsDir instead
AdminLogsDir string `mapstructure:"adminLogsDir"`
// Deprecated: Use Paths.BaseConfig instead
BaseConfig string `mapstructure:"baseConfig"`
// Deprecated: Use Auth.Token.Enabled instead
IsBasicAuth bool `mapstructure:"isBasicAuth"`
// Deprecated: Use Auth.Basic.Username instead
BasicAuthUsername string `mapstructure:"basicAuthUsername"`
// Deprecated: Use Auth.Basic.Password instead
BasicAuthPassword string `mapstructure:"basicAuthPassword"`
// Deprecated: Use Auth.Token.Enabled instead
IsAuthToken bool `mapstructure:"isAuthToken"`
// Deprecated: Use Auth.Token.Value instead
AuthToken string `mapstructure:"authToken"`
// Deprecated: Use UI.LogEncodingCharset instead
LogEncodingCharset string `mapstructure:"logEncodingCharset"`
// Deprecated: Use UI.NavbarColor instead
NavbarColor string `mapstructure:"navbarColor"`
// Deprecated: Use UI.NavbarTitle instead
NavbarTitle string `mapstructure:"navbarTitle"`
// Deprecated: Use UI.MaxDashboardPageLimit instead
MaxDashboardPageLimit int `mapstructure:"maxDashboardPageLimit"`
// Legacy fields for backward compatibility - End
// Other settings

View File

@ -64,7 +64,10 @@ func (l *ConfigLoader) Load() (*Config, error) {
}
func (l *ConfigLoader) setupViper() error {
homeDir := l.getHomeDir()
homeDir, err := l.getHomeDir()
if err != nil {
return err
}
xdgConfig := l.getXDGConfig(homeDir)
resolver := newResolver("DAGU_HOME", filepath.Join(homeDir, ".dagu"), xdgConfig)
@ -75,13 +78,12 @@ func (l *ConfigLoader) setupViper() error {
return l.setExecutableDefault()
}
func (l *ConfigLoader) getHomeDir() string {
func (l *ConfigLoader) getHomeDir() (string, error) {
dir, err := os.UserHomeDir()
if err != nil {
log.Fatalf("could not determine home directory: %v", err)
return ""
return "", fmt.Errorf("could not determine home directory: %w", err)
}
return dir
return dir, nil
}
func (l *ConfigLoader) getXDGConfig(homeDir string) XDGConfig {

View File

@ -84,7 +84,7 @@ message: %s
-----
`
func (e *mail) Run(_ context.Context) error {
func (e *mail) Run(ctx context.Context) error {
_, _ = e.stdout.Write(
[]byte(fmt.Sprintf(
mailLogTemplate,
@ -95,6 +95,7 @@ func (e *mail) Run(_ context.Context) error {
)),
)
err := e.mailer.Send(
ctx,
e.cfg.From,
[]string{e.cfg.To},
e.cfg.Subject,

View File

@ -16,6 +16,7 @@ import (
"github.com/mitchellh/mapstructure"
"golang.org/x/crypto/ssh"
"github.com/dagu-org/dagu/internal/cmdutil"
"github.com/dagu-org/dagu/internal/digraph"
)
@ -83,34 +84,44 @@ func newSSHExec(_ context.Context, step digraph.Step) (Executor, error) {
)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to create decoder: %w", err)
}
if err := md.Decode(step.ExecutorConfig.Config); err != nil {
return nil, err
return nil, fmt.Errorf("failed to decode ssh config: %w", err)
}
cfg := &sshExecConfig{
var port string
switch v := def.Port.(type) {
case int:
port = fmt.Sprintf("%d", v)
case string:
port = v
default:
port = fmt.Sprintf("%v", def.Port)
}
if port == "" {
port = "22"
}
cfg, err := cmdutil.SubstituteStringFields(sshExecConfig{
User: def.User,
IP: def.IP,
Key: def.Key,
Password: def.Password,
Port: port,
})
if err != nil {
return nil, fmt.Errorf("failed to substitute string fields for ssh config: %w", err)
}
// Handle Port as either string or int
port := os.ExpandEnv(fmt.Sprintf("%v", def.Port))
if port == "" {
port = "22"
}
cfg.Port = port
// StrictHostKeyChecking is not supported yet.
if def.StrictHostKeyChecking {
return nil, errStrictHostKey
}
// Select the authentication method.
authMethod, err := selectSSHAuthMethod(cfg)
authMethod, err := selectSSHAuthMethod(&cfg)
if err != nil {
return nil, err
}
@ -126,7 +137,7 @@ func newSSHExec(_ context.Context, step digraph.Step) (Executor, error) {
return &sshExec{
step: step,
config: cfg,
config: &cfg,
sshConfig: sshConfig,
stdout: os.Stdout,
}, nil

View File

@ -13,6 +13,7 @@ import (
"sync"
"syscall"
"github.com/dagu-org/dagu/internal/cmdutil"
"github.com/dagu-org/dagu/internal/digraph"
"github.com/dagu-org/dagu/internal/fileutil"
"github.com/google/uuid"
@ -41,10 +42,21 @@ func newSubWorkflow(
return nil, fmt.Errorf("failed to get dag context: %w", err)
}
subDAG, err := dagCtx.Finder.Find(ctx, step.SubWorkflow.Name)
config, err := cmdutil.SubstituteStringFields(struct {
Name string
Params string
}{
Name: step.SubWorkflow.Name,
Params: step.SubWorkflow.Params,
})
if err != nil {
return nil, fmt.Errorf("failed to substitute string fields: %w", err)
}
subDAG, err := dagCtx.Finder.Find(ctx, config.Name)
if err != nil {
return nil, fmt.Errorf(
"failed to find subworkflow %q: %w", step.SubWorkflow.Name, err,
"failed to find subworkflow %q: %w", config.Name, err,
)
}
@ -53,13 +65,11 @@ func newSubWorkflow(
return nil, fmt.Errorf("failed to generate request ID: %w", err)
}
params := os.ExpandEnv(step.SubWorkflow.Params)
args := []string{
"start",
fmt.Sprintf("--requestID=%s", requestID),
"--quiet",
fmt.Sprintf("--params=%q", params),
fmt.Sprintf("--params=%q", config.Params),
subDAG.Location,
}

View File

@ -8,7 +8,7 @@ import (
)
// parseFuncCall parses the function call in the step definition.
// deprecated: use subworkflow instead.
// Deprecated: use subworkflow instead.
func parseFuncCall(step *Step, call *callFuncDef, funcs []*funcDef) error {
if call == nil {
return nil

View File

@ -35,7 +35,7 @@ func NewExecutionGraph(steps ...digraph.Step) (*ExecutionGraph, error) {
}
for _, step := range steps {
node := &Node{data: NodeData{Step: step}}
node.init()
node.Init()
graph.dict[node.id] = node
graph.nodes = append(graph.nodes, node)
}
@ -55,7 +55,7 @@ func CreateRetryExecutionGraph(ctx context.Context, nodes ...*Node) (*ExecutionG
nodes: []*Node{},
}
for _, node := range nodes {
node.init()
node.Init()
graph.dict[node.id] = node
graph.nodes = append(graph.nodes, node)
}
@ -167,7 +167,7 @@ func (g *ExecutionGraph) setupRetry(ctx context.Context) error {
if retry[u] || dict[u] == NodeStatusError ||
dict[u] == NodeStatusCancel {
logger.Info(ctx, "clear node state", "step", g.dict[u].data.Step.Name)
g.dict[u].clearState()
g.dict[u].ClearState()
retry[u] = true
}
for _, v := range g.from[u] {

View File

@ -1,13 +1,14 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later
package scheduler
package scheduler_test
import (
"context"
"testing"
"github.com/dagu-org/dagu/internal/digraph"
"github.com/dagu-org/dagu/internal/digraph/scheduler"
"github.com/stretchr/testify/require"
)
@ -20,7 +21,7 @@ func TestCycleDetection(t *testing.T) {
step2.Name = "2"
step2.Depends = []string{"1"}
_, err := NewExecutionGraph(step1, step2)
_, err := scheduler.NewExecutionGraph(step1, step2)
if err == nil {
t.Fatal("cycle detection should be detected.")
@ -28,81 +29,80 @@ func TestCycleDetection(t *testing.T) {
}
func TestRetryExecution(t *testing.T) {
nodes := []*Node{
{
data: NodeData{
nodes := []*scheduler.Node{
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "1", Command: "true"},
State: NodeState{
Status: NodeStatusSuccess,
State: scheduler.NodeState{
Status: scheduler.NodeStatusSuccess,
},
},
},
{
data: NodeData{
}),
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "2", Command: "true", Depends: []string{"1"}},
State: NodeState{
Status: NodeStatusError,
State: scheduler.NodeState{
Status: scheduler.NodeStatusError,
},
},
},
{
data: NodeData{
),
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "3", Command: "true", Depends: []string{"2"}},
State: NodeState{
Status: NodeStatusCancel,
State: scheduler.NodeState{
Status: scheduler.NodeStatusCancel,
},
},
},
{
data: NodeData{
),
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "4", Command: "true", Depends: []string{}},
State: NodeState{
Status: NodeStatusSkipped,
State: scheduler.NodeState{
Status: scheduler.NodeStatusSkipped,
},
},
},
{
data: NodeData{
),
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "5", Command: "true", Depends: []string{"4"}},
State: NodeState{
Status: NodeStatusError,
State: scheduler.NodeState{
Status: scheduler.NodeStatusError,
},
},
},
{
data: NodeData{
),
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "6", Command: "true", Depends: []string{"5"}},
State: NodeState{
Status: NodeStatusSuccess,
State: scheduler.NodeState{
Status: scheduler.NodeStatusSuccess,
},
},
},
{
data: NodeData{
),
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "7", Command: "true", Depends: []string{"6"}},
State: NodeState{
Status: NodeStatusSkipped,
State: scheduler.NodeState{
Status: scheduler.NodeStatusSkipped,
},
},
},
{
data: NodeData{
),
scheduler.NodeWithData(
scheduler.NodeData{
Step: digraph.Step{Name: "8", Command: "true", Depends: []string{}},
State: NodeState{
Status: NodeStatusSkipped,
State: scheduler.NodeState{
Status: scheduler.NodeStatusSkipped,
},
},
},
),
}
ctx := context.Background()
_, err := CreateRetryExecutionGraph(ctx, nodes...)
_, err := scheduler.CreateRetryExecutionGraph(ctx, nodes...)
require.NoError(t, err)
require.Equal(t, NodeStatusSuccess, nodes[0].State().Status)
require.Equal(t, NodeStatusNone, nodes[1].State().Status)
require.Equal(t, NodeStatusNone, nodes[2].State().Status)
require.Equal(t, NodeStatusSkipped, nodes[3].State().Status)
require.Equal(t, NodeStatusNone, nodes[4].State().Status)
require.Equal(t, NodeStatusNone, nodes[5].State().Status)
require.Equal(t, NodeStatusNone, nodes[6].State().Status)
require.Equal(t, NodeStatusSkipped, nodes[7].State().Status)
require.Equal(t, scheduler.NodeStatusSuccess, nodes[0].State().Status)
require.Equal(t, scheduler.NodeStatusNone, nodes[1].State().Status)
require.Equal(t, scheduler.NodeStatusNone, nodes[2].State().Status)
require.Equal(t, scheduler.NodeStatusSkipped, nodes[3].State().Status)
require.Equal(t, scheduler.NodeStatusNone, nodes[4].State().Status)
require.Equal(t, scheduler.NodeStatusNone, nodes[5].State().Status)
require.Equal(t, scheduler.NodeStatusNone, nodes[6].State().Status)
require.Equal(t, scheduler.NodeStatusSkipped, nodes[7].State().Status)
}

View File

@ -9,7 +9,6 @@ import (
"context"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strings"
@ -96,6 +95,12 @@ func (s NodeStatus) String() string {
}
}
func NodeWithData(data NodeData) *Node {
return &Node{
data: data,
}
}
func NewNode(step digraph.Step, state NodeState) *Node {
return &Node{
data: NodeData{Step: step, State: state},
@ -108,6 +113,33 @@ func (n *Node) Data() NodeData {
return n.data
}
func (n *Node) ScriptFilename() string {
n.mu.RLock()
defer n.mu.RUnlock()
if n.scriptFile != nil {
return n.scriptFile.Name()
}
return ""
}
func (n *Node) CloseLog() error {
n.logLock.Lock()
defer n.logLock.Unlock()
if n.logFile != nil {
return n.logFile.Close()
}
return nil
}
func (n *Node) LogFilename() string {
n.logLock.Lock()
defer n.logLock.Unlock()
if n.logFile != nil {
return n.logFile.Name()
}
return ""
}
func (n *Node) setError(err error) {
n.mu.Lock()
defer n.mu.Unlock()
@ -134,7 +166,7 @@ func (n *Node) Execute(ctx context.Context) error {
})
ctx = digraph.WithDagContext(ctx, dagCtx)
cmd, err := n.setupExec(ctx)
cmd, err := n.SetupExec(ctx)
if err != nil {
return err
}
@ -164,13 +196,13 @@ func (n *Node) Execute(ctx context.Context) error {
return n.data.State.Error
}
func (n *Node) finish() {
func (n *Node) Finish() {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.FinishedAt = time.Now()
}
func (n *Node) setupExec(ctx context.Context) (executor.Executor, error) {
func (n *Node) SetupExec(ctx context.Context) (executor.Executor, error) {
n.mu.Lock()
defer n.mu.Unlock()
@ -233,42 +265,42 @@ func (n *Node) setupExec(ctx context.Context) (executor.Executor, error) {
return cmd, nil
}
func (n *Node) getRetryCount() int {
func (n *Node) GetRetryCount() int {
n.mu.RLock()
defer n.mu.RUnlock()
return n.data.State.RetryCount
}
func (n *Node) setRetriedAt(retriedAt time.Time) {
func (n *Node) SetRetriedAt(retriedAt time.Time) {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.RetriedAt = retriedAt
}
func (n *Node) getDoneCount() int {
func (n *Node) GetDoneCount() int {
n.mu.RLock()
defer n.mu.RUnlock()
return n.data.State.DoneCount
}
func (n *Node) clearState() {
func (n *Node) ClearState() {
n.data.State = NodeState{}
}
func (n *Node) setStatus(status NodeStatus) {
func (n *Node) SetStatus(status NodeStatus) {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.Status = status
}
func (n *Node) markError(err error) {
func (n *Node) MarkError(err error) {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.Error = err
n.data.State.Status = NodeStatusError
}
func (n *Node) signal(ctx context.Context, sig os.Signal, allowOverride bool) {
func (n *Node) Signal(ctx context.Context, sig os.Signal, allowOverride bool) {
n.mu.Lock()
defer n.mu.Unlock()
status := n.data.State.Status
@ -277,9 +309,9 @@ func (n *Node) signal(ctx context.Context, sig os.Signal, allowOverride bool) {
if allowOverride && n.data.Step.SignalOnStop != "" {
sigsig = unix.SignalNum(n.data.Step.SignalOnStop)
}
log.Printf("Sending %s signal to %s", sigsig, n.data.Step.Name)
logger.Info(ctx, "Sending signal", "signal", sigsig, "step", n.data.Step.Name)
if err := n.cmd.Kill(sigsig); err != nil {
logger.Error(ctx, "failed to send signal", "err", err)
logger.Error(ctx, "Failed to send signal", "err", err, "step", n.data.Step.Name)
}
}
if status == NodeStatusRunning {
@ -287,7 +319,7 @@ func (n *Node) signal(ctx context.Context, sig os.Signal, allowOverride bool) {
}
}
func (n *Node) cancel() {
func (n *Node) Cancel(ctx context.Context) {
n.mu.Lock()
defer n.mu.Unlock()
status := n.data.State.Status
@ -295,22 +327,31 @@ func (n *Node) cancel() {
n.data.State.Status = NodeStatusCancel
}
if n.cancelFunc != nil {
log.Printf("canceling node: %s", n.data.Step.Name)
logger.Info(ctx, "canceling node", "step", n.data.Step.Name)
n.cancelFunc()
}
}
func (n *Node) setup(logDir string, requestID string) error {
func (n *Node) Setup(logDir string, requestID string) error {
n.mu.Lock()
defer n.mu.Unlock()
// Set the log file path
n.data.State.StartedAt = time.Now()
n.data.State.Log = filepath.Join(logDir, fmt.Sprintf("%s.%s.%s.log",
fileutil.SafeName(n.data.Step.Name),
n.data.State.StartedAt.Format("20060102.15:04:05.000"),
stringutil.TruncString(requestID, 8),
))
startedAt := time.Now()
safeName := fileutil.SafeName(n.data.Step.Name)
timestamp := startedAt.Format("20060102.15:04:05.000")
postfix := stringutil.TruncString(requestID, 8)
logFilename := fmt.Sprintf("%s.%s.%s.log", safeName, timestamp, postfix)
if !fileutil.FileExists(logDir) {
if err := os.MkdirAll(logDir, 0755); err != nil {
return fmt.Errorf("failed to create log directory %q: %w", logDir, err)
}
}
filePath := filepath.Join(logDir, logFilename)
n.data.State.Log = filePath
n.data.State.StartedAt = startedAt
// Replace the special environment variables in the command
// Why this is necessary:
@ -320,7 +361,7 @@ func (n *Node) setup(logDir string, requestID string) error {
// we need to replace the name differently for each node.
envKeyLogPath := fmt.Sprintf("STEP_%d_DAG_EXECUTION_LOG_PATH", n.id)
if err := os.Setenv(envKeyLogPath, n.data.State.Log); err != nil {
return err
return fmt.Errorf("failed to set environment variable %q: %w", envKeyLogPath, err)
}
// Expand environment variables in the step
@ -335,19 +376,66 @@ func (n *Node) setup(logDir string, requestID string) error {
n.data.Step.Dir = os.ExpandEnv(n.data.Step.Dir)
if err := n.setupLog(); err != nil {
return err
return fmt.Errorf("failed to setup log: %w", err)
}
if err := n.setupStdout(); err != nil {
return err
return fmt.Errorf("failed to setup stdout: %w", err)
}
if err := n.setupStderr(); err != nil {
return err
return fmt.Errorf("failed to setup stderr: %w", err)
}
if err := n.setupRetryPolicy(); err != nil {
return err
return fmt.Errorf("failed to setup retry policy: %w", err)
}
if err := n.setupScript(); err != nil {
return fmt.Errorf("failed to setup script: %w", err)
}
return nil
}
return n.setupScript()
func (n *Node) Teardown() error {
if n.done {
return nil
}
n.logLock.Lock()
n.done = true
var lastErr error
for _, w := range []*bufio.Writer{n.logWriter, n.stdoutWriter} {
if w != nil {
if err := w.Flush(); err != nil {
lastErr = err
}
}
}
for _, f := range []*os.File{n.logFile, n.stdoutFile} {
if f != nil {
if err := f.Sync(); err != nil {
lastErr = err
}
_ = f.Close()
}
}
n.logLock.Unlock()
if n.scriptFile != nil {
_ = os.Remove(n.scriptFile.Name())
}
if lastErr != nil {
n.data.State.Error = lastErr
}
return lastErr
}
func (n *Node) IncRetryCount() {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.RetryCount++
}
func (n *Node) IncDoneCount() {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.DoneCount++
}
var (
@ -420,50 +508,6 @@ func (n *Node) setupLog() error {
n.logWriter = bufio.NewWriter(n.logFile)
return nil
}
func (n *Node) teardown() error {
if n.done {
return nil
}
n.logLock.Lock()
n.done = true
var lastErr error
for _, w := range []*bufio.Writer{n.logWriter, n.stdoutWriter} {
if w != nil {
if err := w.Flush(); err != nil {
lastErr = err
}
}
}
for _, f := range []*os.File{n.logFile, n.stdoutFile} {
if f != nil {
if err := f.Sync(); err != nil {
lastErr = err
}
_ = f.Close()
}
}
n.logLock.Unlock()
if n.scriptFile != nil {
_ = os.Remove(n.scriptFile.Name())
}
if lastErr != nil {
n.data.State.Error = lastErr
}
return lastErr
}
func (n *Node) incRetryCount() {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.RetryCount++
}
func (n *Node) incDoneCount() {
n.mu.Lock()
defer n.mu.Unlock()
n.data.State.DoneCount++
}
var (
nextNodeID = 1
@ -478,7 +522,7 @@ func getNextNodeID() int {
return v
}
func (n *Node) init() {
func (n *Node) Init() {
n.mu.Lock()
defer n.mu.Unlock()
if n.id != 0 {

View File

@ -1,382 +1,251 @@
// Copyright (C) 2024 Yota Hamada
// SPDX-License-Identifier: GPL-3.0-or-later
package scheduler
package scheduler_test
import (
"context"
"fmt"
"math/rand"
"os"
"path/filepath"
"path"
"syscall"
"testing"
"time"
"github.com/dagu-org/dagu/internal/digraph"
"github.com/dagu-org/dagu/internal/digraph/scheduler"
"github.com/dagu-org/dagu/internal/test"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
)
func nodeTextCtxWithDagContext() context.Context {
return digraph.NewContext(context.Background(), nil, nil, nil, "", "")
type nodeHelper struct {
*scheduler.Node
test.Helper
reqID string
}
func TestExecute(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "true",
OutputVariables: &digraph.SyncMap{},
}}}
require.NoError(t, n.Execute(nodeTextCtxWithDagContext()))
require.Nil(t, n.data.State.Error)
type nodeOption func(*scheduler.NodeData)
func withNodeCmdArgs(cmd string) nodeOption {
return func(data *scheduler.NodeData) {
data.Step.CmdWithArgs = cmd
}
}
func TestError(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "false",
OutputVariables: &digraph.SyncMap{},
}}}
err := n.Execute(nodeTextCtxWithDagContext())
require.True(t, err != nil)
require.Equal(t, n.data.State.Error, err)
func withNodeCommand(command string) nodeOption {
return func(data *scheduler.NodeData) {
data.Step.Command = command
}
}
func TestSignal(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "sleep",
Args: []string{"100"},
OutputVariables: &digraph.SyncMap{},
}}}
go func() {
time.Sleep(100 * time.Millisecond)
n.signal(context.Background(), syscall.SIGTERM, false)
}()
n.setStatus(NodeStatusRunning)
err := n.Execute(nodeTextCtxWithDagContext())
require.Error(t, err)
require.Equal(t, n.State().Status, NodeStatusCancel)
func withNodeSignalOnStop(signal string) nodeOption {
return func(data *scheduler.NodeData) {
data.Step.SignalOnStop = signal
}
}
func TestSignalSpecified(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "sleep",
Args: []string{"100"},
OutputVariables: &digraph.SyncMap{},
SignalOnStop: "SIGINT",
}}}
go func() {
time.Sleep(100 * time.Millisecond)
n.signal(context.Background(), syscall.SIGTERM, true)
}()
n.setStatus(NodeStatusRunning)
err := n.Execute(nodeTextCtxWithDagContext())
require.Error(t, err)
require.Equal(t, n.State().Status, NodeStatusCancel)
func withNodeStdout(stdout string) nodeOption {
return func(data *scheduler.NodeData) {
data.Step.Stdout = stdout
}
}
func TestLog(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "echo",
Args: []string{"done"},
Dir: os.Getenv("HOME"),
OutputVariables: &digraph.SyncMap{},
}},
func withNodeStderr(stderr string) nodeOption {
return func(data *scheduler.NodeData) {
data.Step.Stderr = stderr
}
}
func withNodeScript(script string) nodeOption {
return func(data *scheduler.NodeData) {
data.Step.Script = script
}
}
func withNodeOutput(output string) nodeOption {
return func(data *scheduler.NodeData) {
data.Step.Output = output
}
}
func setupNode(t *testing.T, opts ...nodeOption) nodeHelper {
th := test.Setup(t)
data := scheduler.NodeData{Step: digraph.Step{}}
for _, opt := range opts {
opt(&data)
}
runTestNode(t, n)
node := scheduler.NodeWithData(data)
reqID := uuid.Must(uuid.NewRandom()).String()
dat, _ := os.ReadFile(n.logFile.Name())
require.Equal(t, "done\n", string(dat))
return nodeHelper{node, th, reqID}
}
func TestStdout(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "echo",
Args: []string{"done"},
Dir: os.Getenv("HOME"),
Stdout: "stdout.log",
OutputVariables: &digraph.SyncMap{},
}},
}
func (n nodeHelper) Execute(t *testing.T) {
err := n.Node.Setup(n.Config.Paths.LogDir, n.reqID)
require.NoError(t, err, "failed to setup node")
runTestNode(t, n)
err = n.Node.Execute(n.execContext())
require.NoError(t, err, "failed to execute node")
f := filepath.Join(os.Getenv("HOME"), n.data.Step.Stdout)
dat, _ := os.ReadFile(f)
require.Equal(t, "done\n", string(dat))
err = n.Teardown()
require.NoError(t, err, "failed to teardown node")
}
func TestStderr(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "sh",
Script: `
echo Stdout message >&1
echo Stderr message >&2
`,
Dir: os.Getenv("HOME"),
Stdout: "test-stderr-stdout.log",
Stderr: "test-stderr-stderr.log",
OutputVariables: &digraph.SyncMap{},
}},
}
func (n nodeHelper) ExecuteFail(t *testing.T, expectedErr string) {
err := n.Node.Execute(n.execContext())
require.Error(t, err, "expected error")
require.Contains(t, err.Error(), expectedErr, "unexpected error")
}
runTestNode(t, n)
func (n nodeHelper) AssertLogContains(t *testing.T, expected string) {
dat, err := os.ReadFile(n.Node.LogFilename())
require.NoErrorf(t, err, "failed to read log file %q", n.Node.LogFilename())
require.Contains(t, string(dat), expected, "log file does not contain expected string")
}
f := filepath.Join(os.Getenv("HOME"), n.data.Step.Stderr)
dat, _ := os.ReadFile(f)
require.Equal(t, "Stderr message\n", string(dat))
func (n nodeHelper) AssertOutput(t *testing.T, key, value string) {
require.NotNil(t, n.Node.Data().Step.OutputVariables, "output variables not set")
data, ok := n.Node.Data().Step.OutputVariables.Load(key)
require.True(t, ok, "output variable not found")
require.Equal(t, fmt.Sprintf("%s=%s", key, value), data, "output variable value mismatch")
}
f = filepath.Join(os.Getenv("HOME"), n.data.Step.Stdout)
dat, _ = os.ReadFile(f)
require.Equal(t, "Stdout message\n", string(dat))
func (n nodeHelper) execContext() context.Context {
return digraph.NewContext(n.Context, &digraph.DAG{}, nil, nil, n.reqID, "logFile")
}
func TestNode(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "echo",
Args: []string{"hello"},
OutputVariables: &digraph.SyncMap{},
}},
}
n.incDoneCount()
require.Equal(t, 1, n.getDoneCount())
t.Parallel()
n.incRetryCount()
require.Equal(t, 1, n.getRetryCount())
t.Run("Execute", func(t *testing.T) {
node := setupNode(t, withNodeCommand("true"))
node.Execute(t)
})
t.Run("Error", func(t *testing.T) {
node := setupNode(t, withNodeCommand("false"))
node.ExecuteFail(t, "exit status 1")
})
t.Run("Signal", func(t *testing.T) {
node := setupNode(t, withNodeCommand("sleep 3"))
go func() {
time.Sleep(100 * time.Millisecond)
node.Signal(node.Context, syscall.SIGTERM, false)
}()
n.id = 1
n.init()
require.Nil(t, n.data.Step.Variables)
node.SetStatus(scheduler.NodeStatusRunning)
n.id = 0
n.init()
require.Equal(t, n.data.Step.Variables, []string{})
}
func TestOutput(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
CmdWithArgs: "echo hello",
Output: "OUTPUT_TEST",
OutputVariables: &digraph.SyncMap{},
}},
}
err := n.setup(os.Getenv("HOME"), "test-request-id-output")
require.NoError(t, err)
defer func() {
_ = n.teardown()
}()
runTestNode(t, n)
dat, _ := os.ReadFile(n.logFile.Name())
require.Equal(t, "hello\n", string(dat))
require.Equal(t, "hello", os.ExpandEnv("$OUTPUT_TEST"))
// Use the previous output in the subsequent step
n2 := &Node{data: NodeData{
Step: digraph.Step{
CmdWithArgs: "echo $OUTPUT_TEST",
Output: "OUTPUT_TEST2",
OutputVariables: &digraph.SyncMap{},
}},
}
runTestNode(t, n2)
require.Equal(t, "hello", os.ExpandEnv("$OUTPUT_TEST2"))
// Use the previous output in the subsequent step inside a script
n3 := &Node{data: NodeData{
Step: digraph.Step{
Command: "sh",
Script: "echo $OUTPUT_TEST2",
Output: "OUTPUT_TEST3",
OutputVariables: &digraph.SyncMap{},
}},
}
runTestNode(t, n3)
require.Equal(t, "hello", os.ExpandEnv("$OUTPUT_TEST3"))
}
func TestOutputJson(t *testing.T) {
for i, test := range []struct {
CmdWithArgs string
Want string
WantArgs int
}{
{
CmdWithArgs: `echo {\"key\":\"value\"}`,
Want: `{"key":"value"}`,
WantArgs: 1,
},
{
CmdWithArgs: `echo "{\"key\": \"value\"}"`,
Want: `{"key": "value"}`,
WantArgs: 1,
},
} {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
CmdWithArgs: test.CmdWithArgs,
Output: "OUTPUT_JSON_TEST",
OutputVariables: &digraph.SyncMap{},
}},
}
err := n.setup(os.Getenv("HOME"), fmt.Sprintf("test-output-jsondb-%d", i))
require.NoError(t, err)
defer func() {
_ = n.teardown()
}()
runTestNode(t, n)
require.Equal(t, test.WantArgs, len(n.data.Step.Args))
v, _ := n.data.Step.OutputVariables.Load("OUTPUT_JSON_TEST")
require.Equal(t, fmt.Sprintf("OUTPUT_JSON_TEST=%s", test.Want), v)
require.Equal(t, test.Want, os.ExpandEnv("$OUTPUT_JSON_TEST"))
})
}
}
func TestOutputSpecialChar(t *testing.T) {
for i, test := range []struct {
CmdWithArgs string
Want string
WantArgs int
}{
{
CmdWithArgs: `echo "hello\tworld"`,
Want: `hello\tworld`,
WantArgs: 1,
},
{
CmdWithArgs: `echo hello"\t"world`,
Want: `hello\tworld`,
WantArgs: 1,
},
{
CmdWithArgs: `echo hello\tworld`,
Want: `hello\tworld`,
WantArgs: 1,
},
{
CmdWithArgs: `echo hello\nworld`,
Want: `hello\nworld`,
WantArgs: 1,
},
{
CmdWithArgs: `echo {\"key\":\"value\"}`,
Want: `{"key":"value"}`,
WantArgs: 1,
},
{
CmdWithArgs: `echo "{\"key\":\"value\"}"`,
Want: `{"key":"value"}`,
WantArgs: 1,
},
} {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
CmdWithArgs: test.CmdWithArgs,
Output: "OUTPUT_SPECIALCHAR_TEST",
OutputVariables: &digraph.SyncMap{},
}},
}
err := n.setup(os.Getenv("HOME"), fmt.Sprintf("test-output-specialchar-%d", i))
require.NoError(t, err)
defer func() {
_ = n.teardown()
}()
runTestNode(t, n)
require.Equal(t, test.WantArgs, len(n.data.Step.Args))
v, _ := n.data.Step.OutputVariables.Load("OUTPUT_SPECIALCHAR_TEST")
require.Equal(t, fmt.Sprintf("OUTPUT_SPECIALCHAR_TEST=%s", test.Want), v)
require.Equal(t, test.Want, os.ExpandEnv("$OUTPUT_SPECIALCHAR_TEST"))
})
}
}
func TestRunScript(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: "sh",
Args: []string{},
Script: `
echo hello
`,
Output: "SCRIPT_TEST",
OutputVariables: &digraph.SyncMap{},
}},
}
err := n.setup(os.Getenv("HOME"),
fmt.Sprintf("test-request-id-%d", rand.Int()))
require.NoError(t, err)
require.FileExists(t, n.logFile.Name())
b, _ := os.ReadFile(n.scriptFile.Name())
require.Equal(t, n.data.Step.Script, string(b))
require.NoError(t, err)
err = n.Execute(nodeTextCtxWithDagContext())
require.NoError(t, err)
err = n.teardown()
require.NoError(t, err)
require.Equal(t, "hello", os.Getenv("SCRIPT_TEST"))
require.NoFileExists(t, n.scriptFile.Name())
}
func TestTeardown(t *testing.T) {
n := &Node{data: NodeData{
Step: digraph.Step{
Command: testCommand,
Args: []string{},
OutputVariables: &digraph.SyncMap{},
}},
}
runTestNode(t, n)
// no error since done flag is true
err := n.teardown()
require.NoError(t, err)
require.NoError(t, n.data.State.Error)
// error
n.done = false
err = n.teardown()
require.Error(t, err)
require.Error(t, n.data.State.Error)
}
func runTestNode(t *testing.T, n *Node) {
t.Helper()
err := n.setup(os.Getenv("HOME"),
fmt.Sprintf("test-request-id-%d", rand.Int()))
require.NoError(t, err)
err = n.Execute(nodeTextCtxWithDagContext())
require.NoError(t, err)
err = n.teardown()
require.NoError(t, err)
node.ExecuteFail(t, "signal: terminated")
require.Equal(t, scheduler.NodeStatusCancel.String(), node.State().Status.String())
})
t.Run("SignalOnStop", func(t *testing.T) {
node := setupNode(t, withNodeCommand("sleep 3"), withNodeSignalOnStop("SIGINT"))
go func() {
time.Sleep(100 * time.Millisecond)
node.Signal(node.Context, syscall.SIGTERM, true) // allow override signal
}()
node.SetStatus(scheduler.NodeStatusRunning)
node.ExecuteFail(t, "signal: interrupt")
require.Equal(t, scheduler.NodeStatusCancel.String(), node.State().Status.String())
})
t.Run("LogOutput", func(t *testing.T) {
node := setupNode(t, withNodeCommand("echo hello"))
node.Execute(t)
node.AssertLogContains(t, "hello")
})
t.Run("Stdout", func(t *testing.T) {
random := path.Join(os.TempDir(), uuid.Must(uuid.NewRandom()).String())
defer os.Remove(random)
node := setupNode(t, withNodeCommand("echo hello"), withNodeStdout(random))
node.Execute(t)
file := node.Data().Step.Stdout
dat, _ := os.ReadFile(file)
require.Equalf(t, "hello\n", string(dat), "unexpected stdout content: %s", string(dat))
})
t.Run("Stderr", func(t *testing.T) {
random := path.Join(os.TempDir(), uuid.Must(uuid.NewRandom()).String())
defer os.Remove(random)
node := setupNode(t,
withNodeCommand("sh"),
withNodeStderr(random),
withNodeScript("echo hello >&2"),
)
node.Execute(t)
file := node.Data().Step.Stderr
dat, _ := os.ReadFile(file)
require.Equalf(t, "hello\n", string(dat), "unexpected stderr content: %s", string(dat))
})
t.Run("Output", func(t *testing.T) {
node := setupNode(t, withNodeCmdArgs("echo hello"), withNodeOutput("OUTPUT_TEST"))
node.Execute(t)
node.AssertOutput(t, "OUTPUT_TEST", "hello")
})
t.Run("OutputJSON", func(t *testing.T) {
node := setupNode(t, withNodeCmdArgs(`echo '{"key": "value"}'`), withNodeOutput("OUTPUT_JSON_TEST"))
node.Execute(t)
node.AssertOutput(t, "OUTPUT_JSON_TEST", `{"key": "value"}`)
})
t.Run("OutputJSONUnescaped", func(t *testing.T) {
node := setupNode(t, withNodeCmdArgs(`echo {\"key\":\"value\"}`), withNodeOutput("OUTPUT_JSON_TEST"))
node.Execute(t)
node.AssertOutput(t, "OUTPUT_JSON_TEST", `{"key":"value"}`)
})
t.Run("OutputSpecialChar", func(t *testing.T) {
t.Parallel()
testCases := []struct {
CmdWithArgs string
Want string
}{
{
CmdWithArgs: `echo "hello\tworld"`,
Want: `hello\tworld`,
},
{
CmdWithArgs: `echo hello"\t"world`,
Want: `hello\tworld`,
},
{
CmdWithArgs: `echo hello\tworld`,
Want: `hello\tworld`,
},
{
CmdWithArgs: `echo hello\nworld`,
Want: `hello\nworld`,
},
{
CmdWithArgs: `echo {\"key\":\"value\"}`,
Want: `{"key":"value"}`,
},
{
CmdWithArgs: `echo "{\"key\":\"value\"}"`,
Want: `{"key":"value"}`,
},
}
for i, tc := range testCases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
node := setupNode(t, withNodeCmdArgs(tc.CmdWithArgs), withNodeOutput("OUTPUT_SPECIALCHAR_TEST"))
node.Execute(t)
node.AssertOutput(t, "OUTPUT_SPECIALCHAR_TEST", tc.Want)
})
}
})
t.Run("Script", func(t *testing.T) {
node := setupNode(t, withNodeScript("echo hello"), withNodeOutput("SCRIPT_TEST"))
node.Execute(t)
node.AssertOutput(t, "SCRIPT_TEST", "hello")
// check script file is removed
scriptFilePath := node.ScriptFilename()
require.NotEmpty(t, scriptFilePath)
require.NoFileExists(t, scriptFilePath, "script file not removed")
})
}

View File

@ -129,7 +129,7 @@ func (sc *Scheduler) Schedule(ctx context.Context, graph *ExecutionGraph, done c
logger.Infof(ctx, "Checking pre conditions for \"%s\"", node.data.Step.Name)
if err := digraph.EvalConditions(node.data.Step.Preconditions); err != nil {
logger.Infof(ctx, "Pre conditions failed for \"%s\"", node.data.Step.Name)
node.setStatus(NodeStatusSkipped)
node.SetStatus(NodeStatusSkipped)
node.setError(err)
continue NodesIteration
}
@ -138,20 +138,20 @@ func (sc *Scheduler) Schedule(ctx context.Context, graph *ExecutionGraph, done c
wg.Add(1)
logger.Info(ctx, "Step execution started", "step", node.data.Step.Name)
node.setStatus(NodeStatusRunning)
node.SetStatus(NodeStatusRunning)
go func(node *Node) {
defer func() {
if panicObj := recover(); panicObj != nil {
stack := string(debug.Stack())
err := fmt.Errorf("panic recovered: %v\n%s", panicObj, stack)
logger.Error(ctx, "Panic occurred", "error", err, "step", node.data.Step.Name, "stack", stack)
node.markError(err)
node.MarkError(err)
sc.setLastError(err)
}
}()
defer func() {
node.finish()
node.Finish()
wg.Done()
}()
@ -159,7 +159,7 @@ func (sc *Scheduler) Schedule(ctx context.Context, graph *ExecutionGraph, done c
if err := sc.setupNode(node); err != nil {
setupSucceed = false
sc.setLastError(err)
node.markError(err)
node.MarkError(err)
}
defer func() {
@ -177,37 +177,40 @@ func (sc *Scheduler) Schedule(ctx context.Context, graph *ExecutionGraph, done c
case sc.isTimeout(graph.startedAt):
logger.Info(ctx, "Step execution deadline exceeded", "step", node.data.Step.Name, "error", execErr)
node.setStatus(NodeStatusCancel)
node.SetStatus(NodeStatusCancel)
sc.setLastError(execErr)
case sc.isCanceled():
sc.setLastError(execErr)
case node.retryPolicy.Limit > node.getRetryCount():
case node.retryPolicy.Limit > node.GetRetryCount():
// retry
node.incRetryCount()
logger.Info(ctx, "Step execution failed. Retrying...", "step", node.data.Step.Name, "error", execErr, "retry", node.getRetryCount())
node.IncRetryCount()
logger.Info(ctx, "Step execution failed. Retrying...", "step", node.data.Step.Name, "error", execErr, "retry", node.GetRetryCount())
time.Sleep(node.retryPolicy.Interval)
node.setRetriedAt(time.Now())
node.setStatus(NodeStatusNone)
node.SetRetriedAt(time.Now())
node.SetStatus(NodeStatusNone)
default:
// finish the node
node.setStatus(NodeStatusError)
node.markError(execErr)
node.SetStatus(NodeStatusError)
node.MarkError(execErr)
sc.setLastError(execErr)
}
}
if node.State().Status != NodeStatusCancel {
node.incDoneCount()
node.IncDoneCount()
}
if node.data.Step.RepeatPolicy.Repeat {
if execErr == nil || node.data.Step.ContinueOn.Failure {
if !sc.isCanceled() {
time.Sleep(node.data.Step.RepeatPolicy.Interval)
if done != nil {
done <- node
}
continue ExecRepeat
}
}
@ -223,12 +226,12 @@ func (sc *Scheduler) Schedule(ctx context.Context, graph *ExecutionGraph, done c
// finish the node
if node.State().Status == NodeStatusRunning {
node.setStatus(NodeStatusSuccess)
node.SetStatus(NodeStatusSuccess)
}
if err := sc.teardownNode(node); err != nil {
sc.setLastError(err)
node.setStatus(NodeStatusError)
node.SetStatus(NodeStatusError)
}
if done != nil {
@ -288,14 +291,14 @@ func (sc *Scheduler) setLastError(err error) {
func (sc *Scheduler) setupNode(node *Node) error {
if !sc.dry {
return node.setup(sc.logDir, sc.requestID)
return node.Setup(sc.logDir, sc.requestID)
}
return nil
}
func (sc *Scheduler) teardownNode(node *Node) error {
if !sc.dry {
return node.teardown()
return node.Teardown()
}
return nil
}
@ -383,7 +386,7 @@ func (sc *Scheduler) Signal(
// for a repetitive task, we'll wait for the job to finish
// until time reaches max wait time
if !node.data.Step.RepeatPolicy.Repeat {
node.signal(ctx, sig, allowOverride)
node.Signal(ctx, sig, allowOverride)
}
}
@ -399,10 +402,10 @@ func (sc *Scheduler) Signal(
}
// Cancel sends -1 signal to all nodes.
func (sc *Scheduler) Cancel(g *ExecutionGraph) {
func (sc *Scheduler) Cancel(ctx context.Context, g *ExecutionGraph) {
sc.setCanceled()
for _, node := range g.Nodes() {
node.cancel()
node.Cancel(ctx)
}
}
@ -456,20 +459,20 @@ func isReady(g *ExecutionGraph, node *Node) bool {
case NodeStatusError:
if !n.data.Step.ContinueOn.Failure {
ready = false
node.setStatus(NodeStatusCancel)
node.SetStatus(NodeStatusCancel)
node.setError(errUpstreamFailed)
}
case NodeStatusSkipped:
if !n.data.Step.ContinueOn.Skipped {
ready = false
node.setStatus(NodeStatusSkipped)
node.SetStatus(NodeStatusSkipped)
node.setError(errUpstreamSkipped)
}
case NodeStatusCancel:
ready = false
node.setStatus(NodeStatusCancel)
node.SetStatus(NodeStatusCancel)
case NodeStatusNone, NodeStatusRunning:
ready = false
@ -487,26 +490,27 @@ func (sc *Scheduler) runHandlerNode(ctx context.Context, graph *ExecutionGraph,
node.data.State.FinishedAt = time.Now()
}()
node.setStatus(NodeStatusRunning)
node.SetStatus(NodeStatusRunning)
if !sc.dry {
err := node.setup(sc.logDir, sc.requestID)
err := node.Setup(sc.logDir, sc.requestID)
if err != nil {
node.setStatus(NodeStatusError)
node.SetStatus(NodeStatusError)
return nil
}
defer func() {
_ = node.teardown()
_ = node.Teardown()
}()
ctx = sc.buildStepContextForHandler(ctx, graph)
err = node.Execute(ctx)
if err != nil {
node.setStatus(NodeStatusError)
node.SetStatus(NodeStatusError)
return err
} else {
node.setStatus(NodeStatusSuccess)
node.SetStatus(NodeStatusSuccess)
}
} else {
node.setStatus(NodeStatusSuccess)
node.SetStatus(NodeStatusSuccess)
}
return nil

File diff suppressed because it is too large Load Diff

View File

@ -7,29 +7,55 @@ package digraph
// This struct is used to unmarshal the YAML data.
// The data is then converted to the DAG struct.
type definition struct {
Name string
Group string
Description string
Schedule any
SkipIfSuccessful bool
LogDir string
Env any
HandlerOn handlerOnDef
Functions []*funcDef // deprecated
Steps []stepDef
SMTP smtpConfigDef
MailOn *mailOnDef
ErrorMail mailConfigDef
InfoMail mailConfigDef
TimeoutSec int
DelaySec int
RestartWaitSec int
// Name is the name of the DAG.
Name string
// Group is the group of the DAG for grouping DAGs on the UI.
Group string
// Description is the description of the DAG.
Description string
// Schedule is the cron schedule to run the DAG.
Schedule any
// SkipIfSuccessful is the flag to skip the DAG on schedule when it is
// executed manually before the schedule.
SkipIfSuccessful bool
// LogFile is the file to write the log.
LogDir string
// Env is the environment variables setting.
Env any
// HandlerOn is the handler configuration.
HandlerOn handlerOnDef
// Deprecated: Don't use this field
Functions []*funcDef // deprecated
// Steps is the list of steps to run.
Steps []stepDef
// SMTP is the SMTP configuration.
SMTP smtpConfigDef
// MailOn is the mail configuration.
MailOn *mailOnDef
// ErrorMail is the mail configuration for error.
ErrorMail mailConfigDef
// InfoMail is the mail configuration for information.
InfoMail mailConfigDef
// TimeoutSec is the timeout in seconds to finish the DAG.
TimeoutSec int
// DelaySec is the delay in seconds to start the first node.
DelaySec int
// RestartWaitSec is the wait in seconds to when the DAG is restarted.
RestartWaitSec int
// HistRetentionDays is the retention days of the history.
HistRetentionDays *int
Preconditions []*conditionDef
MaxActiveRuns int
Params string
// Precondition is the condition to run the DAG.
Preconditions []*conditionDef
// MaxActiveRuns is the maximum number of concurrent steps.
MaxActiveRuns int
// Params is the default parameters for the steps.
Params string
// MaxCleanUpTimeSec is the maximum time in seconds to clean up the DAG.
// It is a wait time to kill the processes when it is requested to stop.
// If the time is exceeded, the process is killed.
MaxCleanUpTimeSec *int
Tags any
// Tags is the tags for the DAG.
Tags any
}
type conditionDef struct {
@ -45,27 +71,48 @@ type handlerOnDef struct {
}
type stepDef struct {
Name string
Description string
Dir string
Executor any
Command any
Shell string
Script string
Stdout string
Stderr string
Output string
Depends []string
ContinueOn *continueOnDef
RetryPolicy *retryPolicyDef
RepeatPolicy *repeatPolicyDef
MailOnError bool
// Name is the name of the step.
Name string
// Description is the description of the step.
Description string
// Dir is the working directory of the step.
Dir string
// Executor is the executor configuration.
Executor any
// Command is the command to run (on shell).
Command any
// Shell is the shell to run the command. Default is `$SHELL` or `sh`.
Shell string
// Script is the script to run.
Script string
// Stdout is the file to write the stdout.
Stdout string
// Stderr is the file to write the stderr.
Stderr string
// Output is the variable name to store the output.
Output string
// Depends is the list of steps to depend on.
Depends []string
// ContinueOn is the condition to continue on.
ContinueOn *continueOnDef
// RetryPolicy is the retry policy.
RetryPolicy *retryPolicyDef
// RepeatPolicy is the repeat policy.
RepeatPolicy *repeatPolicyDef
// MailOnError is the flag to send mail on error.
MailOnError bool
// Precondition is the condition to run the step.
Preconditions []*conditionDef
SignalOnStop *string
Env string
Call *callFuncDef // deprecated
Run string // Run is a sub workflow to run
Params string // Params is the parameters for the sub workflow
// SignalOnStop is the signal when the step is requested to stop.
// When it is empty, the same signal as the parent process is sent.
// It can be KILL when the process does not stop over the timeout.
SignalOnStop *string
// Deprecated: Don't use this field
Call *callFuncDef // deprecated
// Run is a sub workflow to run
Run string
// Params is the parameters for the sub workflow
Params string
}
type funcDef struct {

View File

@ -5,7 +5,6 @@ package fileutil
import (
"errors"
"log"
"os"
"path/filepath"
"slices"
@ -99,13 +98,6 @@ func MustTempDir(pattern string) string {
return t
}
// LogErr logs error if it's not nil.
func LogErr(action string, err error) {
if err != nil {
log.Printf("%s failed. %s", action, err)
}
}
// TruncString TurnString returns truncated string.
func TruncString(val string, max int) string {
if len(val) > max {

View File

@ -4,10 +4,7 @@
package fileutil
import (
"bytes"
"errors"
"io"
"log"
"os"
"path/filepath"
"testing"
@ -142,33 +139,6 @@ func Test_MustTempDir(t *testing.T) {
})
}
func Test_LogErr(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
origStdout := os.Stdout
r, w, err := os.Pipe()
require.NoError(t, err)
os.Stdout = w
log.SetOutput(w)
defer func() {
os.Stdout = origStdout
log.SetOutput(origStdout)
}()
LogErr("test action", errors.New("test error"))
os.Stdout = origStdout
_ = w.Close()
var buf bytes.Buffer
_, err = io.Copy(&buf, r)
require.NoError(t, err)
s := buf.String()
require.Contains(t, s, "test action failed")
require.Contains(t, s, "test error")
})
}
func TestTruncString(t *testing.T) {
t.Run("Valid", func(t *testing.T) {
// Test empty string

View File

@ -5,14 +5,16 @@ package mailer
import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"log"
"net/smtp"
"os"
"path/filepath"
"strings"
"github.com/dagu-org/dagu/internal/logger"
)
// Mailer is a mailer that sends emails.
@ -50,16 +52,13 @@ var (
// SendMail sends an email.
func (m *Mailer) Send(
ctx context.Context,
from string,
to []string,
subject, body string,
attachments []string,
) error {
log.Printf(
"Sending an email to %s, subject is \"%s\"",
strings.Join(to, ","),
subject,
)
logger.Info(ctx, "Sending an email", "to", to, "subject", subject)
if m.username == "" && m.password == "" {
return m.sendWithNoAuth(from, to, subject, body, attachments)
}

View File

@ -6,44 +6,101 @@ package test
import (
"bytes"
"context"
"fmt"
"os"
"path/filepath"
"sync"
"syscall"
"testing"
"time"
"github.com/dagu-org/dagu/internal/agent"
"github.com/dagu-org/dagu/internal/client"
"github.com/dagu-org/dagu/internal/config"
"github.com/dagu-org/dagu/internal/digraph"
"github.com/dagu-org/dagu/internal/digraph/scheduler"
"github.com/dagu-org/dagu/internal/fileutil"
"github.com/dagu-org/dagu/internal/logger"
"github.com/dagu-org/dagu/internal/persistence"
dsclient "github.com/dagu-org/dagu/internal/persistence/client"
"github.com/spf13/viper"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var setupLock sync.Mutex
var executablePath string
func init() {
executablePath = filepath.Join(fileutil.MustGetwd(), "../../.local/bin/dagu")
}
// TestHelperOption defines functional options for Helper
type TestHelperOption func(*Helper)
// WithCaptureLoggingOutput creates a logging capture option
func WithCaptureLoggingOutput() TestHelperOption {
return func(h *Helper) {
h.LoggingOutput = &SyncBuffer{buf: new(bytes.Buffer)}
loggerInstance := logger.NewLogger(
logger.WithDebug(),
logger.WithFormat("text"),
logger.WithWriter(h.LoggingOutput),
)
h.Context = logger.WithFixedLogger(h.Context, loggerInstance)
}
}
// Setup creates a new Helper instance for testing
func Setup(t *testing.T, opts ...TestHelperOption) Helper {
setupLock.Lock()
defer setupLock.Unlock()
random := uuid.New().String()
tmpDir := fileutil.MustTempDir(fmt.Sprintf("dagu-test-%s", random))
require.NoError(t, os.Setenv("DAGU_HOME", tmpDir))
cfg, err := config.Load()
require.NoError(t, err)
cfg.Paths.Executable = executablePath
cfg.Paths.LogDir = filepath.Join(tmpDir, "logs")
dataStores := dsclient.NewDataStores(
cfg.Paths.DAGsDir,
cfg.Paths.DataDir,
cfg.Paths.SuspendFlagsDir,
dsclient.DataStoreOptions{
LatestStatusToday: cfg.LatestStatusToday,
},
)
helper := Helper{
Context: createDefaultContext(),
Config: cfg,
Client: client.New(dataStores, cfg.Paths.Executable, cfg.WorkDir),
DataStores: dataStores,
tmpDir: tmpDir,
}
for _, opt := range opts {
opt(&helper)
}
t.Cleanup(helper.Cleanup)
return helper
}
// Helper provides test utilities and configuration
type Helper struct {
Context context.Context
Config *config.Config
LoggingOutput *SyncBuffer
tmpDir string
}
Client client.Client
DataStores persistence.DataStores
// DataStore creates a new DataStores instance
func (h Helper) DataStore() persistence.DataStores {
return dsclient.NewDataStores(
h.Config.Paths.DAGsDir,
h.Config.Paths.DataDir,
h.Config.Paths.SuspendFlagsDir,
dsclient.DataStoreOptions{
LatestStatusToday: h.Config.LatestStatusToday,
},
)
}
// Client creates a new Client instance
func (h Helper) Client() client.Client {
return client.New(h.DataStore(), h.Config.Paths.Executable, h.Config.WorkDir)
tmpDir string
}
// Cleanup removes temporary test directories
@ -51,6 +108,155 @@ func (h Helper) Cleanup() {
_ = os.RemoveAll(h.tmpDir)
}
func (h Helper) LoadDAGFile(t *testing.T, filename string) DAG {
t.Helper()
filePath := filepath.Join(fileutil.MustGetwd(), "testdata", filename)
dag, err := digraph.Load(h.Context, "", filePath, "")
require.NoError(t, err)
return DAG{
Helper: &h,
DAG: dag,
}
}
type DAG struct {
*Helper
*digraph.DAG
}
func (d *DAG) AssertLatestStatus(t *testing.T, expected scheduler.Status) {
t.Helper()
var latestStatusValue scheduler.Status
assert.Eventually(t, func() bool {
latestStatus, err := d.Client.GetLatestStatus(d.Context, d.DAG)
require.NoError(t, err)
latestStatusValue = latestStatus.Status
return latestStatus.Status == expected
}, time.Second*3, time.Millisecond*50, "expected latest status to be %q, got %q", expected, latestStatusValue)
}
func (d *DAG) AssertHistoryCount(t *testing.T, expected int) {
t.Helper()
// the +1 to the limit is needed to ensure that the number of the history
// entries is exactly the expected number
history := d.Client.GetRecentHistory(d.Context, d.DAG, expected+1)
require.Len(t, history, expected)
}
func (d *DAG) AssertCurrentStatus(t *testing.T, expected scheduler.Status) {
t.Helper()
var lastCurrentStatus scheduler.Status
assert.Eventuallyf(t, func() bool {
currentStatus, err := d.Client.GetCurrentStatus(d.Context, d.DAG)
require.NoError(t, err)
lastCurrentStatus = currentStatus.Status
return currentStatus.Status == expected
}, time.Second*2, time.Millisecond*50, "expected current status to be %q, got %q", expected, lastCurrentStatus)
}
type AgentOption func(*Agent)
func WithAgentOptions(options *agent.Options) AgentOption {
return func(a *Agent) {
a.opts = options
}
}
func (d *DAG) Agent(opts ...AgentOption) *Agent {
requestID := genRequestID()
logDir := d.Config.Paths.LogDir
logFile := filepath.Join(d.Config.Paths.LogDir, requestID+".log")
helper := &Agent{
Helper: d.Helper,
DAG: d.DAG,
}
for _, opt := range opts {
opt(helper)
}
if helper.opts == nil {
helper.opts = &agent.Options{}
}
helper.Agent = agent.New(
requestID,
d.DAG,
logDir,
logFile,
d.Client,
d.DataStores,
helper.opts,
)
return helper
}
func genRequestID() string {
id, err := uuid.NewRandom()
if err != nil {
panic(err)
}
return id.String()
}
type Agent struct {
*Helper
*digraph.DAG
*agent.Agent
opts *agent.Options
}
func (a *Agent) RunError(t *testing.T) {
t.Helper()
err := a.Agent.Run(a.Context)
assert.Error(t, err)
status := a.Agent.Status().Status
require.Equal(t, scheduler.StatusError.String(), status.String())
}
func (a *Agent) RunCancel(t *testing.T) {
t.Helper()
err := a.Agent.Run(a.Context)
assert.NoError(t, err)
status := a.Agent.Status().Status
require.Equal(t, scheduler.StatusCancel.String(), status.String())
}
func (a *Agent) RunCheckErr(t *testing.T, expectedErr string) {
t.Helper()
err := a.Agent.Run(a.Context)
require.Error(t, err, "expected error %q, got nil", expectedErr)
require.Contains(t, err.Error(), expectedErr)
status := a.Agent.Status()
require.Equal(t, scheduler.StatusCancel.String(), status.Status.String())
}
func (a *Agent) RunSuccess(t *testing.T) {
t.Helper()
err := a.Agent.Run(a.Context)
assert.NoError(t, err)
status := a.Agent.Status().Status
require.Equal(t, scheduler.StatusSuccess.String(), status.String())
}
func (a *Agent) Abort() {
a.Signal(a.Context, syscall.SIGTERM)
}
// SyncBuffer provides thread-safe buffer operations
type SyncBuffer struct {
buf *bytes.Buffer
@ -69,71 +275,6 @@ func (b *SyncBuffer) String() string {
return b.buf.String()
}
// TestHelperOption defines functional options for Helper
type TestHelperOption func(*Helper)
// WithCaptureLoggingOutput creates a logging capture option
func WithCaptureLoggingOutput() TestHelperOption {
return func(h *Helper) {
h.LoggingOutput = &SyncBuffer{buf: new(bytes.Buffer)}
loggerInstance := logger.NewLogger(
logger.WithDebug(),
logger.WithFormat("text"),
logger.WithWriter(h.LoggingOutput),
)
h.Context = logger.WithFixedLogger(h.Context, loggerInstance)
}
}
var setupLock sync.Mutex
// Setup creates a new Helper instance for testing
func Setup(t *testing.T, opts ...TestHelperOption) Helper {
setupLock.Lock()
defer setupLock.Unlock()
tmpDir := fileutil.MustTempDir("test")
require.NoError(t, os.Setenv("DAGU_HOME", tmpDir))
cfg, err := config.Load()
require.NoError(t, err)
cfg.Paths.Executable = filepath.Join(fileutil.MustGetwd(), "../../.local/bin/dagu")
helper := Helper{
Context: createDefaultContext(),
Config: cfg,
tmpDir: tmpDir,
}
for _, opt := range opts {
opt(&helper)
}
t.Cleanup(helper.Cleanup)
return helper
}
// SetupForDir creates a new Helper instance with a specific configuration directory
func SetupForDir(t *testing.T, dir string) Helper {
setupLock.Lock()
defer setupLock.Unlock()
tmpDir := fileutil.MustTempDir("test")
require.NoError(t, os.Setenv("HOME", tmpDir))
configureViper(dir)
cfg, err := config.Load()
require.NoError(t, err)
return Helper{
Context: createDefaultContext(),
Config: cfg,
tmpDir: tmpDir,
}
}
// createDefaultContext creates a context with default logger settings
func createDefaultContext() context.Context {
ctx := context.Background()
@ -142,10 +283,3 @@ func createDefaultContext() context.Context {
logger.WithFormat("text"),
))
}
// configureViper sets up Viper configuration
func configureViper(dir string) {
viper.AddConfigPath(dir)
viper.SetConfigType("yaml")
viper.SetConfigName("admin")
}